# Hugging Face Space file-view header (kept as comments so the module parses):
# api_for_model / app.py
# muruga778's picture
# Update app.py
# 044ce91 verified
# raw
# history blame
# 7.32 kB
import json, os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import timm
from timm.data import resolve_model_data_config, create_transform
from transformers import AutoTokenizer, AutoModel
import gradio as gr
import ast
from huggingface_hub import hf_hub_download
# Space repository that checkpoints are re-fetched from when the local copy
# cannot be loaded; overridable through the SPACE_REPO_NAME environment variable.
SPACE_REPO = os.getenv("SPACE_REPO_NAME", "muruga778/api_for_model")  # change if your space id differs


def _redownload_and_load(filename: str):
    """Force a fresh download of *filename* from ``SPACE_REPO`` and load it on CPU."""
    print("⬇️ Force-downloading from Hugging Face Hub cache...")
    local_path = hf_hub_download(
        repo_id=SPACE_REPO,
        repo_type="space",
        filename=filename,
        force_download=True,
    )
    size_mb = os.path.getsize(local_path) / 1024 / 1024
    print("βœ… Downloaded to:", local_path, "size(MB)=", size_mb)
    return torch.load(local_path, map_location="cpu")


def safe_torch_load(filename: str):
    """Load a ``torch.save`` checkpoint onto the CPU, healing a bad local copy.

    The file is first loaded from *filename* as given. If that raises anything
    at all (missing file, truncated or corrupted download, ...), the file is
    force-downloaded from the ``SPACE_REPO`` Space on the Hugging Face Hub and
    the downloaded copy is loaded instead.

    NOTE: ``torch.load`` unpickles its input, so only use this on checkpoints
    from a trusted source (here: the Space's own repository).

    Args:
        filename: checkpoint path, also used as the file name inside the Space repo.

    Returns:
        Whatever object the checkpoint contains, with tensors mapped to CPU.
    """
    try:
        print(f"πŸ”Ž Loading weights: {filename} (local)")
        return torch.load(filename, map_location="cpu")
    except Exception as err:  # deliberately broad: any local failure triggers a re-download
        print(f"⚠️ Local load failed for {filename}: {err!r}")
        # Called from inside the except block so a download/load failure keeps
        # the local error chained as its context.
        return _redownload_and_load(filename)
# Run inference on the GPU when torch reports CUDA is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_json(path):
    """Read and parse the JSON document stored at *path*.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    locale-dependent default encoding, so non-ASCII label names load the same
    on every host. The ``utf-8-sig`` codec also accepts (and drops) a leading
    byte-order mark, which some editors write and which ``json`` rejects.
    """
    with open(path, "r", encoding="utf-8-sig") as f:
        return json.load(f)
def clean_state_dict(sd):
    """Normalise a loaded checkpoint into a plain ``{param_name: value}`` state dict.

    Two common wrappers are undone:

    * a checkpoint saved as ``{"state_dict": ...}``, ``{"model": ...}`` or
      ``{"model_state_dict": ...}`` is unwrapped. The keys are tried once each,
      in that order, against the result of the previous step.
    * the ``"module."`` prefix that (Distributed)DataParallel adds to every
      parameter name is stripped.

    Non-dict input is returned unchanged.
    """
    for key in ("state_dict", "model", "model_state_dict"):
        if isinstance(sd, dict) and isinstance(sd.get(key), dict):
            sd = sd[key]

    prefix = "module."

    def _has_prefix(k):
        return isinstance(k, str) and k.startswith(prefix)

    if isinstance(sd, dict) and any(_has_prefix(k) for k in sd):
        # Remove only a *leading* prefix. str.replace(prefix, "", 1) would also
        # delete a "module." found in the middle of a key ("head.module.w").
        sd = {(k[len(prefix):] if _has_prefix(k) else k): v for k, v in sd.items()}
    return sd
def softmax_np(x, axis=None):
    """Numerically stable softmax of *x*, returned as a NumPy array.

    Args:
        x: array-like of logits.
        axis: axis to normalise over. ``None`` (the default, and the original
            behaviour) treats the whole input as a single distribution. Pass
            ``axis=-1`` to normalise each row of a batch of logits independently.

    Returns:
        Array with the same shape as *x* whose entries along *axis* sum to ~1.
    """
    x = np.asarray(x)
    # Shift by the max so the largest exponent is exp(0) and exp() cannot overflow.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    # After the shift the denominator is already >= 1. The epsilon is kept only
    # so results stay identical to the previous implementation.
    return e / (np.sum(e, axis=axis, keepdims=True) + 1e-9)
