cfoli committed on
Commit
7376e8f
·
1 Parent(s): 6069169

Initial draft of Gradio app

Browse files
Files changed (1) hide show
  1. gradio_app_chestvision_pro.py +318 -0
gradio_app_chestvision_pro.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# -*- coding: utf-8 -*-
"""gradio_app_chestvision-PRO.ipynb

Automatically generated by Colab.

Original file is located at
https://colab.research.google.com/drive/1gVrx5TyipNPvn8D7GaK0pNBCnLeYTAD_
"""

# NOTE(review): the original notebook export contained IPython shell magics
#   !pip install --upgrade gradio
#   !pip install lightning torchmetrics
# which are a SyntaxError in a plain .py file. Install the dependencies from
# the command line (or a requirements file) before running this script:
#   pip install --upgrade gradio lightning torchmetrics
13
+
14
+ """### Import dependencies"""
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import torchvision
19
+ from torchvision import transforms, models, datasets
20
+ from torch import nn, optim
21
+ from torch.utils.data import DataLoader, Dataset
22
+ from tqdm import tqdm
23
+ from torch.utils.data import random_split
24
+ import pytorch_lightning as torch_light
25
+ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
26
+ import torchmetrics
27
+ from torchmetrics import Metric
28
+ import os
29
+ import shutil
30
+ import subprocess
31
+ import pandas as pd
32
+ from PIL import Image
33
+ import gradio
34
+ from functools import partial
35
+
"""### Set parameters"""

# Global configuration shared by preprocessing, the model, and the app.
configs = {
    # Target input resolution, stored as (W, H); the resize transform
    # reverses it into torchvision's (H, W) order.
    "IMAGE_SIZE": (224, 224),
    # Inputs are treated as 3-channel RGB images.
    "NUM_CHANNELS": 3,
    # One output per pathology label.
    "NUM_CLASSES": 15,

    # ImageNet channel statistics, used because the backbones are
    # ImageNet-pretrained.
    "MEAN": (0.485, 0.456, 0.406),
    "STD": (0.229, 0.224, 0.225),

    # Backbone pre-selected in the Gradio dropdown.
    "DEFAULT_BACKBONE": "ViT-base-16",

    # Probability cut-off for reporting a condition as present.
    "THRESHOLD": 0.5,
}
# Maps user-facing backbone names to Hugging Face Hub model identifiers.
MODEL_REGISTRY = dict([
    ("CheXFormer-small", "m42-health/CXformer-small"),
    ("ViT-base-16", "google/vit-base-patch16-224"),
])

# In-memory cache of loaded models keyed by backbone name, so a checkpoint is
# only deserialized once per process.
MODEL_CACHE = {}
58
+
"""### Define helper functions"""

