# jacopo22295's picture
# Upload 6 files
# d733d16 verified
import os
import json
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import gradio as gr
from model import build_model, load_weights
# UI texts shown at the top of the Gradio page (user-facing, in Italian).
TITLE = "ResNet34 Corrosion Classifier"
DESCRIPTION = """
Carica o scatta una foto. Il modello (ResNet34) restituisce la classe prevista e le probabilità.
Assicurati di caricare il file dei pesi nella repo come `resnet34_best.pth` (o imposta la variabile di ambiente `CKPT_PATH`).
"""

# Locations of the weights and of the index -> label JSON; both can be
# overridden through environment variables of the same name.
CKPT_PATH = os.environ.get("CKPT_PATH", "resnet34_best.pth")
CLASSES_PATH = os.environ.get("CLASSES_PATH", "classes.json")
# Inference always runs on CPU; passed to `load_weights` as `map_location`.
DEVICE = "cpu"

# Default the math libraries to a single thread (presumably to keep a small
# shared CPU host responsive — confirm). User-provided values are respected.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
# These variables are typically read when the OpenMP/MKL runtimes initialise,
# which may already have happened because torch is imported above. Setting
# them alone may therefore not change torch's intra-op thread pool, so the
# value is applied to torch explicitly.
try:
    torch.set_num_threads(max(1, int(os.environ["OMP_NUM_THREADS"])))
except ValueError:
    # Malformed user-supplied value: keep torch's default thread count.
    pass
def _load_idx2label(path):
    """Load the class-index -> label mapping from the JSON file at ``path``.

    The file may contain a JSON list (position = class index) or a JSON
    object. JSON object keys are always strings, so numeric keys such as
    ``"0"`` are converted to ``int``. Without this, lookups with the integer
    indices produced by ``torch.topk``/``torch.argmax`` would raise
    ``KeyError``. Objects with non-numeric keys are returned unchanged.

    Raises:
        FileNotFoundError: if ``path`` is not an existing file.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"File classi non trovato: {path}")
    with open(path, "r", encoding="utf-8") as fh:
        data = json.load(fh)
    if isinstance(data, dict):
        try:
            data = {int(key): label for key, label in data.items()}
        except ValueError:
            # Keys are not class indices; keep the mapping exactly as loaded.
            pass
    return data


# Loaded at import time: the model size and the UI both depend on it.
IDX2LABEL = _load_idx2label(CLASSES_PATH)
# Geometry of the input pipeline: shorter side resized to 256, then a
# centered 224x224 crop is taken.
_RESIZE_SHORT_SIDE = 256
_CROP_SIZE = 224
# Per-channel RGB mean/std used for normalization (the usual ImageNet values;
# presumably the ones used at training time — confirm against the trainer).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# PIL image -> normalized float tensor of shape (3, 224, 224).
preprocess = transforms.Compose(
    [
        transforms.Resize(
            _RESIZE_SHORT_SIDE,
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.CenterCrop(_CROP_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Lazily-populated cache for the classifier (see `get_model`).
_model = None


def get_model():
    """Return the classifier, building and loading it on first use.

    Construction is deferred until the first call, so importing this module
    does not require the checkpoint to exist. The loaded network is cached in
    the module-level ``_model`` and reused on every later call.

    Returns:
        The model produced by ``build_model``/``load_weights``, in eval mode.

    Raises:
        FileNotFoundError: if ``CKPT_PATH`` does not point to a file.
    """
    global _model
    if _model is None:
        if not os.path.isfile(CKPT_PATH):
            raise FileNotFoundError(
                f"Checkpoint non trovato: {CKPT_PATH}. Carica i pesi nella Space o imposta CKPT_PATH."
            )
        model = build_model(num_classes=len(IDX2LABEL))
        model = load_weights(model, CKPT_PATH, map_location=DEVICE)
        # Switch to inference behaviour (e.g. dropout/batch-norm) once, so
        # every caller gets a ready-to-use model, not only `predict`.
        model.eval()
        _model = model
    return _model
def _label_for(idx):
    """Return the label for class index ``idx`` from ``IDX2LABEL``.

    ``IDX2LABEL`` comes from JSON. A list is indexed directly. A JSON object
    may still carry string keys (``"0"``), so fall back to ``str(idx)`` when
    the integer key is absent.
    """
    if isinstance(IDX2LABEL, dict) and idx not in IDX2LABEL:
        return IDX2LABEL[str(idx)]
    return IDX2LABEL[idx]


def predict(image: Image.Image, topk: int = 5):
    """Classify ``image`` and report the ``topk`` most likely classes.

    Args:
        image: Picture from the Gradio image component, or ``None`` when the
            component is empty/cleared.
        topk: Number of classes to report; clamped to ``[1, num_classes]``
            and converted with ``int`` in case the slider delivers a float.

    Returns:
        ``(label_scores, message)``: a ``{label: probability}`` dict for
        ``gr.Label`` and a Markdown string for ``gr.Markdown``. On failure the
        dict is empty, the message contains the error text, and the full
        traceback is logged instead of being silently discarded.
    """
    if image is None:
        return {}, "Nessuna immagine."
    try:
        model = get_model()
        model.eval()
        with torch.no_grad():
            # Drop alpha / expand grayscale so the tensor always has 3 channels.
            tensor = preprocess(image.convert("RGB")).unsqueeze(0)
            logits = model(tensor)
            probs = torch.softmax(logits, dim=1).squeeze(0)
        k = int(min(max(1, topk), probs.shape[0]))
        values, indices = torch.topk(probs, k=k)
        label_scores = {
            _label_for(int(i.item())): float(v.item())
            for v, i in zip(values, indices)
        }
        pred_label = _label_for(int(torch.argmax(probs).item()))
        return label_scores, f"Predizione: **{pred_label}**"
    except Exception as e:
        # UI boundary: show a readable message, but keep the traceback in the
        # server logs so failures can be diagnosed.
        import logging

        logging.getLogger(__name__).exception("Inference failed")
        return {}, f"Errore durante l'inferenza: {e}"
# Number of classes known to the app; bounds the Top-K slider and the label view.
_num_classes = len(IDX2LABEL)

# Two-column layout: inputs (image + Top-K + button) on the left, results on
# the right. Inference runs both on button click and whenever the image changes.
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", sources=["upload", "webcam"], label="Immagine")
            # Clamp the default so it never exceeds the slider maximum when
            # the label file has fewer than 5 classes.
            topk = gr.Slider(
                1,
                _num_classes,
                value=min(5, _num_classes),
                step=1,
                label="Top-K",
            )
            btn = gr.Button("Analizza immagine")
        with gr.Column(scale=1):
            lbl = gr.Label(label="Probabilità", num_top_classes=_num_classes)
            txt = gr.Markdown()
    btn.click(predict, inputs=[img_in, topk], outputs=[lbl, txt])
    img_in.change(predict, inputs=[img_in, topk], outputs=[lbl, txt])
# Launch the Gradio server only when this file is run as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()