# --- Triage rules (simple + demo friendly)

# Baseline severity per lower-cased predicted label (higher = more urgent).
# Labels missing from this table fall back to 2 inside triage().
SEVERITY_BY_LABEL = {
    "acne": 1,
    "tinea": 2,
    "tinea versicolor": 1,
    "eczema": 2,
    "urticaria": 2,
    "psoriasis": 2,
    "folliculitis": 2,
    "impetigo": 3,
    "herpes zoster": 3,
    "drug rash": 4,
    "scabies": 3,
    "unknown": 2,
}

# Lower-case phrases looked for in the symptom text. Matching is a plain
# substring test. NOTE(review): that means e.g. "face" also fires on "surface".
RED_FLAG_WORDS = [
    "fever", "breathing", "shortness of breath", "face", "eye", "mouth", "genital",
    "severe pain", "blister", "purple", "swelling", "rapid", "spreading", "bleeding",
]


def triage(label, conf, text):
    """Turn a prediction plus free-text symptoms into a coarse urgency bucket.

    score = baseline severity of *label* (2 if unlisted)
          + number of red-flag phrases found in *text*, capped at 2
          + 1 if *conf* < 0.50, plus 1 more if *conf* < 0.35
    and is then clamped to 1..5.

    Args:
        label: predicted class name; compared lower-cased and stripped (None allowed).
        conf: probability of the predicted class.
        text: the user's symptom description (None allowed).

    Returns:
        ``(stage, score, note)``. The stage is "SELF-CARE / MONITOR" for scores up to 2,
        "DOCTOR (24–48h)" for 3-4 and "URGENT NOW" for 5. The note is a fixed disclaimer.
    """
    key = (label or "").lower().strip()
    haystack = (text or "").lower()

    flag_count = sum(phrase in haystack for phrase in RED_FLAG_WORDS)
    red_flag_bonus = min(flag_count, 2)
    low_conf_bonus = int(conf < 0.50) + int(conf < 0.35)

    raw_score = SEVERITY_BY_LABEL.get(key, 2) + red_flag_bonus + low_conf_bonus
    score = int(min(max(raw_score, 1), 5))

    if score <= 2:
        stage = "SELF-CARE / MONITOR"
    elif score <= 4:
        stage = "DOCTOR (24–48h)"
    else:
        stage = "URGENT NOW"

    note = "Not medical advice. If rapidly worsening / fever / face-eye involvement / breathing trouble β†’ seek urgent care."
    return stage, score, note
# ---- Load config + label map
CFG = load_json("fusion_config.json")
LABEL_MAP = load_json("label_map.json")


def _require_contiguous_indices(indices, source):
    """Raise unless *indices* are exactly 0..n-1, because CLASSES[i] must name model output i."""
    ordered = sorted(indices)
    if ordered != list(range(len(ordered))):
        raise ValueError(
            f"label_map.json ({source}) must use class indices 0..{len(ordered) - 1} "
            f"exactly once each, got {ordered}"
        )


def _parse_label_map(label_map):
    """Return ``(classes, label2idx)`` for any supported label-map layout.

    Supported layouts:

    * ``{"classes": [...], "label2idx": {...}}``: the current format. ``label2idx``
      is optional and is derived from the class order when absent.
    * ``{"0": "eczema", ...}``: index -> label (older format).
    * ``{"eczema": 0, ...}``: label -> index (older format).
    * ``["eczema", ...]``: a bare list in class-index order.

    Raises:
        ValueError: for an unsupported layout, for index maps that are not
            exactly 0..n-1, or when the map defines no classes at all.
    """
    if isinstance(label_map, list):
        classes = [str(x) for x in label_map]
        label2idx = {c: i for i, c in enumerate(classes)}
    elif not isinstance(label_map, dict):
        raise ValueError(f"Unsupported label_map.json layout: {type(label_map).__name__}")
    elif isinstance(label_map.get("classes"), list):
        classes = [str(x) for x in label_map["classes"]]
        label2idx = label_map.get("label2idx", {c: i for i, c in enumerate(classes)})
    elif all(isinstance(k, str) and k.isdigit() for k in label_map):
        idx2label = {int(k): str(v) for k, v in label_map.items()}
        # A gap (e.g. {"0", "2"}) would shift every later label onto the wrong output.
        _require_contiguous_indices(idx2label, "index -> label")
        classes = [idx2label[i] for i in sorted(idx2label)]
        label2idx = {c: i for i, c in enumerate(classes)}
    else:
        label2idx = {str(k): int(v) for k, v in label_map.items()}
        _require_contiguous_indices(label2idx.values(), "label -> index")
        classes = [c for c, _ in sorted(label2idx.items(), key=lambda kv: kv[1])]

    if not classes:
        # A zero-class head can never match the trained checkpoints.
        raise ValueError("label_map.json defines no classes")
    return classes, label2idx


