EngReem85 committed (verified)
Commit f077fac · 1 Parent(s): 2d7d0ad

Create model.py

Files changed (1):
  1. model.py +1519 -0
model.py ADDED
@@ -0,0 +1,1519 @@
1
+ # Import core image-processing and utility libraries
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from glob import glob
7
+ from PIL import Image
8
+ from sklearn.model_selection import train_test_split
9
+
10
+ # Import data handling tools
11
+ import pandas as pd
12
+ import seaborn as sns
13
+ sns.set_style('darkgrid')
14
+ from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
15
+ from sklearn.utils import shuffle
16
+ from torch.utils.data import WeightedRandomSampler
17
+ from skimage.feature import local_binary_pattern
18
+
19
+ # Import deep learning libraries
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import torch.optim as optim
24
+ import torchvision.transforms as transforms
25
+ import torchvision.models as models
26
+ from torch.utils.data import Dataset, DataLoader
27
+
28
+ # Define dataset path and classes
29
+ DATASET_PATH = "/kaggle/input/ms-dfu/DFU_CLASSES(4)"
30
+ CLASSES = ["NONE", "INFECTION", "ISCHAEMIA", "BOTH"]
31
+
32
+ # Ensure output directories exist
33
+ os.makedirs("/kaggle/working/logs", exist_ok=True)
34
+ os.makedirs("/kaggle/working/predictions", exist_ok=True)
35
+ os.makedirs("/kaggle/working/visualizations", exist_ok=True)
36
+
37
+ # Squeeze-and-Excitation Block
38
+ class SEBlock(nn.Module):
39
+ def __init__(self, in_channels, reduction=16):
40
+ super(SEBlock, self).__init__()
41
+ self.fc1 = nn.Linear(in_channels, in_channels // reduction, bias=False)
42
+ self.fc2 = nn.Linear(in_channels // reduction, in_channels, bias=False)
43
+ self.global_pool = nn.AdaptiveAvgPool2d(1)
44
+
45
+ def forward(self, x):
46
+ batch, channels, _, _ = x.size()
47
+ y = self.global_pool(x).view(batch, channels)
48
+ y = F.relu(self.fc1(y))
49
+ y = torch.sigmoid(self.fc2(y)).view(batch, channels, 1, 1)
50
+ return x * y
51
+
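# --- Editorial usage sketch (not part of the original commit) ---
# A minimal shape check for SEBlock; the helper name _demo_se_block is hypothetical and the
# snippet relies only on torch, imported above.
def _demo_se_block():
    se = SEBlock(in_channels=64)
    x = torch.randn(2, 64, 28, 28)
    y = se(x)
    assert y.shape == x.shape  # channel re-weighting preserves the input shape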
52
+ # Focal Loss Implementation
53
+ class FocalLoss(nn.Module):
54
+ def __init__(self, gamma=3.0, alpha=0.5):
55
+ super(FocalLoss, self).__init__()
56
+ self.gamma = gamma
57
+ self.alpha = alpha
58
+
59
+ def forward(self, inputs, targets):
60
+ ce_loss = F.cross_entropy(inputs, targets, reduction='none')
61
+ pt = torch.exp(-ce_loss)
62
+ focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
63
+ return focal_loss.mean()
64
+
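# Editorial sketch (assumed usage, not from the original commit): because (1 - pt) <= 1, the
# focal term can only down-weight the per-sample cross-entropy, so the focal value is bounded
# by alpha * CE. The helper name is hypothetical.
def _demo_focal_loss():
    logits = torch.randn(8, 4)
    targets = torch.randint(0, 4, (8,))
    focal = FocalLoss(gamma=3.0, alpha=0.5)(logits, targets)
    ce = F.cross_entropy(logits, targets)
    assert focal.item() <= 0.5 * ce.item() + 1e-6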
65
+ # Channel-Centric Depth-wise Group Shuffle (CCDGS) Block
66
+ class CCDGSBlock(nn.Module):
67
+ def __init__(self, in_channels, group_size=4):
68
+ super(CCDGSBlock, self).__init__()
69
+ self.group_size = group_size
70
+ self.group_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=group_size, bias=False)
71
+ self.bn1 = nn.BatchNorm2d(in_channels)
72
+ self.depth_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels, bias=False)
73
+ self.bn2 = nn.BatchNorm2d(in_channels)
74
+ self.point_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)
75
+ self.bn3 = nn.BatchNorm2d(in_channels)
76
+ self.global_pool = nn.AdaptiveAvgPool2d(1)
77
+
78
+ def channel_shuffle(self, x, groups):
79
+ batchsize, num_channels, height, width = x.size()
80
+ channels_per_group = num_channels // groups
81
+ x = x.view(batchsize, groups, channels_per_group, height, width)
82
+ x = torch.transpose(x, 1, 2).contiguous()
83
+ x = x.view(batchsize, -1, height, width)
84
+ return x
85
+
86
+ def forward(self, x):
87
+ out = self.group_conv(x)
88
+ out = self.bn1(out)
89
+ out = F.relu(out)
90
+ out = self.channel_shuffle(out, self.group_size)
91
+ out = self.depth_conv(out)
92
+ out = self.bn2(out)
93
+ out = F.relu(out)
94
+ out = self.point_conv(out)
95
+ out = self.bn3(out)
96
+ out = F.relu(out)
97
+ out = self.global_pool(out)
98
+ return out
99
+
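# Editorial sketch (hypothetical helper): CCDGSBlock ends with global average pooling, so any
# spatial size collapses to 1x1 while the channel count is preserved; in_channels must be
# divisible by group_size for the grouped convolution and the channel shuffle to work.
def _demo_ccdgs_block():
    block = CCDGSBlock(in_channels=16, group_size=4)
    x = torch.randn(2, 16, 7, 7)
    assert block(x).shape == (2, 16, 1, 1)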
100
+ # Triplet Attention Module
101
+ class TripletAttention(nn.Module):
102
+ def __init__(self, in_channels, kernel_size=7):
103
+ super(TripletAttention, self).__init__()
104
+ self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
105
+ self.conv2 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
106
+ self.conv3 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
107
+
108
+ def z_pool(self, x):
109
+ max_pool = torch.max(x, dim=1, keepdim=True)[0]
110
+ avg_pool = torch.mean(x, dim=1, keepdim=True)
111
+ return torch.cat([max_pool, avg_pool], dim=1)
112
+
113
+ def forward(self, x):
114
+ x1 = torch.rot90(x, 1, [2, 3])
115
+ x1 = self.z_pool(x1)
116
+ x1 = self.conv1(x1)
117
+ x1 = torch.sigmoid(x1)
118
+ x1 = torch.rot90(x1, -1, [2, 3])
119
+ y1 = x * x1
120
+ x2 = torch.rot90(x, 1, [1, 3])
121
+ x2 = self.z_pool(x2)
122
+ x2 = self.conv2(x2)
123
+ x2 = torch.sigmoid(x2)
124
+ x2 = torch.rot90(x2, -1, [1, 3])
125
+ y2 = x * x2
126
+ x3 = self.z_pool(x)
127
+ x3 = self.conv3(x3)
128
+ x3 = torch.sigmoid(x3)
129
+ y3 = x * x3
130
+ out = (y1 + y2 + y3) / 3.0
131
+ return out
132
+
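# Editorial sketch (hypothetical helper): each of the three branches produces a sigmoid gate
# that is rotated back and multiplied with the input, so the averaged output keeps the input shape.
def _demo_triplet_attention():
    att = TripletAttention(in_channels=16)
    x = torch.randn(2, 16, 7, 7)
    assert att(x).shape == x.shape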
133
+ # Dense-ShuffleGCANet Model
134
+ class DenseShuffleGCANet(nn.Module):
135
+ def __init__(self, num_classes=4, handcrafted_feature_dim=41):
136
+ super(DenseShuffleGCANet, self).__init__()
137
+ densenet = models.densenet169(weights='IMAGENET1K_V1')
138
+ self.features = densenet.features
139
+ self.ccdgs = CCDGSBlock(in_channels=1664, group_size=4)
140
+ self.triplet_attention = TripletAttention(in_channels=1664)
141
+ self.se_block = SEBlock(in_channels=1664)
142
+ self.global_pool = nn.AdaptiveAvgPool2d(1)
143
+ self.flatten = nn.Flatten()
144
+ self.dropout = nn.Dropout(0.6)
145
+ self.fc1 = nn.Linear(1664 + handcrafted_feature_dim, 512)
146
+ self.fc2 = nn.Linear(512, num_classes)
147
+
148
+ def forward(self, x, handcrafted_features=None):
149
+ x = self.features(x)
150
+ x = self.ccdgs(x)
151
+ x = self.triplet_attention(x)
152
+ x = self.se_block(x)
153
+ x = self.global_pool(x)
154
+ x = self.flatten(x)
155
+ if handcrafted_features is not None:
156
+ x = torch.cat([x, handcrafted_features], dim=1)
157
+ x = self.dropout(x)
158
+ x = F.relu(self.fc1(x))
159
+ x = self.dropout(x)
160
+ x = self.fc2(x)
161
+ return x
162
+
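# Editorial sketch (hypothetical helper): a dummy forward pass through DenseShuffleGCANet with a
# 41-dimensional handcrafted vector (9 LBP + 24 colour + 8 edge bins, matching
# handcrafted_feature_dim=41). Instantiating the model downloads DenseNet-169 ImageNet weights
# on first use.
def _demo_model_forward():
    model = DenseShuffleGCANet(num_classes=4, handcrafted_feature_dim=41)
    images = torch.randn(2, 3, 224, 224)
    handcrafted = torch.randn(2, 41)
    logits = model(images, handcrafted)
    assert logits.shape == (2, 4)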
163
+ # Function to display sample images
164
+ def display_sample_images(images, labels, split_name, classes, num_samples=4):
165
+ plt.figure(figsize=(15, 10))
166
+ for class_idx, class_name in enumerate(classes):
167
+ class_indices = [i for i, label in enumerate(labels) if label == class_idx]
168
+ if not class_indices:
169
+ continue
170
+ selected_indices = class_indices[:num_samples]
171
+ for i, idx in enumerate(selected_indices):
172
+ img = cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB)
173
+ plt.subplot(len(classes), num_samples, class_idx * num_samples + i + 1)
174
+ plt.imshow(img)
175
+ plt.title(f'{class_name}')
176
+ plt.axis('off')
177
+ plt.suptitle(f'{split_name} Sample Images')
178
+ plt.tight_layout(rect=[0, 0, 1, 0.95])
179
+ plt.savefig(f'/kaggle/working/visualizations/{split_name.lower()}_samples.png')
180
+ plt.close()
181
+
182
+ # Function to visualize handcrafted features
183
+ # def visualize_handcrafted_features(images, labels, classes, num_samples=2):
184
+ # for class_idx, class_name in enumerate(classes):
185
+ # class_indices = [i for i, label in enumerate(labels) if label == class_idx]
186
+ # if not class_indices:
187
+ # continue
188
+ # selected_indices = class_indices[:num_samples]
189
+ # for idx in selected_indices:
190
+ # img = cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB)
191
+ # gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
192
+
193
+ # gray_blur = cv2.GaussianBlur(gray, (5, 5), 0)
194
+ # gray_eq = cv2.equalizeHist(gray_blur)
195
+
196
+ # median = np.median(gray_eq)
197
+ # lower_threshold = int(max(0, 0.66 * median))
198
+ # upper_threshold = int(min(255, 1.33 * median))
199
+
200
+ # edges = cv2.Canny(gray_eq, lower_threshold, upper_threshold)
201
+ # print(f"Visualization - Class: {class_name}, Image {idx}, Edge pixels: {np.sum(edges > 0)}")
202
+
203
+ # edge_hist, _ = np.histogram(edges.ravel(), bins=8, range=(0, 256), density=True)
204
+
205
+ # lbp = local_binary_pattern(gray, P=8, R=1, method='uniform')
206
+ # lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, 10), density=True)
207
+ # color_hist = []
208
+ # for channel in range(img.shape[2]):
209
+ # hist, _ = np.histogram(img[:, :, channel], bins=8, range=(0, 256), density=True)
210
+ # color_hist.extend(hist)
211
+
212
+ # plt.figure(figsize=(18, 4))
213
+ # plt.subplot(1, 5, 1)
214
+ # plt.imshow(img)
215
+ # plt.title(f'Original ({class_name})')
216
+ # plt.axis('off')
217
+
218
+ # plt.subplot(1, 5, 2)
219
+ # plt.imshow(edges, cmap='gray')
220
+ # plt.title('Canny Edge Map')
221
+ # plt.axis('off')
222
+
223
+ # plt.subplot(1, 5, 3)
224
+ # plt.bar(range(len(lbp_hist)), lbp_hist)
225
+ # plt.title('LBP Histogram')
226
+
227
+ # plt.subplot(1, 5, 4)
228
+ # plt.bar(range(len(color_hist)), color_hist)
229
+ # plt.title('Color Histogram')
230
+
231
+ # plt.subplot(1, 5, 5)
232
+ # plt.bar(range(len(edge_hist)), edge_hist)
233
+ # plt.title('Edge Histogram')
234
+
235
+ # plt.tight_layout()
236
+ # plt.savefig(f'/kaggle/working/visualizations/handcrafted_features_{class_name}_{idx}.png')
237
+ # plt.close()
238
+
239
+ # Function to visualize handcrafted features
240
+ # def visualize_handcrafted_features(images, labels, classes, num_samples=1):
241
+ # for class_idx, class_name in enumerate(classes):
242
+ # class_indices = [i for i, label in enumerate(labels) if label == class_idx]
243
+ # if not class_indices:
244
+ # continue
245
+ # selected_indices = class_indices[:num_samples]
246
+ # for idx in selected_indices:
247
+ # img = cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB)
248
+ # gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
249
+
250
+ # gray_blur = cv2.GaussianBlur(gray, (5, 5), 0)
251
+ # gray_eq = cv2.equalizeHist(gray_blur)
252
+
253
+ # median = np.median(gray_eq)
254
+ # lower_threshold = int(max(0, 0.66 * median))
255
+ # upper_threshold = int(min(255, 1.33 * median))
256
+
257
+ # edges = cv2.Canny(gray_eq, lower_threshold, upper_threshold)
258
+ # print(f"Visualization - Class: {class_name}, Image {idx}, Edge pixels: {np.sum(edges > 0)}")
259
+
260
+ # edge_hist, _ = np.histogram(edges.ravel(), bins=8, range=(0, 256), density=True)
261
+
262
+ # lbp = local_binary_pattern(gray, P=8, R=1, method='uniform')
263
+ # lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, 10), density=True)
264
+ # color_hist = []
265
+ # for channel in range(img.shape[2]):
266
+ # hist, _ = np.histogram(img[:, :, channel], bins=8, range=(0, 256), density=True)
267
+ # color_hist.extend(hist)
268
+
269
+ # plt.figure(figsize=(18, 4))
270
+ # plt.subplot(1, 5, 1)
271
+ # plt.imshow(img)
272
+ # plt.title(f'Original ({class_name})')
273
+ # plt.axis('off')
274
+
275
+ # plt.subplot(1, 5, 2)
276
+ # plt.imshow(edges, cmap='gray')
277
+ # plt.title('Canny Edge Map')
278
+ # plt.axis('off')
279
+
280
+ # # LBP Histogram with bold values
281
+ # plt.subplot(1, 5, 3)
282
+ # bars = plt.bar(range(len(lbp_hist)), lbp_hist)
283
+ # plt.title('LBP Histogram')
284
+ # for bar in bars:
285
+ # height = bar.get_height()
286
+ # plt.text(bar.get_x() + bar.get_width()/2., height,
287
+ # f'{height:.2f}',
288
+ # ha='center', va='bottom',
289
+ # fontsize=4, fontweight='bold') # Bold and slightly larger
290
+
291
+ # # Color Histogram with bold values
292
+ # plt.subplot(1, 5, 4)
293
+ # bars = plt.bar(range(len(color_hist)), color_hist)
294
+ # plt.title('Color Histogram')
295
+ # for bar in bars:
296
+ # height = bar.get_height()
297
+ # plt.text(bar.get_x() + bar.get_width()/2., height,
298
+ # f'{height:.2f}',
299
+ # ha='center', va='bottom',
300
+ # fontsize=4, fontweight='bold') # Bold and slightly larger
301
+
302
+ # # Edge Histogram with bold values
303
+ # plt.subplot(1, 5, 5)
304
+ # bars = plt.bar(range(len(edge_hist)), edge_hist)
305
+ # plt.title('Edge Histogram')
306
+ # for bar in bars:
307
+ # height = bar.get_height()
308
+ # plt.text(bar.get_x() + bar.get_width()/2., height,
309
+ # f'{height:.2f}',
310
+ # ha='center', va='bottom',
311
+ # fontsize=4, fontweight='bold') # Bold and slightly larger
312
+
313
+ # plt.tight_layout()
314
+ # plt.savefig(f'/kaggle/working/visualizations/handcrafted_features_{class_name}_{idx}.png')
315
+ # plt.close()
316
+
317
+
318
+ def visualize_handcrafted_features(images, labels, classes, num_samples=1):
319
+ # Create main visualization directory
320
+ main_dir = '/kaggle/working/visualizations/handcrafted_features'
321
+ os.makedirs(main_dir, exist_ok=True)
322
+
323
+ for class_idx, class_name in enumerate(classes):
324
+ # Create class-specific subdirectory
325
+ class_dir = os.path.join(main_dir, f"class_{class_idx}_{class_name}")
326
+ os.makedirs(class_dir, exist_ok=True)
327
+
328
+ class_indices = [i for i, label in enumerate(labels) if label == class_idx]
329
+ if not class_indices:
330
+ continue
331
+
332
+ selected_indices = class_indices[:num_samples]
333
+ for idx in selected_indices:
334
+ img = cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB)
335
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
336
+
337
+ # Preprocessing
338
+ gray_blur = cv2.GaussianBlur(gray, (5, 5), 0)
339
+ gray_eq = cv2.equalizeHist(gray_blur)
340
+
341
+ # Edge detection
342
+ median = np.median(gray_eq)
343
+ lower_threshold = int(max(0, 0.66 * median))
344
+ upper_threshold = int(min(255, 1.33 * median))
345
+ edges = cv2.Canny(gray_eq, lower_threshold, upper_threshold)
346
+ print(f"Visualization - Class: {class_name}, Image {idx}, Edge pixels: {np.sum(edges > 0)}")
347
+
348
+ # Feature extraction
349
+ edge_hist, _ = np.histogram(edges.ravel(), bins=8, range=(0, 256), density=True)
350
+ lbp = local_binary_pattern(gray, P=8, R=1, method='uniform')
351
+ lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, 10), density=True)
352
+ color_hist = []
353
+ for channel in range(img.shape[2]):
354
+ hist, _ = np.histogram(img[:, :, channel], bins=8, range=(0, 256), density=True)
355
+ color_hist.extend(hist)
356
+
357
+ # 1. Original Image
358
+ plt.figure(figsize=(5, 5))
359
+ plt.imshow(img)
360
+ plt.title(f'Original ({class_name})')
361
+ plt.axis('off')
362
+ plt.savefig(os.path.join(class_dir, f'sample_{idx}_original.png'),
363
+ dpi=120, bbox_inches='tight')
364
+ plt.close()
365
+
366
+ # 2. Edge Map
367
+ plt.figure(figsize=(5, 5))
368
+ plt.imshow(edges, cmap='gray')
369
+ plt.title('Canny Edge Map')
370
+ plt.axis('off')
371
+ plt.savefig(os.path.join(class_dir, f'sample_{idx}_edges.png'),
372
+ dpi=120, bbox_inches='tight')
373
+ plt.close()
374
+
375
+ # 3. LBP Histogram (with bold value labels)
376
+ plt.figure(figsize=(8, 4))
377
+ bars = plt.bar(range(len(lbp_hist)), lbp_hist)
378
+ plt.title('LBP Histogram')
379
+ for bar in bars:
380
+ height = bar.get_height()
381
+ plt.text(bar.get_x() + bar.get_width()/2., height,
382
+ f'{height:.2f}',
383
+ ha='center', va='bottom',
384
+ fontsize=8, fontweight='bold') # Slightly larger font for separate image
385
+ plt.savefig(os.path.join(class_dir, f'sample_{idx}_lbp_hist.png'),
386
+ dpi=120, bbox_inches='tight')
387
+ plt.close()
388
+
389
+ # 4. Color Histogram
390
+ plt.figure(figsize=(10, 4))
391
+ bars = plt.bar(range(len(color_hist)), color_hist)
392
+ plt.title('Color Histogram')
393
+ for bar in bars:
394
+ height = bar.get_height()
395
+ plt.text(bar.get_x() + bar.get_width()/2., height,
396
+ f'{height:.2f}',
397
+ ha='center', va='bottom',
398
+ fontsize=8, fontweight='bold') # Slightly larger font
399
+ plt.savefig(os.path.join(class_dir, f'sample_{idx}_color_hist.png'),
400
+ dpi=120, bbox_inches='tight')
401
+ plt.close()
402
+
403
+ # 5. Edge Histogram
404
+ plt.figure(figsize=(8, 4))
405
+ bars = plt.bar(range(len(edge_hist)), edge_hist)
406
+ plt.title('Edge Histogram')
407
+ for bar in bars:
408
+ height = bar.get_height()
409
+ plt.text(bar.get_x() + bar.get_width()/2., height,
410
+ f'{height:.2f}',
411
+ ha='center', va='bottom',
412
+ fontsize=8, fontweight='bold') # Slightly larger font
413
+ plt.savefig(os.path.join(class_dir, f'sample_{idx}_edge_hist.png'),
414
+ dpi=120, bbox_inches='tight')
415
+ plt.close()
416
+
417
+ print(f"Saved separate visualizations for class {class_name} sample {idx} in: {class_dir}")
418
+
419
+
420
+
421
+ # Function to extract handcrafted features
422
+ def extract_handcrafted_features(image):
423
+ if isinstance(image, torch.Tensor):
424
+ image = image.cpu().numpy()
425
+ image = np.transpose(image, (1, 2, 0))
426
+ image = (image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])) * 255
427
+ image = image.astype(np.uint8)
428
+
429
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
430
+
431
+ gray_blur = cv2.GaussianBlur(gray, (5, 5), 0)
432
+ gray_eq = cv2.equalizeHist(gray_blur)
433
+
434
+ median = np.median(gray_eq)
435
+ lower_threshold = int(max(0, 0.66 * median))
436
+ upper_threshold = int(min(255, 1.33 * median))
437
+
438
+ edges = cv2.Canny(gray_eq, lower_threshold, upper_threshold)
439
+ print(f"Edge detection stats - Lower threshold: {lower_threshold}, Upper threshold: {upper_threshold}, Edge pixels: {np.sum(edges > 0)}")
440
+
441
+ edge_hist, _ = np.histogram(edges.ravel(), bins=8, range=(0, 256), density=True)
442
+
443
+ lbp = local_binary_pattern(gray, P=8, R=1, method='uniform')
444
+ lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, 10), density=True)
445
+
446
+ color_hist = []
447
+ for channel in range(image.shape[2]):
448
+ hist, _ = np.histogram(image[:, :, channel], bins=8, range=(0, 256), density=True)
449
+ color_hist.extend(hist)
450
+
451
+ features = np.concatenate([lbp_hist, color_hist, edge_hist])
452
+ return torch.tensor(features, dtype=torch.float32)
453
+
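# Editorial sketch (hypothetical helper): the handcrafted descriptor concatenates 9 uniform-LBP
# bins, 3 x 8 colour-histogram bins and 8 edge-histogram bins, giving the 41 values expected by
# DenseShuffleGCANet(handcrafted_feature_dim=41).
def _demo_handcrafted_dim():
    dummy = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
    feats = extract_handcrafted_features(dummy)
    assert feats.shape == (41,)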
454
+ # Function to load images with handcrafted features
455
+ def load_images(split_path, classes, use_csv=False):
456
+ image_data = []
457
+ labels = []
458
+ handcrafted_features = []
459
+ print(f"Loading images from: {split_path}")
460
+ csv_path = os.path.join(DATASET_PATH, "labels.csv")
461
+ if use_csv and os.path.exists(csv_path):
462
+ print("Found labels.csv, loading dataset from CSV")
463
+ df = pd.read_csv(csv_path)
464
+ print("CSV columns:", df.columns)
465
+ for idx, row in df.iterrows():
466
+ img_path = os.path.join(DATASET_PATH, row['image_path'])
467
+ label_name = row['label']
468
+ if label_name not in CLASSES:
469
+ print(f"Warning: Label {label_name} not in CLASSES, skipping")
470
+ continue
471
+ label = CLASSES.index(label_name)
472
+ try:
473
+ img = Image.open(img_path).convert('RGB')
474
+ img_array = np.array(img)
475
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
476
+ features = extract_handcrafted_features(img_array)
477
+ except Exception as e:
478
+ print(f"Warning: Failed to load image {img_path}: {e}")
479
+ continue
480
+ image_data.append(img_array)
481
+ labels.append(label)
482
+ handcrafted_features.append(features)
483
+ else:
484
+ if not os.path.exists(split_path):
485
+ print(f"Error: Directory {split_path} does not exist")
486
+ return image_data, labels, handcrafted_features
487
+ for class_idx, class_name in enumerate(classes):
488
+ class_path = os.path.join(split_path, class_name)
489
+ print(f"Checking class: {class_name} at {class_path}")
490
+ if not os.path.exists(class_path):
491
+ print(f"Warning: Class directory {class_path} does not exist")
492
+ continue
493
+ all_files = glob(os.path.join(class_path, '*'))
494
+ print(f"All files in {class_path}: {all_files[:5]}")
495
+ image_paths = glob(os.path.join(class_path, '*.[jJ][pP][gG]')) + \
496
+ glob(os.path.join(class_path, '*.[jJ][pP][eE][gG]')) + \
497
+ glob(os.path.join(class_path, '*.[pP][nN][gG]'))
498
+ print(f"Found {len(image_paths)} images for class {class_name}")
499
+ for img_path in image_paths:
500
+ try:
501
+ img = Image.open(img_path).convert('RGB')
502
+ img_array = np.array(img)
503
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
504
+ features = extract_handcrafted_features(img_array)
505
+ except Exception as e:
506
+ print(f"Warning: Failed to load image {img_path}: {e}")
507
+ continue
508
+ image_data.append(img_array)
509
+ labels.append(class_idx)
510
+ handcrafted_features.append(features)
511
+ print(f"Total images loaded: {len(image_data)}")
512
+ return image_data, labels, handcrafted_features
513
+
514
+ # Function to visualize dataset distribution
515
+ # def visualize_data_distribution(split_path, split_name, classes):
516
+ # class_counts = []
517
+ # for class_name in classes:
518
+ # class_path = os.path.join(split_path, class_name)
519
+ # image_paths = glob(os.path.join(class_path, '*.[jJ][pP][gG]')) + \
520
+ # glob(os.path.join(class_path, '*.[jJ][pP][eE][gG]')) + \
521
+ # glob(os.path.join(class_path, '*.png'))
522
+ # class_counts.append(len(image_paths))
523
+ # print(f"Split: {split_name}, Class: {class_name}, Number of images: {len(image_paths)}")
524
+ # plt.figure(figsize=(10, 6))
525
+ # plt.bar(classes, class_counts)
526
+ # plt.title(f'{split_name} Dataset Distribution')
527
+ # plt.xlabel('Classes')
528
+ # plt.ylabel('Number of Images')
529
+ # plt.xticks(rotation=45, ha='right')
530
+ # plt.tight_layout()
531
+ # plt.savefig(f"/kaggle/working/visualizations/{split_name.lower()}_distribution.png")
532
+ # plt.close()
533
+
534
+ # Function to visualize dataset distribution
535
+ def visualize_data_distribution(split_path, split_name, classes):
536
+ class_counts = []
537
+ for class_name in classes:
538
+ class_path = os.path.join(split_path, class_name)
539
+ image_paths = glob(os.path.join(class_path, '*.[jJ][pP][gG]')) + \
540
+ glob(os.path.join(class_path, '*.[jJ][pP][eE][gG]')) + \
541
+ glob(os.path.join(class_path, '*.[pP][nN][gG]'))
542
+ class_counts.append(len(image_paths))
543
+ print(f"Split: {split_name}, Class: {class_name}, Number of images: {len(image_paths)}")
544
+
545
+ plt.figure(figsize=(10, 6))
546
+ bars = plt.bar(classes, class_counts) # Store the bar objects
547
+ plt.title(f'{split_name} Dataset Distribution')
548
+ plt.xlabel('Classes')
549
+ plt.ylabel('Number of Images')
550
+ plt.xticks(rotation=45, ha='right')
551
+
552
+ # Add value labels on top of each bar
553
+ for bar in bars:
554
+ height = bar.get_height()
555
+ plt.text(bar.get_x() + bar.get_width()/2., height,
556
+ f'{height}',
557
+ ha='center', va='bottom',
558
+ fontsize=10, fontweight='bold')
559
+
560
+ plt.tight_layout()
561
+ plt.savefig(f"/kaggle/working/visualizations/{split_name.lower()}_distribution.png")
562
+ plt.close()
563
+
564
+
565
+ # Custom Dataset Class
566
+ class FootUlcerDataset(Dataset):
567
+ def __init__(self, images, labels, handcrafted_features):
568
+ self.images = images
569
+ self.labels = labels
570
+ self.handcrafted_features = handcrafted_features
571
+
572
+ def __len__(self):
573
+ return len(self.images)
574
+
575
+ def __getitem__(self, idx):
576
+ image = self.images[idx]
577
+ label = self.labels[idx]
578
+ features = self.handcrafted_features[idx]
579
+ image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
580
+ if label in [2, 3]:
581
+ transform = transforms.Compose([
582
+ transforms.RandomHorizontalFlip(p=0.5),
583
+ transforms.RandomRotation(30),
584
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
585
+ transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
586
+ transforms.ToTensor(),
587
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
588
+ ])
589
+ else:
590
+ transform = transforms.Compose([
591
+ transforms.RandomHorizontalFlip(p=0.5),
592
+ transforms.ToTensor(),
593
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
594
+ ])
595
+ image = transform(image)
596
+ return image, label, features
597
+
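# Editorial sketch (assumed wiring, not shown in this excerpt): building a training DataLoader
# from FootUlcerDataset. WeightedRandomSampler is imported at the top of the file; one plausible
# use is to oversample the minority classes, as sketched here. All names in this helper are
# hypothetical.
def _demo_build_train_loader(train_images, train_labels, train_features, batch_size=16):
    dataset = FootUlcerDataset(train_images, train_labels, train_features)
    class_counts = np.maximum(np.bincount(np.array(train_labels), minlength=len(CLASSES)), 1)
    sample_weights = [1.0 / class_counts[label] for label in train_labels]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)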
598
+ # Function to print model summary
599
+ def print_model_summary(model, input_size=(3, 224, 224), handcrafted_feature_dim=41):
600
+ device = next(model.parameters()).device
601
+ model.eval()
602
+ print("\nModel Summary:")
603
+ print("=" * 80)
604
+ print(f"{'Layer':<30} {'Output Shape':<25} {'Param #':<15}")
605
+ print("-" * 80)
606
+
607
+ total_params = 0
608
+ x = torch.randn(1, *input_size).to(device)
609
+ handcrafted_features = torch.randn(1, handcrafted_feature_dim).to(device)
610
+
611
+ def register_hook(module, input, output):
612
+ nonlocal total_params
613
+ class_name = str(module.__class__.__name__)
614
+ param_count = sum(p.numel() for p in module.parameters(recurse=False))  # recurse=False so nested modules are not double-counted in the total
615
+ total_params += param_count
616
+ output_shape = list(output.shape) if isinstance(output, torch.Tensor) else "N/A"
617
+ print(f"{class_name:<30} {str(output_shape):<25} {param_count:<15}")
618
+
619
+ hooks = []
620
+ for name, module in model.named_modules():
621
+ if module != model:
622
+ hooks.append(module.register_forward_hook(register_hook))
623
+
624
+ with torch.no_grad():
625
+ model(x, handcrafted_features)
626
+
627
+ for hook in hooks:
628
+ hook.remove()
629
+
630
+ print("-" * 80)
631
+ print(f"Total Parameters: {total_params:,}")
632
+ print("=" * 80)
633
+
634
+ # Function to plot ROC curves
635
+ def plot_roc_curves(labels, probabilities, split_name, classes, model_idx=None):
636
+ plt.figure(figsize=(10, 8))
637
+ for i, class_name in enumerate(classes):
638
+ fpr, tpr, _ = roc_curve(np.array(labels) == i, probabilities[:, i])
639
+ roc_auc = auc(fpr, tpr)
640
+ plt.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.2f})')
641
+ plt.plot([0, 1], [0, 1], 'k--')
642
+ plt.xlim([0.0, 1.0])
643
+ plt.ylim([0.0, 1.05])
644
+ plt.xlabel('False Positive Rate')
645
+ plt.ylabel('True Positive Rate')
646
+ plt.title(f'ROC Curves - {split_name}' + (f' (Model {model_idx})' if model_idx is not None else ''))
647
+ plt.legend(loc='lower right')
648
+ plt.grid(True)
649
+ filename = f'/kaggle/working/visualizations/roc_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png')
650
+ plt.savefig(filename)
651
+ plt.close()
652
+
653
+ # Function to visualize feature extraction layer by layer
654
+ # def visualize_feature_extraction(model, dataloader, device, classes, num_samples=1):
655
+ # model.eval()
656
+ # feature_maps = {}
657
+ # layer_names = ['features', 'ccdgs', 'triplet_attention', 'se_block', 'fc1', 'fc2']
658
+
659
+ # print("\nModel structure (named modules):")
660
+ # for name, module in model.named_modules():
661
+ # print(f"Layer: {name}, Module: {type(module).__name__}")
662
+
663
+ # print("\nRegistering forward hooks for layers:", layer_names)
664
+
665
+ # def get_hook(name):
666
+ # def hook(module, input, output):
667
+ # feature_maps[name] = output.detach()
668
+ # print(f"Captured output for {name}, shape: {output.shape}")
669
+ # return hook
670
+
671
+ # hooks = []
672
+ # for name in layer_names:
673
+ # module = getattr(model, name, None)
674
+ # if module:
675
+ # hooks.append(module.register_forward_hook(get_hook(name)))
676
+ # print(f"Hook registered for {name}")
677
+ # else:
678
+ # print(f"Warning: Layer {name} not found in model")
679
+
680
+ # images_list = []
681
+ # labels_list = []
682
+ # probs_list = []
683
+ # features_list = []
684
+ # with torch.no_grad():
685
+ # for images, labels, features in dataloader:
686
+ # images, labels, features = images.to(device), labels.to(device), features.to(device)
687
+ # print(f"Processing batch with {images.shape[0]} images, features shape: {features.shape}")
688
+ # outputs = model(images, features)
689
+ # probs = F.softmax(outputs, dim=1)
690
+ # images_list.extend(images.cpu().numpy())
691
+ # labels_list.extend(labels.cpu().numpy())
692
+ # probs_list.extend(probs.cpu().numpy())
693
+ # features_list.extend(features.cpu().numpy())
694
+ # break
695
+
696
+ # print(f"Removing {len(hooks)} hooks")
697
+ # for hook in hooks:
698
+ # hook.remove()
699
+
700
+ # print(f"Feature maps captured: {list(feature_maps.keys())}")
701
+
702
+ # for idx in range(min(num_samples, len(images_list))):
703
+ # img = images_list[idx].transpose(1, 2, 0)
704
+ # img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
705
+ # img = np.clip(img, 0, 1)
706
+ # true_label = CLASSES[labels_list[idx]]
707
+
708
+ # plt.figure(figsize=(5, 5))
709
+ # plt.imshow(img)
710
+ # plt.title(f'Input Image (Class: {true_label})')
711
+ # plt.axis('off')
712
+ # input_img_path = f'/kaggle/working/visualizations/input_image_sample_{idx}.png'
713
+ # plt.savefig(input_img_path)
714
+ # plt.close()
715
+ # print(f"Saved input image to: {input_img_path}")
716
+
717
+ # for layer_name in layer_names:
718
+ # if layer_name not in feature_maps:
719
+ # print(f"No feature map for {layer_name}, skipping visualization")
720
+ # continue
721
+
722
+ # features = feature_maps[layer_name][idx]
723
+ # print(f"Visualizing {layer_name}, feature shape: {features.shape}, feature dim:{features.dim()}")
724
+
725
+ # if features.dim() == 3:
726
+ # num_channels = min(features.shape[0], 16)
727
+ # plt.figure(figsize=(15, 10))
728
+ # for i in range(num_channels):
729
+ # plt.subplot(4, 4, i + 1)
730
+ # feature_map = features[i].cpu().numpy()
731
+ # feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min() + 1e-8)
732
+ # plt.imshow(feature_map, cmap='viridis')
733
+ # plt.title(f'Channel {i+1}')
734
+ # plt.axis('off')
735
+ # plt.suptitle(f'Feature Maps - {layer_name} (Class: {true_label})')
736
+ # plt.tight_layout(rect=[0, 0, 1, 0.95])
737
+ # feature_map_path = f'/kaggle/working/visualizations/feature_maps_{layer_name}_{idx}.png'
738
+ # plt.savefig(feature_map_path)
739
+ # plt.close()
740
+ # print(f"Saved {num_channels} feature maps for {layer_name} to: {feature_map_path}")
741
+
742
+ # else:
743
+ # values = features.flatten().cpu().numpy()
744
+ # plt.figure(figsize=(10, 5))
745
+ # plt.bar(range(len(values)), values, color='blue') # Changed bar color to blue
746
+ # plt.title(f'Feature Vector - {layer_name} (Class: {true_label})')
747
+ # plt.xlabel('Index')
748
+ # plt.ylabel('Value')
749
+ # ## Keep the grid only if it looks good
750
+ # plt.grid(True, linestyle='--', alpha=0.6) # Added light grid for better readability
751
+ # plt.tight_layout()
752
+ # vector_path = f'/kaggle/working/visualizations/feature_vector_{layer_name}_{idx}.png'
753
+ # plt.savefig(vector_path)
754
+ # plt.close()
755
+ # print(f"Saved feature vector with {len(values)} elements for {layer_name} to: {vector_path}")
756
+
757
+ # plt.figure(figsize=(8, 6))
758
+ # bars = plt.bar(classes, probs_list[idx]) # Store the bar objects
759
+ # plt.title(f'Classification Probabilities (True: {true_label})')
760
+ # plt.xlabel('Classes')
761
+ # plt.ylabel('Probability')
762
+ # plt.xticks(rotation=45)
763
+
764
+ # # Add value labels on top of each bar
765
+ # for bar in bars:
766
+ # height = bar.get_height()
767
+ # plt.text(bar.get_x() + bar.get_width()/2., height,
768
+ # f'{height:.3f}',
769
+ # ha='center', va='bottom',
770
+ # fontsize=10, fontweight='bold')
771
+
772
+ # plt.tight_layout()
773
+ # probs_path = f'/kaggle/working/visualizations/classification_probs_{idx}.png'
774
+ # plt.savefig(probs_path)
775
+ # plt.close()
776
+ # print(f"Saved classification probabilities to: {probs_path}")
777
+
778
+ # print("\nListing saved visualization files:")
779
+ # os.system('ls /kaggle/working/visualizations/')
780
+
781
+ def visualize_feature_extraction(model, dataloader, device, classes, num_samples_per_class=1):
782
+ model.eval()
783
+ feature_maps = {}
784
+ layer_names = ['features', 'ccdgs', 'triplet_attention', 'se_block', 'fc1', 'fc2']
785
+
786
+ print("\nModel structure (named modules):")
787
+ for name, module in model.named_modules():
788
+ print(f"Layer: {name}, Module: {type(module).__name__}")
789
+
790
+ print("\nRegistering forward hooks for layers:", layer_names)
791
+
792
+ def get_hook(name):
793
+ def hook(module, input, output):
794
+ feature_maps[name] = output.detach()
795
+ print(f"Captured output for {name}, shape: {output.shape}")
796
+ return hook
797
+
798
+ hooks = []
799
+ for name in layer_names:
800
+ module = getattr(model, name, None)
801
+ if module:
802
+ hooks.append(module.register_forward_hook(get_hook(name)))
803
+ print(f"Hook registered for {name}")
804
+ else:
805
+ print(f"Warning: Layer {name} not found in model")
806
+
807
+ # Collect samples from each class
808
+ class_samples = {class_idx: [] for class_idx in range(len(classes))}
809
+ with torch.no_grad():
810
+ for images, labels, features in dataloader:
811
+ images, labels, features = images.to(device), labels.to(device), features.to(device)
812
+ outputs = model(images, features)
813
+ probs = F.softmax(outputs, dim=1)
814
+
815
+ for i in range(len(images)):
816
+ class_idx = labels[i].item()
817
+ if len(class_samples[class_idx]) < num_samples_per_class:
818
+ class_samples[class_idx].append((
819
+ images[i].cpu().numpy(),
820
+ labels[i].cpu().numpy(),
821
+ probs[i].cpu().numpy(),
822
+ features[i].cpu().numpy()
823
+ ))
824
+
825
+ # Check if we have enough samples from each class
826
+ if all(len(samples) >= num_samples_per_class for samples in class_samples.values()):
827
+ break
828
+
829
+ print(f"Removing {len(hooks)} hooks")
830
+ for hook in hooks:
831
+ hook.remove()
832
+
833
+ print(f"Feature maps captured: {list(feature_maps.keys())}")
834
+
835
+ # Process one sample from each class
836
+ for class_idx in range(len(classes)):
837
+ if not class_samples[class_idx]:
838
+ print(f"No samples found for class {class_idx} ({classes[class_idx]})")
839
+ continue
840
+
841
+ # Take the first sample for this class
842
+ img, label, prob, features = class_samples[class_idx][0]
843
+ true_label = classes[label]
844
+
845
+ # Process input image
846
+ img = img.transpose(1, 2, 0)
847
+ img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
848
+ img = np.clip(img, 0, 1)
849
+
850
+ plt.figure(figsize=(5, 5))
851
+ plt.imshow(img)
852
+ plt.title(f'Input Image (Class: {true_label})')
853
+ plt.axis('off')
854
+ input_img_path = f'/kaggle/working/visualizations/class_{class_idx}_input_image.png'
855
+ plt.savefig(input_img_path)
856
+ plt.close()
857
+ print(f"Saved input image for class {true_label} to: {input_img_path}")
858
+
859
+ # Process feature maps for each layer
860
+ for layer_name in layer_names:
861
+ if layer_name not in feature_maps:
862
+ print(f"No feature map for {layer_name}, skipping visualization")
863
+ continue
864
+
865
+ features = feature_maps[layer_name][class_idx]  # NOTE: feature_maps holds only the last forward batch, so this assumes that batch contains one sample per class, indexed in class order
866
+ print(f"Visualizing {layer_name} for class {true_label}, feature shape: {features.shape}")
867
+
868
+ if features.dim() == 3:
869
+ num_channels = min(features.shape[0], 16)
870
+ plt.figure(figsize=(15, 10))
871
+ for i in range(num_channels):
872
+ plt.subplot(4, 4, i + 1)
873
+ feature_map = features[i].cpu().numpy()
874
+ feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min() + 1e-8)
875
+ plt.imshow(feature_map, cmap='viridis')
876
+ plt.title(f'Channel {i+1}')
877
+ plt.axis('off')
878
+ plt.suptitle(f'Feature Maps - {layer_name} (Class: {true_label})')
879
+ plt.tight_layout(rect=[0, 0, 1, 0.95])
880
+ feature_map_path = f'/kaggle/working/visualizations/class_{class_idx}_feature_maps_{layer_name}.png'
881
+ plt.savefig(feature_map_path)
882
+ plt.close()
883
+ print(f"Saved {num_channels} feature maps for {layer_name} to: {feature_map_path}")
884
+
885
+ # else:
886
+ # values = features.flatten().cpu().numpy()
887
+ # plt.figure(figsize=(10, 5))
888
+ # plt.bar(range(len(values)), values, color='blue')
889
+ # plt.title(f'Feature Vector - {layer_name} (Class: {true_label})')
890
+ # plt.xlabel('Index')
891
+ # plt.ylabel('Value')
892
+ # plt.grid(True, linestyle='--', alpha=0.6)
893
+ # plt.tight_layout()
894
+ # vector_path = f'/kaggle/working/visualizations/class_{class_idx}_feature_vector_{layer_name}.png'
895
+ # plt.savefig(vector_path)
896
+ # plt.close()
897
+ # print(f"Saved feature vector with {len(values)} elements for {layer_name} to: {vector_path}")
898
+
899
+ else:
900
+ values = features.flatten().cpu().numpy()
901
+ num_features = len(values)
902
+
903
+ # Adjust figure width based on number of features
904
+ fig_width = max(20, num_features * 0.025) # 0.025 inches per bar (adjustable)
905
+ plt.figure(figsize=(fig_width, 5)) # Wider for more bars
906
+
907
+ # Plot bars with optimized width & spacing
908
+ bars = plt.bar(
909
+ range(num_features),
910
+ values,
911
+ color='#1f77b4', # Matplotlib default blue (better than 'blue')
912
+ edgecolor='#1f77b4',
913
+ linewidth=0.05, # Thinner border for dense plots
914
+ width=0.9, # Slightly narrower to guarantee gaps
915
+ align='center'
916
+ )
917
+
918
+ # Hide x-axis labels if too many features
919
+ if num_features > 100:
920
+ ticks = list(range(0, num_features, 50)) + [num_features-1] # Add last feature
921
+ plt.xticks(ticks)  # show only sparse tick positions
922
+
923
+ plt.title(f'Feature Vector - {layer_name} (Class: {true_label})')
924
+ plt.xlabel('Feature')
925
+ plt.ylabel('Activation Value')
926
+ plt.grid(True, linestyle=':', alpha=0.5)
927
+ plt.tight_layout()
928
+ vector_path = f'/kaggle/working/visualizations/class_{class_idx}_feature_vector_{layer_name}.png'
929
+ plt.savefig(vector_path, dpi=120, bbox_inches='tight', facecolor='white')
930
+ plt.close()
931
+ print(f"Saved feature vector with {len(values)} elements for {layer_name} to: {vector_path}")
932
+
933
+ # Process classification probabilities
934
+ plt.figure(figsize=(8, 6))
935
+ bars = plt.bar(classes, prob)
936
+ plt.title(f'Classification Probabilities (True: {true_label})')
937
+ plt.xlabel('Classes')
938
+ plt.ylabel('Probability')
939
+ plt.xticks(rotation=45)
940
+
941
+ for bar in bars:
942
+ height = bar.get_height()
943
+ plt.text(bar.get_x() + bar.get_width()/2., height,
944
+ f'{height:.3f}',
945
+ ha='center', va='bottom',
946
+ fontsize=10, fontweight='bold')
947
+
948
+ plt.tight_layout()
949
+ probs_path = f'/kaggle/working/visualizations/class_{class_idx}_classification_probs.png'
950
+ plt.savefig(probs_path)
951
+ plt.close()
952
+ print(f"Saved classification probabilities to: {probs_path}")
953
+
954
+ print("\nListing saved visualization files:")
955
+ os.system('ls /kaggle/working/visualizations/')
956
+
957
+ # Training Function
958
+ def train_model(model, dataloader, criterion, optimizer, device, epochs=100, model_idx=0):
959
+ history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
960
+ best_val_loss = float('inf')
961
+ best_val_acc = 0.0
962
+ best_train_acc = 0.0
963
+ patience = 10
964
+ counter = 0
965
+ scaler = torch.cuda.amp.GradScaler()
966
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)
967
+ train_batches = len(dataloader['train'])
968
+ val_batches = len(dataloader['val'])
969
+ print(f"Training dataset: {train_batches} batches")
970
+ print(f"Validation dataset: {val_batches} batches")
971
+ for epoch in range(epochs):
972
+ print(f"\n--- Epoch {epoch+1}/{epochs} ---")
973
+ model.train()
974
+ running_loss = 0.0
975
+ correct, total = 0, 0
976
+ for batch_idx, (images, labels, features) in enumerate(dataloader['train']):
977
+ print(f"Training epoch {epoch+1}, batch {batch_idx+1}/{train_batches}")
978
+ images, labels, features = images.to(device), labels.to(device), features.to(device)
979
+ optimizer.zero_grad()
980
+ with torch.cuda.amp.autocast():
981
+ outputs = model(images, features)
982
+ loss = criterion(outputs, labels)
983
+ scaler.scale(loss).backward()
984
+ scaler.step(optimizer)
985
+ scaler.update()
986
+ running_loss += loss.item()
987
+ _, predicted = outputs.max(1)
988
+ total += labels.size(0)
989
+ correct += predicted.eq(labels).sum().item()
990
+ epoch_loss = running_loss / train_batches if train_batches > 0 else 0.0
991
+ epoch_acc = 100. * correct / total if total > 0 else 0.0
992
+ history['train_loss'].append(epoch_loss)
993
+ history['train_acc'].append(epoch_acc)
994
+ best_train_acc = max(best_train_acc, epoch_acc)
995
+ model.eval()
996
+ val_loss = 0.0
997
+ val_correct, val_total = 0, 0
998
+ with torch.no_grad():
999
+ for batch_idx, (val_images, val_labels, val_features) in enumerate(dataloader['val']):
1000
+ print(f"Validation epoch {epoch+1}, batch {batch_idx+1}/{val_batches}")
1001
+ val_images, val_labels, val_features = val_images.to(device), val_labels.to(device), val_features.to(device)
1002
+ with torch.cuda.amp.autocast():
1003
+ val_outputs = model(val_images, val_features)
1004
+ loss = criterion(val_outputs, val_labels)
1005
+ val_loss += loss.item()
1006
+ _, predicted = val_outputs.max(1)
1007
+ val_total += val_labels.size(0)
1008
+ val_correct += predicted.eq(val_labels).sum().item()
1009
+ val_epoch_loss = val_loss / val_batches if val_batches > 0 else 0.0
1010
+ val_epoch_acc = 100. * val_correct / val_total if val_total > 0 else 0.0
1011
+ history['val_loss'].append(val_epoch_loss)
1012
+ history['val_acc'].append(val_epoch_acc)
1013
+ best_val_acc = max(best_val_acc, val_epoch_acc)
1014
+ scheduler.step(val_epoch_loss)
1015
+ if val_epoch_loss < best_val_loss:
1016
+ best_val_loss = val_epoch_loss
1017
+ counter = 0
1018
+ torch.save(model.state_dict(), f'/kaggle/working/best_model_{model_idx}.pth')
1019
+ else:
1020
+ counter += 1
1021
+ if counter >= patience:
1022
+ print("Early stopping triggered")
1023
+ break
1024
+ print(f"Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")
1025
+ print(f"Validation Loss: {val_epoch_loss:.4f}, Accuracy: {val_epoch_acc:.2f}%")
1026
+ print(f"Best Training Accuracy (Model {model_idx}): {best_train_acc:.2f}%")
1027
+ print(f"Best Validation Accuracy (Model {model_idx}): {best_val_acc:.2f}%")
1028
+ return history, best_train_acc, best_val_acc
1029
+
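# Editorial sketch (assumed invocation; the original training driver is not part of this excerpt):
# train_model expects a dict of loaders and saves the best checkpoint as best_model_{model_idx}.pth.
# The optimiser settings below are illustrative assumptions.
def _demo_train(train_loader, val_loader, model_idx=0):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DenseShuffleGCANet(num_classes=len(CLASSES)).to(device)
    criterion = FocalLoss(gamma=3.0, alpha=0.5)
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    history, best_train_acc, best_val_acc = train_model(
        model, {'train': train_loader, 'val': val_loader},
        criterion, optimizer, device, epochs=100, model_idx=model_idx)
    plot_training_history(history, epochs=100, model_idx=model_idx)
    return model, best_val_acc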
1030
+ # Function to Plot Training History
1031
+ def plot_training_history(history, epochs, model_idx=0):
1032
+ epochs_range = range(1, len(history['train_loss']) + 1)
1033
+ plt.figure(figsize=(12, 5))
1034
+ plt.subplot(1, 2, 1)
1035
+ plt.plot(epochs_range, history['train_loss'], label='Training Loss')
1036
+ plt.plot(epochs_range, history['val_loss'], label='Validation Loss')
1037
+ plt.xlabel('Epochs')
1038
+ plt.ylabel('Loss')
1039
+ plt.title(f'Training and Validation Loss (Model {model_idx})')
1040
+ plt.legend()
1041
+ plt.grid(True)
1042
+ plt.subplot(1, 2, 2)
1043
+ plt.plot(epochs_range, history['train_acc'], label='Training Accuracy')
1044
+ plt.plot(epochs_range, history['val_acc'], label='Validation Accuracy')
1045
+ plt.xlabel('Epochs')
1046
+ plt.ylabel('Accuracy (%)')
1047
+ plt.title(f'Training and Validation Accuracy (Model {model_idx})')
1048
+ plt.legend()
1049
+ plt.grid(True)
1050
+ plt.tight_layout()
1051
+ plt.savefig(f'/kaggle/working/visualizations/training_history_model_{model_idx}.png')
1052
+ plt.close()
1053
+
1054
+ # Function to Evaluate Model
1055
+ def evaluate_model(model, dataloader, device, split_name, classes, model_idx=None, use_tta=False):
1056
+ model.eval()
1057
+ correct = 0
1058
+ total = 0
1059
+ all_predictions = []
1060
+ all_labels = []
1061
+ all_probs = []
1062
+ mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1)
1063
+ std = torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1)
1064
+ tta_transforms = [
1065
+ transforms.Compose([
1066
+ transforms.ToTensor(),
1067
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1068
+ ]),
1069
+ transforms.Compose([
1070
+ transforms.RandomHorizontalFlip(p=1.0),
1071
+ transforms.ToTensor(),
1072
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1073
+ ]),
1074
+ transforms.Compose([
1075
+ transforms.RandomRotation(10),
1076
+ transforms.ToTensor(),
1077
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1078
+ ])
1079
+ ]
1080
+ with torch.no_grad():
1081
+ for images, labels, features in dataloader:
1082
+ images, labels, features = images.to(device), labels.to(device), features.to(device)
1083
+ if use_tta:
1084
+ batch_probs = []
1085
+ for transform in tta_transforms:
1086
+ denorm_images = images * std + mean
1087
+ denorm_images = denorm_images.clamp(0, 1) * 255
1088
+ denorm_images = denorm_images.to(torch.uint8)
1089
+ tta_images = torch.stack([
1090
+ transform(Image.fromarray(img.cpu().numpy().transpose(1, 2, 0)))
1091
+ for img in denorm_images
1092
+ ]).to(device)
1093
+ outputs = model(tta_images, features)
1094
+ batch_probs.append(F.softmax(outputs, dim=1))
1095
+ avg_probs = torch.stack(batch_probs).mean(dim=0)
1096
+ _, predicted = torch.max(avg_probs, 1)
1097
+ all_probs.extend(avg_probs.cpu().numpy())
1098
+ else:
1099
+ outputs = model(images, features)
1100
+ _, predicted = torch.max(outputs.data, 1)
1101
+ all_probs.extend(F.softmax(outputs, dim=1).cpu().numpy())
1102
+ all_predictions.extend(predicted.cpu().numpy())
1103
+ all_labels.extend(labels.cpu().numpy())
1104
+ total += labels.size(0)
1105
+ correct += (predicted == labels).sum().item()
1106
+ accuracy = 100 * correct / total if total > 0 else 0.0
1107
+ # cm = confusion_matrix(all_labels, all_predictions)
1108
+ # plt.figure(figsize=(10, 8))
1109
+ # sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
1110
+ # plt.title(f'Confusion Matrix - {split_name}' + (f' (Model {model_idx})' if model_idx is not None else ''))
1111
+ # plt.xlabel('Predicted')
1112
+ # plt.ylabel('True')
1113
+
1114
+
1115
+ cm = confusion_matrix(all_labels, all_predictions)
1116
+ plt.figure(figsize=(10, 8))
1117
+
1118
+ # Create heatmap with custom annotation formatting
1119
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
1120
+ xticklabels=classes, yticklabels=classes,
1121
+ annot_kws={'size': 12, 'weight': 'bold'}) # Larger bold annotations
1122
+
1123
+ # Make title and axis labels bold
1124
+ plt.title(f'Confusion Matrix - {split_name}' +
1125
+ (f' (Model {model_idx})' if model_idx is not None else ''),
1126
+ fontsize=14, fontweight='bold') # Bold title with larger font
1127
+
1128
+ plt.xlabel('Predicted', fontsize=12, fontweight='bold') # Bold x-label
1129
+ plt.ylabel('True', fontsize=12, fontweight='bold') # Bold y-label
1130
+
1131
+ plt.tight_layout()
1132
+ filename = f'/kaggle/working/visualizations/cm_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png')
1133
+ plt.savefig(filename)
1134
+ plt.close()
1135
+ report = classification_report(all_labels, all_predictions, target_names=classes, output_dict=True)
1136
+ report_df = pd.DataFrame(report).transpose()
1137
+ report_filename = f'/kaggle/working/classification_report_{split_name.lower()}' + (f'_model_{model_idx}.csv' if model_idx is not None else '_ensemble.csv')
1138
+ report_df.to_csv(report_filename)
1139
+ all_probs = np.array(all_probs)
1140
+ plot_roc_curves(all_labels, all_probs, split_name, classes, model_idx)
1141
+ return accuracy, all_predictions, all_labels, all_probs, report_df
1142
+
1143
+ # Ensemble Voting Function
1144
+ def ensemble_voting(models, dataloader, device, split_name, classes):
1145
+ all_predictions = []
1146
+ all_labels = []
1147
+ all_probs = []
1148
+ for model in models:
1149
+ model.eval()
1150
+ with torch.no_grad():
1151
+ for images, labels, features in dataloader:
1152
+ images, labels, features = images.to(device), labels.to(device), features.to(device)
1153
+ votes = []
1154
+ probs = []
1155
+ for model in models:
1156
+ outputs = model(images, features)
1157
+ _, predicted = torch.max(outputs.data, 1)
1158
+ votes.append(predicted.cpu().numpy())
1159
+ probs.append(F.softmax(outputs, dim=1).cpu().numpy())
1160
+ votes = np.array(votes)
1161
+ final_predictions = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=votes)
1162
+ avg_probs = np.mean(probs, axis=0)
1163
+ all_predictions.extend(final_predictions)
1164
+ all_labels.extend(labels.cpu().numpy())
1165
+ all_probs.extend(avg_probs)
1166
+ accuracy = 100 * sum(np.array(all_predictions) == np.array(all_labels)) / len(all_labels)
1167
+ # cm = confusion_matrix(all_labels, all_predictions)
1168
+ # plt.figure(figsize=(10, 8))
1169
+ # sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
1170
+ # plt.title(f'Confusion Matrix - {split_name} (Ensemble)')
1171
+ # plt.xlabel('Predicted')
1172
+ # plt.ylabel('True')
1173
+
1174
+
1175
+ cm = confusion_matrix(all_labels, all_predictions)
1176
+ plt.figure(figsize=(10, 8))
1177
+
1178
+ # Create heatmap with custom annotation formatting
1179
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
1180
+ xticklabels=classes, yticklabels=classes,
1181
+ annot_kws={'size': 12, 'weight': 'bold'}) # Larger bold annotations
1182
+
1183
+ plt.title(f'Confusion Matrix - {split_name} (Ensemble)', fontsize=14, fontweight='bold')
1184
+
1185
+ plt.xlabel('Predicted', fontsize=12, fontweight='bold') # Bold x-label
1186
+ plt.ylabel('True', fontsize=12, fontweight='bold') # Bold y-label
1187
+
1188
+
1189
+ plt.tight_layout()
1190
+ plt.savefig(f'/kaggle/working/visualizations/cm_{split_name.lower()}_ensemble.png')
1191
+ plt.close()
1192
+ report = classification_report(all_labels, all_predictions, target_names=classes, output_dict=True)
1193
+ report_df = pd.DataFrame(report).transpose()
1194
+ report_df.to_csv(f'/kaggle/working/classification_report_{split_name.lower()}_ensemble.csv')
1195
+ all_probs = np.array(all_probs)
1196
+ plot_roc_curves(all_labels, all_probs, f'{split_name} (Ensemble)', classes)
1197
+ return accuracy, all_predictions, all_labels, all_probs, report_df
1198
+
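# Editorial sketch (assumed wiring): loading the independently trained checkpoints saved by
# train_model and running majority voting on a test loader. The checkpoint paths follow
# train_model's save pattern; the ensemble size is an illustrative assumption.
def _demo_ensemble(test_loader, num_models=3):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ensemble = []
    for i in range(num_models):
        m = DenseShuffleGCANet(num_classes=len(CLASSES)).to(device)
        m.load_state_dict(torch.load(f'/kaggle/working/best_model_{i}.pth', map_location=device))
        ensemble.append(m)
    return ensemble_voting(ensemble, test_loader, device, 'Test', CLASSES)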
1199
+ # Function to visualize voting process
1200
+ def visualize_voting_process(models, dataloader, device, classes, num_samples=5):
1201
+ model_predictions = []
1202
+ true_labels = []
1203
+ images_list = []
1204
+ with torch.no_grad():
1205
+ for images, labels, features in dataloader:
1206
+ images, labels, features = images.to(device), labels.to(device), features.to(device)
1207
+ batch_preds = []
1208
+ for model in models:
1209
+ model.eval()
1210
+ outputs = model(images, features)
1211
+ _, predicted = torch.max(outputs.data, 1)
1212
+ batch_preds.append(predicted.cpu().numpy())
1213
+ model_predictions.extend(np.array(batch_preds).T)
1214
+ true_labels.extend(labels.cpu().numpy())
1215
+ images_list.extend(images.cpu().numpy())
1216
+ if len(true_labels) >= num_samples:
1217
+ break
1218
+ model_predictions = model_predictions[:num_samples]
1219
+ true_labels = true_labels[:num_samples]
1220
+ images_list = images_list[:num_samples]
1221
+ plt.figure(figsize=(15, num_samples * 3))
1222
+ for i in range(num_samples):
1223
+ img = images_list[i].transpose(1, 2, 0)
1224
+ img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
1225
+ img = np.clip(img, 0, 1)
1226
+ plt.subplot(num_samples, 1, i + 1)
1227
+ plt.imshow(img)
1228
+ preds = [CLASSES[p] for p in model_predictions[i]]
1229
+ ensemble_pred = CLASSES[np.bincount(model_predictions[i]).argmax()]
1230
+ title = f'True: {CLASSES[true_labels[i]]}\n' + \
1231
+ ', '.join(f'Model {j+1}: {p}' for j, p in enumerate(preds)) + '\n' + \
1232
+ f'Ensemble: {ensemble_pred}'
1233
+ plt.title(title)
1234
+ plt.axis('off')
1235
+ plt.tight_layout()
1236
+ plt.savefig('/kaggle/working/visualizations/voting_process.png')
1237
+ plt.close()
1238
+
1239
+ # Function to visualize predictions per class
1240
+ # def visualize_predictions_per_class(model, dataloader, device, classes, split_name, model_idx=None, num_samples=4):
1241
+ # model.eval()
1242
+ # class_images = {i: [] for i in range(len(classes))}
1243
+ # class_preds = {i: [] for i in range(len(classes))}
1244
+ # class_labels = {i: [] for i in range(len(classes))}
1245
+ # with torch.no_grad():
1246
+ # for images, labels, features in dataloader:
1247
+ # images, labels, features = images.to(device), labels.to(device), features.to(device)
1248
+ # outputs = model(images, features)
1249
+ # _, predicted = torch.max(outputs.data, 1)
1250
+ # for img, pred, label in zip(images.cpu().numpy(), predicted.cpu().numpy(), labels.cpu().numpy()):
1251
+ # if len(class_images[label]) < num_samples:
1252
+ # class_images[label].append(img)
1253
+ # class_preds[label].append(pred)
1254
+ # class_labels[label].append(label)
1255
+ # if all(len(class_images[i]) >= num_samples for i in range(len(classes))):
1256
+ # break
1257
+ # for class_idx, class_name in enumerate(classes):
1258
+ # plt.figure(figsize=(15, 5))
1259
+ # for i in range(min(num_samples, len(class_images[class_idx]))):
1260
+ # img = class_images[class_idx][i].transpose(1, 2, 0)
1261
+ # img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
1262
+ # img = np.clip(img, 0, 1)
1263
+ # plt.subplot(1, num_samples, i + 1)
1264
+ # plt.imshow(img)
1265
+ # plt.title(f'True: {class_name}\nPred: {CLASSES[class_preds[class_idx][i]]}')
1266
+ # plt.axis('off')
1267
+ # plt.suptitle(f'Predictions for {class_name} ({split_name})')
1268
+ # plt.tight_layout(rect=[0, 0, 1, 0.95])
1269
+ # filename = f'/kaggle/working/visualizations/predictions_{class_name}_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png')
1270
+ # plt.savefig(filename)
1271
+ # plt.close()
1272
+
+ def visualize_predictions_per_class(model, dataloader, device, classes, split_name, model_idx=None, num_samples=4):
+     model.eval()
+     class_images = {i: [] for i in range(len(classes))}
+     class_preds = {i: [] for i in range(len(classes))}
+     class_labels = {i: [] for i in range(len(classes))}
+     with torch.no_grad():
+         for images, labels, features in dataloader:
+             images, labels, features = images.to(device), labels.to(device), features.to(device)
+             outputs = model(images, features)
+             _, predicted = torch.max(outputs.data, 1)
+             for img, pred, label in zip(images.cpu().numpy(), predicted.cpu().numpy(), labels.cpu().numpy()):
+                 if len(class_images[label]) < num_samples:
+                     class_images[label].append(img)
+                     class_preds[label].append(pred)
+                     class_labels[label].append(label)
+             if all(len(class_images[i]) >= num_samples for i in range(len(classes))):
+                 break
+     for class_idx, class_name in enumerate(classes):
+         plt.figure(figsize=(15, 5))
+         for i in range(min(num_samples, len(class_images[class_idx]))):
+             img = class_images[class_idx][i].transpose(1, 2, 0)
+             img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
+             img = np.clip(img, 0, 1)
+             plt.subplot(1, num_samples, i + 1)
+             plt.imshow(img)
+             plt.title(f'True: {class_name}\nPred: {CLASSES[class_preds[class_idx][i]]}',
+                       fontweight='bold')  # Bold title
+             plt.axis('off')
+         # Make suptitle bold and adjust font properties
+         plt.suptitle(f'Predictions for {class_name} ({split_name})',
+                      fontweight='bold',
+                      fontsize=12)  # Optional: slightly larger font
+         plt.tight_layout(rect=[0, 0, 1, 0.95])
+         filename = f'/kaggle/working/visualizations/predictions_{class_name}_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png')
+         plt.savefig(filename)
+         plt.close()
+
+
+ # # Function to visualize predictions combined per class
+ # def visualize_predictions_grid_per_class(model, dataloader, device, classes, split_name, model_idx=None, num_samples=2):
+ #     import os
+ #     os.makedirs("/kaggle/working/visualizations", exist_ok=True)
+
+ #     model.eval()
+ #     num_classes = len(classes)
+
+ #     # Collect samples
+ #     class_images = {i: [] for i in range(num_classes)}
+ #     class_preds = {i: [] for i in range(num_classes)}
+ #     class_labels = {i: [] for i in range(num_classes)}
+
+ #     with torch.no_grad():
+ #         for images, labels, features in dataloader:
+ #             images, labels, features = images.to(device), labels.to(device), features.to(device)
+ #             outputs = model(images, features)
+ #             _, predicted = torch.max(outputs.data, 1)
+ #             for img, pred, label in zip(images.cpu().numpy(), predicted.cpu().numpy(), labels.cpu().numpy()):
+ #                 if len(class_images[label]) < num_samples:
+ #                     class_images[label].append(img)
+ #                     class_preds[label].append(pred)
+ #                     class_labels[label].append(label)
+ #             if all(len(class_images[i]) >= num_samples for i in range(num_classes)):
+ #                 break
+
+ #     # Plot: Grid of num_samples rows × num_classes columns
+ #     plt.figure(figsize=(4 * num_classes, 4 * num_samples))
+ #     for row in range(num_samples):
+ #         for class_idx in range(num_classes):
+ #             if row >= len(class_images[class_idx]):
+ #                 continue
+ #             img = class_images[class_idx][row].transpose(1, 2, 0)
+ #             img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])  # unnormalize
+ #             img = np.clip(img, 0, 1)
+ #             ax_idx = row * num_classes + class_idx + 1
+ #             plt.subplot(num_samples, num_classes, ax_idx)
+ #             true_label = classes[class_labels[class_idx][row]]
+ #             pred_label = classes[class_preds[class_idx][row]]
+ #             plt.imshow(img)
+ #             plt.title(f'True: {true_label}\nPred: {pred_label}')
+ #             plt.axis('off')
+
+ #     plt.suptitle(f'{split_name} Predictions Grid ({num_samples}×{num_classes})')
+ #     filename = f'/kaggle/working/visualizations/predictions_grid_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png')
+ #     plt.tight_layout(rect=[0, 0, 1, 0.95])
+ #     plt.savefig(filename)
+ #     plt.close()
+
+
+ # Main Execution
+ if __name__ == "__main__":
+     # Set device
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     # Debug dataset directory
+     print("Checking dataset directory structure:")
+     for dirname, _, filenames in os.walk(DATASET_PATH):
+         print(f"Directory: {dirname}, Files: {len(filenames)}")
+         for filename in filenames[:5]:
+             print(f" - {os.path.join(dirname, filename)}")
+
+     # Check for CSV file
+     csv_path = os.path.join(DATASET_PATH, "labels.csv")
+     use_csv = os.path.exists(csv_path)
+     if use_csv:
+         print("Detected labels.csv, will load dataset from CSV")
+     else:
+         print("No labels.csv found, assuming directory-based structure")
+
+     # Load dataset
+     train_path = os.path.join(DATASET_PATH, "TRAIN")
+     val_path = os.path.join(DATASET_PATH, "VALIDATION")
+     test_path = os.path.join(DATASET_PATH, "TEST")
+
+     train_images, train_labels, train_features = load_images(train_path, CLASSES, use_csv)
+     val_images, val_labels, val_features = load_images(val_path, CLASSES, use_csv)
+     test_images, test_labels, test_features = load_images(test_path, CLASSES, use_csv)
+
+     # Check if datasets are empty
+     if not train_images:
+         raise ValueError("Training dataset is empty. Please check the dataset path, class names, image files, or CSV structure.")
+     if not val_images:
+         print("Warning: Validation dataset is empty. Creating validation split from training data.")
+         train_images, val_images, train_labels, val_labels, train_features, val_features = train_test_split(
+             train_images, train_labels, train_features, test_size=0.2, stratify=train_labels, random_state=42
+         )
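+         # Stratified 80/20 split: test_size=0.2 holds out a fifth of the training images for validation while
+         # stratify=train_labels keeps the class proportions the same in both halves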
+     if not test_images:
+         print("Warning: Test dataset is empty.")
+         test_images, test_labels, test_features = [], [], []
+
+     # Visualize dataset
+     visualize_data_distribution(train_path, "Train", CLASSES)
+     visualize_data_distribution(val_path, "Validation", CLASSES)
+     visualize_data_distribution(test_path, "Test", CLASSES)
+     display_sample_images(train_images, train_labels, "Train", CLASSES)
+     display_sample_images(val_images, val_labels, "Validation", CLASSES)
+     display_sample_images(test_images, test_labels, "Test", CLASSES)
+     visualize_handcrafted_features(train_images, train_labels, CLASSES)
+
+     # Create datasets
+     train_dataset = FootUlcerDataset(train_images, train_labels, train_features)
+     val_dataset = FootUlcerDataset(val_images, val_labels, val_features)
+     test_dataset = FootUlcerDataset(test_images, test_labels, test_features)
+
+     # Create WeightedRandomSampler
+     train_labels_np = np.array(train_labels)
+     class_counts = np.array([sum(train_labels_np == i) for i in range(len(CLASSES))])
+     print(f"Class counts: {dict(zip(CLASSES, class_counts))}")
+     if np.any(class_counts == 0):
+         print("Warning: Some classes have zero samples in the training set.")
+     class_weights = 1.0 / (class_counts + 1e-6)
+     sample_weights = class_weights[train_labels_np]
+     sampler = WeightedRandomSampler(sample_weights, len(train_labels), replacement=True)
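+     # Each training example is drawn with probability proportional to 1/count(its class), so minority classes are
+     # oversampled with replacement. Illustrative, hypothetical counts: [300, 100, 50] -> weights ~ [0.0033, 0.010, 0.020],
+     # i.e. an image from the rarest class is roughly 6x more likely to be drawn per pick than one from the largest class.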
+
+     # Create DataLoaders
+     batch_size = 32
+     dataloader = {
+         'train': DataLoader(train_dataset, batch_size=batch_size, sampler=sampler),
+         'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=False),
+         'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
+     }
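+     # The train loader gets its shuffling and class balancing from the weighted sampler (DataLoader forbids combining
+     # sampler with shuffle=True); val and test keep a fixed order so predictions stay aligned with the stored labels.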
+
+     # Train and evaluate models
+     num_models = 3
+     models_list = []
+     test_accuracies = []
+     best_train_accuracies = []
+     best_val_accuracies = []
+
+     for i in range(num_models):
+         print(f"\nTraining Model {i+1}/{num_models}")
+         model = DenseShuffleGCANet(num_classes=len(CLASSES), handcrafted_feature_dim=41).to(device)
+         print(f"\nModel {i+1} Summary:")
+         print_model_summary(model, input_size=(3, 224, 224), handcrafted_feature_dim=41)
+
+         criterion = FocalLoss(gamma=3.0, alpha=0.5)
+         optimizer = optim.Adam(model.parameters(), lr=0.00005, weight_decay=0.001)
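+         # Assuming the FocalLoss class defined earlier follows the standard focal-loss formulation, gamma=3.0 strongly
+         # down-weights easy, already-correct samples and alpha=0.5 is the balancing factor; the small learning rate
+         # (5e-5) and weight_decay=0.001 act as additional regularisation.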
+         history, best_train_acc, best_val_acc = train_model(model, dataloader, criterion, optimizer, device, epochs=100, model_idx=i)
+         plot_training_history(history, len(history['train_loss']), model_idx=i)
+         best_train_accuracies.append(best_train_acc)
+         best_val_accuracies.append(best_val_acc)
+
+         print(f"\nEvaluating Model {i+1} on Training Set")
+         train_acc, train_preds, train_labels, _, _ = evaluate_model(model, dataloader['train'], device, 'Train', CLASSES, i)
+         print(f"Model {i+1} Train Accuracy: {train_acc:.2f}%")
+
+         print(f"\nEvaluating Model {i+1} on Validation Set")
+         val_acc, val_preds, val_labels, _, _ = evaluate_model(model, dataloader['val'], device, 'Validation', CLASSES, i)
+         print(f"Model {i+1} Validation Accuracy: {val_acc:.2f}%")
+
+         print(f"\nEvaluating Model {i+1} on Test Set")
+         test_acc, test_preds, test_labels, _, _ = evaluate_model(model, dataloader['test'], device, 'Test', CLASSES, i)
+         print(f"Model {i+1} Test Accuracy: {test_acc:.2f}%")
+         test_accuracies.append(test_acc)
+
+         visualize_predictions_per_class(model, dataloader['test'], device, CLASSES, 'Test', model_idx=i)
+         if i == 0:
+             print(f"\nVisualizing Feature Extraction for Model {i+1}")
+             visualize_feature_extraction(model, dataloader['test'], device, CLASSES, num_samples_per_class=1)
+         models_list.append(model)
+
+     # Evaluate ensemble
+     print("\nEvaluating Ensemble on Training Set")
+     ensemble_train_acc, _, _, _, _ = ensemble_voting(models_list, dataloader['train'], device, 'Train', CLASSES)
+     print(f"Ensemble Train Accuracy: {ensemble_train_acc:.2f}%")
+
+     print("\nEvaluating Ensemble on Validation Set")
+     ensemble_val_acc, _, _, _, _ = ensemble_voting(models_list, dataloader['val'], device, 'Validation', CLASSES)
+     print(f"Ensemble Validation Accuracy: {ensemble_val_acc:.2f}%")
+
+     print("\nEvaluating Ensemble on Test Set")
+     ensemble_test_acc, ensemble_test_preds, ensemble_test_labels, _, _ = ensemble_voting(models_list, dataloader['test'], device, 'Test', CLASSES)
+     print(f"Ensemble Test Accuracy: {ensemble_test_acc:.2f}%")
+
+     visualize_predictions_per_class(models_list[0], dataloader['test'], device, CLASSES, 'Test', model_idx=None)
+     visualize_voting_process(models_list, dataloader['test'], device, CLASSES)
+
+     # Evaluate TTA
+     print("\nEvaluating Best Model with TTA on Test Set")
+     tta_acc, _, _, _, _ = evaluate_model(models_list[0], dataloader['test'], device, 'Test_TTA', CLASSES, model_idx=0, use_tta=True)
+     print(f"TTA Test Accuracy: {tta_acc:.2f}%")
+
+     # Statistical Analysis
+     print("\nStatistical Analysis of Model Performance:")
+     print(f"Mean Test Accuracy: {np.mean(test_accuracies):.2f}% ± {np.std(test_accuracies):.2f}%")
+     print(f"Best Training Accuracies: {[f'{acc:.2f}%' for acc in best_train_accuracies]}")
+     print(f"Best Validation Accuracies: {[f'{acc:.2f}%' for acc in best_val_accuracies]}")
+
+     # Save predictions
+     predictions_df = pd.DataFrame({
+         'True_Label': [CLASSES[label] for label in ensemble_test_labels],
+         'Predicted_Label': [CLASSES[pred] for pred in ensemble_test_preds]
+     })
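+     # Assumes the /kaggle/working/predictions/ directory was created earlier in the script; if not, an
+     # os.makedirs('/kaggle/working/predictions', exist_ok=True) call would be needed before saving.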
1506
+ predictions_df.to_csv('/kaggle/working/predictions/test_predictions_ensemble.csv', index=False)
1507
+ print("Predictions saved to /kaggle/working/predictions/test_predictions_ensemble.csv")
1508
+
1509
+ # Save summary report
1510
+ summary_report = {
1511
+ 'Model': [f'Model {i+1}' for i in range(num_models)] + ['Ensemble', 'TTA'],
1512
+ 'Test_Accuracy': test_accuracies + [ensemble_test_acc, tta_acc],
1513
+ 'Best_Train_Accuracy': best_train_accuracies + [None, None],
1514
+ 'Best_Val_Accuracy': best_val_accuracies + [None, None]
1515
+ }
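+     # The Ensemble and TTA rows have no training history of their own, so their train/val accuracy columns stay None.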
+     summary_df = pd.DataFrame(summary_report)
+     summary_df.to_csv('/kaggle/working/summary_report.csv', index=False)
+     print("\nSummary Report:")
+     print(summary_df)