jacopo22295 commited on
Commit
aa983ba
·
verified ·
1 Parent(s): dda036a

Upload 7 files

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -35
  2. README.md +28 -12
  3. app.py +82 -0
  4. classes.json +11 -0
  5. model.py +28 -0
  6. requirements.txt +5 -0
  7. resnet34_best.pth +3 -0
.gitattributes CHANGED
@@ -1,35 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,12 +1,28 @@
1
- ---
2
- title: Resnet34 Testrun
3
- emoji: ⚡
4
- colorFrom: red
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 5.45.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # ResNet34 Corrosion Classifier — Hugging Face Space
3
+
4
+ Semplice Space Gradio che carica un modello ResNet34 e predice 9 classi di corrosione.
5
+
6
+ ## File
7
+ - `app.py`: interfaccia Gradio.
8
+ - `model.py`: definizione modello e caricamento pesi.
9
+ - `classes.json`: etichette delle classi.
10
+ - `requirements.txt`: dipendenze.
11
+ - `.gitattributes`: abilita LFS per i file `.pth`.
12
+ - `resnet34_best.pth`: **DA CARICARE DA TE** (non incluso).
13
+
14
+ ## Istruzioni
15
+ 1. Crea una nuova Space su Hugging Face (Gradio + Python).
16
+ 2. Carica questi file nella Space.
17
+ 3. Aggiungi il tuo file di pesi `resnet34_best.pth` (usa Git LFS se > 50 MB).
18
+ 4. (Opzionale) Se il file si chiama diversamente, imposta una variabile d'ambiente `CKPT_PATH`
19
+ nelle Settings della Space, oppure modifica `CKPT_PATH` in `app.py`.
20
+ 5. Avvia la Space.
21
+
22
+ ## Uso
23
+ - Carica o scatta una foto, poi clicca **Analizza immagine**.
24
+ - La card a destra mostra le probabilità (Top-K) e la predizione.
25
+
26
+ ## Note
27
+ - Il modello gira su CPU per default. Se vuoi più velocità, passa a una Space con GPU.
28
+ - Le trasformazioni input usano Resize 256 → CenterCrop 224 e normalizzazione ImageNet.
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import json
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ import gradio as gr
9
+
10
+ from model import build_model, load_weights
11
+
12
+ TITLE = "ResNet34 Corrosion Classifier"
13
+ DESCRIPTION = \"\"\
14
+ Carica o scatta una foto. Il modello (ResNet34) restituisce la classe prevista e le probabilità.
15
+ Assicurati di **caricare il file dei pesi** nella repo come `resnet34_best.pth` (o aggiorna il percorso qui sotto).
16
+ \"\"\
17
+
18
+ # ====== Config ======
19
+ CKPT_PATH = os.environ.get("CKPT_PATH", "resnet34_best.pth")
20
+ CLASSES_PATH = os.environ.get("CLASSES_PATH", "classes.json")
21
+ DEVICE = "cpu" # su Spaces CPU per default
22
+
23
+ with open(CLASSES_PATH, "r", encoding="utf-8") as f:
24
+ IDX2LABEL = json.load(f)
25
+
26
+ preprocess = transforms.Compose([
27
+ transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
28
+ transforms.CenterCrop(224),
29
+ transforms.ToTensor(),
30
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
31
+ std=[0.229, 0.224, 0.225]),
32
+ ])
33
+
34
+ # Lazy load del modello
35
+ _model = None
36
+ def get_model():
37
+ global _model
38
+ if _model is None:
39
+ model = build_model(num_classes=len(IDX2LABEL))
40
+ if not os.path.isfile(CKPT_PATH):
41
+ raise FileNotFoundError(
42
+ f\"Checkpoint non trovato: {CKPT_PATH}. Carica i pesi nella Space o imposta CKPT_PATH.\"
43
+ )
44
+ model = load_weights(model, CKPT_PATH, map_location=DEVICE)
45
+ _model = model
46
+ return _model
47
+
48
+ def predict(image: Image.Image, topk: int = 5):
49
+ if image is None:
50
+ return {"Errore": 1.0}, "Nessuna immagine."
51
+ model = get_model()
52
+ model.eval()
53
+ with torch.no_grad():
54
+ img = image.convert("RGB")
55
+ tensor = preprocess(img).unsqueeze(0)
56
+ logits = model(tensor)
57
+ probs = F.softmax(logits, dim=1).squeeze(0)
58
+ topk = min(topk, probs.shape[0])
59
+ values, indices = torch.topk(probs, k=topk)
60
+ label_scores = {IDX2LABEL[i.item()]: float(v.item()) for v, i in zip(values, indices)}
61
+ pred_label = IDX2LABEL[int(torch.argmax(probs).item())]
62
+ msg = f"Predizione: **{pred_label}**"
63
+ return label_scores, msg
64
+
65
+ with gr.Blocks(fill_height=True) as demo:
66
+ gr.Markdown(f"# {TITLE}")
67
+ gr.Markdown(DESCRIPTION)
68
+
69
+ with gr.Row():
70
+ with gr.Column(scale=1):
71
+ img_in = gr.Image(type="pil", sources=["upload", "webcam"], label="Immagine")
72
+ topk = gr.Slider(1, 9, value=5, step=1, label="Top-K")
73
+ btn = gr.Button("Analizza immagine")
74
+ with gr.Column(scale=1):
75
+ lbl = gr.Label(label="Probabilità", num_top_classes=9)
76
+ txt = gr.Markdown()
77
+
78
+ btn.click(predict, inputs=[img_in, topk], outputs=[lbl, txt])
79
+ img_in.change(predict, inputs=[img_in, topk], outputs=[lbl, txt])
80
+
81
+ if __name__ == "__main__":
82
+ demo.launch()
classes.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "crevice_corrosion",
3
+ "erosion_corrosion",
4
+ "galvanic_corrosion",
5
+ "mic_corrosion",
6
+ "no_corrosion",
7
+ "pitting_corrosion",
8
+ "stress_corrosion",
9
+ "under_insulation_corrosion",
10
+ "uniform_corrosion"
11
+ ]
model.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models
5
+
6
+ def build_model(num_classes: int) -> nn.Module:
7
+ model = models.resnet34(weights=None)
8
+ in_features = model.fc.in_features
9
+ model.fc = nn.Linear(in_features, num_classes)
10
+ return model
11
+
12
+ def load_weights(model: nn.Module, ckpt_path: str, map_location="cpu") -> nn.Module:
13
+ state = torch.load(ckpt_path, map_location=map_location)
14
+ # Support both full state dicts and {'model': state_dict} formats
15
+ if isinstance(state, dict) and "state_dict" in state:
16
+ state = state["state_dict"]
17
+ if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
18
+ state = state["model"]
19
+ # Strip possible 'module.' prefixes
20
+ new_state = {}
21
+ for k, v in state.items():
22
+ if k.startswith("module."):
23
+ new_state[k[len("module."):]] = v
24
+ else:
25
+ new_state[k] = v
26
+ model.load_state_dict(new_state, strict=False)
27
+ model.eval()
28
+ return model
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch>=2.2.0
2
+ torchvision>=0.17.0
3
+ pillow>=10.3.0
4
+ numpy>=1.26.4
5
+ gradio==4.44.1
resnet34_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f03fbf24fb21f3eb66b8d348bbf908408c0c3b4176384ddc341ad23da3d553a
3
+ size 255709619