CLASSES, label2idx = _parse_label_map(LABEL_MAP)
NUM_CLASSES = len(CLASSES)
print("βœ… NUM_CLASSES:", NUM_CLASSES)
print("βœ… First labels:", CLASSES[:5])
# Model hyper-parameters read from fusion_config.json; each fallback applies
# only when the key is absent from the config.
_CFG_FALLBACKS = {
    "img_backbone": "tf_efficientnetv2_s",
    "img_size": 384,
    "text_model_name": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "max_len": 128,
}
IMG_BACKBONE = CFG.get("img_backbone", _CFG_FALLBACKS["img_backbone"])
IMG_SIZE = int(CFG.get("img_size", _CFG_FALLBACKS["img_size"]))
TEXT_MODEL_NAME = CFG.get("text_model_name", _CFG_FALLBACKS["text_model_name"])
MAX_LEN = int(CFG.get("max_len", _CFG_FALLBACKS["max_len"]))
# ---- Image model
img_model = timm.create_model(IMG_BACKBONE, pretrained=False, num_classes=NUM_CLASSES)
sd_img = clean_state_dict(safe_torch_load("best_scin_image.pt"))
# strict=True: the checkpoint must match the architecture exactly, so a wrong
# img_backbone / class count fails loudly here instead of mispredicting.
img_model.load_state_dict(sd_img, strict=True)
# load_state_dict copies the values into the model's own parameters, so the
# checkpoint dict is now a redundant second copy of every weight. Release it
# rather than keeping it alive as a module global.
del sd_img
img_model.to(DEVICE).eval()

