Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from torchvision import transforms | |
| import gradio as gr | |
| from model import build_model, load_weights | |
# Page heading and intro text; both are rendered with gr.Markdown in the UI below.
TITLE = "ResNet34 Corrosion Classifier"
DESCRIPTION = "\n".join(
    [
        "",
        "Carica o scatta una foto. Il modello (ResNet34) restituisce la classe prevista e le probabilità.",
        "Assicurati di caricare il file dei pesi nella repo come `resnet34_best.pth` (o imposta la variabile di ambiente `CKPT_PATH`).",
        "",
    ]
)
# Runtime configuration; both file paths can be overridden through environment variables.
CKPT_PATH = os.environ.get("CKPT_PATH", "resnet34_best.pth")
CLASSES_PATH = os.environ.get("CLASSES_PATH", "classes.json")
DEVICE = "cpu"

# Default to single-threaded CPU math unless the environment already says otherwise.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
# torch is already imported by the time this runs, so its OpenMP/MKL runtime may have read the
# environment before the defaults above were set. Apply the thread count explicitly so the
# setting actually takes effect; skip malformed or non-positive values (set_num_threads rejects them).
_omp_threads = os.environ["OMP_NUM_THREADS"]
if _omp_threads.isdecimal() and int(_omp_threads) > 0:
    torch.set_num_threads(int(_omp_threads))
# Load the index -> label mapping; fail at startup if the file is missing.
if not os.path.isfile(CLASSES_PATH):
    raise FileNotFoundError(f"File classi non trovato: {CLASSES_PATH}")
with open(CLASSES_PATH, "r", encoding="utf-8") as f:
    IDX2LABEL = json.load(f)
# JSON object keys are always decoded as str, but labels are looked up by integer class index
# (IDX2LABEL[i.item()] in predict). Turn an object such as {"0": "...", "1": "..."} into
# int-keyed form. A JSON list is already int-indexable and any other dict shape is left as-is.
# NOTE(review): the expected classes.json schema is not visible here; confirm against the
# training export.
if isinstance(IDX2LABEL, dict) and all(k.isdecimal() for k in IDX2LABEL):
    IDX2LABEL = {int(k): v for k, v in IDX2LABEL.items()}
# Per-channel normalization statistics (the standard ImageNet mean/std values).
# NOTE(review): presumably identical to the training-time pipeline; confirm.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Inference preprocessing:
# - bicubic resize so the shorter side is 256 px;
# - central 224x224 crop;
# - conversion to a float tensor;
# - per-channel normalization.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Lazily built model, cached at module level after the first successful load.
_model = None


def get_model():
    """Return the classifier, building it and loading its weights on first use.

    The first call builds the network with one output per entry in IDX2LABEL, loads the
    checkpoint at CKPT_PATH onto DEVICE, and caches the result in ``_model``; later calls
    return the cached object. No lock is taken, so concurrent first calls may each load it.

    Raises:
        FileNotFoundError: if CKPT_PATH does not point to an existing file.
    """
    global _model
    if _model is not None:
        return _model

    if not os.path.isfile(CKPT_PATH):
        raise FileNotFoundError(
            f"Checkpoint non trovato: {CKPT_PATH}. Carica i pesi nella Space o imposta CKPT_PATH."
        )
    network = build_model(num_classes=len(IDX2LABEL))
    _model = load_weights(network, CKPT_PATH, map_location=DEVICE)
    return _model
def predict(image: Image.Image, topk: int = 5):
    """Classify one image and return its top-k class probabilities.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the input is empty.
        topk: how many of the most probable classes to report. It is clamped to
            ``[1, number of model outputs]``.

    Returns:
        A ``(label_scores, message)`` tuple:

        - ``label_scores`` maps each of the top-k labels to its softmax probability.
        - ``message`` is a Markdown status line with the arg-max label.

        When no image is given, or anything fails, ``label_scores`` is ``{}`` and
        ``message`` describes the problem.
    """
    if image is None:
        return {}, "Nessuna immagine."
    try:
        model = get_model()
        model.eval()  # evaluation behavior for layers such as dropout / batch-norm
        with torch.no_grad():
            img = image.convert("RGB")  # ensure 3 channels (drops alpha, expands grayscale)
            tensor = preprocess(img).unsqueeze(0)  # add a batch dimension of size 1
            logits = model(tensor)
            probs = torch.softmax(logits, dim=1).squeeze(0)
            k = int(min(max(1, topk), probs.shape[0]))
            values, indices = torch.topk(probs, k=k)
            label_scores = {IDX2LABEL[i.item()]: float(v.item()) for v, i in zip(values, indices)}
            pred_label = IDX2LABEL[int(torch.argmax(probs).item())]
        msg = f"Predizione: **{pred_label}**"
        return label_scores, msg
    except Exception as e:
        # UI boundary: report the error in the page instead of failing the request, but also
        # log the full traceback so the failure can be diagnosed from the server logs.
        import logging

        logging.getLogger(__name__).exception("Inference failed")
        return {}, f"Errore durante l'inferenza: {e}"
# Gradio UI:
# - left column: image input (upload or webcam), a Top-K slider and an "analyze" button;
# - right column: class probabilities and a Markdown status line.
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", sources=["upload", "webcam"], label="Immagine")
            # Keep the default within [minimum, maximum] when there are fewer than 5 classes.
            topk = gr.Slider(
                1,
                len(IDX2LABEL),
                value=min(5, len(IDX2LABEL)),
                step=1,
                label="Top-K",
            )
            btn = gr.Button("Analizza immagine")
        with gr.Column(scale=1):
            lbl = gr.Label(label="Probabilità", num_top_classes=len(IDX2LABEL))
            txt = gr.Markdown()
    # Run inference on an explicit click and also whenever the image input changes
    # (a cleared input reaches predict as None and yields the "no image" message).
    btn.click(predict, inputs=[img_in, topk], outputs=[lbl, txt])
    img_in.change(predict, inputs=[img_in, topk], outputs=[lbl, txt])

if __name__ == "__main__":
    demo.launch()