jacopo22295 committed on
Commit
d733d16
·
verified ·
1 Parent(s): c442e5e

Upload 6 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -1
  2. README.md +1 -2
  3. app.py +29 -23
  4. model.py +0 -2
  5. requirements.txt +2 -1
.gitattributes CHANGED
@@ -1 +1 @@
1
- *.pth filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -10,7 +10,6 @@ app_file: app.py
10
  pinned: false
11
  ---
12
 
13
-
14
  # ResNet34 Corrosion Classifier — Hugging Face Space
15
 
16
  Semplice Space Gradio che carica un modello ResNet34 e predice 9 classi di corrosione.
@@ -27,7 +26,7 @@ Semplice Space Gradio che carica un modello ResNet34 e predice 9 classi di corro
27
  1. Crea una nuova Space su Hugging Face (Gradio + Python).
28
  2. Carica questi file nella Space.
29
  3. Aggiungi il tuo file di pesi `resnet34_best.pth` (usa Git LFS se > 50 MB).
30
- 4. (Opzionale) Se il file si chiama diversamente, imposta una variabile d'ambiente `CKPT_PATH`
31
  nelle Settings della Space, oppure modifica `CKPT_PATH` in `app.py`.
32
  5. Avvia la Space.
33
 
 
10
  pinned: false
11
  ---
12
 
 
13
  # ResNet34 Corrosion Classifier — Hugging Face Space
14
 
15
  Semplice Space Gradio che carica un modello ResNet34 e predice 9 classi di corrosione.
 
26
  1. Crea una nuova Space su Hugging Face (Gradio + Python).
27
  2. Carica questi file nella Space.
28
  3. Aggiungi il tuo file di pesi `resnet34_best.pth` (usa Git LFS se > 50 MB).
29
+ 4. (Opzionale) Se il file si chiama diversamente, imposta la variabile d'ambiente `CKPT_PATH`
30
  nelle Settings della Space, oppure modifica `CKPT_PATH` in `app.py`.
31
  5. Avvia la Space.
32
 
app.py CHANGED
@@ -12,14 +12,18 @@ from model import build_model, load_weights
12
  TITLE = "ResNet34 Corrosion Classifier"
13
  DESCRIPTION = """
14
  Carica o scatta una foto. Il modello (ResNet34) restituisce la classe prevista e le probabilità.
15
- Assicurati di **caricare il file dei pesi** nella repo come `resnet34_best.pth` (o aggiorna il percorso qui sotto).
16
  """
17
 
18
- # ====== Config ======
19
  CKPT_PATH = os.environ.get("CKPT_PATH", "resnet34_best.pth")
20
  CLASSES_PATH = os.environ.get("CLASSES_PATH", "classes.json")
21
- DEVICE = "cpu" # su Spaces CPU per default
22
 
 
 
 
 
 
23
  with open(CLASSES_PATH, "r", encoding="utf-8") as f:
24
  IDX2LABEL = json.load(f)
25
 
@@ -31,36 +35,38 @@ preprocess = transforms.Compose([
31
  std=[0.229, 0.224, 0.225]),
32
  ])
33
 
34
- # Lazy load del modello
35
  _model = None
36
  def get_model():
37
  global _model
38
  if _model is None:
39
- model = build_model(num_classes=len(IDX2LABEL))
40
  if not os.path.isfile(CKPT_PATH):
41
  raise FileNotFoundError(
42
- f\"Checkpoint non trovato: {CKPT_PATH}. Carica i pesi nella Space o imposta CKPT_PATH.\"
43
  )
 
44
  model = load_weights(model, CKPT_PATH, map_location=DEVICE)
45
  _model = model
46
  return _model
47
 
48
  def predict(image: Image.Image, topk: int = 5):
49
- if image is None:
50
- return {"Errore": 1.0}, "Nessuna immagine."
51
- model = get_model()
52
- model.eval()
53
- with torch.no_grad():
54
- img = image.convert("RGB")
55
- tensor = preprocess(img).unsqueeze(0)
56
- logits = model(tensor)
57
- probs = F.softmax(logits, dim=1).squeeze(0)
58
- topk = min(topk, probs.shape[0])
59
- values, indices = torch.topk(probs, k=topk)
60
- label_scores = {IDX2LABEL[i.item()]: float(v.item()) for v, i in zip(values, indices)}
61
- pred_label = IDX2LABEL[int(torch.argmax(probs).item())]
62
- msg = f"Predizione: **{pred_label}**"
63
- return label_scores, msg
 
 
 
64
 
65
  with gr.Blocks(fill_height=True) as demo:
66
  gr.Markdown(f"# {TITLE}")
@@ -69,10 +75,10 @@ with gr.Blocks(fill_height=True) as demo:
69
  with gr.Row():
70
  with gr.Column(scale=1):
71
  img_in = gr.Image(type="pil", sources=["upload", "webcam"], label="Immagine")
72
- topk = gr.Slider(1, 9, value=5, step=1, label="Top-K")
73
  btn = gr.Button("Analizza immagine")
