File size: 7,316 Bytes
bc4c032
 
 
6f85d0e
007ae18
681b6e2
bc4c032
59123a7
bc4c032
 
 
 
 
007ae18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc4c032
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
007ae18
 
 
 
 
 
 
 
 
 
 
af19b8e
 
007ae18
 
 
681b6e2
af19b8e
007ae18
 
681b6e2
bc4c032
 
 
 
 
 
 
044ce91
bc4c032
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
044ce91
bc4c032
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import json, os
import numpy as np
from PIL import Image




import torch
import torch.nn as nn
import timm
from timm.data import resolve_model_data_config, create_transform
from transformers import AutoTokenizer, AutoModel
import gradio as gr
import ast
from huggingface_hub import hf_hub_download

# Hub repo id used as the fallback weight source in safe_torch_load;
# override via the SPACE_REPO_NAME environment variable.
SPACE_REPO = os.getenv("SPACE_REPO_NAME", "muruga778/api_for_model")  # change if your space id differs

def safe_torch_load(filename: str):
    """Load a torch checkpoint, falling back to a fresh Hub download.

    Tries the local file first; if loading fails (e.g. a corrupted or
    LFS-pointer file), force-downloads the same file from this Space's
    Hugging Face Hub repo (``SPACE_REPO``) and loads the cached copy.

    Args:
        filename: Checkpoint path relative to the repo root.

    Returns:
        Whatever ``torch.load`` produces (typically a state dict).

    Raises:
        Exception: propagated if the Hub download or second load fails.
    """
    try:
        # BUG FIX: the original f-strings printed the literal text
        # "(unknown)" instead of interpolating the filename; mojibake
        # emoji bytes in the messages are also repaired here.
        print(f"🔎 Loading weights: {filename} (local)")
        return torch.load(filename, map_location="cpu")
    except Exception as e:
        print(f"⚠️ Local load failed for {filename}: {repr(e)}")
        print("⬇️ Force-downloading from Hugging Face Hub cache...")
        cached = hf_hub_download(
            repo_id=SPACE_REPO,
            repo_type="space",
            filename=filename,
            force_download=True,
        )
        print("✅ Downloaded to:", cached, "size(MB)=", os.path.getsize(cached) / 1024 / 1024)
        return torch.load(cached, map_location="cpu")


# Prefer GPU when available; all models and input tensors are moved here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_json(path):
    """Read the file at *path* and return its parsed JSON content."""
    with open(path, "r") as handle:
        payload = json.load(handle)
    return payload

def clean_state_dict(sd):
    """Normalize a loaded checkpoint into a flat state dict.

    Unwraps common container keys ("state_dict", "model",
    "model_state_dict") and strips a DataParallel "module." prefix
    from the keys.

    Args:
        sd: Object returned by ``torch.load`` — usually a dict.

    Returns:
        The (possibly unwrapped) state dict with any *leading*
        "module." prefix removed from each key; non-dict input is
        returned unchanged.
    """
    # Unwrap nested checkpoint containers, one level per known key.
    for key in ["state_dict", "model", "model_state_dict"]:
        if isinstance(sd, dict) and key in sd and isinstance(sd[key], dict):
            sd = sd[key]
    # BUG FIX: the original used k.replace("module.", "", 1), which
    # removes the first occurrence *anywhere* in the key (so
    # "layer.module.w" became "layer.w"); only a true leading
    # DataParallel prefix should be stripped.
    if isinstance(sd, dict) and any(k.startswith("module.") for k in sd.keys()):
        sd = {
            (k[len("module."):] if k.startswith("module.") else k): v
            for k, v in sd.items()
        }
    return sd

def softmax_np(x):
    """Numerically stable softmax over a 1-D logit array (NumPy)."""
    # Shift by the max logit so np.exp never overflows.
    exps = np.exp(x - np.max(x))
    return exps / (np.sum(exps) + 1e-9)

