muruga778 committed on
Commit
bc4c032
·
verified ·
1 Parent(s): 19d3fdc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, os
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import timm
8
+ from timm.data import resolve_model_data_config, create_transform
9
+ from transformers import AutoTokenizer, AutoModel
10
+ import gradio as gr
11
+
12
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
def load_json(path):
    """Read the JSON document stored at ``path`` and return the parsed object.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    locale encoding, so non-ASCII content (for example label names) is read
    the same way on every host.

    Args:
        path: Filesystem path of the JSON file.

    Returns:
        Whatever ``json.load`` produces for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
17
+
18
def clean_state_dict(sd):
    """Unwrap a loaded checkpoint and strip a leading ``module.`` from its keys.

    The wrapper keys ``"state_dict"``, ``"model"`` and ``"model_state_dict"``
    are tried in that order; whenever the current object is a dict holding a
    dict under that key, the nested dict replaces it.  Afterwards, if any key
    starts with ``"module."``, that prefix is removed from the keys that
    carry it.

    Only a *leading* ``"module."`` is removed.  Keys that merely contain the
    text further in (e.g. ``"enc.module.w"``) are left untouched, so
    legitimate parameter names are not corrupted.

    Args:
        sd: Object returned by the checkpoint loader, usually a dict.

    Returns:
        The unwrapped mapping with prefixes removed, or ``sd`` unchanged when
        it is not a dict.
    """
    for key in ("state_dict", "model", "model_state_dict"):
        if isinstance(sd, dict) and isinstance(sd.get(key), dict):
            sd = sd[key]

    if not isinstance(sd, dict):
        return sd

    prefix = "module."

    def _has_prefix(name):
        return isinstance(name, str) and name.startswith(prefix)

    if any(_has_prefix(k) for k in sd):
        sd = {(k[len(prefix):] if _has_prefix(k) else k): v for k, v in sd.items()}
    return sd
25
+
26
def softmax_np(x):
    """Return the softmax of the NumPy array ``x``.

    The global maximum is subtracted before exponentiation so large logits
    do not overflow, and ``1e-9`` is added to the denominator.
    """
    shifted = x - np.max(x)
    exps = np.exp(shifted)
    denom = np.sum(exps) + 1e-9
    return exps / denom
30
+
31
# --- Triage rules (simple + demo friendly)
# Baseline severity per predicted label; triage() falls back to 2 for labels
# that are not listed here.
SEVERITY_BY_LABEL = {
    "acne": 1, "tinea": 2, "tinea versicolor": 1, "eczema": 2, "urticaria": 2,
    "psoriasis": 2, "folliculitis": 2, "impetigo": 3, "herpes zoster": 3,
    "drug rash": 4, "scabies": 3, "unknown": 2
}
# Words/phrases in the free-text symptoms that raise the triage score.
RED_FLAG_WORDS = [
    "fever","breathing","shortness of breath","face","eye","mouth","genital",
    "severe pain","blister","purple","swelling","rapid","spreading","bleeding"
]


def triage(label, conf, text):
    """Turn a prediction plus the symptom text into a triage recommendation.

    Scoring:
      * start from ``SEVERITY_BY_LABEL`` for the lower-cased, stripped label
        (2 when the label is not listed);
      * add 1 for one red-flag hit in ``text``, 2 for two or more;
      * add 1 when ``conf`` is below 0.50 and another 1 when below 0.35;
      * clamp the result to the range 1..5.

    Red flags are matched at the start of a word, so "eye" still matches
    "eyes" and "rapid" matches "rapidly", but "face" no longer fires on
    "surface" as a plain substring test would.

    Args:
        label: Predicted class name; ``None`` or empty is treated as unlisted.
        conf: Confidence of the prediction, compared against 0.50 and 0.35.
        text: Free-text symptoms; ``None`` is treated as empty.

    Returns:
        ``(stage, score, note)``: the stage string, the integer score in
        1..5, and a fixed disclaimer string.
    """
    import re  # local import: the file's import block does not provide `re`

    label_l = (label or "").lower().strip()
    text_l = (text or "").lower()

    score = SEVERITY_BY_LABEL.get(label_l, 2)

    # \b anchors each red flag to a word start; re.escape keeps multi-word
    # phrases literal.
    hits = sum(1 for w in RED_FLAG_WORDS if re.search(r"\b" + re.escape(w), text_l))
    if hits >= 2:
        score += 2
    elif hits == 1:
        score += 1

    # Low confidence escalates; the two thresholds stack.
    if conf < 0.50:
        score += 1
    if conf < 0.35:
        score += 1

    score = int(max(1, min(5, score)))
    stage = "SELF-CARE / MONITOR" if score <= 2 else ("DOCTOR (24–48h)" if score <= 4 else "URGENT NOW")
    note = "Not medical advice. If rapidly worsening / fever / face-eye involvement / breathing trouble → seek urgent care."
    return stage, score, note
55
+
56
# ---- Load config + label map
CFG = load_json("fusion_config.json")
LABEL_MAP = load_json("label_map.json")

# label_map.json may be {"label": idx} or {"0": "label"}; either layout is
# turned into CLASSES, the label names ordered by class index.
if not all(isinstance(k, str) and k.isdigit() for k in LABEL_MAP):
    # {"label": idx} layout
    label2idx = {k: int(v) for k, v in LABEL_MAP.items()}
    CLASSES = sorted(label2idx, key=label2idx.get)
else:
    # {"0": "label"} layout
    idx2label = {int(k): v for k, v in LABEL_MAP.items()}
    CLASSES = [idx2label[i] for i in sorted(idx2label)]

NUM_CLASSES = len(CLASSES)

# Model hyper-parameters; the second argument is used when the key is absent
# from fusion_config.json.
IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
IMG_SIZE = int(CFG.get("img_size", 384))
TEXT_MODEL_NAME = CFG.get("text_model_name", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
MAX_LEN = int(CFG.get("max_len", 128))
74
+
75
# ---- Image model
# The architecture is built without pretrained weights; the fine-tuned
# checkpoint is loaded next and must match exactly (strict=True).
# NOTE(review): depending on the installed torch version, torch.load may
# unpickle arbitrary objects - only load checkpoints from a trusted source.
img_model = timm.create_model(IMG_BACKBONE, pretrained=False, num_classes=NUM_CLASSES)
sd_img = clean_state_dict(torch.load("best_scin_image.pt", map_location="cpu"))
img_model.load_state_dict(sd_img, strict=True)
img_model.to(DEVICE)
img_model.eval()

# Evaluation-time preprocessing: the model's own data config with the input
# size overridden by IMG_SIZE.
data_cfg = resolve_model_data_config(img_model)
data_cfg.update(input_size=(3, IMG_SIZE, IMG_SIZE))
img_tfm = create_transform(is_training=False, **data_cfg)
84
+
85
# ---- Text model
class TextClassifier(nn.Module):
    """Text encoder followed by dropout and a linear classification head."""

    def __init__(self, model_name, num_classes, dropout=0.2):
        """Build the encoder via ``AutoModel.from_pretrained(model_name)``.

        Args:
            model_name: Identifier passed to ``AutoModel.from_pretrained``.
            num_classes: Output size of the linear head.
            dropout: Dropout probability applied before the head.
        """
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        self.drop = nn.Dropout(dropout)
        hidden_size = self.backbone.config.hidden_size
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return the head's logits for the given tokenized batch.

        Uses the backbone's ``pooler_output`` when it exists and is not
        ``None``; otherwise falls back to the first position of
        ``last_hidden_state``.
        """
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        feat = getattr(out, "pooler_output", None)
        if feat is None:
            feat = out.last_hidden_state[:, 0]
        return self.head(self.drop(feat))