# Eval-time preprocessing: timm's resolved data config for this model, with
# the input resolution overridden to the configured IMG_SIZE.
data_cfg = resolve_model_data_config(img_model)
data_cfg["input_size"] = (3, IMG_SIZE, IMG_SIZE)
img_tfm = create_transform(**data_cfg, is_training=False)
# ---- Text model
class TextClassifier(nn.Module):
    """Hugging Face encoder followed by dropout and a linear layer producing class logits.

    The submodule attribute names (``backbone``, ``drop``, ``head``) form the
    parameter names in the state dict, so they must match the checkpoint that
    is loaded into this model.
    """

    def __init__(self, model_name, num_classes, dropout=0.2):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        hidden_size = self.backbone.config.hidden_size
        self.drop = nn.Dropout(dropout)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of token ids.

        Uses the backbone's ``pooler_output`` when it exists and is not None,
        otherwise the hidden state of the first token.
        """
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled = getattr(encoded, "pooler_output", None)
        if pooled is None:
            pooled = encoded.last_hidden_state[:, 0]
        return self.head(self.drop(pooled))
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = TextClassifier(TEXT_MODEL_NAME, NUM_CLASSES)
sd_txt = clean_state_dict(safe_torch_load("best_scin_text.pt"))
# strict=False tolerates key differences between checkpoint and model, but it
# also silently accepts a checkpoint that does not match at all. In that case
# the affected weights (e.g. the classifier head) are simply not loaded.
# Report what was skipped so such a mismatch is visible in the Space logs.
_txt_load_report = text_model.load_state_dict(sd_txt, strict=False)
if _txt_load_report.missing_keys:
    print(
        f"⚠️ best_scin_text.pt: {len(_txt_load_report.missing_keys)} model keys not found "
        f"in checkpoint, e.g. {_txt_load_report.missing_keys[:5]}"
    )
if _txt_load_report.unexpected_keys:
    print(
        f"⚠️ best_scin_text.pt: {len(_txt_load_report.unexpected_keys)} checkpoint keys "
        f"ignored, e.g. {_txt_load_report.unexpected_keys[:5]}"
    )
# The model now holds its own copy of the weights; release the checkpoint dict.
del sd_txt, _txt_load_report
text_model.to(DEVICE).eval()
# Late-fusion weights for the image and text probability vectors. They are
# renormalised to sum to 1 so the fused vector is itself a distribution.
# `or {}` also covers an explicit "fusion_weights": null in the config.
_fusion_cfg = CFG.get("fusion_weights") or {}
W_IMG = float(_fusion_cfg.get("image", 0.6))
W_TXT = float(_fusion_cfg.get("text", 0.4))
# Chained comparisons are False for NaN, so this also rejects NaN/inf weights.
if not (0 <= W_IMG < float("inf") and 0 <= W_TXT < float("inf") and W_IMG + W_TXT > 0):
    raise ValueError(
        "fusion_weights must be finite, non-negative and not both zero; "
        f"got image={W_IMG}, text={W_TXT}"
    )
_fusion_total = W_IMG + W_TXT
W_IMG, W_TXT = W_IMG / _fusion_total, W_TXT / _fusion_total
def _to_rgb(image):
    """Return *image* as an RGB PIL image; accepts a PIL image or anything ``Image.open`` accepts."""
    if hasattr(image, "convert"):
        return image.convert("RGB")
    # convert() loads the pixels into a new image, so the file handle opened
    # here can (and should) be closed immediately afterwards.
    with Image.open(image) as opened:
        return opened.convert("RGB")


@torch.inference_mode()
def predict(image, symptom_text, topk=3):
    """Classify a skin image plus symptom text and attach a triage suggestion.

    Image and text are scored by their own models. The two softmax
    distributions are averaged with weights W_IMG / W_TXT.

    Args:
        image: PIL image (what the Gradio UI sends) or a path / file object.
        symptom_text: free-text symptoms; None or "" is tokenised as empty text.
        topk: number of ranked classes to list; clamped to [0, number of classes].

    Returns:
        ``(markdown_summary, top_k_lines)``, or ``("Upload an image.", "")``
        when no image is given.
    """
    if image is None:
        return "Upload an image.", ""

    pil = _to_rgb(image)
    x_img = img_tfm(pil).unsqueeze(0).to(DEVICE)
    tok = tokenizer(
        symptom_text or "",
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    tok = {k: v.to(DEVICE) for k, v in tok.items()}

    # inference_mode already disables autograd, so no .detach() is needed.
    img_logits = img_model(x_img)[0].float().cpu().numpy()
    txt_logits = text_model(tok["input_ids"], tok["attention_mask"])[0].float().cpu().numpy()

    p_img = softmax_np(img_logits)
    p_txt = softmax_np(txt_logits)
    p = W_IMG * p_img + W_TXT * p_txt

    pred_idx = int(np.argmax(p))
    pred_label = CLASSES[pred_idx]
    conf = float(p[pred_idx])

    # Clamp to [0, n]: a negative topk (possible from direct API calls) would
    # otherwise slice off the *last* entries instead of selecting the top ones.
    k = max(0, min(int(topk), len(CLASSES)))
    top_idx = np.argsort(-p)[:k]
    top_lines = [f"{i+1}) {CLASSES[int(ix)]} β€” {float(p[int(ix)]):.2f}" for i, ix in enumerate(top_idx)]

    stage, sev_score, note = triage(pred_label, conf, symptom_text)
    out1 = f"**Prediction:** {pred_label}\n\n**Confidence:** {conf:.2f}\n\n**Triage:** {stage} (score {sev_score}/5)\n\n{note}"
    out2 = "\n".join(top_lines)
    return out1, out2
def _build_demo():
    """Assemble the Gradio UI: (image, symptom text, top-k) -> (markdown result, top-k list)."""
    input_components = [
        gr.Image(type="pil", label="Skin image"),
        gr.Textbox(lines=3, label="Symptoms (text)"),
        gr.Slider(1, 5, value=3, step=1, label="Top-K"),
    ]
    output_components = [
        gr.Markdown(label="Result"),
        gr.Textbox(label="Top-K"),
    ]
    return gr.Interface(
        fn=predict,
        inputs=input_components,
        outputs=output_components,
        title="SmartSkin β€” SCIN Multimodal (Image + Text Fusion)",
        description="Demo only. Not medical advice.",
    )


demo = _build_demo()

if __name__ == "__main__":
    demo.launch()