huseyincavus commited on
Commit
924a630
·
verified ·
1 Parent(s): 15940f8

Upload DeepStroke_SE_ResNeXt50.ipynb

Browse files
Files changed (1) hide show
  1. DeepStroke_SE_ResNeXt50.ipynb +653 -0
DeepStroke_SE_ResNeXt50.ipynb ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "kernelspec": {
4
+ "name": "python3",
5
+ "display_name": "Python 3",
6
+ "language": "python"
7
+ },
8
+ "language_info": {
9
+ "name": "python",
10
+ "version": "3.11.11",
11
+ "mimetype": "text/x-python",
12
+ "codemirror_mode": {
13
+ "name": "ipython",
14
+ "version": 3
15
+ },
16
+ "pygments_lexer": "ipython3",
17
+ "nbconvert_exporter": "python",
18
+ "file_extension": ".py"
19
+ },
20
+ "colab": {
21
+ "provenance": [],
22
+ "gpuType": "T4"
23
+ },
24
+ "accelerator": "GPU",
25
+ "kaggle": {
26
+ "accelerator": "nvidiaTeslaT4",
27
+ "dataSources": [
28
+ {
29
+ "sourceId": 12128762,
30
+ "sourceType": "datasetVersion",
31
+ "datasetId": 7637522
32
+ }
33
+ ],
34
+ "dockerImageVersionId": 31041,
35
+ "isInternetEnabled": true,
36
+ "language": "python",
37
+ "sourceType": "notebook",
38
+ "isGpuEnabled": true
39
+ }
40
+ },
41
+ "nbformat_minor": 0,
42
+ "nbformat": 4,
43
+ "cells": [
44
+ {
45
+ "source": [
46
+ "# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE\n",
47
+ "# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.\n",
48
+ "import kagglehub\n",
49
+ "kagglehub.login()\n"
50
+ ],
51
+ "metadata": {
52
+ "id": "y-g5cNB2gIQw"
53
+ },
54
+ "cell_type": "code",
55
+ "outputs": [],
56
+ "execution_count": null
57
+ },
58
+ {
59
+ "cell_type": "markdown",
60
+ "source": [
61
+ "https://www.kaggle.com/code/huseyincavus/deepstroke-se-resnext50/notebook"
62
+ ],
63
+ "metadata": {
64
+ "id": "N3v_TfvYgIuD"
65
+ }
66
+ },
67
+ {
68
+ "source": [
69
+ "# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,\n",
70
+ "# THEN FEEL FREE TO DELETE THIS CELL.\n",
71
+ "# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON\n",
72
+ "# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR\n",
73
+ "# NOTEBOOK.\n",
74
+ "\n",
75
+ "huseyincavus_deepstroke_path = kagglehub.dataset_download('huseyincavus/deepstroke')\n",
76
+ "\n",
77
+ "print('Data source import complete.')\n"
78
+ ],
79
+ "metadata": {
80
+ "id": "9JQXx5XYgIQx"
81
+ },
82
+ "cell_type": "code",
83
+ "outputs": [],
84
+ "execution_count": null
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "source": [
89
+ "# Cell 1: Display random images from the dataset\n",
90
+ "import os\n",
91
+ "import random\n",
92
+ "import matplotlib.pyplot as plt\n",
93
+ "import matplotlib.image as mpimg\n",
94
+ "\n",
95
+ "# Define the data directories\n",
96
+ "base_dir = \"/kaggle/input/deepstroke/DeepStroke1_Data\"\n",
97
+ "ischaemic_dir = os.path.join(base_dir, \"Ischaemic\")\n",
98
+ "non_ischaemic_dir = os.path.join(base_dir, \"Non-Ischaemic\")\n",
99
+ "hemoraj_dir = os.path.join(base_dir, \"Hemoraj\")\n",
100
+ "\n",
101
+ "def display_random_images(directory, num_images=4):\n",
102
+ " if not os.path.exists(directory):\n",
103
+ " print(f\"Directory not found: {directory}\")\n",
104
+ " return\n",
105
+ "\n",
106
+ " all_files = os.listdir(directory)\n",
107
+ " # Filter out non-image files\n",
108
+ " image_files = [f for f in all_files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
109
+ "\n",
110
+ " if not image_files:\n",
111
+ " print(f\"No image files found in {directory}\")\n",
112
+ " return\n",
113
+ "\n",
114
+ " random_files = random.sample(image_files, min(num_images, len(image_files)))\n",
115
+ "\n",
116
+ " # Create a figure and axes for the subplots\n",
117
+ " fig, axes = plt.subplots(1, len(random_files), figsize=(15, 5))\n",
118
+ " if len(random_files) == 1:\n",
119
+ " axes = [axes]\n",
120
+ "\n",
121
+ " # Display each image in a subplot\n",
122
+ " for i, file in enumerate(random_files):\n",
123
+ " image_path = os.path.join(directory, file)\n",
124
+ " img = mpimg.imread(image_path)\n",
125
+ " axes[i].imshow(img)\n",
126
+ " axes[i].axis('off')\n",
127
+ " axes[i].set_title(os.path.basename(file))\n",
128
+ "\n",
129
+ " plt.tight_layout()\n",
130
+ " plt.show()\n",
131
+ "\n",
132
+ "# Display Ischaemic images\n",
133
+ "if os.path.exists(ischaemic_dir):\n",
134
+ " print(\"Ischaemic Images:\")\n",
135
+ " print(f\"Number of Ischaemic samples: {len(os.listdir(ischaemic_dir))}\")\n",
136
+ " display_random_images(ischaemic_dir)\n",
137
+ "else:\n",
138
+ " print(f\"Directory not found: {ischaemic_dir}\")\n",
139
+ "\n",
140
+ "# Display Non-Ischaemic images\n",
141
+ "if os.path.exists(non_ischaemic_dir):\n",
142
+ " print(\"\\nNon-Ischaemic Images:\")\n",
143
+ " print(f\"Number of Non-Ischemic samples: {len(os.listdir(non_ischaemic_dir))}\")\n",
144
+ " display_random_images(non_ischaemic_dir)\n",
145
+ "else:\n",
146
+ " print(f\"Directory not found: {non_ischaemic_dir}\")\n",
147
+ "\n",
148
+ "# Display Hemoraj images\n",
149
+ "if os.path.exists(hemoraj_dir):\n",
150
+ " print(\"\\nHemoraj Images:\")\n",
151
+ "\n",
152
+ " # Check if Hemoraj has subdirectories (class folders)\n",
153
+ " subdirs = [d for d in os.listdir(hemoraj_dir) if os.path.isdir(os.path.join(hemoraj_dir, d))]\n",
154
+ "\n",
155
+ " if subdirs: # If there are subdirectories (class folders)\n",
156
+ " for subdir in subdirs:\n",
157
+ " class_dir = os.path.join(hemoraj_dir, subdir)\n",
158
+ " print(f\"-- {subdir} Images:\")\n",
159
+ " print(f\" Number of {subdir} samples: {len(os.listdir(class_dir))}\")\n",
160
+ " display_random_images(class_dir)\n",
161
+ " else: # If images are directly in the base directory\n",
162
+ " print(f\"Number of samples: {len([f for f in os.listdir(hemoraj_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])}\")\n",
163
+ " display_random_images(hemoraj_dir)\n",
164
+ "else:\n",
165
+ " print(f\"Directory not found: {hemoraj_dir}\")"
166
+ ],
167
+ "metadata": {
168
+ "id": "E_hVhkJtX0Qc",
169
+ "trusted": true,
170
+ "execution": {
171
+ "iopub.status.busy": "2025-06-11T09:23:49.210399Z",
172
+ "iopub.execute_input": "2025-06-11T09:23:49.21068Z",
173
+ "iopub.status.idle": "2025-06-11T09:23:54.643448Z",
174
+ "shell.execute_reply.started": "2025-06-11T09:23:49.210656Z",
175
+ "shell.execute_reply": "2025-06-11T09:23:54.642689Z"
176
+ }
177
+ },
178
+ "outputs": [],
179
+ "execution_count": null
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "source": [
184
+ "# Cell 2: Prepare the dataset for training (Optimized for Speed)\n",
185
+ "import os\n",
186
+ "import random\n",
187
+ "import torch\n",
188
+ "from torch.utils.data import Dataset, DataLoader, random_split\n",
189
+ "from torchvision import transforms\n",
190
+ "from PIL import Image\n",
191
+ "import numpy as np\n",
192
+ "import matplotlib.pyplot as plt\n",
193
+ "\n",
194
+ "# --- OPTIMIZATION: Get the number of available CPU cores ---\n",
195
+ "# This will be used to parallelize data loading.\n",
196
+ "try:\n",
197
+ " NUM_CPUS = os.cpu_count()\n",
198
+ "except:\n",
199
+ " NUM_CPUS = 4 # A reasonable fallback for platforms where os.cpu_count() might fail\n",
200
+ "\n",
201
+ "# Constants for ResNet50\n",
202
+ "IMG_WIDTH = 224\n",
203
+ "IMG_HEIGHT = 224\n",
204
+ "BATCH_SIZE = 128 # Increased for better GPU utilization\n",
205
+ "\n",
206
+ "# ImageNet normalization values for ResNet50\n",
207
+ "MEAN = [0.485, 0.456, 0.406]\n",
208
+ "STD = [0.229, 0.224, 0.225]\n",
209
+ "\n",
210
+ "# Setup device for GPU usage\n",
211
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
212
+ "print(f\"Using device: {device}\")\n",
213
+ "print(f\"Using {NUM_CPUS} CPU workers for data loading.\")\n",
214
+ "\n",
215
+ "# DeepStroke dataset paths\n",
216
+ "base_dir = \"/kaggle/input/deepstroke/DeepStroke1_Data\"\n",
217
+ "normal_dir = os.path.join(base_dir, \"Non-Ischaemic\") # Normal images (label 0)\n",
218
+ "abnormal_dirs = [\n",
219
+ " os.path.join(base_dir, \"Ischaemic\"), # Abnormal images (label 1)\n",
220
+ " os.path.join(base_dir, \"Hemoraj\") # Abnormal images (label 1)\n",
221
+ "]\n",
222
+ "\n",
223
+ "class ImageDataset(Dataset):\n",
224
+ " def __init__(self, normal_dir, abnormal_dirs, transform=None):\n",
225
+ " self.image_paths = []\n",
226
+ " self.labels = []\n",
227
+ "\n",
228
+ " # Add normal images (label 0)\n",
229
+ " if os.path.exists(normal_dir):\n",
230
+ " for f in os.listdir(normal_dir):\n",
231
+ " if f.lower().endswith(('.png', '.jpg', '.jpeg')):\n",
232
+ " self.image_paths.append(os.path.join(normal_dir, f))\n",
233
+ " self.labels.append(0) # Normal\n",
234
+ "\n",
235
+ " # Add abnormal images (label 1)\n",
236
+ " for abnormal_dir in abnormal_dirs:\n",
237
+ " if os.path.exists(abnormal_dir):\n",
238
+ " for f in os.listdir(abnormal_dir):\n",
239
+ " if f.lower().endswith(('.png', '.jpg', '.jpeg')):\n",
240
+ " self.image_paths.append(os.path.join(abnormal_dir, f))\n",
241
+ " self.labels.append(1) # Abnormal\n",
242
+ "\n",
243
+ " self.transform = transform\n",
244
+ "\n",
245
+ " def __len__(self):\n",
246
+ " return len(self.image_paths)\n",
247
+ "\n",
248
+ " def __getitem__(self, idx):\n",
249
+ " image_path = self.image_paths[idx]\n",
250
+ " image = Image.open(image_path).convert('RGB')\n",
251
+ " label = self.labels[idx]\n",
252
+ " if self.transform:\n",
253
+ " image = self.transform(image)\n",
254
+ " return image, torch.tensor(label, dtype=torch.long)\n",
255
+ "\n",
256
+ "# Define the transformations for training and for validation/testing\n",
257
+ "train_transform = transforms.Compose([\n",
258
+ " transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),\n",
259
+ " transforms.RandomHorizontalFlip(),\n",
260
+ " transforms.RandomRotation(degrees=15),\n",
261
+ " transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),\n",
262
+ " transforms.ToTensor(),\n",
263
+ " transforms.Normalize(mean=MEAN, std=STD)\n",
264
+ "])\n",
265
+ "\n",
266
+ "val_test_transform = transforms.Compose([\n",
267
+ " transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),\n",
268
+ " transforms.ToTensor(),\n",
269
+ " transforms.Normalize(mean=MEAN, std=STD)\n",
270
+ "])\n",
271
+ "\n",
272
+ "# Create dataset with augmented images for training\n",
273
+ "dataset = ImageDataset(normal_dir, abnormal_dirs, transform=train_transform)\n",
274
+ "\n",
275
+ "# Split dataset into training (70%), validation (15%), and test (15%)\n",
276
+ "generator = torch.Generator().manual_seed(42)\n",
277
+ "train_size = int(0.7 * len(dataset))\n",
278
+ "val_size = int(0.15 * len(dataset))\n",
279
+ "test_size = len(dataset) - train_size - val_size\n",
280
+ "train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size], generator=generator)\n",
281
+ "\n",
282
+ "# --- Oversampling for Training Set ---\n",
283
+ "train_indices = train_dataset.indices\n",
284
+ "normal_indices = [i for i in train_indices if dataset[i][1].item() == 0]\n",
285
+ "abnormal_indices = [i for i in train_indices if dataset[i][1].item() == 1]\n",
286
+ "\n",
287
+ "num_normal = len(normal_indices)\n",
288
+ "num_abnormal = len(abnormal_indices)\n",
289
+ "print(f\"Training set before balancing: Normal={num_normal}, Abnormal={num_abnormal}\")\n",
290
+ "\n",
291
+ "if num_normal < num_abnormal:\n",
292
+ " oversampled_indices = random.choices(normal_indices, k=num_abnormal - num_normal)\n",
293
+ " new_train_indices = train_indices + oversampled_indices\n",
294
+ " print(f\"Oversampling Normal class: added {len(oversampled_indices)} samples\")\n",
295
+ "elif num_abnormal < num_normal:\n",
296
+ " oversampled_indices = random.choices(abnormal_indices, k=num_normal - num_abnormal)\n",
297
+ " new_train_indices = train_indices + oversampled_indices\n",
298
+ " print(f\"Oversampling Abnormal class: added {len(oversampled_indices)} samples\")\n",
299
+ "else:\n",
300
+ " new_train_indices = train_indices\n",
301
+ " print(\"Classes are already balanced\")\n",
302
+ "\n",
303
+ "random.shuffle(new_train_indices)\n",
304
+ "train_dataset = torch.utils.data.Subset(dataset, new_train_indices)\n",
305
+ "\n",
306
+ "# Create an \"original\" dataset for displaying purposes\n",
307
+ "original_dataset = ImageDataset(normal_dir, abnormal_dirs, transform=None)\n",
308
+ "original_train_dataset = torch.utils.data.Subset(original_dataset, new_train_indices)\n",
309
+ "\n",
310
+ "# Create datasets with validation/test transforms\n",
311
+ "val_dataset_with_transform = torch.utils.data.Subset(\n",
312
+ " ImageDataset(normal_dir, abnormal_dirs, transform=val_test_transform),\n",
313
+ " val_dataset.indices\n",
314
+ ")\n",
315
+ "test_dataset_with_transform = torch.utils.data.Subset(\n",
316
+ " ImageDataset(normal_dir, abnormal_dirs, transform=val_test_transform),\n",
317
+ " test_dataset.indices\n",
318
+ ")\n",
319
+ "\n",
320
+ "# --- OPTIMIZATION: Create DataLoaders using multiple CPU cores ---\n",
321
+ "# num_workers > 0 enables multi-process data loading.\n",
322
+ "# pin_memory=True speeds up CPU to GPU data transfer.\n",
323
+ "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=NUM_CPUS)\n",
324
+ "val_loader = DataLoader(val_dataset_with_transform, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=NUM_CPUS)\n",
325
+ "test_loader = DataLoader(test_dataset_with_transform, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=NUM_CPUS)\n",
326
+ "\n",
327
+ "\n",
328
+ "# The rest of the cell remains the same...\n",
329
+ "# Count images per class in the full dataset\n",
330
+ "normal_count = sum(1 for _, label in dataset if label.item() == 0)\n",
331
+ "abnormal_count = sum(1 for _, label in dataset if label.item() == 1)\n",
332
+ "\n",
333
+ "print(f\"\\nFull dataset:\")\n",
334
+ "print(f\"Number of normal images: {normal_count}\")\n",
335
+ "print(f\"Number of abnormal images: {abnormal_count}\")\n",
336
+ "print(f\"Total images: {len(dataset)}\")\n",
337
+ "\n",
338
+ "class_names = [\"Normal\", \"Abnormal\"]\n",
339
+ "\n",
340
+ "def display_random_images_from_dataset(dataset, save=False, filename=\"random_images.png\", num_images=4):\n",
341
+ " indices = random.sample(range(len(dataset)), min(num_images, len(dataset)))\n",
342
+ " fig, axes = plt.subplots(1, len(indices), figsize=(15, 5))\n",
343
+ " if len(indices) == 1:\n",
344
+ " axes = [axes]\n",
345
+ " for i, idx in enumerate(indices):\n",
346
+ " image, label = dataset[idx]\n",
347
+ " image = image.cpu() if isinstance(image, torch.Tensor) and image.device.type != 'cpu' else image\n",
348
+ " if isinstance(image, torch.Tensor):\n",
349
+ " image = image.cpu().numpy().transpose((1, 2, 0))\n",
350
+ " image = image * np.array(STD) + np.array(MEAN)\n",
351
+ " image = np.clip(image, 0, 1)\n",
352
+ " axes[i].imshow(image)\n",
353
+ " axes[i].axis('off')\n",
354
+ " label_idx = label.item() if torch.is_tensor(label) else label\n",
355
+ " axes[i].set_title(f\"Label: {class_names[label_idx]}\")\n",
356
+ " plt.tight_layout()\n",
357
+ " if save:\n",
358
+ " plt.savefig(filename)\n",
359
+ " print(f\"Saved random images to {filename}\")\n",
360
+ " plt.show()\n",
361
+ "\n",
362
+ "print(\"\\nRandom Images from Training Set:\")\n",
363
+ "display_random_images_from_dataset(train_dataset, save=False, filename=\"random_train_images.png\")\n",
364
+ "\n",
365
+ "def display_augmented_images(original_dataset, augmented_dataset, save=False, base_filename=\"augmented_\", num_images=4):\n",
366
+ " indices = random.sample(range(len(original_dataset)), min(num_images, len(original_dataset)))\n",
367
+ "\n",
368
+ " for i, idx in enumerate(indices):\n",
369
+ " orig_image, label = original_dataset[idx]\n",
370
+ " label_idx = label.item() if torch.is_tensor(label) else label\n",
371
+ " file_path = original_dataset.dataset.image_paths[original_dataset.indices[idx]]\n",
372
+ "\n",
373
+ " augmented_images = []\n",
374
+ " for _ in range(10):\n",
375
+ " augmented_image, _ = augmented_dataset[idx]\n",
376
+ " augmented_images.append(augmented_image)\n",
377
+ "\n",
378
+ " fig_aug, axes_aug = plt.subplots(2, 5, figsize=(15, 6))\n",
379
+ " fig_aug.suptitle(f\"Class: {class_names[label_idx]}\\nFilename: {os.path.basename(file_path)}\", fontsize=14)\n",
380
+ " axes_aug = axes_aug.flatten()\n",
381
+ "\n",
382
+ " for j, aug_img in enumerate(augmented_images):\n",
383
+ " aug_disp = aug_img.numpy().transpose((1, 2, 0))\n",
384
+ " aug_disp = aug_disp * np.array(STD) + np.array(MEAN)\n",
385
+ " aug_disp = np.clip(aug_disp, 0, 1)\n",
386
+ " axes_aug[j].imshow(aug_disp)\n",
387
+ " axes_aug[j].axis('off')\n",
388
+ " axes_aug[j].set_title(f\"Augmented {j+1}\")\n",
389
+ "\n",
390
+ " plt.tight_layout()\n",
391
+ " if save:\n",
392
+ " out_filename = f\"{base_filename}{class_names[label_idx]}_{i+1}.png\"\n",
393
+ " plt.savefig(out_filename)\n",
394
+ " plt.show()\n",
395
+ "\n",
396
+ "print(\"\\nSome Augmented Images:\")\n",
397
+ "display_augmented_images(original_train_dataset, train_dataset, save=False)"
398
+ ],
399
+ "metadata": {
400
+ "id": "mLGvtdz6X0Qe",
401
+ "trusted": true,
402
+ "execution": {
403
+ "iopub.status.busy": "2025-06-11T09:59:02.189054Z",
404
+ "iopub.execute_input": "2025-06-11T09:59:02.189649Z",
405
+ "iopub.status.idle": "2025-06-11T10:05:07.493857Z",
406
+ "shell.execute_reply.started": "2025-06-11T09:59:02.189616Z",
407
+ "shell.execute_reply": "2025-06-11T10:05:07.493183Z"
408
+ }
409
+ },
410
+ "outputs": [],
411
+ "execution_count": null
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "source": [
416
+ "import torch\n",
417
+ "\n",
418
+ "torch.cuda.empty_cache()\n",
419
+ "torch.cuda.synchronize() # optional: waits for all kernels to finish"
420
+ ],
421
+ "metadata": {
422
+ "trusted": true,
423
+ "execution": {
424
+ "iopub.status.busy": "2025-06-11T10:05:42.242558Z",
425
+ "iopub.execute_input": "2025-06-11T10:05:42.242855Z",
426
+ "iopub.status.idle": "2025-06-11T10:05:42.246728Z",
427
+ "shell.execute_reply.started": "2025-06-11T10:05:42.242834Z",
428
+ "shell.execute_reply": "2025-06-11T10:05:42.246031Z"
429
+ },
430
+ "id": "bwhjOi91gIQz"
431
+ },
432
+ "outputs": [],
433
+ "execution_count": null
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "source": [
438
+ "import torch\n",
439
+ "import torch.nn as nn\n",
440
+ "import torch.optim as optim\n",
441
+ "import torchvision.models as models\n",
442
+ "from torchvision.models.resnet import Bottleneck\n",
443
+ "from tqdm import tqdm\n",
444
+ "import matplotlib.pyplot as plt\n",
445
+ "import numpy as np\n",
446
+ "import os\n",
447
+ "import pickle\n",
448
+ "# --- OPTIMIZATION: Import for Automatic Mixed Precision (AMP) ---\n",
449
+ "from torch.cuda.amp import GradScaler, autocast\n",
450
+ "\n",
451
+ "# --- OPTIMIZATION: Configuration updated for new batch size ---\n",
452
+ "num_epochs = 30\n",
453
+ "patience = 5\n",
454
+ "# Batch size was 32, now 128 (4x). Scale learning rate by 4x.\n",
455
+ "learning_rate = 4e-4 # Previously 1e-4\n",
456
+ "weight_decay = 1e-4\n",
457
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
458
+ "print(f\"Using device: {device}\")\n",
459
+ "print(f\"Learning rate adjusted to {learning_rate} for larger batch size.\")\n",
460
+ "\n",
461
+ "# Define Focal Loss\n",
462
+ "class FocalLoss(nn.Module):\n",
463
+ " def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):\n",
464
+ " super(FocalLoss, self).__init__()\n",
465
+ " self.alpha = alpha\n",
466
+ " self.gamma = gamma\n",
467
+ " self.reduction = reduction\n",
468
+ " self.criterion = nn.BCEWithLogitsLoss(reduction='none')\n",
469
+ "\n",
470
+ " def forward(self, inputs, targets):\n",
471
+ " BCE_loss = self.criterion(inputs, targets)\n",
472
+ " pt = torch.exp(-BCE_loss)\n",
473
+ " F_loss = self.alpha * (1 - pt)**self.gamma * BCE_loss\n",
474
+ " if self.reduction == 'mean':\n",
475
+ " return torch.mean(F_loss)\n",
476
+ " elif self.reduction == 'sum':\n",
477
+ " return torch.sum(F_loss)\n",
478
+ " else:\n",
479
+ " return F_loss\n",
480
+ "\n",
481
+ "# --- Model Definition (No changes needed here) ---\n",
482
+ "class SELayer(nn.Module):\n",
483
+ " def __init__(self, channel, reduction=16):\n",
484
+ " super(SELayer, self).__init__()\n",
485
+ " self.avg_pool = nn.AdaptiveAvgPool2d(1)\n",
486
+ " self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False), nn.ReLU(inplace=True), nn.Linear(channel // reduction, channel, bias=False), nn.Sigmoid())\n",
487
+ " def forward(self, x):\n",
488
+ " b, c, _, _ = x.size(); y = self.avg_pool(x).view(b, c); y = self.fc(y).view(b, c, 1, 1); return x * y.expand_as(x)\n",
489
+ "\n",
490
+ "class SEBottleneck(Bottleneck):\n",
491
+ " expansion = 4\n",
492
+ " def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None, se_reduction=16):\n",
493
+ " super(SEBottleneck, self).__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)\n",
494
+ " self.se = SELayer(planes * self.expansion, reduction=se_reduction)\n",
495
+ " def forward(self, x):\n",
496
+ " identity = x; out = self.conv1(x); out = self.bn1(out); out = self.relu(out); out = self.conv2(out); out = self.bn2(out); out = self.relu(out); out = self.conv3(out); out = self.bn3(out); out = self.se(out)\n",
497
+ " if self.downsample is not None: identity = self.downsample(x)\n",
498
+ " out += identity; out = self.relu(out); return out\n",
499
+ "\n",
500
+ "def get_seresnext50(num_classes=1, se_reduction=16):\n",
501
+ " model = models.resnext50_32x4d(pretrained=True); base_width = model.base_width\n",
502
+ " def replace_bottlenecks(module, se_reduction_ratio, base_width):\n",
503
+ " for name, child_module in module.named_children():\n",
504
+ " if isinstance(child_module, Bottleneck):\n",
505
+ " inplanes = child_module.conv1.in_channels; planes = child_module.conv3.out_channels // child_module.expansion; stride = child_module.stride; downsample = child_module.downsample; groups = child_module.conv2.groups; dilation = child_module.conv2.dilation[0]\n",
506
+ " new_bottleneck = SEBottleneck(inplanes=inplanes, planes=planes, stride=stride, downsample=downsample, groups=groups, base_width=base_width, dilation=dilation, se_reduction=se_reduction_ratio)\n",
507
+ " new_bottleneck.load_state_dict(child_module.state_dict(), strict=False); setattr(module, name, new_bottleneck)\n",
508
+ " else: replace_bottlenecks(child_module, se_reduction_ratio, base_width)\n",
509
+ " replace_bottlenecks(model, se_reduction, base_width)\n",
510
+ " in_features = model.fc.in_features; model.fc = nn.Linear(in_features, num_classes); return model\n",
511
+ "\n",
512
+ "# --- Initialization (No changes needed here) ---\n",
513
+ "se_reduction_ratio = 16\n",
514
+ "model = get_seresnext50(num_classes=1, se_reduction=se_reduction_ratio)\n",
515
+ "if torch.cuda.device_count() > 1:\n",
516
+ " print(f\"Using {torch.cuda.device_count()} GPUs: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}\")\n",
517
+ " model = nn.DataParallel(model)\n",
518
+ "model = model.to(device)\n",
519
+ "criterion = FocalLoss(alpha=0.25, gamma=2.0)\n",
520
+ "optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
521
+ "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)\n",
522
+ "\n",
523
+ "# --- OPTIMIZATION: Training function with Mixed Precision ---\n",
524
+ "def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, patience, se_reduction=16):\n",
525
+ " best_val_loss = float('inf')\n",
526
+ " patience_counter = 0\n",
527
+ " history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'lr': []}\n",
528
+ "\n",
529
+ " # Initialize the gradient scaler for mixed precision\n",
530
+ " scaler = GradScaler()\n",
531
+ " print(\"Training with Automatic Mixed Precision (AMP) enabled.\")\n",
532
+ "\n",
533
+ " for epoch in range(num_epochs):\n",
534
+ " # Training phase\n",
535
+ " model.train()\n",
536
+ " train_loss, train_correct, train_total = 0.0, 0, 0\n",
537
+ " pbar_train = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')\n",
538
+ " for inputs, labels in pbar_train:\n",
539
+ " inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True).float().view(-1, 1)\n",
540
+ " optimizer.zero_grad()\n",
541
+ "\n",
542
+ " # Use autocast for the forward pass\n",
543
+ " with autocast():\n",
544
+ " outputs = model(inputs)\n",
545
+ " loss = criterion(outputs, labels)\n",
546
+ "\n",
547
+ " # Scale the loss and call backward()\n",
548
+ " scaler.scale(loss).backward()\n",
549
+ " # Unscale the gradients and call optimizer.step()\n",
550
+ " scaler.step(optimizer)\n",
551
+ " # Update the scaler for the next iteration\n",
552
+ " scaler.update()\n",
553
+ "\n",
554
+ " train_loss += loss.item() * inputs.size(0)\n",
555
+ " predicted = (torch.sigmoid(outputs) > 0.5).float()\n",
556
+ " train_total += labels.size(0)\n",
557
+ " train_correct += (predicted == labels).sum().item()\n",
558
+ " pbar_train.set_postfix({'loss': f'{loss.item():.4f}'})\n",
559
+ "\n",
560
+ " # ... (rest of the training loop is the same) ...\n",
561
+ "\n",
562
+ " epoch_train_loss = train_loss / len(train_loader.dataset)\n",
563
+ " epoch_train_acc = train_correct / train_total\n",
564
+ " history['train_loss'].append(epoch_train_loss)\n",
565
+ " history['train_acc'].append(epoch_train_acc)\n",
566
+ " history['lr'].append(optimizer.param_groups[0]['lr'])\n",
567
+ "\n",
568
+ " # Validation phase (autocast is recommended here too for consistency and speed)\n",
569
+ " model.eval()\n",
570
+ " val_loss, val_correct, val_total = 0.0, 0, 0\n",
571
+ " with torch.no_grad():\n",
572
+ " pbar_val = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')\n",
573
+ " for inputs, labels in pbar_val:\n",
574
+ " inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True).float().view(-1, 1)\n",
575
+ " with autocast():\n",
576
+ " outputs = model(inputs)\n",
577
+ " loss = criterion(outputs, labels)\n",
578
+ "\n",
579
+ " val_loss += loss.item() * inputs.size(0)\n",
580
+ " predicted = (torch.sigmoid(outputs) > 0.5).float()\n",
581
+ " val_total += labels.size(0)\n",
582
+ " val_correct += (predicted == labels).sum().item()\n",
583
+ " pbar_val.set_postfix({'loss': f'{loss.item():.4f}'})\n",
584
+ "\n",
585
+ " epoch_val_loss = val_loss / len(val_loader.dataset)\n",
586
+ " epoch_val_acc = val_correct / val_total\n",
587
+ " history['val_loss'].append(epoch_val_loss)\n",
588
+ " history['val_acc'].append(epoch_val_acc)\n",
589
+ "\n",
590
+ " print(f\"Epoch {epoch+1}/{num_epochs}: Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}, Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}\")\n",
591
+ " scheduler.step(epoch_val_loss)\n",
592
+ "\n",
593
+ " if epoch_val_loss < best_val_loss:\n",
594
+ " best_val_loss = epoch_val_loss\n",
595
+ " patience_counter = 0\n",
596
+ " checkpoint = {'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'epoch': epoch, 'val_loss': best_val_loss, 'history': history, 'config': {'learning_rate': learning_rate, 'weight_decay': weight_decay, 'batch_size': train_loader.batch_size, 'se_reduction_ratio': se_reduction}}\n",
597
+ " torch.save(checkpoint, 'best_seresnext50_model.pth')\n",
598
+ " print(f\"Model improved! Saved checkpoint at epoch {epoch+1}\")\n",
599
+ " else:\n",
600
+ " patience_counter += 1\n",
601
+ " print(f\"Model didn't improve for {patience_counter}/{patience} epochs\")\n",
602
+ " if patience_counter >= patience:\n",
603
+ " print(f\"Early stopping triggered after {epoch+1} epochs\")\n",
604
+ " break\n",
605
+ "\n",
606
+ " # Save final model regardless of performance\n",
607
+ " final_checkpoint = {'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'epoch': epoch, 'val_loss': epoch_val_loss, 'history': history, 'config': {'learning_rate': learning_rate, 'weight_decay': weight_decay, 'batch_size': train_loader.batch_size, 'se_reduction_ratio': se_reduction}}\n",
608
+ " torch.save(final_checkpoint, 'final_seresnext50_model.pth')\n",
609
+ " with open('seresnext50_training_history.pkl', 'wb') as f:\n",
610
+ " pickle.dump(history, f)\n",
611
+ " if os.path.exists('best_seresnext50_model.pth'):\n",
612
+ " checkpoint = torch.load('best_seresnext50_model.pth')\n",
613
+ " model.load_state_dict(checkpoint['model_state_dict'])\n",
614
+ " return model, history\n",
615
+ "\n",
616
+ "# Start training\n",
617
+ "model, history = train_model(\n",
618
+ " model=model,\n",
619
+ " train_loader=train_loader,\n",
620
+ " val_loader=val_loader,\n",
621
+ " criterion=criterion,\n",
622
+ " optimizer=optimizer,\n",
623
+ " scheduler=scheduler,\n",
624
+ " num_epochs=num_epochs,\n",
625
+ " patience=patience,\n",
626
+ " se_reduction=se_reduction_ratio\n",
627
+ ")\n",
628
+ "\n",
629
+ "# Plot training history\n",
630
+ "plt.figure(figsize=(15, 10))\n",
631
+ "plt.subplot(2, 2, 1); plt.plot(history['train_loss'], label='Training Loss'); plt.plot(history['val_loss'], label='Validation Loss'); plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend(); plt.title('Loss Curves')\n",
632
+ "plt.subplot(2, 2, 2); plt.plot(history['train_acc'], label='Training Accuracy'); plt.plot(history['val_acc'], label='Validation Accuracy'); plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.legend(); plt.title('Accuracy Curves')\n",
633
+ "plt.subplot(2, 2, 3); plt.plot(history['lr']); plt.xlabel('Epoch'); plt.ylabel('Learning Rate'); plt.title('Learning Rate Schedule')\n",
634
+ "plt.subplot(2, 2, 4); plt.axis('off'); info_text = (f\"Model: SE-ResNeXt50\\nSE Reduction Ratio: {se_reduction_ratio}\\nOptimizer: AdamW\\nInitial LR: {learning_rate}\\nWeight Decay: {weight_decay}\\nLoss: Focal Loss (α={criterion.alpha}, γ={criterion.gamma})\"); plt.text(0.1, 0.5, info_text, fontsize=12)\n",
635
+ "plt.tight_layout(); plt.savefig(f'seresnext50_r{se_reduction_ratio}_training_plots.png', dpi=300); plt.show()\n",
636
+ "print(f\"Training completed! SE-ResNeXt50 with reduction ratio {se_reduction_ratio}\")"
637
+ ],
638
+ "metadata": {
639
+ "trusted": true,
640
+ "execution": {
641
+ "iopub.status.busy": "2025-06-11T10:05:48.487763Z",
642
+ "iopub.execute_input": "2025-06-11T10:05:48.488061Z",
643
+ "iopub.status.idle": "2025-06-11T10:28:31.578634Z",
644
+ "shell.execute_reply.started": "2025-06-11T10:05:48.488041Z",
645
+ "shell.execute_reply": "2025-06-11T10:28:31.577803Z"
646
+ },
647
+ "id": "HjN6hDXCgIQ0"
648
+ },
649
+ "outputs": [],
650
+ "execution_count": null
651
+ }
652
+ ]
653
+ }