96
+
97
import warnings  # local to this section; the file's import block lacks it

tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = TextClassifier(TEXT_MODEL_NAME, NUM_CLASSES)
# NOTE(review): depending on the installed torch version, torch.load may
# unpickle arbitrary objects - only load checkpoints from a trusted source.
sd_txt = clean_state_dict(torch.load("best_scin_text.pt", map_location="cpu"))

# strict=False tolerates key differences between the checkpoint and the
# model, but a silent mismatch would leave parts of the network at their
# freshly initialized values.  Keep the tolerant load and surface what did
# not line up.
_txt_load_report = text_model.load_state_dict(sd_txt, strict=False)
if _txt_load_report.missing_keys or _txt_load_report.unexpected_keys:
    warnings.warn(
        "best_scin_text.pt does not match TextClassifier exactly; "
        f"missing keys: {list(_txt_load_report.missing_keys)}; "
        f"unexpected keys: {list(_txt_load_report.unexpected_keys)}",
        RuntimeWarning,
    )

text_model.to(DEVICE).eval()
102
+
103
# Late-fusion weights for the image and text probabilities.  Defaults are
# 0.6 / 0.4 when fusion_config.json does not provide them ("fusion_weights"
# missing or null).
_fusion_weights = CFG.get("fusion_weights") or {}
W_IMG = float(_fusion_weights.get("image", 0.6))
W_TXT = float(_fusion_weights.get("text", 0.4))
s = W_IMG + W_TXT
# A negative weight or a non-positive total cannot be normalized into a
# valid mixture; fail with a clear message instead of a bare
# ZeroDivisionError or silently wrong probabilities.
if W_IMG < 0 or W_TXT < 0 or s <= 0:
    raise ValueError(
        "fusion_weights must be non-negative with a positive sum, "
        f"got image={W_IMG}, text={W_TXT}"
    )
