"""Corrosion classifier demo (ViT-B/16) for Hugging Face Spaces ZeroGPU.

The model is built and loaded once on CPU; the GPU is requested on demand
only for the duration of a single prediction via the ``spaces.GPU`` decorator.
"""

import json
import os

import gradio as gr
import numpy as np
import spaces  # IMPORTANT: enables ZeroGPU on-demand GPU allocation
import torch
import torchvision.transforms as T
from PIL import Image

MODEL_WEIGHTS = os.getenv("MODEL_WEIGHTS_PATH", "vit_b16_best.pth")
CLASSES_PATH = os.getenv("CLASSES_PATH", "classes.json")
IMAGE_SIZE = 224


def load_classes(path: str):
    """Load the class labels from a JSON file.

    NOTE(review): downstream code indexes the result with integer positions
    (``_classes[i]``) — this assumes the JSON holds a list of labels, not a
    dict keyed by string indices; confirm against the actual classes.json.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def build_transforms(img_size: int = IMAGE_SIZE):
    """Return the eval-time preprocessing pipeline (ImageNet normalization)."""
    return T.Compose([
        T.Resize((img_size, img_size), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def create_model(num_classes: int):
    """Build a ViT-B/16 backbone with a fresh classification head."""
    import timm  # local import: defer the heavy dependency until model build

    return timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=num_classes)


# Module-level CPU cache: built once, reused across predictions.
_model_cpu = None
_classes = None
_tfm = None


def setup_cpu():
    """Lazily initialize classes, transforms and the CPU model (idempotent).

    Raises:
        FileNotFoundError: if the weights file is missing.
    """
    global _model_cpu, _classes, _tfm
    if _classes is None:
        _classes = load_classes(CLASSES_PATH)
    if _tfm is None:
        _tfm = build_transforms()
    if _model_cpu is None:
        if not os.path.exists(MODEL_WEIGHTS):
            raise FileNotFoundError(f"File pesi non trovato: {MODEL_WEIGHTS}")
        model = create_model(num_classes=len(_classes))
        # SECURITY: torch.load unpickles arbitrary objects — only load
        # checkpoints from a trusted source (consider weights_only=True on
        # torch >= 2.0 for plain state dicts).
        state = torch.load(MODEL_WEIGHTS, map_location="cpu")
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        # Strip common checkpoint prefixes (DataParallel / wrapper modules).
        state = {k.replace("module.", "").replace("model.", ""): v for k, v in state.items()}
        model.load_state_dict(state, strict=False)
        model.eval()
        _model_cpu = model


@spaces.GPU(duration=20)  # request the GPU on demand (20s covers one inference)
@torch.inference_mode()
def predict(image: Image.Image):
    """Classify *image* and return ``(human_readable_text, detail_dict)``.

    The detail dict contains the top prediction, its confidence (percent,
    2 decimals), the top-3 entries, and the device used.
    """
    if image is None:
        # Gradio passes None when no image was provided; fail with a clear
        # user-facing error instead of an AttributeError.
        raise gr.Error("Nessuna immagine fornita.")

    # Prepare everything on CPU first.
    setup_cpu()

    # Move to GPU only for the duration of this call.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _model_cpu
    if device == "cuda":
        model = model.to("cuda")

    if image.mode != "RGB":
        image = image.convert("RGB")
    x = _tfm(image).unsqueeze(0)
    if device == "cuda":
        x = x.to("cuda")

    logits = model(x)
    probs = torch.softmax(logits, dim=1).detach().cpu().numpy().squeeze(0)

    # Back on CPU so the GPU is released promptly (ZeroGPU is strict about this).
    if device == "cuda":
        model.to("cpu")

    # Top-k entries, robust to models with fewer than 3 classes
    # (the original hard-coded indices [0..2] and crashed otherwise).
    k = min(3, len(probs))
    top_idx = np.argsort(-probs)[:k]
    top = [
        {"label": _classes[int(i)], "confidence": round(float(probs[i]) * 100, 2)}
        for i in top_idx
    ]

    pred_label = top[0]["label"]
    pred_conf = top[0]["confidence"]
    result = {
        "prediction": pred_label,
        "confidence": pred_conf,
        "top3": top,
        "device": device,
    }
    human = f"Tipo: {pred_label} — Affidabilità: {pred_conf}% (device: {device})"
    return human, result


title = "Corrosion Classifier (ViT-B/16 • ZeroGPU-ready)"
description = (
    "Carica o usa la webcam. Predice il tipo di corrosione con affidabilità. "
    "ZeroGPU: GPU allocata solo durante la predizione."
)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            inp = gr.Image(label="Immagine", type="pil", sources=["upload", "webcam"], image_mode="RGB")
            analyze_btn = gr.Button("Analizza immagine", variant="primary")
        with gr.Column():
            out_text = gr.Textbox(label="Risultato", interactive=False)
            out_json = gr.JSON(label="Dettagli (top-3)")
    analyze_btn.click(fn=predict, inputs=[inp], outputs=[out_text, out_json])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)