# skinsight-space / app.py
# Committed by xameedius: "Changed path" (b31262a, verified)
# Startup trace: one message is printed before the imports below run and one
# after they finish (presumably so the log shows how far start-up got).
# NOTE(review): the "βœ…" prefix looks like a mis-encoded emoji -- confirm
# against the author's original file before changing the string.
print("βœ… app.py import started")
import json
import gradio as gr
import torch
import timm
from PIL import Image
import torchvision.transforms as T
print("βœ… imports done")
# ---- Configuration ----
# Artifact paths; they are relative, so they resolve against the working
# directory the app is started from.
MODEL_PATH = "skinsight_model.pt"
CLASS_PATH = "class_names.json"
META_PATH = "model_meta.json"

# Inference runs on CPU.
device = torch.device("cpu")

# Load metadata
with open(META_PATH, "r", encoding="utf-8") as f:
    meta = json.load(f)

model_name = meta["model_name"]  # required key, e.g. tf_efficientnetv2_s
img_size = int(meta.get("img_size", 320))  # falls back to 320
num_classes = int(meta.get("num_classes", 23))  # falls back to 23

# Load class names
with open(CLASS_PATH, "r", encoding="utf-8") as f:
    class_names = json.load(f)

# Validate with an explicit raise: a bare `assert` is stripped when Python
# runs with -O, which would let a label/logit mismatch go unnoticed.
if len(class_names) != num_classes:
    raise ValueError(
        "class_names.json length != num_classes "
        f"({len(class_names)} != {num_classes})"
    )
# ---- Preprocessing ----
# The pipeline resizes to a square of side `img_size`, converts to a tensor
# and applies ImageNet mean/std normalization (what most timm EfficientNet
# models are trained with). If training used different statistics, change the
# values below to match.
_resize = T.Resize((img_size, img_size))
_normalize = T.Normalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)
transform = T.Compose([_resize, T.ToTensor(), _normalize])

print("βœ… about to load model")
def _build_model():
    """Create the timm architecture named in the metadata, without pretrained weights.

    Returns:
        The model created by ``timm.create_model`` for ``model_name`` with a
        classification head of ``num_classes`` outputs.
    """
    return timm.create_model(
        model_name,
        pretrained=False,
        num_classes=num_classes,
    )
# Startup trace emitted at import time, before the weight-loading helper below
# is defined.
# NOTE(review): the same message is already printed right before
# `_build_model` is defined -- this looks like an accidental duplicate; confirm
# before removing.
print("βœ… about to load model")
def _load_weights(m):
    """Load the weights stored at ``MODEL_PATH`` and return a model in eval mode.

    Two on-disk formats are supported:

    * a fully pickled ``torch.nn.Module`` -- it replaces ``m`` entirely;
    * a plain ``state_dict``, or a checkpoint dict holding one under the
      ``"state_dict"`` key -- it is loaded into ``m``.

    Args:
        m: Freshly built model whose architecture should match the saved
            weights.

    Returns:
        The model carrying the loaded weights, switched to eval mode.

    Raises:
        RuntimeError: If the file holds neither a module nor a dict, or if
            none of the entries expected by ``m`` are present in the file
            (which would leave the model with its initial weights).
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True, under
    # which a fully pickled nn.Module (Case 1) cannot be loaded -- confirm
    # against the torch version this app is pinned to.
    obj = torch.load(MODEL_PATH, map_location="cpu")

    # Case 1: entire model saved -- the pickled module replaces ``m``.
    if isinstance(obj, torch.nn.Module):
        m = obj

    # Case 2: checkpoint dict or plain state_dict
    elif isinstance(obj, dict):
        state = obj.get("state_dict", obj)

        # Handle the DataParallel prefix "module.". Only a *leading* prefix is
        # removed: stripping the substring anywhere would corrupt keys such as
        # "se_module.fc1.weight".
        prefix = "module."
        state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state.items()
        }

        missing, unexpected = m.load_state_dict(state, strict=False)

        # strict=False tolerates partial mismatches, but if every expected
        # entry is missing nothing was loaded at all: wrong file or wrong
        # architecture.
        expected = m.state_dict()
        if expected and len(missing) == len(expected):
            raise RuntimeError(
                f"No weights from {MODEL_PATH!r} matched the model; "
                "check that the checkpoint and the architecture agree."
            )
        if missing or unexpected:
            # Surface partial mismatches instead of ignoring them silently.
            print(
                f"WARNING: state_dict mismatch: {len(missing)} missing key(s) "
                f"(e.g. {list(missing)[:3]}), {len(unexpected)} unexpected "
                f"key(s) (e.g. {list(unexpected)[:3]})"
            )

    else:
        raise RuntimeError("Unknown .pt format: expected nn.Module or (checkpoint/)state_dict dict.")

    m.eval()
    # Emitted on both load paths so the log looks the same for either format.
    print("βœ… model loaded, starting UI...")
    return m
# Build the architecture, load the saved weights into it, then move the result
# to the target device and put it in eval mode. `nn.Module.to` and
# `nn.Module.eval` both return the module itself, so the calls can be chained.
model = _load_weights(_build_model()).to(device).eval()
def predict(image: Image.Image):
    """Classify one image and return the highest-probability classes.

    Args:
        image: PIL image supplied by the Gradio input component. Gradio passes
            ``None`` when the user submits without providing an image.

    Returns:
        A list of up to three ``{"label": str, "score": float}`` dicts, ordered
        from most to least probable. ``score`` is the softmax probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of failing with an
        # AttributeError on `None.convert`.
        raise gr.Error("Please upload an image first.")

    # Force three channels, preprocess, and add the batch dimension.
    x = transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1)[0]

    # Never request more entries than the model has classes.
    k = min(3, probs.shape[0])
    top = torch.topk(probs, k=k)

    # `.tolist()` already yields plain Python floats / ints.
    return [
        {"label": class_names[idx], "score": score}
        for score, idx in zip(top.values.tolist(), top.indices.tolist())
    ]
# ---- UI ----
# Single-function Gradio interface: one PIL image in, JSON predictions out.
# The title is derived from the loaded metadata so it cannot drift from the
# model that is actually being served.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload skin image"),
    outputs=gr.JSON(label="Top predictions"),
    title=f"SkinSight – {model_name} ({num_classes} classes)",
    description="Free inference hosted on Hugging Face Spaces (CPU). First request after idle may take longer (cold start).",
)

# Start the web server.
demo.launch()