"""Gradio front-end for a ResNet34 corrosion image classifier.

At import time this module reads the class-label file and builds the UI; the
model itself is loaded lazily on the first prediction (see ``get_model``).
"""

import os

# Limit OpenMP/MKL threading for CPU inference. These variables are set BEFORE
# ``torch`` is imported because some OpenMP/MKL runtimes read them only when
# the library is loaded; setting them after ``import torch`` is not reliably
# honoured. ``setdefault`` keeps any value the deployment environment provides.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

import json  # noqa: E402
import logging  # noqa: E402

import torch  # noqa: E402
import torch.nn.functional as F  # noqa: E402,F401
from PIL import Image  # noqa: E402
from torchvision import transforms  # noqa: E402
import gradio as gr  # noqa: E402

from model import build_model, load_weights  # noqa: E402

logger = logging.getLogger(__name__)

TITLE = "ResNet34 Corrosion Classifier"
DESCRIPTION = """
Carica o scatta una foto. Il modello (ResNet34) restituisce la classe prevista e le probabilità.
Assicurati di caricare il file dei pesi nella repo come `resnet34_best.pth` (o imposta la variabile di ambiente `CKPT_PATH`).
"""

CKPT_PATH = os.environ.get("CKPT_PATH", "resnet34_best.pth")
CLASSES_PATH = os.environ.get("CLASSES_PATH", "classes.json")
DEVICE = "cpu"


def _load_labels(path):
    """Load the class-index -> label mapping from a JSON file.

    The file may contain either a JSON list (list position = class index) or a
    JSON object. JSON object keys are always strings, so for an object they
    are converted to ``int``. Otherwise the integer indices produced by
    ``torch.topk``/``torch.argmax`` in ``predict`` would never match, and every
    lookup would raise ``KeyError``.

    Args:
        path: path of the JSON label file.

    Returns:
        A list of labels, or a dict mapping ``int`` class index to label.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the content is not a list/object, is empty, or has
            object keys that are not integers.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"File classi non trovato: {path}")
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    if isinstance(raw, dict):
        try:
            raw = {int(key): label for key, label in raw.items()}
        except ValueError as err:
            raise ValueError(
                f"Le chiavi di {path} devono essere indici interi di classe."
            ) from err
    elif not isinstance(raw, list):
        raise ValueError(f"Formato classi non valido in {path}: atteso lista o oggetto JSON.")
    if not raw:
        # An empty label set would build a 0-class model and a 1..0 slider.
        raise ValueError(f"Nessuna classe definita in {path}.")
    return raw


IDX2LABEL = _load_labels(CLASSES_PATH)

# Resize -> center crop 224 -> tensor -> normalize with the mean/std values
# torchvision uses for ImageNet. Presumably this mirrors the training-time
# eval transform; confirm against the training code if it changes.
preprocess = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Lazily-initialised model instance; populated by get_model().
_model = None


def get_model():
    """Return the classifier, building and loading it on first use.

    The instance is cached in the module-level ``_model``, so the checkpoint
    is read only once per process. It is switched to eval mode once, here,
    rather than on every request.

    Returns:
        The model returned by ``load_weights``, on ``DEVICE`` and in eval mode.

    Raises:
        FileNotFoundError: if the checkpoint at ``CKPT_PATH`` is missing.
    """
    global _model
    if _model is None:
        if not os.path.isfile(CKPT_PATH):
            raise FileNotFoundError(
                f"Checkpoint non trovato: {CKPT_PATH}. Carica i pesi nella Space o imposta CKPT_PATH."
            )
        model = build_model(num_classes=len(IDX2LABEL))
        model = load_weights(model, CKPT_PATH, map_location=DEVICE)
        model.to(DEVICE)
        model.eval()
        _model = model
    return _model


def predict(image: Image.Image, topk: int = 5):
    """Classify ``image`` and return its top-``topk`` class probabilities.

    Args:
        image: PIL image from the Gradio component, or ``None`` when the
            input has been cleared.
        topk: number of classes to report; clamped to ``[1, num_classes]``
            and truncated to ``int`` (Gradio sliders may deliver floats).

    Returns:
        ``(label_scores, message)``: a ``{label: probability}`` dict for the
        ``gr.Label`` output and a Markdown status string. Any exception is
        logged and reported through ``message`` with an empty dict, so a bad
        request never takes down the app.
    """
    try:
        if image is None:
            return {}, "Nessuna immagine."
        model = get_model()
        with torch.no_grad():
            # convert("RGB") normalises grayscale/RGBA/palette inputs to 3 channels.
            tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(DEVICE)
            logits = model(tensor)
            probs = torch.softmax(logits, dim=1).squeeze(0)
        k = int(min(max(1, topk), probs.shape[0]))
        values, indices = torch.topk(probs, k=k)
        label_scores = {IDX2LABEL[i.item()]: float(v.item()) for v, i in zip(values, indices)}
        pred_label = IDX2LABEL[int(torch.argmax(probs).item())]
        return label_scores, f"Predizione: **{pred_label}**"
    except Exception as e:  # UI boundary: report the error instead of crashing
        logger.exception("Inference failed")
        return {}, f"Errore durante l'inferenza: {e}"


# The slider default must lie within [minimum, maximum]; a hard-coded 5 is
# out of range when there are fewer than 5 classes.
_DEFAULT_TOPK = min(5, len(IDX2LABEL))

with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", sources=["upload", "webcam"], label="Immagine")
            topk = gr.Slider(1, len(IDX2LABEL), value=_DEFAULT_TOPK, step=1, label="Top-K")
            btn = gr.Button("Analizza immagine")
        with gr.Column(scale=1):
            lbl = gr.Label(label="Probabilità", num_top_classes=len(IDX2LABEL))
            txt = gr.Markdown()
    btn.click(predict, inputs=[img_in, topk], outputs=[lbl, txt])
    # Also re-run on every image change (including clearing it, which yields
    # the "Nessuna immagine." message).
    img_in.change(predict, inputs=[img_in, topk], outputs=[lbl, txt])

if __name__ == "__main__":
    demo.launch()