# --- Triage rules (simple + demo friendly)
# Baseline severity (1-5) per predicted condition; unseen labels get 2.
SEVERITY_BY_LABEL = {
    "acne": 1, "tinea": 2, "tinea versicolor": 1, "eczema": 2, "urticaria": 2,
    "psoriasis": 2, "folliculitis": 2, "impetigo": 3, "herpes zoster": 3,
    "drug rash": 4, "scabies": 3, "unknown": 2
}
# Substrings in the symptom text that escalate the severity score.
RED_FLAG_WORDS = [
    "fever","breathing","shortness of breath","face","eye","mouth","genital",
    "severe pain","blister","purple","swelling","rapid","spreading","bleeding"
]

def triage(label, conf, text):
    """Map a prediction plus free-text symptoms to a triage stage.

    Args:
        label: Predicted condition name (matched case-insensitively).
        conf: Fused model confidence in [0, 1]; low values escalate.
        text: Raw symptom description typed by the user (may be None).

    Returns:
        (stage, score, note): stage string, integer severity 1-5, and a
        fixed disclaimer note.
    """
    normalized_label = (label or "").lower().strip()
    normalized_text = (text or "").lower()
    severity = SEVERITY_BY_LABEL.get(normalized_label, 2)

    # Red-flag keywords found in the text bump severity by up to 2.
    flag_hits = sum(1 for word in RED_FLAG_WORDS if word in normalized_text)
    if flag_hits >= 2:
        severity += 2
    elif flag_hits == 1:
        severity += 1

    # Low model confidence also escalates (twice below 0.35).
    if conf < 0.50:
        severity += 1
    if conf < 0.35:
        severity += 1

    severity = int(max(1, min(5, severity)))
    if severity <= 2:
        stage = "SELF-CARE / MONITOR"
    elif severity <= 4:
        stage = "DOCTOR (24–48h)"
    else:
        stage = "URGENT NOW"
    note = "Not medical advice. If rapidly worsening / fever / face-eye involvement / breathing trouble β†’ seek urgent care."
    return stage, severity, note

# ---- Load config + label map
# Both JSON files are expected alongside this script in the Space repo.
CFG = load_json("fusion_config.json")
LABEL_MAP = load_json("label_map.json")

# Your label_map.json looks like: {"classes":[...], "label2idx":{...}}
if isinstance(LABEL_MAP, dict) and "classes" in LABEL_MAP and isinstance(LABEL_MAP["classes"], list):
    CLASSES = [str(x) for x in LABEL_MAP["classes"]]
    label2idx = LABEL_MAP.get("label2idx", {c: i for i, c in enumerate(CLASSES)})

# Older possible formats:
elif isinstance(LABEL_MAP, dict) and all(isinstance(k, str) and k.isdigit() for k in LABEL_MAP.keys()):
    # {"0":"eczema", ...}  — index (as string) -> label
    idx2label = {int(k): str(v) for k, v in LABEL_MAP.items()}
    CLASSES = [idx2label[i] for i in sorted(idx2label.keys())]
    label2idx = {c: i for i, c in enumerate(CLASSES)}

else:
    # {"eczema": 0, ...}  — label -> index; CLASSES ordered by index
    label2idx = {str(k): int(v) for k, v in LABEL_MAP.items()}
    CLASSES = [c for c, _ in sorted(label2idx.items(), key=lambda x: x[1])]

NUM_CLASSES = len(CLASSES)
print("βœ… NUM_CLASSES:", NUM_CLASSES)
print("βœ… First labels:", CLASSES[:5])

