File size: 3,154 Bytes
aa983ba
 
 
 
 
 
 
 
 
 
 
 
c442e5e
aa983ba
d733d16
c442e5e
aa983ba
 
 
d733d16
aa983ba
d733d16
 
 
 
 
aa983ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d733d16
aa983ba
d733d16
aa983ba
 
 
 
 
d733d16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa983ba
 
 
 
 
 
 
 
d733d16
aa983ba
 
d733d16
aa983ba
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89

import os
import json
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import gradio as gr

from model import build_model, load_weights

# Page heading; rendered as a level-1 Markdown title by the UI below.
TITLE = "ResNet34 Corrosion Classifier"
# Introductory Markdown shown under the heading (user-facing text, kept verbatim).
DESCRIPTION = (
    "\n"
    "Carica o scatta una foto. Il modello (ResNet34) restituisce la classe prevista e le probabilità.\n"
    "Assicurati di caricare il file dei pesi nella repo come `resnet34_best.pth` (o imposta la variabile di ambiente `CKPT_PATH`).\n"
)

# Checkpoint and class-list locations; each can be overridden through the
# environment variable of the same name.
CKPT_PATH = os.getenv("CKPT_PATH", "resnet34_best.pth")
CLASSES_PATH = os.getenv("CLASSES_PATH", "classes.json")
# Passed as ``map_location`` when the checkpoint is loaded.
DEVICE = "cpu"

# Default to single-threaded CPU math unless the deployment already chose a
# value for these variables.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

# NOTE: `torch` is imported before this point, and the OpenMP/MKL runtimes
# normally read these variables when they are first loaded, so setting them
# here may come too late to have any effect. Apply the same limit through
# torch's own API so the intended thread count is actually honoured.
_omp_threads = os.environ["OMP_NUM_THREADS"]
# Skip values torch cannot accept (non-numeric such as "4,2", or zero) instead
# of failing at import time; in that case the environment setting stands alone.
if _omp_threads.isdigit() and int(_omp_threads) > 0:
    torch.set_num_threads(int(_omp_threads))

# Load the index -> label table; its length also determines the number of
# model outputs and the range of the Top-K slider further down.
if not os.path.isfile(CLASSES_PATH):
    raise FileNotFoundError(f"File classi non trovato: {CLASSES_PATH}")
with open(CLASSES_PATH, "r", encoding="utf-8") as f:
    IDX2LABEL = json.load(f)

# JSON object keys are always decoded as strings, but labels are looked up by
# integer class index at prediction time. When the file is an object such as
# {"0": "label_a", "1": "label_b"}, convert the keys to ints so those lookups
# succeed. A JSON array is already indexable by int and is left untouched.
if isinstance(IDX2LABEL, dict) and all(key.isdigit() for key in IDX2LABEL):
    IDX2LABEL = {int(key): label for key, label in IDX2LABEL.items()}

# Per-channel normalisation constants (the commonly used ImageNet statistics;
# presumably the same ones the checkpoint was trained with -- confirm against
# the training pipeline).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Deterministic inference-time pipeline, applied in this order:
# resize shorter side to 256 (bicubic) -> centre crop 224 -> tensor -> normalise.
_eval_steps = [
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
preprocess = transforms.Compose(_eval_steps)

# Process-wide cache for the loaded network; filled on first use.
_model = None
def get_model():
    """Return the classifier, building and loading it on the first call.

    The network is created with one output per entry in ``IDX2LABEL`` and its
    weights are read from ``CKPT_PATH`` onto ``DEVICE``. Later calls return
    the cached instance.

    Raises:
        FileNotFoundError: If ``CKPT_PATH`` does not point to an existing file.
    """
    global _model
    if _model is not None:
        return _model

    if not os.path.isfile(CKPT_PATH):
        raise FileNotFoundError(
            f"Checkpoint non trovato: {CKPT_PATH}. Carica i pesi nella Space o imposta CKPT_PATH."
        )

    net = build_model(num_classes=len(IDX2LABEL))
    # Only publish to the cache once loading has succeeded, so a failed load
    # is retried on the next call.
    _model = load_weights(net, CKPT_PATH, map_location=DEVICE)
    return _model

def predict(image: Image.Image, topk: int = 5):
    """Classify one image and report the most probable classes.

    Args:
        image: PIL image supplied by the UI, or ``None`` when nothing is loaded.
        topk: How many classes to report; clamped to the range
            ``[1, number of model outputs]``.

    Returns:
        A ``(scores, message)`` pair. ``scores`` maps each of the top-k labels
        to its softmax probability; ``message`` is a Markdown line naming the
        most likely class. If anything fails during inference, ``scores`` is
        empty and ``message`` carries the error text instead of the exception
        propagating to the UI.
    """
    if image is None:
        return {}, "Nessuna immagine."

    try:
        net = get_model()
        net.eval()
        with torch.no_grad():
            # Force three channels, then add the batch dimension the model expects.
            batch = preprocess(image.convert("RGB")).unsqueeze(0)
            probs = torch.softmax(net(batch), dim=1).squeeze(0)

            n_outputs = probs.shape[0]
            k = int(min(max(1, topk), n_outputs))
            top_probs, top_ids = torch.topk(probs, k=k)

            label_scores = {}
            for prob, class_id in zip(top_probs.tolist(), top_ids.tolist()):
                label_scores[IDX2LABEL[class_id]] = float(prob)

            best_label = IDX2LABEL[int(torch.argmax(probs).item())]
        return label_scores, f"Predizione: **{best_label}**"
    except Exception as e:
        # UI boundary: surface the failure as text rather than crashing the app.
        return {}, f"Errore durante l'inferenza: {e}"

# Build the two-column page: inputs (image, Top-K, button) on the left,
# results (probabilities and a Markdown summary) on the right.
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", sources=["upload", "webcam"], label="Immagine")
            # The initial value must lie within the slider range: with fewer
            # than 5 classes a hard-coded 5 would exceed the maximum.
            topk = gr.Slider(
                1,
                len(IDX2LABEL),
                value=min(5, len(IDX2LABEL)),
                step=1,
                label="Top-K",
            )
            btn = gr.Button("Analizza immagine")
        with gr.Column(scale=1):
            lbl = gr.Label(label="Probabilità", num_top_classes=len(IDX2LABEL))
            txt = gr.Markdown()

    # Run inference both on an explicit button press and whenever the image
    # changes (upload, webcam capture, or clearing the input).
    btn.click(predict, inputs=[img_in, topk], outputs=[lbl, txt])
    img_in.change(predict, inputs=[img_in, topk], outputs=[lbl, txt])

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()