Spaces:
Sleeping
Sleeping
File size: 7,316 Bytes
bc4c032 6f85d0e 007ae18 681b6e2 bc4c032 59123a7 bc4c032 007ae18 bc4c032 007ae18 af19b8e 007ae18 681b6e2 af19b8e 007ae18 681b6e2 bc4c032 044ce91 bc4c032 044ce91 bc4c032 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import json, os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import timm
from timm.data import resolve_model_data_config, create_transform
from transformers import AutoTokenizer, AutoModel
import gradio as gr
import ast
from huggingface_hub import hf_hub_download
SPACE_REPO = os.getenv("SPACE_REPO_NAME", "muruga778/api_for_model")  # change if your space id differs

def safe_torch_load(filename: str):
    """Load a torch checkpoint, falling back to a fresh Hub download.

    1) try the local file
    2) if missing/corrupted -> force-download the same file from the
       Space repo on the Hugging Face Hub and load the cached copy

    NOTE(review): torch.load unpickles arbitrary objects; only load
    checkpoints from repos you control/trust.
    """
    try:
        # Fixed: the log previously printed a literal "(unknown)" instead of
        # the actual filename (corrupted f-string placeholder).
        print(f"Loading weights: {filename} (local)")
        return torch.load(filename, map_location="cpu")
    except Exception as e:
        print(f"Local load failed for {filename}: {repr(e)}")
        print("Force-downloading from Hugging Face Hub cache...")
        cached = hf_hub_download(
            repo_id=SPACE_REPO,
            repo_type="space",
            filename=filename,
            force_download=True,
        )
        print("Downloaded to:", cached, "size(MB)=", os.path.getsize(cached) / 1024 / 1024)
        return torch.load(cached, map_location="cpu")
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def load_json(path):
    """Deserialize and return the JSON document stored at *path*."""
    with open(path, "r") as fh:
        raw = fh.read()
    return json.loads(raw)
def clean_state_dict(sd):
    """Normalize a loaded checkpoint into a flat state dict.

    Unwraps common container keys ("state_dict", "model",
    "model_state_dict") and strips the DataParallel "module." prefix.

    Fixed: the old code used k.replace("module.", "", 1), which also
    removed an *interior* "module." from keys that do not start with it
    (e.g. "enc.module.w" -> "enc.w"). Now only a true prefix is stripped.
    """
    for key in ["state_dict", "model", "model_state_dict"]:
        if isinstance(sd, dict) and key in sd and isinstance(sd[key], dict):
            sd = sd[key]
    if isinstance(sd, dict) and any(k.startswith("module.") for k in sd):
        sd = {
            (k[len("module."):] if k.startswith("module.") else k): v
            for k, v in sd.items()
        }
    return sd
def softmax_np(x):
    """Numerically stable softmax over a 1-D numpy array.

    Shifting by the max prevents exp overflow; the tiny epsilon in the
    denominator guards against an all-zero sum.
    """
    shifted = np.exp(x - x.max())
    return shifted / (shifted.sum() + 1e-9)
# --- Triage rules (simple + demo friendly)

# Baseline severity per predicted label, 1 (mild) .. 5 (urgent);
# labels not listed here fall back to 2 inside triage().
SEVERITY_BY_LABEL = {
    "acne": 1, "tinea": 2, "tinea versicolor": 1, "eczema": 2, "urticaria": 2,
    "psoriasis": 2, "folliculitis": 2, "impetigo": 3, "herpes zoster": 3,
    "drug rash": 4, "scabies": 3, "unknown": 2
}

# Substrings of the free-text symptom description that escalate severity.
RED_FLAG_WORDS = [
    "fever","breathing","shortness of breath","face","eye","mouth","genital",
    "severe pain","blister","purple","swelling","rapid","spreading","bleeding"
]

