pedrobouzon committed on
Commit
30cb9cf
·
1 Parent(s): dc64968

first commit

Browse files
Files changed (1) hide show
  1. app.py +284 -0
app.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import onnxruntime as ort
4
+ import numpy as np
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ import torch.nn.functional as F
8
+ import matplotlib.pyplot as plt
9
+
10
+
11
# Raw one-hot metadata feature names (original dataset-style column names,
# presumably PAD-UFES-20 conventions — TODO confirm). These are remapped
# below via `metadata_mapping` into the model's training-time feature names.
# Order matters: each name's index is the position written into the
# metadata input vector fed to the ONNX model.
_metadata_columns = [
    "age", "usePesticide_I", "usePesticide_False", "usePesticide_True", "gender_M", "gender_F", "gender_O",
    "familySkinCancerHistory_False", "familySkinCancerHistory_True", "familySkinCancerHistory_I", "familyCancerHistory_True",
    "familyCancerHistory_False", "familyCancerHistory_I", "fitzpatrickSkinType_2.0", "fitzpatrickSkinType_1.0",
    "fitzpatrickSkinType_4.0", "fitzpatrickSkinType_3.0", "fitzpatrickSkinType_5.0", "macroBodyRegion_PEITORAL",
    "macroBodyRegion_NARIZ", "macroBodyRegion_LABIOS", "macroBodyRegion_DORSO", "macroBodyRegion_ANTEBRACO", "macroBodyRegion_BRACO",
    "macroBodyRegion_PERNA", "macroBodyRegion_FACE", "macroBodyRegion_MAO", "macroBodyRegion_COURO CABELUDO", "macroBodyRegion_PESCOCO",
    "macroBodyRegion_PE", "macroBodyRegion_ORELHA", "macroBodyRegion_COXA", "macroBodyRegion_ABDOME",
    "hasItched_True", "hasItched_False", "hasItched_I", "hasGrown_I", "hasGrown_False", "hasGrown_True", "hasHurt_True", "hasHurt_False",
    "hasHurt_I", "hasChanged_I", "hasChanged_False", "hasChanged_True", "hasBled_False", "hasBled_True", "hasBled_I", "hasElevation_I",
    "hasElevation_False", "hasElevation_True"
]
23
+
24
# Mapping from the raw column names above to the feature names the ONNX model
# was trained with (values such as "fitspatrick_2.0" keep the training-time
# spelling on purpose — do not "fix" them or the lookup will silently miss).
metadata_mapping = {
    "age": "age",
    # NOTE(review): the "_I" (unknown) pesticide value maps to itself while
    # True/False map to "pesticide_*" — verify this matches the training columns.
    "usePesticide_I": "usePesticide_I",
    "usePesticide_False": "pesticide_False",
    "usePesticide_True": "pesticide_True",
    "gender_M": "gender_MALE",
    "gender_F": "gender_FEMALE",
    "gender_O": "gender_OTHER",
    "familySkinCancerHistory_False": "skin_cancer_history_False",
    "familySkinCancerHistory_True": "skin_cancer_history_True",
    "familySkinCancerHistory_I": "familySkinCancerHistory_I",
    "familyCancerHistory_True": "cancer_history_True",
    "familyCancerHistory_False": "cancer_history_False",
    "familyCancerHistory_I": "familyCancerHistory_I",
    "fitzpatrickSkinType_2.0": "fitspatrick_2.0",
    "fitzpatrickSkinType_1.0": "fitspatrick_1.0",
    "fitzpatrickSkinType_4.0": "fitspatrick_4.0",
    "fitzpatrickSkinType_3.0": "fitspatrick_3.0",
    "fitzpatrickSkinType_5.0": "fitspatrick_5.0",
    # NOTE(review): this key has no entry in `_metadata_columns`, so the remap
    # below never uses it — confirm whether type 6.0 should be a model feature.
    "fitzpatrickSkinType_6.0": "fitspatrick_6.0",
    "macroBodyRegion_PEITORAL": "region_CHEST",
    "macroBodyRegion_NARIZ": "region_NOSE",
    "macroBodyRegion_LABIOS": "region_LIP",
    "macroBodyRegion_DORSO": "region_BACK",
    "macroBodyRegion_ANTEBRACO": "region_FOREARM",
    "macroBodyRegion_BRACO": "region_ARM",
    # NOTE(review): PERNA (leg) maps to region_THIGH, the same target as COXA
    # (thigh) below — looks like a duplicate; verify whether region_LEG was
    # intended. As written, list.index() resolves both to the first occurrence.
    "macroBodyRegion_PERNA": "region_THIGH",
    "macroBodyRegion_FACE": "region_FACE",
    "macroBodyRegion_MAO": "region_HAND",
    "macroBodyRegion_COURO CABELUDO": "region_SCALP",
    "macroBodyRegion_PESCOCO": "region_NECK",
    "macroBodyRegion_PE": "region_FOOT",
    "macroBodyRegion_ORELHA": "region_EAR",
    "macroBodyRegion_COXA": "region_THIGH",
    "macroBodyRegion_ABDOME": "region_ABDOMEN",
    "hasItched_True": "itch_True",
    "hasItched_False": "itch_False",
    "hasItched_I": "itch_UNK",
    "hasGrown_I": "grew_UNK",
    "hasGrown_False": "grew_False",
    "hasGrown_True": "grew_True",
    "hasHurt_True": "hurt_True",
    "hasHurt_False": "hurt_False",
    "hasHurt_I": "hurt_UNK",
    "hasChanged_I": "changed_UNK",
    "hasChanged_False": "changed_False",
    "hasChanged_True": "changed_True",
    "hasBled_False": "bleed_False",
    "hasBled_True": "bleed_True",
    "hasBled_I": "bleed_UNK",
    "hasElevation_I": "elevation_UNK",
    "hasElevation_False": "elevation_False",
    "hasElevation_True": "elevation_True"
}
78
+
79
# Translate the raw metadata column names into the model's training-time
# feature names, preserving order (index == position in the input vector).
# The None filter mirrors the original intent of skipping unmapped columns.
_translated = []
for _raw_col in _metadata_columns:
    _model_col = metadata_mapping[_raw_col]
    if _model_col is not None:
        _translated.append(_model_col)