# helper class for loading a pre-trained backbone with a classification head
# ===================================================================================================
class get_pretrained_model(nn.Module):
    """Pretrained transformer backbone plus a multi-label classification head.

    The backbone is fetched from the Hugging Face Hub using the
    ``MODEL_REGISTRY`` lookup, fully frozen, and optionally the last N encoder
    layers are unfrozen for fine-tuning. Classification uses the CLS token
    embedding.
    """

    def __init__(
            self,
            model_name: str,
            num_classes: int,
            num_layers_to_unfreeze: int = 0):
        """
        Args:
            model_name: key into MODEL_REGISTRY selecting the backbone.
            num_classes: number of output labels (one logit each).
            num_layers_to_unfreeze: how many trailing encoder layers to leave
                trainable; 0 keeps the whole backbone frozen.
        """
        super().__init__()

        # Imported lazily here because the original file referenced AutoModel
        # without importing it anywhere, which raised NameError at
        # construction time. Requires the `transformers` package.
        from transformers import AutoModel

        print(f"Loading pretrained [{model_name}] model")

        self.backbone = AutoModel.from_pretrained(
            MODEL_REGISTRY[model_name],
            trust_remote_code=True)

        hidden_size = self.backbone.config.hidden_size

        # Freeze entire backbone first
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Selectively unfreeze last N layers
        if num_layers_to_unfreeze > 0:
            self._unfreeze_last_n_layers(num_layers_to_unfreeze)

        # Single classification head on top of the CLS embedding
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(0.4),
            nn.Linear(hidden_size, num_classes))

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes)."""
        outputs = self.backbone(x)

        # Use CLS token as the image-level embedding
        img_embeddings = outputs.last_hidden_state[:, 0]

        return self.classifier(img_embeddings)

    def _unfreeze_last_n_layers(self, n: int):
        """Make the last ``n`` transformer encoder layers trainable."""
        # Backbone layouts differ between architectures; probe both common ones.
        if hasattr(self.backbone, "encoder"):
            encoder_layers = self.backbone.encoder.layer
        elif hasattr(self.backbone, "vision_model"):
            encoder_layers = self.backbone.vision_model.encoder.layer
        else:
            raise ValueError("Cannot find encoder layers in backbone.")

        total_layers = len(encoder_layers)
        n = min(n, total_layers)  # clamp to avoid unfreezing beyond depth

        print(f"Unfreezing last {n} of {total_layers} transformer layers.")

        for layer in encoder_layers[-n:]:
            for param in layer.parameters():
                param.requires_grad = True
# helper pipeline for preprocessing input images
# ===================================================================================================
# Resize to model input size, convert to tensor, then normalize with ImageNet
# statistics (the backbones are ImageNet-pretrained). IMAGE_SIZE is stored as
# (W, H) while torchvision's Resize expects (H, W), hence the reversal.
_resize_hw = configs["IMAGE_SIZE"][::-1]
preprocess_fxn = transforms.Compose([
    transforms.Resize(size=_resize_hw),
    transforms.ToTensor(),
    transforms.Normalize(configs["MEAN"], configs["STD"], inplace=True),
])
# Map numeric class indices to human-readable condition labels
# (presumably the NIH ChestX-ray14 label set plus "No finding" — verify
# against the training data).
_CLASS_NAMES = (
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Effusion",
    "Emphysema",
    "Fibrosis",
    "Hernia",
    "Infiltration",
    "Mass",
    "No finding",
    "Nodule",
    "Pleural_Thickening",
    "Pneumonia",
    "Pneumothorax",
)
labels_dict = dict(enumerate(_CLASS_NAMES))
144
+
"""### Create torch lightning model (i.e., classifier) module"""