def triage(label, conf, text):
    """Map (predicted label, confidence, symptom text) to triage advice.

    Returns (stage, score, note): score is clamped to 1..5; stage is one
    of "SELF-CARE / MONITOR" (<=2), "DOCTOR (24–48h)" (<=4), "URGENT NOW".
    Non-ASCII characters in the stage/note strings were mojibake in the
    checked-in file and have been restored (en dash / arrow — confirm
    against the intended copy).
    """
    label_l = (label or "").lower().strip()
    text_l = (text or "").lower()
    score = SEVERITY_BY_LABEL.get(label_l, 2)
    # Count red-flag mentions; two or more is a bigger escalation.
    hits = sum(1 for w in RED_FLAG_WORDS if w in text_l)
    if hits >= 2: score += 2
    elif hits == 1: score += 1
    # Low model confidence makes the advice more cautious (both can stack).
    if conf < 0.50: score += 1
    if conf < 0.35: score += 1
    score = int(max(1, min(5, score)))
    stage = "SELF-CARE / MONITOR" if score <= 2 else ("DOCTOR (24–48h)" if score <= 4 else "URGENT NOW")
    note = "Not medical advice. If rapidly worsening / fever / face-eye involvement / breathing trouble → seek urgent care."
    return stage, score, note
# ---- Load config + label map
CFG = load_json("fusion_config.json")
LABEL_MAP = load_json("label_map.json")
# Normalize the label map into CLASSES (index-ordered names) + label2idx.
# Current label_map.json format: {"classes":[...], "label2idx":{...}}
if isinstance(LABEL_MAP, dict) and "classes" in LABEL_MAP and isinstance(LABEL_MAP["classes"], list):
    CLASSES = [str(x) for x in LABEL_MAP["classes"]]
    label2idx = LABEL_MAP.get("label2idx", {c: i for i, c in enumerate(CLASSES)})
# Older possible formats:
elif isinstance(LABEL_MAP, dict) and all(isinstance(k, str) and k.isdigit() for k in LABEL_MAP.keys()):
    # {"0":"eczema", ...}
    idx2label = {int(k): str(v) for k, v in LABEL_MAP.items()}
    CLASSES = [idx2label[i] for i in sorted(idx2label.keys())]
    label2idx = {c: i for i, c in enumerate(CLASSES)}
else:
    # {"eczema": 0, ...}
    label2idx = {str(k): int(v) for k, v in LABEL_MAP.items()}
    CLASSES = [c for c, _ in sorted(label2idx.items(), key=lambda x: x[1])]
