# Hugging Face Spaces status header captured with the file ("Spaces: Sleeping") —
# UI residue from the Space page, not part of the application code.
print("β app.py import started")

import json

import gradio as gr
import timm
import torch
import torchvision.transforms as T
from PIL import Image

print("β imports done")

MODEL_PATH = "skinsight_model.pt"
CLASS_PATH = "class_names.json"
META_PATH = "model_meta.json"

# Spaces free tier is CPU-only; pin the device explicitly.
device = torch.device("cpu")

# ---- Load metadata ----
with open(META_PATH, "r", encoding="utf-8") as f:
    meta = json.load(f)

model_name = meta["model_name"]               # e.g. tf_efficientnetv2_s
img_size = int(meta.get("img_size", 320))     # training resolution (px)
num_classes = int(meta.get("num_classes", 23))

# ---- Load class names ----
with open(CLASS_PATH, "r", encoding="utf-8") as f:
    class_names = json.load(f)

# Validate with an explicit exception rather than `assert`:
# asserts are stripped when Python runs with -O, which would let a
# mismatched label file through silently.
if len(class_names) != num_classes:
    raise ValueError("class_names.json length != num_classes")

# ---- Preprocessing ----
# Most timm EfficientNet models are trained with ImageNet normalization.
# If your training used different mean/std, change these values to match.
transform = T.Compose([
    T.Resize((img_size, img_size)),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)),
])
# NOTE: the original printed this line twice back-to-back (copy-paste
# duplicate); one occurrence is enough for cold-start tracing.
print("β about to load model")


def _build_model():
    """Instantiate the architecture with random weights.

    `pretrained=False` because the trained weights are loaded afterwards
    from MODEL_PATH by `_load_weights`.
    """
    return timm.create_model(model_name, pretrained=False, num_classes=num_classes)
def _load_weights(m):
    """Load MODEL_PATH into `m` and return the model in eval mode.

    Accepts two on-disk formats:
      1. a fully pickled ``nn.Module`` (the saved object replaces `m`);
      2. a checkpoint dict (``{"state_dict": ...}``) or a bare state_dict.

    Raises:
        RuntimeError: if the .pt file holds neither an nn.Module nor a dict.
    """
    obj = torch.load(MODEL_PATH, map_location="cpu")

    # Case 1: entire model object was saved.
    if isinstance(obj, torch.nn.Module):
        obj.eval()
        print("β model loaded, starting UI...")
        return obj

    # Case 2: checkpoint dict or plain state_dict.
    if isinstance(obj, dict):
        state = obj.get("state_dict", obj)
        # Strip a DataParallel "module." *prefix* only. str.replace would
        # also mangle keys that merely contain "module." deeper in the
        # dotted path (e.g. "blocks.module.0.weight").
        state = {k.removeprefix("module."): v for k, v in state.items()}
        missing, unexpected = m.load_state_dict(state, strict=False)
        # strict=False tolerates small diffs, but large mismatches mean the
        # wrong architecture was built — surface them instead of hiding them.
        if missing or unexpected:
            print(f"β load_state_dict mismatch — missing: {missing}, unexpected: {unexpected}")
        m.eval()
        return m

    raise RuntimeError("Unknown .pt format: expected nn.Module or (checkpoint/)state_dict dict.")
# Build the architecture, then swap in the trained weights.
model = _load_weights(_build_model())
model.to(device)
model.eval()
def predict(image: Image.Image):
    """Classify a skin image and return the top predictions.

    Args:
        image: uploaded image in any mode (converted to RGB). Gradio passes
            ``None`` when the user submits without selecting an image.

    Returns:
        list[dict]: up to three ``{"label": str, "score": float}`` entries,
        ordered by descending probability. Empty list for ``None`` input.
    """
    # Guard against an empty submission — the original crashed with
    # AttributeError on None.convert().
    if image is None:
        return []

    x = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1)[0]

    k = min(3, probs.shape[0])
    topk = torch.topk(probs, k=k)
    return [
        {"label": class_names[int(idx)], "score": float(score)}
        for score, idx in zip(topk.values.tolist(), topk.indices.tolist())
    ]
# ---- Gradio UI ----
image_input = gr.Image(type="pil", label="Upload skin image")
prediction_output = gr.JSON(label="Top predictions")

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=prediction_output,
    title="SkinSight β tf_efficientnetv2_s (23 classes)",
    description="Free inference hosted on Hugging Face Spaces (CPU). First request after idle may take longer (cold start).",
)

demo.launch()