_metadata_columns = _translated

# Load the exported ONNX classifier; on failure keep ort_session = None so the
# UI can still start and `predict` can report "Model not loaded" gracefully.
try:
    ort_session = ort.InferenceSession("./pad25_mobilenetv3_folder_1.onnx")
except Exception as e:
    print(f"Error loading ONNX model: {e}")
    ort_session = None
else:
    print("ONNX model loaded successfully.")
86
+
87
# Class labels in the model's output order — presumably actinic keratosis,
# basal cell carcinoma, melanoma, nevus, squamous cell carcinoma, and
# seborrheic keratosis (TODO confirm against the training label encoder).
LABELS = ['ACK', 'BCC', 'MEL', 'NEV', 'SCC', 'SEK']
88
+
89
def create_plot(probs_history, steps_labels):
    """Plot per-class probability (%) evolution as metadata features are added.

    Args:
        probs_history: list of {label: probability} dicts, one per inference
            step (probabilities in [0, 1]).
        steps_labels: x-axis tick label for each step (same length as
            probs_history).

    Returns:
        A matplotlib Figure. The top-3 classes (by final-step probability) are
        drawn bold with value annotations; the remaining classes are faded.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Reshape step-wise dicts into per-class series, converted to percent.
    class_data = {label: [] for label in LABELS}
    for step_probs in probs_history:
        for label, prob in step_probs.items():
            class_data[label].append(prob * 100)

    # Identify top 3 classes based on final probability
    final_probs = {label: values[-1] for label, values in class_data.items()}
    top_classes = sorted(final_probs, key=final_probs.get, reverse=True)[:3]

    # x -> list of (y, text, color) pending annotations, for overlap handling.
    annotations = {}

    # Plot every class
    for name, values in class_data.items():
        x_vals = range(len(values))

        if name in top_classes:  # Highlight top classes
            line, = ax.plot(x_vals, values, label=name, linewidth=2, marker='o')
            color = line.get_color()

            # Collect text annotations (only top classes get value labels).
            for x, y in zip(x_vals, values):
                annotations.setdefault(x, []).append((y, f"{y:.1f}", color))
        else:
            # Other low-probability classes, drawn faded.
            # FIX: alpha was 1 (fully opaque), contradicting the "faded"
            # intent of this branch; 0.3 actually de-emphasizes these lines.
            ax.plot(x_vals, values, label=name, alpha=0.3, linewidth=1)

    # Process annotations column by column, pushing overlapping labels apart.
    for x in sorted(annotations.keys()):
        points = sorted(annotations[x], key=lambda p: p[0])

        min_dist = 5       # minimum vertical gap between stacked labels (% units)
        last_text_y = -100  # sentinel below any plottable value

        for y, text, color in points:
            text_y = y + 3  # default offset above the data point

            if text_y < last_text_y + min_dist:
                text_y = last_text_y + min_dist

            ax.text(x, text_y, text, ha='center', fontweight='bold', fontsize=10, color='black')
            last_text_y = text_y

    ax.set_xticks(range(len(steps_labels)))
    ax.set_xticklabels(steps_labels, rotation=30, ha='right')
    ax.set_ylabel("Probability (%)")
    ax.set_xlabel("Incremental Features Added")
    ax.set_ylim(0, 115)  # headroom above 100% so annotations do not clip
    ax.grid(True, linestyle='--', alpha=0.3)
    ax.legend(loc='upper right', bbox_to_anchor=(1.10, 1), borderaxespad=0., framealpha=0.8)

    plt.tight_layout()
    return fig
147
+
148
def predict(image, age, region, cancer_history, skin_cancer_history, bleed, hurt, itch, grown, changed, elevation):
    """Run stepwise inference: image-only baseline, then re-run the model while
    incrementally adding each metadata feature, to show how each one shifts
    the predicted class probabilities.

    Args:
        image: file path to the lesion image, or None (a zero tensor is used).
        age: patient age (number) or None.
        region: body-region string matching a `region_*` feature, or None.
        cancer_history..elevation: "True"/"False"/"None" radio-button strings.

    Returns:
        (final_probabilities_dict, matplotlib Figure), or
        ("Model not loaded", None) when the ONNX session failed to initialize.
    """
    if ort_session is None:
        return "Model not loaded", None

    # Ordered steps; each adds one feature on top of all previous ones.
    steps = [
        ("Baseline (Image only)", {}),
        (f"Age ({age})", {"age": age}),
        (f"Region ({region})", {"region": region}),
    ]

    symptoms_map = {
        "Cancer History": ("cancer_history", cancer_history),
        "Skin Cancer History": ("skin_cancer_history", skin_cancer_history),
        "Bleed": ("bleed", bleed),
        "Hurt": ("hurt", hurt),
        "Itch": ("itch", itch),
        "Grew": ("grew", grown),
        "Changed": ("changed", changed),
        "Elevation": ("elevation", elevation)
    }

    for label, (key, val) in symptoms_map.items():
        steps.append((f"{label} ({val})", {key: val}))

    probs_history = []
    steps_labels = []

    # Preprocess to the 224x224 ImageNet-normalized tensor the model expects;
    # with no image, a zero tensor stands in so metadata-only steps still run.
    if image is not None:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        image_pil = Image.open(image).convert('RGB')
        image_tensor = transform(image_pil).unsqueeze(0)
    else:
        image_tensor = torch.zeros(1, 3, 224, 224)

    def set_feature(vector, feature_name, value):
        # One-hot encode: set the "<feature>_<value>" column to 1 if the model
        # has such a column; silently ignore unknown feature/value combos.
        col_name = f"{feature_name}_{value}"
        if col_name in _metadata_columns:
            idx = _metadata_columns.index(col_name)
            vector[idx] = 1.0

    accumulated_features = {}

    for step_name, new_features in steps:
        steps_labels.append(step_name)
        accumulated_features.update(new_features)

        metadata_vector = np.zeros(len(_metadata_columns), dtype=np.float32)

        # Age is the only numeric feature; NaN marks "not provided (yet)".
        # NOTE(review): feeding NaN assumes the model was trained to tolerate
        # missing age — confirm against the training pipeline.
        # (FIX: removed a dead `if val is not None else np.nan` conditional —
        # the original outer guard already ensured val was not None.)
        if "age" in _metadata_columns:
            age_val = accumulated_features.get("age")
            metadata_vector[_metadata_columns.index("age")] = (
                float(age_val) if age_val is not None else np.nan
            )

        if "region" in accumulated_features and accumulated_features["region"]:
            set_feature(metadata_vector, "region", accumulated_features["region"])

        symptom_keys = ["cancer_history", "skin_cancer_history", "bleed", "hurt", "itch", "grew", "changed", "elevation"]
        for key in symptom_keys:
            if key in accumulated_features:
                val = accumulated_features[key]
                if val != "None":  # "None" radio choice means "unknown" -> all-zero one-hot
                    set_feature(metadata_vector, key, val)

        # Add the batch dimension. ONNX Runtime consumes numpy arrays directly;
        # the original numpy -> torch -> numpy round trip was redundant.
        metadata_batch = metadata_vector[np.newaxis, :]

        ort_inputs = {
            ort_session.get_inputs()[0].name: image_tensor.numpy(),
            ort_session.get_inputs()[1].name: metadata_batch
        }
        ort_outs = ort_session.run(None, ort_inputs)
        raw_scores = ort_outs[0][0]
        # NOTE(review): the original named this `log_probs` yet applies softmax,
        # which is only the right inverse for raw logits (for log-probabilities
        # it would be exp()). Behavior kept as-is — confirm the model head.
        probs = F.softmax(torch.tensor(raw_scores), dim=0).numpy()

        probs_dict = {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
        probs_history.append(probs_dict)

    final_result = probs_history[-1]

    plot = create_plot(probs_history, steps_labels)

    return final_result, plot
236
+
237
def clear_func():
    """Reset all UI widgets: image/age/region to empty, the eight symptom
    radios back to the "None" choice, and clear the output label and plot."""
    # 3 leading Nones: image, age, region dropdown; 8 "None" strings for the
    # radios; 2 trailing Nones: output label and plot.
    return (None, None, None) + ("None",) * 8 + (None, None)
239
+
240
# --- Gradio UI wiring ------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# PRISM: A Clinically Interpretable Stepwise Framework for Multimodal Skin Cancer Diagnosis (DOI: TODO)")

    # Top row: lesion image on the left, numeric age + body region on the right.
    with gr.Row():
        with gr.Column():
            # type="filepath" hands `predict` a path string, not an array.
            image = gr.Image(type="filepath", height=534, label="Input Image",)
        with gr.Column():
            age = gr.Number(label="Age", value=None)
            # NOTE(review): 'TORSO' has no matching `region_TORSO` feature in
            # metadata_mapping (DORSO maps to region_BACK), so selecting it is
            # a silent no-op in `set_feature` — confirm intended choices.
            dropdown = gr.Dropdown(multiselect=False, allow_custom_value=False, label="Region", choices=['ARM', 'NECK', 'FACE', 'HAND', 'FOREARM', 'CHEST', 'NOSE', 'THIGH', 'SCALP', 'EAR', 'BACK', 'FOOT', 'ABDOMEN', 'LIP', 'TORSO', None])

    # Symptom/history radios; the string "None" (not Python None) encodes
    # "unknown" and is filtered out in `predict`.
    with gr.Row():
        with gr.Column():
            cancer_history = gr.Radio(label="Cancer history", choices=["True", "False", "None"], value="None")
            skin_cancer_history = gr.Radio(label="Skin cancer history", choices=["True", "False", "None"], value="None")
            bleed = gr.Radio(label="Bled", choices=["True", "False", "None"], value="None")
            hurt = gr.Radio(label="Pain", choices=["True", "False", "None"], value="None")
        with gr.Column():
            itch = gr.Radio(label="Itch", choices=["True", "False", "None"], value="None")
            grown = gr.Radio(label="Grew", choices=["True", "False", "None"], value="None")
            changed = gr.Radio(label="Changed", choices=["True", "False", "None"], value="None")
            elevation = gr.Radio(label="Elevation", choices=["True", "False", "None"], value="None")

    # Clickable example rows: [image, age, region, then the 8 radio values],
    # matching the inputs list below (image deliberately None).
    examples = [
        [None, 45, "ARM", "True", "False", "True", "False", "True", "True", "False", "True"],
        [None, 30, "TORSO", "False", "False", "False", "False", "False", "False", "False", "False"],
        [None, 60, "ARM", "False", "True", "False", "True", "False", "True", "True", "False"],
    ]
    gr.Examples(examples=examples, inputs=[image, age, dropdown, cancer_history, skin_cancer_history, bleed, hurt, itch, grown, changed, elevation])

    # Outputs: stepwise probability plot and the final class-probability label.
    with gr.Row():
        with gr.Column():
            output_plot = gr.Plot(label="Incremental Prediction Change")
        with gr.Column():
            output = gr.Label(label="Output", num_top_classes=6)

    with gr.Row():
        with gr.Column():
            submit = gr.Button("Submit")
            submit.click(predict, inputs=[image, age, dropdown, cancer_history, skin_cancer_history, bleed, hurt, itch, grown, changed, elevation], outputs=[output, output_plot])

            # Clear resets all 11 inputs plus both outputs (13 components,
            # matching the 13-tuple returned by clear_func).
            clear = gr.Button("Clear")
            clear.click(clear_func, inputs=[], outputs=[image, age, dropdown, cancer_history, skin_cancer_history, bleed, hurt, itch, grown, changed, elevation, output, output_plot])

# share=True opens a public Gradio tunnel in addition to the local server.
demo.launch(share=True)