# Normalize so the two weights sum to 1.
W_IMG, W_TXT = W_IMG / s, W_TXT / s
107
+
108
@torch.inference_mode()
def predict(image, symptom_text, topk=3):
    """Classify a skin image plus symptom text and build the demo's outputs.

    The image and text models are run separately, their logits are converted
    to probabilities with ``softmax_np``, and the two distributions are mixed
    as ``W_IMG * p_img + W_TXT * p_txt``.  The top class and its fused
    probability are passed to ``triage``.

    Args:
        image: Object with a ``convert`` method (e.g. a PIL image), or
            anything ``Image.open`` accepts; ``None`` when nothing was
            uploaded.
        symptom_text: Free-text symptoms; ``None`` is treated as ``""``.
        topk: Number of ranked classes to list.  It is converted to ``int``
            and clamped to ``1..len(CLASSES)``, so zero or negative values
            cannot produce an empty or wrongly sliced ranking.

    Returns:
        ``(summary, ranking)``: a Markdown summary string and a
        newline-separated top-k listing.  When ``image`` is ``None`` the
        pair ``("Upload an image.", "")`` is returned.
    """
    if image is None:
        return "Upload an image.", ""

    pil = image.convert("RGB") if hasattr(image, "convert") else Image.open(image).convert("RGB")
    x_img = img_tfm(pil).unsqueeze(0).to(DEVICE)

    # NOTE(review): an empty symptom box is still tokenized as "" and the
    # text branch contributes with weight W_TXT - confirm this is intended.
    text = symptom_text or ""
    tok = tokenizer(text, truncation=True, padding="max_length", max_length=MAX_LEN, return_tensors="pt")
    tok = {name: tensor.to(DEVICE) for name, tensor in tok.items()}

    # [0] selects the single item of the batch of one.
    img_logits = img_model(x_img)[0].detach().float().cpu().numpy()
    txt_logits = text_model(tok["input_ids"], tok["attention_mask"])[0].detach().float().cpu().numpy()

    p_img = softmax_np(img_logits)
    p_txt = softmax_np(txt_logits)
    p = W_IMG * p_img + W_TXT * p_txt

    pred_idx = int(np.argmax(p))
    pred_label = CLASSES[pred_idx]
    conf = float(p[pred_idx])

    # Clamp so a zero/negative request cannot yield an empty list or a
    # negative slice bound (which would drop entries from the end instead).
    k = max(1, min(int(topk), len(CLASSES)))
    top_idx = np.argsort(-p)[:k]
    top_lines = [f"{i+1}) {CLASSES[int(ix)]} — {float(p[int(ix)]):.2f}" for i, ix in enumerate(top_idx)]

    stage, sev_score, note = triage(pred_label, conf, symptom_text)

    out1 = f"**Prediction:** {pred_label}\n\n**Confidence:** {conf:.2f}\n\n**Triage:** {stage} (score {sev_score}/5)\n\n{note}"
    out2 = "\n".join(top_lines)
    return out1, out2
139
+
140
# ---- Gradio UI
# Input widgets, in the order predict() takes its arguments.
_demo_inputs = [
    gr.Image(type="pil", label="Skin image"),
    gr.Textbox(lines=3, label="Symptoms (text)"),
    gr.Slider(1, 5, value=3, step=1, label="Top-K"),
]
# Output widgets, in the order predict() returns its values.
_demo_outputs = [
    gr.Markdown(label="Result"),
    gr.Textbox(label="Top-K"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="SmartSkin — SCIN Multimodal (Image + Text Fusion)",
    description="Demo only. Not medical advice.",
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()