class modelModule(torch_light.LightningModule):
    """Multi-label chest X-ray classifier built on a pretrained backbone.

    Wraps `get_pretrained_model` and adds BCE-with-logits loss plus weighted
    multilabel accuracy / F1 / AUROC / mAP metrics for train, val, and test.
    """

    def __init__(self, num_classes, backbone_model_name, num_layers_to_unfreeze):
        """
        Args:
            num_classes: number of output labels.
            backbone_model_name: MODEL_REGISTRY key selecting the backbone.
            num_layers_to_unfreeze: trailing encoder layers left trainable.
        """
        super().__init__()

        # Persist the constructor arguments inside every checkpoint so that
        # `modelModule.load_from_checkpoint(ckpt_path)` can rebuild the model
        # without re-supplying them (the inference code relies on this;
        # without it, loading raises a TypeError for the missing arguments).
        self.save_hyperparameters()

        self.num_classes = num_classes
        self.backbone_model_name = backbone_model_name
        self.num_layers_to_unfreeze = num_layers_to_unfreeze

        # Load a pretrained backbone and attach a fresh classification head
        self.model = get_pretrained_model(
            num_classes=self.num_classes,
            model_name=self.backbone_model_name,
            num_layers_to_unfreeze=self.num_layers_to_unfreeze)

        # Per-class binary classification loss operating on raw logits
        self.loss_function = torch.nn.BCEWithLogitsLoss()

        # average="weighted": per-class scores weighted by class support.
        # threshold=0.5: probability cut-off for binarizing predictions.
        # For AUROC / mAP, `thresholds=10` evaluates at 10 evenly spaced
        # thresholds (approximate but memory-bounded) and averages.
        self.accuracy_function = torchmetrics.classification.MultilabelAccuracy(
            num_labels=self.num_classes, average="weighted", threshold=0.5)
        self.f1_score_function = torchmetrics.classification.MultilabelF1Score(
            num_labels=self.num_classes, average="weighted", threshold=0.5)
        self.auroc_function = torchmetrics.classification.MultilabelAUROC(
            num_labels=self.num_classes, average="weighted", thresholds=10)
        self.map_score_function = torchmetrics.classification.MultilabelAveragePrecision(
            num_labels=self.num_classes, average="weighted", thresholds=10)

    def forward(self, x):
        # Forward pass through the wrapped backbone + head
        return self.model(x)

    def _common_step(self, batch, batch_idx):
        """
        Shared logic for train / val / test steps.
        Computes loss and evaluation metrics.
        """
        x, y = batch

        # Raw logits, then probabilities for the metric functions
        y_logits = self.forward(x)
        y_prob = torch.sigmoid(y_logits)

        # BCEWithLogitsLoss expects float targets and raw logits
        loss = self.loss_function(y_logits, y.float())

        accuracy = self.accuracy_function(y_prob, y)
        f1_score = self.f1_score_function(y_prob, y)
        auroc = self.auroc_function(y_prob, y)
        mAP = self.map_score_function(y_prob, y)  # mean average precision

        return loss, y_logits, y, accuracy, f1_score, auroc, mAP

    def training_step(self, batch, batch_idx):
        loss, y_logits, y, accuracy, f1_score, auroc, mAP = self._common_step(batch, batch_idx)

        # Log epoch-level training metrics
        self.log_dict(
            {"train_loss": loss, "train_accuracy": accuracy, "train_f1_score": f1_score, "train_auroc": auroc, "train_mAP": mAP},
            on_step=False, on_epoch=True, prog_bar=True)

        # Lightning expects the "loss" key for backprop
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss, y_logits, y, accuracy, f1_score, auroc, mAP = self._common_step(batch, batch_idx)

        # Log validation metrics
        self.log_dict(
            {"val_loss": loss, "val_accuracy": accuracy, "val_f1_score": f1_score, "val_auroc": auroc, "val_mAP": mAP},
            on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, y_logits, y, accuracy, f1_score, auroc, mAP = self._common_step(batch, batch_idx)

        # Log test metrics
        self.log_dict(
            {"test_loss": loss, "test_accuracy": accuracy, "test_f1_score": f1_score, "test_auroc": auroc, "test_mAP": mAP},
            on_step=False, on_epoch=True, prog_bar=True)

    def predict_step(self, batch, batch_idx):
        """
        Prediction logic used by trainer.predict().
        Returns sigmoid probabilities without computing loss.
        """
        # Accept either a bare tensor or an (x, y) batch
        x = batch if not isinstance(batch, (tuple, list)) else batch[0]
        logits = self.forward(x)

        # Convert logits to probabilities for inference
        return torch.sigmoid(logits)

    def configure_optimizers(self):
        # Adam over all parameters; frozen ones have requires_grad=False and
        # receive no updates
        return optim.Adam(self.parameters(), lr=3e-5)
244
+
"""### Create function for running inference (i.e., assistive medical diagnosis)"""

@torch.inference_mode()
def run_diagnosis(
        backbone_name,
        input_image,
        preprocess_fn=None,
        Idx2labels=None,
        threshold=configs["THRESHOLD"]):
    """Run multi-label diagnosis on one chest X-ray image.

    Args:
        backbone_name: MODEL_REGISTRY key selecting which model to use.
        input_image: PIL image uploaded through the Gradio UI.
        preprocess_fn: transform mapping a PIL image to a model-ready tensor.
        Idx2labels: dict mapping class index -> condition name.
        threshold: probability cut-off for reporting a condition.

    Returns:
        (newline-joined predicted condition names, {label: probability}).

    Raises:
        KeyError: if ``backbone_name`` is not a registered backbone.
        FileNotFoundError: if the matching checkpoint is missing.
    """
    if backbone_name not in MODEL_REGISTRY:
        raise KeyError(f"Unknown backbone: {backbone_name}")

    # Chest X-rays are often single-channel PNGs; the Normalize transform
    # expects 3 channels, so force RGB before preprocessing.
    input_image = input_image.convert("RGB")
    x = preprocess_fn(input_image).unsqueeze(0)

    # Resolve the checkpoint for the requested backbone.
    # NOTE(review): MODEL_REGISTRY maps name -> HF hub id (a plain string), so
    # the original `backbone_info["ckpt"]` indexing raised TypeError at
    # runtime. Checkpoints are assumed to be stored as "<backbone_name>.ckpt"
    # under CKPT_ROOT — TODO confirm the actual filenames.
    ckpt_path = os.path.join(CKPT_ROOT, f"{backbone_name}.ckpt")

    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Load model once and cache it for subsequent requests
    if backbone_name not in MODEL_CACHE:
        # backbone_model_name must be the registry KEY (get_pretrained_model
        # looks it up in MODEL_REGISTRY), not the hub id string.
        # num_classes is passed explicitly so loading also works for
        # checkpoints saved without hyperparameters.
        MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint(
            ckpt_path,
            num_classes=configs["NUM_CLASSES"],
            backbone_model_name=backbone_name,
            num_layers_to_unfreeze=2)
    model = MODEL_CACHE[backbone_name]

    model.eval()

    # Forward pass on the model's device
    x = x.to(model.device)
    logits = model(x)
    probs = torch.sigmoid(logits)[0].cpu().numpy()

    # Full probability map for the Gradio Label component
    output_probs = {
        Idx2labels[i]: float(p) for i, p in enumerate(probs)
    }

    # Conditions whose probability clears the threshold
    predicted_classes = [
        Idx2labels[i] for i, p in enumerate(probs) if p >= threshold
    ]

    return "\n".join(predicted_classes), output_probs
286
+
"""### Gradio app"""

# Checkpoints and example images are expected in the working directory.
example_list_img_names = os.listdir(os.getcwd())
CKPT_ROOT = os.getcwd()

# Every .png in the working directory becomes a clickable example, paired
# with the default backbone choice.
example_list = []
for example_img in example_list_img_names:
    if example_img.lower().endswith(".png"):
        example_list.append(
            [configs["DEFAULT_BACKBONE"], os.path.join(os.getcwd(), example_img)])

# UI components
_backbone_dropdown = gradio.Dropdown(
    ["CheXFormer-small", "ViT-base-16"],
    value="ViT-base-16",
    label="Select Backbone Model")
_xray_input = gradio.Image(type="pil", label="Load chest-X-ray image here")
_conditions_output = gradio.Textbox(label="Predicted Medical Conditions")
_probs_output = gradio.Label(label="Predicted Probabilities", show_label=False)

# Bind the fixed inference arguments up front; the UI supplies only the
# backbone choice and the image.
_diagnose = partial(
    run_diagnosis,
    preprocess_fn=preprocess_fxn,
    Idx2labels=labels_dict,
    threshold=configs["THRESHOLD"])

gradio_app = gradio.Interface(
    fn=_diagnose,
    inputs=[_backbone_dropdown, _xray_input],
    outputs=[_conditions_output, _probs_output],
    examples=example_list,
    cache_examples=True,
    title="ChestVision",
    description="Vision-Transformer solutions for assistive medical diagnosis with Vision-Language-based prediction justification",
    article="Author: C. Foli (02.2026) | Website: coming soon...")

gradio_app.launch()