NUM_CLASSES = len(CLASSES)
# Fixed: these two prints were split mid-string-literal by mojibake in the
# checked-in file (a corrupted leading emoji); restored to single lines.
print("✅ NUM_CLASSES:", NUM_CLASSES)
print("✅ First labels:", CLASSES[:5])
# Model/config hyperparameters with sensible defaults when keys are absent.
IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
IMG_SIZE = int(CFG.get("img_size", 384))
TEXT_MODEL_NAME = CFG.get("text_model_name", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
MAX_LEN = int(CFG.get("max_len", 128))
# ---- Image model
# Create the timm backbone with a head sized to our label set; weights come
# entirely from the checkpoint, so pretrained=False.
img_model = timm.create_model(IMG_BACKBONE, pretrained=False, num_classes=NUM_CLASSES)
sd_img = clean_state_dict(safe_torch_load("best_scin_image.pt"))
# strict=True: the image checkpoint is expected to match the model exactly.
img_model.load_state_dict(sd_img, strict=True)
img_model.to(DEVICE).eval()
# Eval-time preprocessing derived from the model's data config, but with the
# input resolution forced to the configured IMG_SIZE.
data_cfg = resolve_model_data_config(img_model)
data_cfg["input_size"] = (3, IMG_SIZE, IMG_SIZE)
img_tfm = create_transform(**data_cfg, is_training=False)
# ---- Text model
class TextClassifier(nn.Module):
    """Transformer encoder with a dropout + linear classification head."""

    def __init__(self, model_name, num_classes, dropout=0.2):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        self.drop = nn.Dropout(dropout)
        self.head = nn.Linear(self.backbone.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # Prefer the pooled representation; fall back to the [CLS] token
        # embedding for encoders that expose no pooler output.
        pooled = getattr(out, "pooler_output", None)
        if pooled is None:
            pooled = out.last_hidden_state[:, 0]
        return self.head(self.drop(pooled))
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = TextClassifier(TEXT_MODEL_NAME, NUM_CLASSES)
sd_txt = clean_state_dict(safe_torch_load("best_scin_text.pt"))
# NOTE(review): strict=False tolerates missing/unexpected keys here (unlike
# the image model's strict=True) — presumably intentional for this
# checkpoint; confirm it actually loads the head weights.
text_model.load_state_dict(sd_txt, strict=False)
text_model.to(DEVICE).eval()
# Fusion weights for the image/text probability blend, normalized to sum
# to 1 so predict() produces a proper distribution.
W_IMG = float(CFG.get("fusion_weights", {}).get("image", 0.6))
W_TXT = float(CFG.get("fusion_weights", {}).get("text", 0.4))
s = W_IMG + W_TXT
W_IMG, W_TXT = W_IMG / s, W_TXT / s
@torch.inference_mode()
def predict(image, symptom_text, topk=3):
    """Run late-fusion inference for the Gradio UI.

    Returns (markdown summary, top-k text listing). The top-k separator
    character was mojibake in the checked-in file; restored to an em dash
    (confirm against the intended copy).
    """
    if image is None:
        return "Upload an image.", ""
    # Gradio hands us a PIL image; also accept a file path for robustness.
    pil = image.convert("RGB") if hasattr(image, "convert") else Image.open(image).convert("RGB")
    x_img = img_tfm(pil).unsqueeze(0).to(DEVICE)
    tok = tokenizer(symptom_text or "", truncation=True, padding="max_length", max_length=MAX_LEN, return_tensors="pt")
    tok = {k: v.to(DEVICE) for k, v in tok.items()}
    img_logits = img_model(x_img)[0].detach().float().cpu().numpy()
    txt_logits = text_model(tok["input_ids"], tok["attention_mask"])[0].detach().float().cpu().numpy()
    # Late fusion: weighted average of the per-model softmax distributions.
    p_img = softmax_np(img_logits)
    p_txt = softmax_np(txt_logits)
    p = W_IMG * p_img + W_TXT * p_txt
    pred_idx = int(np.argmax(p))
    pred_label = CLASSES[pred_idx]
    conf = float(p[pred_idx])
    k = min(int(topk), len(CLASSES))
    top_idx = np.argsort(-p)[:k]
    top_lines = [f"{i+1}) {CLASSES[int(ix)]} — {float(p[int(ix)]):.2f}" for i, ix in enumerate(top_idx)]
    stage, sev_score, note = triage(pred_label, conf, symptom_text)
    out1 = f"**Prediction:** {pred_label}\n\n**Confidence:** {conf:.2f}\n\n**Triage:** {stage} (score {sev_score}/5)\n\n{note}"
    out2 = "\n".join(top_lines)
    return out1, out2
# Gradio UI: one image input, free-text symptoms, and a top-k slider.
# Fixed: the title's separator character was mojibake; restored to an
# em dash (confirm against the intended copy).
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Skin image"),
        gr.Textbox(lines=3, label="Symptoms (text)"),
        gr.Slider(1, 5, value=3, step=1, label="Top-K"),
    ],
    outputs=[
        gr.Markdown(label="Result"),
        gr.Textbox(label="Top-K"),
    ],
    title="SmartSkin — SCIN Multimodal (Image + Text Fusion)",
    description="Demo only. Not medical advice.",
)

if __name__ == "__main__":
    demo.launch()
|