# Model hyper-parameters, with defaults matching the training setup.
IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
IMG_SIZE = int(CFG.get("img_size", 384))
TEXT_MODEL_NAME = CFG.get("text_model_name", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
MAX_LEN = int(CFG.get("max_len", 128))

# ---- Image model
# timm backbone with a fresh classification head; weights come from the
# fine-tuned checkpoint, so pretrained=False.
img_model = timm.create_model(IMG_BACKBONE, pretrained=False, num_classes=NUM_CLASSES)
sd_img = clean_state_dict(safe_torch_load("best_scin_image.pt"))
img_model.load_state_dict(sd_img, strict=True)  # strict: checkpoint must match the architecture exactly
img_model.to(DEVICE).eval()

# Build the eval-time preprocessing pipeline from the model's data config,
# overriding the input resolution with the configured IMG_SIZE.
data_cfg = resolve_model_data_config(img_model)
data_cfg["input_size"] = (3, IMG_SIZE, IMG_SIZE)
img_tfm = create_transform(**data_cfg, is_training=False)

# ---- Text model
class TextClassifier(nn.Module):
    """Transformer encoder with a dropout + linear classification head."""

    def __init__(self, model_name, num_classes, dropout=0.2):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        self.drop = nn.Dropout(dropout)
        self.head = nn.Linear(self.backbone.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return class logits for the tokenized batch."""
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled = getattr(encoded, "pooler_output", None)
        if pooled is None:
            # Backbone exposes no pooler: fall back to the [CLS] token.
            pooled = encoded.last_hidden_state[:, 0]
        return self.head(self.drop(pooled))

tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = TextClassifier(TEXT_MODEL_NAME, NUM_CLASSES)
sd_txt = clean_state_dict(safe_torch_load("best_scin_text.pt"))
# strict=False: tolerate missing/unexpected keys in the text checkpoint.
text_model.load_state_dict(sd_txt, strict=False)
text_model.to(DEVICE).eval()

# Late-fusion weights for combining image and text probabilities,
# normalized below so they always sum to 1.
W_IMG = float(CFG.get("fusion_weights", {}).get("image", 0.6))
W_TXT = float(CFG.get("fusion_weights", {}).get("text", 0.4))
s = W_IMG + W_TXT
W_IMG, W_TXT = W_IMG / s, W_TXT / s

@torch.inference_mode()
def predict(image, symptom_text, topk=3):
    """Run the fused image+text model and build the two UI outputs.

    Args:
        image: PIL image (or path-like) from the Gradio image widget.
        symptom_text: Free-text symptom description (may be empty/None).
        topk: How many ranked predictions to list.

    Returns:
        (summary_markdown, topk_text) — a markdown result block and a
        newline-joined ranking of the top-k classes.
    """
    if image is None:
        return "Upload an image.", ""

    # Accept either an already-open PIL image or a file path.
    if hasattr(image, "convert"):
        rgb = image.convert("RGB")
    else:
        rgb = Image.open(image).convert("RGB")
    pixel_batch = img_tfm(rgb).unsqueeze(0).to(DEVICE)

    encoded = tokenizer(
        symptom_text or "",
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    image_logits = img_model(pixel_batch)[0].detach().float().cpu().numpy()
    text_logits = text_model(encoded["input_ids"], encoded["attention_mask"])[0].detach().float().cpu().numpy()

    # Weighted late fusion of the two per-class probability vectors.
    fused = W_IMG * softmax_np(image_logits) + W_TXT * softmax_np(text_logits)

    best = int(np.argmax(fused))
    best_label = CLASSES[best]
    best_conf = float(fused[best])

    k = min(int(topk), len(CLASSES))
    ranked = [
        f"{rank}) {CLASSES[int(ix)]} β€” {float(fused[int(ix)]):.2f}"
        for rank, ix in enumerate(np.argsort(-fused)[:k], start=1)
    ]

    stage, severity, note = triage(best_label, best_conf, symptom_text)

    summary = f"**Prediction:** {best_label}\n\n**Confidence:** {best_conf:.2f}\n\n**Triage:** {stage} (score {severity}/5)\n\n{note}"
    return summary, "\n".join(ranked)

# Gradio UI: image + symptom text + top-k slider in; a markdown summary
# and a plain-text ranking out. Launched only when run as a script.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Skin image"),
        gr.Textbox(lines=3, label="Symptoms (text)"),
        gr.Slider(1, 5, value=3, step=1, label="Top-K"),
    ],
    outputs=[
        gr.Markdown(label="Result"),
        gr.Textbox(label="Top-K"),
    ],
    title="SmartSkin β€” SCIN Multimodal (Image + Text Fusion)",
    description="Demo only. Not medical advice.",
)

if __name__ == "__main__":
    demo.launch()