import os, json
import torch
import torchvision.transforms as T
from PIL import Image
import numpy as np
import gradio as gr
import spaces  # <-- IMPORTANT: enables ZeroGPU on-demand GPU allocation

MODEL_WEIGHTS = os.getenv("MODEL_WEIGHTS_PATH", "vit_b16_best.pth")
CLASSES_PATH = os.getenv("CLASSES_PATH", "classes.json")
IMAGE_SIZE = 224
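
# Both paths can be overridden via environment variables (e.g. in the Space
# settings); the defaults assume the weights and class list sit at the repo root.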

def load_classes(path: str):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
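
# classes.json is expected to be a JSON array of label strings: the code below
# indexes it by position (_classes[i]) and sizes the head with len(_classes).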

def build_transforms(img_size=IMAGE_SIZE):
    return T.Compose([
        T.Resize((img_size, img_size), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
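
# The mean/std above are the standard ImageNet statistics used by timm's
# ImageNet-pretrained ViT models, so fine-tuned checkpoints stay consistent.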

def create_model(num_classes: int):
    import timm
    return timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=num_classes)
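
# pretrained=False: the architecture is built empty and the fine-tuned weights
# are loaded from MODEL_WEIGHTS below, so nothing is downloaded at startup.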

# CPU-side cache: filled lazily on first use, reused afterwards
_model_cpu = None
_classes = None
_tfm = None

def setup_cpu():
    global _model_cpu, _classes, _tfm
    if _classes is None:
        _classes = load_classes(CLASSES_PATH)
    if _tfm is None:
        _tfm = build_transforms()
    if _model_cpu is None:
        if not os.path.exists(MODEL_WEIGHTS):
            raise FileNotFoundError(f"Weights file not found: {MODEL_WEIGHTS}")
        model = create_model(num_classes=len(_classes))
        state = torch.load(MODEL_WEIGHTS, map_location="cpu")
        # unwrap common checkpoint layouts and strip DataParallel/wrapper prefixes
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        state = {k.replace("module.", "").replace("model.", ""): v for k, v in state.items()}
        model.load_state_dict(state, strict=False)
        model.eval()
        _model_cpu = model
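
# These globals act as a lazy, process-local cache: setup_cpu() fills them on
# first use, so the checkpoint is not re-read when the same process handles
# another request.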

# <-- HERE the GPU is requested on demand (20 s is plenty for a single inference)
@spaces.GPU(duration=20)
def predict(image: Image.Image):
    # prepare everything on the CPU
    setup_cpu()
    # move to the GPU only for the duration of the call
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _model_cpu
    if device == "cuda":
        model = model.to("cuda")
    if image.mode != "RGB":
        image = image.convert("RGB")
    x = _tfm(image).unsqueeze(0)
    if device == "cuda":
        x = x.to("cuda")
    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(x)
    probs = torch.softmax(logits, dim=1).cpu().numpy().squeeze(0)
    # move the model back to the CPU to release the GPU (ZeroGPU is strict about this)
    if device == "cuda":
        model.to("cpu")
    top_idx = np.argsort(-probs)[:3]
    top_labels = [_classes[i] for i in top_idx]
    top_scores = [float(probs[i]) for i in top_idx]
    pred_label = top_labels[0]
    pred_conf = round(top_scores[0] * 100, 2)
    result = {
        "prediction": pred_label,
        "confidence": pred_conf,
        "top3": [
            {"label": lbl, "confidence": round(score * 100, 2)}
            for lbl, score in zip(top_labels, top_scores)
        ],
        "device": device,
    }
    human = f"Type: {pred_label} - Confidence: {pred_conf}% (device: {device})"
    return human, result
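
# predict returns a (human-readable string, JSON-serializable dict) pair,
# matching the two output components wired to the button below.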

title = "Corrosion Classifier (ViT-B/16 • ZeroGPU-ready)"
description = (
    "Upload an image or use the webcam. Predicts the corrosion type with a "
    "confidence score. ZeroGPU: the GPU is allocated only while a prediction runs."
)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            inp = gr.Image(label="Image", type="pil", sources=["upload", "webcam"], image_mode="RGB")
            analyze_btn = gr.Button("Analyze image", variant="primary")
        with gr.Column():
            out_text = gr.Textbox(label="Result", interactive=False)
            out_json = gr.JSON(label="Details (top-3)")
    analyze_btn.click(fn=predict, inputs=[inp], outputs=[out_text, out_json])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
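
# Minimal client-side sketch (the Space id below is a placeholder; adjust it to
# the published name). gradio_client derives api_name="/predict" from the
# function name bound to the button:
#
#   from gradio_client import Client, handle_file
#   client = Client("your-username/corrosion-classifier")
#   text, details = client.predict(handle_file("sample.jpg"), api_name="/predict")
#   print(text)     # e.g. "Type: ... - Confidence: ...% (device: cuda)"
#   print(details)  # the top-3 dict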