# Corrobotv2 / app.py
import os, json
import torch
import torchvision.transforms as T
from PIL import Image
import numpy as np
import gradio as gr
import spaces  # <-- IMPORTANT: Hugging Face Spaces helper; provides @spaces.GPU for on-demand (ZeroGPU) allocation
MODEL_WEIGHTS = os.getenv("MODEL_WEIGHTS_PATH", "vit_b16_best.pth")
CLASSES_PATH = os.getenv("CLASSES_PATH", "classes.json")
IMAGE_SIZE = 224

def load_classes(path: str):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
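
# NOTE (assumption): classes.json is expected to hold a JSON array of label strings
# indexed by class id, e.g. ["pitting", "uniform", "crevice"] (placeholder names),
# since predictions below index into it with integer class indices.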

def build_transforms(img_size=IMAGE_SIZE):
    # Standard ImageNet normalization statistics, as commonly used with timm ViT models
    return T.Compose([
        T.Resize((img_size, img_size), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

def create_model(num_classes: int):
    import timm
    return timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=num_classes)
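
# Note: timm's "vit_base_patch16_224" expects 224x224 RGB inputs, which matches IMAGE_SIZE above.
# pretrained=False because the fine-tuned weights are loaded from MODEL_WEIGHTS in setup_cpu()
# rather than from ImageNet-pretrained checkpoints.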

# CPU-side caches: model, class labels, and preprocessing transform, reused across requests
_model_cpu = None
_classes = None
_tfm = None

def setup_cpu():
    global _model_cpu, _classes, _tfm
    if _classes is None:
        _classes = load_classes(CLASSES_PATH)
    if _tfm is None:
        _tfm = build_transforms()
    if _model_cpu is None:
        if not os.path.exists(MODEL_WEIGHTS):
            raise FileNotFoundError(f"Weights file not found: {MODEL_WEIGHTS}")
        model = create_model(num_classes=len(_classes))
        state = torch.load(MODEL_WEIGHTS, map_location="cpu")
        # Unwrap {"state_dict": ...} checkpoints and strip common "module."/"model." key prefixes
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        state = {k.replace("module.", "").replace("model.", ""): v for k, v in state.items()}
        model.load_state_dict(state, strict=False)
        model.eval()
        _model_cpu = model

@spaces.GPU(duration=20)  # <-- request the GPU on demand here (20 s is enough for a single inference)
@torch.inference_mode()
def predict(image: Image.Image):
    # Prepare everything on the CPU first (lazy, cached)
    setup_cpu()
    # Move to the GPU only for the duration of this call
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _model_cpu
    if device == "cuda":
        model = model.to("cuda")
    if image.mode != "RGB":
        image = image.convert("RGB")
    x = _tfm(image).unsqueeze(0)
    if device == "cuda":
        x = x.to("cuda")
    logits = model(x)
    probs = torch.softmax(logits, dim=1).detach().cpu().numpy().squeeze(0)
    # Return the model to the CPU to release the GPU (ZeroGPU is strict about this)
    if device == "cuda":
        model.to("cpu")
    top_idx = np.argsort(-probs)[:3]
    top_labels = [_classes[i] for i in top_idx]
    top_scores = [float(probs[i]) for i in top_idx]
    pred_label = top_labels[0]
    pred_conf = round(top_scores[0] * 100, 2)
    result = {
        "prediction": pred_label,
        "confidence": pred_conf,
        "top3": [
            {"label": lbl, "confidence": round(score * 100, 2)}
            for lbl, score in zip(top_labels, top_scores)
        ],
        "device": device,
    }
    human = f"Type: {pred_label} - Confidence: {pred_conf}% (device: {device})"
    return human, result
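
# Illustrative return value (label names are placeholders; real labels come from classes.json):
# ("Type: pitting - Confidence: 93.12% (device: cuda)",
#  {"prediction": "pitting", "confidence": 93.12,
#   "top3": [{"label": "pitting", "confidence": 93.12},
#            {"label": "uniform", "confidence": 4.5},
#            {"label": "crevice", "confidence": 2.38}],
#   "device": "cuda"})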

title = "Corrosion Classifier (ViT-B/16 • ZeroGPU-ready)"
description = (
    "Upload an image or use the webcam. The app predicts the corrosion type with a confidence score. "
    "ZeroGPU: the GPU is allocated only during prediction."
)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            inp = gr.Image(label="Image", type="pil", sources=["upload", "webcam"], image_mode="RGB")
            analyze_btn = gr.Button("Analyze image", variant="primary")
        with gr.Column():
            out_text = gr.Textbox(label="Result", interactive=False)
            out_json = gr.JSON(label="Details (top-3)")
    analyze_btn.click(fn=predict, inputs=[inp], outputs=[out_text, out_json])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)