74
  with gr.Column(scale=1):
75
- lbl = gr.Label(label="Probabilità", num_top_classes=9)
76
  txt = gr.Markdown()
77
 
78
  btn.click(predict, inputs=[img_in, topk], outputs=[lbl, txt])
 
12
  TITLE = "ResNet34 Corrosion Classifier"
13
  DESCRIPTION = """
14
  Carica o scatta una foto. Il modello (ResNet34) restituisce la classe prevista e le probabilità.
15
+ Assicurati di caricare il file dei pesi nella repo come `resnet34_best.pth` (o imposta la variabile di ambiente `CKPT_PATH`).
16
  """
17
 
 
18
  CKPT_PATH = os.environ.get("CKPT_PATH", "resnet34_best.pth")
19
  CLASSES_PATH = os.environ.get("CLASSES_PATH", "classes.json")
20
+ DEVICE = "cpu"
21
 
22
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
23
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
24
+
25
+ if not os.path.isfile(CLASSES_PATH):
26
+ raise FileNotFoundError(f"File classi non trovato: {CLASSES_PATH}")
27
  with open(CLASSES_PATH, "r", encoding="utf-8") as f:
28
  IDX2LABEL = json.load(f)
29
 
 
35
  std=[0.229, 0.224, 0.225]),
36
  ])
37
 
 
38
  _model = None
39
  def get_model():
40
  global _model
41
  if _model is None:
 
42
  if not os.path.isfile(CKPT_PATH):
43
  raise FileNotFoundError(
44
+ f"Checkpoint non trovato: {CKPT_PATH}. Carica i pesi nella Space o imposta CKPT_PATH."
45
  )
46
+ model = build_model(num_classes=len(IDX2LABEL))
47
  model = load_weights(model, CKPT_PATH, map_location=DEVICE)
48
  _model = model
49
  return _model
50
 
51
  def predict(image: Image.Image, topk: int = 5):
52
+ try:
53
+ if image is None:
54
+ return {}, "Nessuna immagine."
55
+ model = get_model()
56
+ model.eval()
57
+ with torch.no_grad():
58
+ img = image.convert("RGB")
59
+ tensor = preprocess(img).unsqueeze(0)
60
+ logits = model(tensor)
61
+ probs = torch.softmax(logits, dim=1).squeeze(0)
62
+ k = int(min(max(1, topk), probs.shape[0]))
63
+ values, indices = torch.topk(probs, k=k)
64
+ label_scores = {IDX2LABEL[i.item()]: float(v.item()) for v, i in zip(values, indices)}
65
+ pred_label = IDX2LABEL[int(torch.argmax(probs).item())]
66
+ msg = f"Predizione: **{pred_label}**"
67
+ return label_scores, msg
68
+ except Exception as e:
69
+ return {}, f"Errore durante l'inferenza: {e}"
70
 
71
  with gr.Blocks(fill_height=True) as demo:
72
  gr.Markdown(f"# {TITLE}")
 
75
  with gr.Row():
76
  with gr.Column(scale=1):
77
  img_in = gr.Image(type="pil", sources=["upload", "webcam"], label="Immagine")
78
+ topk = gr.Slider(1, len(IDX2LABEL), value=5, step=1, label="Top-K")
79
  btn = gr.Button("Analizza immagine")
80
  with gr.Column(scale=1):
81
+ lbl = gr.Label(label="Probabilità", num_top_classes=len(IDX2LABEL))
82
  txt = gr.Markdown()
83
 
84
  btn.click(predict, inputs=[img_in, topk], outputs=[lbl, txt])
model.py CHANGED
@@ -11,12 +11,10 @@ def build_model(num_classes: int) -> nn.Module:
11
 
12
  def load_weights(model: nn.Module, ckpt_path: str, map_location="cpu") -> nn.Module:
13
  state = torch.load(ckpt_path, map_location=map_location)
14
- # Support both full state dicts and {'model': state_dict} formats
15
  if isinstance(state, dict) and "state_dict" in state:
16
  state = state["state_dict"]
17
  if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
18
  state = state["model"]
19
- # Strip possible 'module.' prefixes
20
  new_state = {}
21
  for k, v in state.items():
22
  if k.startswith("module."):
 
11
 
12
  def load_weights(model: nn.Module, ckpt_path: str, map_location="cpu") -> nn.Module:
13
  state = torch.load(ckpt_path, map_location=map_location)
 
14
  if isinstance(state, dict) and "state_dict" in state:
15
  state = state["state_dict"]
16
  if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
17
  state = state["model"]
 
18
  new_state = {}
19
  for k, v in state.items():
20
  if k.startswith("module."):
requirements.txt CHANGED
@@ -1,5 +1,6 @@
 
1
  torch>=2.2.0
2
  torchvision>=0.17.0
3
  pillow>=10.3.0
4
  numpy>=1.26.4
5
- gradio==4.44.1
 
1
+
2
  torch>=2.2.0
3
  torchvision>=0.17.0
4
  pillow>=10.3.0
5
  numpy>=1.26.4
6
+ gradio==4.44.1