HunterNope commited on
Commit
f9d8456
·
1 Parent(s): dcea7e3

SGC-1 - Initial commit. Added model to run demo

Browse files
Files changed (5) hide show
  1. app.py +43 -0
  2. config.py +37 -0
  3. model.py +365 -0
  4. model_params_val_f1=0.878.ckpt +3 -0
  5. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import gradio as gr
4
+ import torchvision.transforms as transforms
5
+
6
+ from model import SkinGlanceCareClassifier
7
+ from config import Config
8
+
9
+
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ cfg = Config()
12
+
13
+ model = SkinGlanceCareClassifier.load_from_checkpoint(
14
+ "model_params_val_f1=0.878.ckpt",
15
+ cfg=cfg
16
+ )
17
+ model.to(device)
18
+ model.eval()
19
+
20
+ transform = transforms.Compose([
21
+ transforms.Resize((224, 224)),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
24
+ std=[0.229, 0.224, 0.225])
25
+ ])
26
+
27
+ def predict(image: Image.Image):
28
+ img = image.convert("RGB")
29
+ x = transform(img).unsqueeze(0).to(device)
30
+ with torch.no_grad():
31
+ logits = model(x)
32
+ probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
33
+ return {f"class_{i}": float(probs[i]) for i in range(len(probs))}
34
+
35
+ iface = gr.Interface(fn=predict,
36
+ inputs=gr.Image(type="pil"),
37
+ outputs=gr.Label(num_top_classes=3),
38
+ title="SkinGlanceCareClassifier",
39
+ description="Upload an image for inference"
40
+ )
41
+
42
+ if __name__ == "__main__":
43
+ iface.launch()
config.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+ @dataclass
5
+ class Config:
6
+ base_model: str = "efficientnet_b3"
7
+
8
+ csv_path: str = os.getenv("CSV_PATH", "./preprocessed_dataset")
9
+ path_images: str = os.getenv("IMAGES_PATH", "./dataset/surajghuwalewala/ham1000-segmentation-and-classification/versions/2/images")
10
+ path_healthy: str = os.getenv("HEALTHY_PATH", "./dataset/MCVSLD/Skin Lesion Dataset/train/Healthy")
11
+
12
+ num_classes: int = 8
13
+ label_classes: tuple = ('MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC', 'HEAL')
14
+
15
+ batch_size: int = 96
16
+ accumulate_grad_batches: int = 2
17
+
18
+ image_size: int = 224
19
+
20
+ num_workers: int = 12
21
+ pin_memory: bool = True
22
+ persistent_workers: bool = True
23
+ prefetch_factor: int = 4
24
+ multiprocessing_context = "spawn"
25
+
26
+ max_epochs: int = 100
27
+ learning_rate: float = 2e-4
28
+ weight_decay: float = 5e-4
29
+ precision: str = "bf16-mixed"
30
+
31
+ use_weighted_sampler: bool = False
32
+ use_smote: bool = True
33
+ use_smote_startegy = "proportional" # "equal"
34
+ cache_in_memory: bool = False
35
+
36
+ channels_last: bool = True
37
+ cudnn_benchmark: bool = True
model.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchmetrics
2
+ import numpy as np
3
+ import torch
4
+ import seaborn as sns
5
+ from torchvision import models
6
+
7
+ import matplotlib.pyplot as plt
8
+ import pytorch_lightning as pl
9
+ import torch.nn as nn
10
+
11
+ from pytorch_grad_cam import GradCAM
12
+ from pytorch_grad_cam.utils.image import show_cam_on_image
13
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
14
+ from sklearn.metrics import confusion_matrix
15
+ from typing import Dict
16
+
17
+ from config import Config
18
+
19
+ class SkinGlanceCareClassifier(pl.LightningModule):
20
+ def __init__(self, cfg: Config):
21
+ super().__init__()
22
+ self.save_hyperparameters(ignore=['cfg'])
23
+ self.cfg = cfg
24
+
25
+ self.model = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1)
26
+
27
+ in_feats = self.model.classifier[1].in_features
28
+ self.model.classifier = nn.Sequential(
29
+ nn.Dropout(0.4),
30
+ nn.Linear(in_feats, 512),
31
+ nn.GELU(),
32
+ nn.Dropout(0.25),
33
+ nn.Linear(512, cfg.num_classes)
34
+ )
35
+
36
+ self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
37
+
38
+ self._setup_metrics()
39
+ self.sample_images: Dict[int, Dict] = {}
40
+
41
+ def on_fit_start(self):
42
+ if self.cfg.channels_last:
43
+ self.to(memory_format=torch.channels_last)
44
+
45
+ def forward(self, x):
46
+ if self.cfg.channels_last and x.dim() == 4:
47
+ x = x.to(memory_format=torch.channels_last)
48
+ return self.model(x)
49
+
50
+ def configure_optimizers(self):
51
+
52
+ optimizer = torch.optim.AdamW(
53
+ self.parameters(),
54
+ lr=self.cfg.learning_rate,
55
+ weight_decay=self.cfg.weight_decay,
56
+ betas=(0.9, 0.999)
57
+ )
58
+
59
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
60
+ optimizer,
61
+ T_0=5,
62
+ T_mult=2,
63
+ eta_min=1e-6
64
+ )
65
+
66
+ return {
67
+ 'optimizer': optimizer,
68
+ 'lr_scheduler': {
69
+ 'scheduler': scheduler,
70
+ 'interval': 'epoch',
71
+ }
72
+ }
73
+
74
+ def training_step(self, batch, batch_idx):
75
+ x, y = batch
76
+
77
+ logits = self(x)
78
+ loss = self.loss_fn(logits, y)
79
+
80
+ preds = torch.argmax(logits, dim=1)
81
+ self.train_acc.update(preds, y)
82
+ self.train_f1.update(preds, y)
83
+
84
+ self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=False)
85
+
86
+ return loss
87
+
88
+ def on_train_epoch_end(self):
89
+ acc = self.train_acc.compute()
90
+ f1 = self.train_f1.compute()
91
+
92
+ self.log('train_acc', acc, prog_bar=True)
93
+ self.log('train_f1', f1, prog_bar=True)
94
+
95
+ self.train_acc.reset()
96
+ self.train_f1.reset()
97
+
98
+ def validation_step(self, batch, batch_idx):
99
+ x, y = batch
100
+
101
+ logits = self(x)
102
+ loss = self.loss_fn(logits, y)
103
+
104
+ preds = torch.argmax(logits, dim=1)
105
+
106
+ self.val_preds.append(preds.detach().cpu())
107
+ self.val_labels.append(y.detach().cpu())
108
+
109
+ self.val_acc.update(preds, y)
110
+ self.val_f1.update(preds, y)
111
+ self.val_precision.update(preds, y)
112
+ self.val_recall.update(preds, y)
113
+
114
+ self.log('val_loss', loss, on_epoch=True, prog_bar=False)
115
+
116
+ return loss
117
+
118
+ def on_validation_epoch_end(self):
119
+ acc = self.val_acc.compute()
120
+ f1 = self.val_f1.compute()
121
+
122
+ self.log('val_acc', acc, prog_bar=True)
123
+ self.log('val_f1', f1, prog_bar=True)
124
+
125
+ if (self.current_epoch % 5 == 4 or self.current_epoch == 0) and not self.trainer.sanity_checking:
126
+ val_preds = torch.cat(self.val_preds)
127
+ val_labels = torch.cat(self.val_labels)
128
+
129
+ cm = confusion_matrix(val_labels.numpy(), val_preds.numpy())
130
+ self._plot_confusion_matrix(cm, "Validation")
131
+
132
+ precision = self.val_precision.compute().cpu().numpy()
133
+ recall = self.val_recall.compute().cpu().numpy()
134
+ self._log_per_class_metrics(precision, recall)
135
+
136
+ self.val_acc.reset()
137
+ self.val_f1.reset()
138
+ self.val_precision.reset()
139
+ self.val_recall.reset()
140
+ self.val_preds.clear()
141
+ self.val_labels.clear()
142
+
143
+ def test_step(self, batch, batch_idx):
144
+ x, y = batch
145
+
146
+ logits = self(x)
147
+ loss = self.loss_fn(logits, y)
148
+
149
+ preds = torch.argmax(logits, dim=1)
150
+
151
+ self.test_preds.append(preds.detach().cpu())
152
+ self.test_labels.append(y.detach().cpu())
153
+
154
+ if batch_idx % 20 == 5:
155
+ for i, lbl in enumerate(y):
156
+ cls = int(lbl.item())
157
+ if cls not in self.sample_images:
158
+ self.sample_images[cls] = {
159
+ "image": x[i].detach().cpu().clone(),
160
+ "label": cls,
161
+ "pred": int(preds[i].item()),
162
+ }
163
+
164
+ self.test_acc.update(preds, y)
165
+ self.test_f1.update(preds, y)
166
+ self.test_precision.update(preds, y)
167
+ self.test_recall.update(preds, y)
168
+
169
+ self.log('test_loss', loss, on_epoch=True)
170
+
171
+ return loss
172
+
173
+ def on_test_epoch_end(self):
174
+
175
+ acc = self.test_acc.compute()
176
+ f1 = self.test_f1.compute()
177
+ precision = self.test_precision.compute()
178
+ recall = self.test_recall.compute()
179
+
180
+ self.log('test_acc', acc, prog_bar=True)
181
+ self.log('test_f1', f1, prog_bar=True)
182
+
183
+ test_preds = torch.cat(self.test_preds)
184
+ test_labels = torch.cat(self.test_labels)
185
+
186
+ cm = confusion_matrix(test_labels.numpy(), test_preds.numpy())
187
+ self._plot_confusion_matrix(cm, "Test")
188
+
189
+ print("\n" + "="*80)
190
+ print("Test Results - Per-Class Metrics:")
191
+ print("="*80)
192
+ print(f"{'Class':<10} {'Precision':<12} {'Recall':<12} {'Instances correctly classified':<10}")
193
+ print("-"*80)
194
+
195
+ for i, cls_name in enumerate(self.cfg.label_classes):
196
+ support = (test_labels == i).sum().item()
197
+ print(f"{cls_name:<10} {precision[i]:.4f} {recall[i]:.4f} {support:<10}")
198
+
199
+ print("-"*80)
200
+ print(f"{'Overall':<10} {'Acc: ' + f'{acc:.4f}':<12} {'F1: ' + f'{f1:.4f}':<12}")
201
+ print("="*80 + "\n")
202
+
203
+ # print("Grad-CAM visualizations!")
204
+ # self._generate_gradcam_visualizations()
205
+
206
+ self.test_acc.reset()
207
+ self.test_f1.reset()
208
+ self.test_precision.reset()
209
+ self.test_recall.reset()
210
+ self.test_preds.clear()
211
+ self.test_labels.clear()
212
+
213
+ def _plot_confusion_matrix(self, cm: np.ndarray, title: str = "Validation"):
214
+
215
+ cmn = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-10)
216
+
217
+ fig, ax = plt.subplots(figsize=(12, 10))
218
+ sns.heatmap(
219
+ cmn,
220
+ annot=True,
221
+ fmt='.2f',
222
+ cmap="Blues",
223
+ ax=ax,
224
+ xticklabels=self.cfg.label_classes,
225
+ yticklabels=self.cfg.label_classes,
226
+ cbar_kws={'label': 'Normalized Count'}
227
+ )
228
+ ax.set_xlabel("Predicted Label", fontsize=12)
229
+ ax.set_ylabel("True Label", fontsize=12)
230
+ ax.set_title(f"{title} Confusion Matrix (Epoch {self.current_epoch})", fontsize=14)
231
+
232
+ plt.tight_layout()
233
+ self.logger.experiment.add_figure(
234
+ f"{title}_Confusion_Matrix",
235
+ fig,
236
+ self.current_epoch
237
+ )
238
+ plt.close(fig)
239
+
240
+ def _log_per_class_metrics(self, precision: np.ndarray, recall: np.ndarray):
241
+ for i, cls_name in enumerate(self.cfg.label_classes):
242
+ self.logger.experiment.add_scalars(
243
+ f"PerClass/{cls_name}",
244
+ {
245
+ "precision": precision[i],
246
+ "recall": recall[i],
247
+ },
248
+ self.current_epoch,
249
+ )
250
+
251
+ def _find_last_conv_module(self, module: nn.Module):
252
+ last_conv = [self.model.features[-1][-1]]
253
+
254
+ for m in module.modules():
255
+ if type(m) is nn.Conv2d:
256
+ last_conv = m
257
+ return last_conv
258
+
259
+ def _generate_gradcam_visualizations(self):
260
+ if not self.sample_images:
261
+ print("No sample images")
262
+ return
263
+
264
+ target_conv = self._find_last_conv_module(self.model)
265
+
266
+ if target_conv is None:
267
+ raise RuntimeError("Not found last layer :(")
268
+
269
+ target_layers = [target_conv]
270
+ print(f"Target layer: {target_conv}")
271
+
272
+ cam = GradCAM(model=self.model, target_layers=target_layers)
273
+
274
+ self.model.eval()
275
+
276
+ orig_requires = [p.requires_grad for p in self.model.parameters()]
277
+ for p in self.model.parameters():
278
+ p.requires_grad_(True)
279
+
280
+ fig, axes = plt.subplots(2, self.cfg.num_classes, figsize=(24, 8))
281
+
282
+ try:
283
+ for cls_idx in range(self.cfg.num_classes):
284
+ if cls_idx not in self.sample_images:
285
+ axes[0, cls_idx].axis('off')
286
+ axes[1, cls_idx].axis('off')
287
+ continue
288
+
289
+ sample = self.sample_images[cls_idx]
290
+
291
+ img_tensor = sample["image"].unsqueeze(0).to(self.device).float()
292
+ true_label = int(sample["label"])
293
+ pred_label = int(sample["pred"])
294
+
295
+ targets = [ClassifierOutputTarget(pred_label)]
296
+
297
+ with torch.enable_grad():
298
+ img_tensor.requires_grad_(True)
299
+ out = self.model(img_tensor)
300
+
301
+ test_loss = out[0, pred_label]
302
+ test_loss.backward(retain_graph=True)
303
+
304
+ grayscale_cam = cam(input_tensor=img_tensor, targets=targets)
305
+ grayscale_cam = grayscale_cam[0, :]
306
+
307
+ img_np = img_tensor.squeeze(0).detach().cpu().numpy().transpose(1, 2, 0)
308
+ mean = np.array([0.485, 0.456, 0.406])
309
+ std = np.array([0.229, 0.224, 0.225])
310
+ img_np = img_np * std + mean
311
+ img_np = np.clip(img_np, 0, 1)
312
+
313
+ visualization = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)
314
+
315
+ axes[0, cls_idx].imshow(img_np)
316
+ axes[0, cls_idx].set_title(
317
+ f"{self.cfg.label_classes[cls_idx]}\nTrue: {self.cfg.label_classes[true_label]}",
318
+ fontsize=10
319
+ )
320
+ axes[0, cls_idx].axis('off')
321
+
322
+ axes[1, cls_idx].imshow(visualization)
323
+ axes[1, cls_idx].set_title(
324
+ f"Pred: {self.cfg.label_classes[pred_label]}",
325
+ fontsize=10,
326
+ color='green' if true_label == pred_label else 'red'
327
+ )
328
+ axes[1, cls_idx].axis('off')
329
+
330
+ plt.suptitle("Grad-CAM Visualizations - Model Focus Areas", fontsize=16, y=1.02)
331
+ plt.tight_layout()
332
+
333
+ self.logger.experiment.add_figure( "GradCAM_Visualizations", fig, self.current_epoch)
334
+
335
+ print("Grad-CAM visualizations - SUCCESS!")
336
+
337
+ finally:
338
+ plt.close(fig)
339
+ for p, orig in zip(self.model.parameters(), orig_requires):
340
+ p.requires_grad_(orig)
341
+ try:
342
+ del cam
343
+ except Exception:
344
+ pass
345
+
346
+ def _setup_metrics(self):
347
+ num_classes = self.cfg.num_classes
348
+
349
+ self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
350
+ self.train_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average='macro')
351
+
352
+ self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes, average='macro')
353
+ self.val_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average='macro')
354
+ self.val_precision = torchmetrics.Precision(task='multiclass', num_classes=num_classes, average=None)
355
+ self.val_recall = torchmetrics.Recall(task='multiclass', num_classes=num_classes, average=None)
356
+
357
+ self.test_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes, average='macro')
358
+ self.test_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average='macro')
359
+ self.test_precision = torchmetrics.Precision(task='multiclass', num_classes=num_classes, average=None)
360
+ self.test_recall = torchmetrics.Recall(task='multiclass', num_classes=num_classes, average=None)
361
+
362
+ self.val_preds = []
363
+ self.val_labels = []
364
+ self.test_preds = []
365
+ self.test_labels = []
model_params_val_f1=0.878.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2bf809400c2908d09fa2913f559a79df8f99b9d1ff72f7b2b52ed3ec61a7fc38
3
+ size 138701074
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ pytorch_lightning
4
+ gradio
5
+ numpy
6
+ scikit-learn
7
+ matplotlib
8
+ seaborn
9
+ pytorch-grad-cam