akushwaha.ext committed on
Commit 7db32ef · 1 Parent(s): 8bf68a3

ct images also added
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ sample_images/* filter=lfs diff=lfs merge=lfs -text
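The added rule routes everything under sample_images/ through Git LFS, matching the existing *.zip/*.zst patterns. A minimal sketch (plain Python, hypothetical helper names) for checking which LFS patterns in .gitattributes would apply to a file before committing large CT images; fnmatch only approximates gitattributes glob semantics, which is enough for simple patterns like these:

    # Hypothetical helper: list LFS patterns from .gitattributes and test a path.
    import fnmatch

    def lfs_patterns(attributes_file=".gitattributes"):
        pats = []
        with open(attributes_file) as f:
            for line in f:
                parts = line.split()
                if parts and "filter=lfs" in parts[1:]:
                    pats.append(parts[0])
        return pats

    def is_lfs_tracked(path, patterns):
        return any(fnmatch.fnmatch(path, p) for p in patterns)

    print(is_lfs_tracked("sample_images/ct_scan_01.jpg", lfs_patterns()))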
Brain_Tumor_MRI_Diagnose_plus_openai.ipynb → Brain_Tumor_MRI_CT_Images_Diagnose_plus_openai.ipynb RENAMED
@@ -1,418 +1,82 @@
  {
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": [],
- "gpuType": "T4"
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- },
- "accelerator": "GPU",
- "widgets": { … ~340 lines of saved Jupyter widget state (layout/progress-bar models for the model.safetensors download) removed with the old metadata block … }
- },
  "cells": [
  {
  "cell_type": "code",
  "source": [
  "# Colab: install required packages\n",
  "!pip install -q kagglehub timm torch torchvision scikit-learn pandas pillow matplotlib gradio openai pytorch-lightning"
- ],
  "metadata": {
  "colab": {
  "base_uri": "https://localhost:8080/"
  },
- "id": "yViCfhqXTD-l",
- "outputId": "901b3218-85e1-4cc5-d380-839d8a3fca86"
  },
- "execution_count": 10,
  "outputs": [
  {
  "output_type": "stream",
  "name": "stdout",
  "text": [
- … two pip download progress bars (ANSI terminal escapes) removed …
  ]
- }
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
  },
- "id": "B7iabh3oOyNe",
- "outputId": "b487c9fb-aa5c-4457-e23b-067c427c15ab"
- },
- "outputs": [
  {
  "output_type": "stream",
  "name": "stdout",
  "text": [
- "Using Colab cache for faster access to the 'brain-tumor-classification-mri' dataset.\n",
- "Dataset path: /kaggle/input/brain-tumor-classification-mri\n",
- "Training\n",
- "Testing\n"
  ]
  }
  ],
  "source": [
  "import kagglehub\n",
- "dataset_ref = \"sartajbhuvaji/brain-tumor-classification-mri\"\n",
  "path = kagglehub.dataset_download(dataset_ref)\n",
  "print(\"Dataset path:\", path)\n",
  "\n",
@@ -424,18 +88,14 @@
  },
  {
  "cell_type": "code",
- "source": [
- "import torch\n",
- "print(torch.cuda.is_available())"
- ],
  "metadata": {
  "colab": {
  "base_uri": "https://localhost:8080/"
  },
  "id": "gq9zpPkmZVK7",
- "outputId": "1afaff9b-7d25-4697-eb7a-8e7c77255fb7"
  },
- "execution_count": 20,
  "outputs": [
  {
  "output_type": "stream",
@@ -444,431 +104,274 @@
  "True\n"
  ]
  }
  ]
  },
  {
  "cell_type": "code",
  "source": [
  "# Create CSV (multi-class) from folders (in-memory)\n",
  "\n",
  "# Change DATA_ROOT to the folder that directly contains glioma_tumor, meningioma_tumor, pituitary_tumor, no_tumor."
  ],
  "metadata": {
- "id": "YvSgspBGU3iG"
  },
- "execution_count": 12,
- "outputs": []
  },
  {
  "cell_type": "code",
  "source": [
  "import os, pandas as pd\n",
  "\n",
- "# <-- EDIT this path to match your dataset root that has the 4 class folders\n",
- "DATA_ROOT = \"/kaggle/input/brain-tumor-classification-mri/Training\"\n",
  "\n",
- "LABEL_MAP = {\n",
- " \"glioma_tumor\": 0,\n",
- " \"meningioma_tumor\": 1,\n",
- " \"pituitary_tumor\": 2,\n",
- " \"no_tumor\": 3\n",
- "}\n",
  "\n",
  "rows = []\n",
- "for folder in sorted(os.listdir(DATA_ROOT)):\n",
- " folder_path = os.path.join(DATA_ROOT, folder)\n",
- " if not os.path.isdir(folder_path):\n",
- " continue\n",
- " if folder not in LABEL_MAP:\n",
- " print(\"Skipping unknown folder:\", folder)\n",
  " continue\n",
- " label = LABEL_MAP[folder]\n",
- " for fname in os.listdir(folder_path):\n",
- " if fname.lower().endswith((\".png\",\".jpg\",\".jpeg\")):\n",
- " rows.append({\"image\": os.path.join(folder_path, fname), \"label\": label})\n",
  "\n",
  "df = pd.DataFrame(rows).sample(frac=1, random_state=42).reset_index(drop=True)\n",
  "print(\"Total images:\", len(df))\n",
  "print(df.label.value_counts())\n",
- "df.head()\n"
- ],
  "metadata": {
  "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 367
  },
- "id": "1grjHP2ZU5st",
- "outputId": "a9c67cab-ac1d-46a1-ff2a-d99ed94d6c7a"
  },
- "execution_count": 13,
  "outputs": [
  {
  "output_type": "stream",
  "name": "stdout",
  "text": [
- "Total images: 2870\n",
- "label\n",
- "2 827\n",
- "0 826\n",
- "1 822\n",
- "3 395\n",
- "Name: count, dtype: int64\n"
  ]
- },
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- " image label\n",
- "0 /kaggle/input/brain-tumor-classification-mri/T... 0\n",
- "1 /kaggle/input/brain-tumor-classification-mri/T... 2\n",
- "2 /kaggle/input/brain-tumor-classification-mri/T... 0\n",
- "3 /kaggle/input/brain-tumor-classification-mri/T... 0\n",
- "4 /kaggle/input/brain-tumor-classification-mri/T... 2"
- ],
- "text/html": [ … ~260 lines of Colab interactive-dataframe HTML/CSS/JS removed … ],
- "application/vnd.google.colaboratory.intrinsic+json": { … Colab dataframe summary JSON removed … }
- },
- "metadata": {},
- "execution_count": 13
  }
  ]
  },
  {
  "cell_type": "code",
  "source": [
- "# Create train / val split (80/20) and display counts\n",
- "from sklearn.model_selection import train_test_split\n",
- "\n",
- "train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['label'], random_state=42)\n",
- "train_df = train_df.reset_index(drop=True)\n",
- "val_df = val_df.reset_index(drop=True)\n",
- "\n",
- "print(\"Train:\", len(train_df), \"Val:\", len(val_df))\n",
- "print(train_df.label.value_counts(), val_df.label.value_counts())\n"
  ],
  "metadata": {
  "colab": {
  "base_uri": "https://localhost:8080/"
  },
- "id": "OvqCFuB1U5wR",
- "outputId": "2964fa89-7d65-4998-dd8f-d21b5d3456f4"
  },
- "execution_count": 14,
  "outputs": [
  {
  "output_type": "stream",
  "name": "stdout",
  "text": [
- "Train: 2296 Val: 574\n",
- "label\n",
- "2 662\n",
- "0 661\n",
- "1 657\n",
- "3 316\n",
- "Name: count, dtype: int64 label\n",
- "1 165\n",
- "2 165\n",
- "0 165\n",
- "3 79\n",
- "Name: count, dtype: int64\n"
  ]
  }
  ]
  },
  {
  "cell_type": "code",
- "source": [
- "# Define Dataset and Transforms (inline)"
- ],
  "metadata": {
  "id": "MYxdFZEiY4_9"
  },
- "execution_count": 15,
- "outputs": []
  },
  {
  "cell_type": "code",
  "source": [
  "import torch\n",
  "from torch.utils.data import Dataset, DataLoader\n",
  "from PIL import Image\n",
  "import torchvision.transforms as T\n",
  "\n",
- "IMG_SIZE = 224\n",
- "\n",
  "train_tf = T.Compose([\n",
  " T.Resize((IMG_SIZE, IMG_SIZE)),\n",
  " T.RandomHorizontalFlip(),\n",
  " T.RandomRotation(10),\n",
- " T.ColorJitter(brightness=0.1, contrast=0.1),\n",
  " T.ToTensor(),\n",
  " T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n",
  "])\n",
@@ -878,117 +381,85 @@
  " T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n",
  "])\n",
  "\n",
- "class MedicalMultiClassDataset(Dataset):\n",
- " def __init__(self, df, transform=None):\n",
  " self.df = df.reset_index(drop=True)\n",
  " self.transform = transform\n",
- "\n",
  " def __len__(self):\n",
  " return len(self.df)\n",
- "\n",
  " def __getitem__(self, idx):\n",
  " row = self.df.iloc[idx]\n",
- " img = Image.open(row['image']).convert('RGB')\n",
  " if self.transform:\n",
  " img = self.transform(img)\n",
- " label = int(row['label'])\n",
  " return img, label\n",
  "\n",
- "# quick smoke test\n",
- "ds = MedicalMultiClassDataset(train_df.iloc[:4], transform=train_tf)\n",
- "for x,y in ds:\n",
- " print(x.shape, y)\n",
- " break\n"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "Nd0TbtAlY7to",
- "outputId": "78520285-f5f9-4987-ef9d-65be96356ca1"
- },
- "execution_count": 16,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "torch.Size([3, 224, 224]) 0\n"
- ]
- }
  ]
  },
  {
  "cell_type": "code",
- "source": [],
  "metadata": {
  "id": "SrJLOm33ZMpC"
  },
- "execution_count": null,
- "outputs": []
  },
  {
  "cell_type": "code",
- "source": [
- "# Build dataloaders"
- ],
  "metadata": {
  "id": "bG2JclRDY-Di"
  },
- "execution_count": 17,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "BATCH_SIZE = 8 # adjust to your GPU memory\n",
- "\n",
- "train_loader = DataLoader(MedicalMultiClassDataset(train_df, transform=train_tf), batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)\n",
- "val_loader = DataLoader(MedicalMultiClassDataset(val_df, transform=val_tf), batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)\n"
- ],
- "metadata": {
- "id": "Jc4pdqreZGDS"
- },
- "execution_count": 18,
- "outputs": []
  },
  {
  "cell_type": "code",
- "source": [
- "# Define model, optimizer, loss, device\n",
- "import timm, torch.nn as nn, torch\n",
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
- "NUM_CLASSES = 4\n",
- "\n",
- "model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=NUM_CLASSES)\n",
- "model = model.to(device)\n",
- "\n",
- "loss_fn = nn.CrossEntropyLoss()\n",
- "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)\n"
- ],
  "metadata": {
  "colab": {
  "base_uri": "https://localhost:8080/",
- "height": 153,
  "referenced_widgets": [
- … eleven referenced widget ids removed …
  ]
  },
- "id": "TRBF6u9FZLnT",
- "outputId": "2e9229ac-0260-41ce-ea49-9c3982ca7ab6"
  },
- "execution_count": 19,
  "outputs": [
  {
  "output_type": "stream",
  "name": "stderr",
@@ -1010,31 +481,90 @@
  "application/vnd.jupyter.widget-view+json": {
  "version_major": 2,
  "version_minor": 0,
- "model_id": "1b0dc97d07a94fca9a24670e3813ff29"
  }
  },
  "metadata": {}
  }
  ]
  },
  {
  "cell_type": "code",
  "source": [
  "from sklearn.metrics import accuracy_score\n",
  "import time, os\n",
  "\n",
- "EPOCHS = 10\n",
- "CKPT_DIR = \"/content/checkpoints\"\n",
- "os.makedirs(CKPT_DIR, exist_ok=True)\n",
  "best_val_acc = 0.0\n",
  "\n",
  "for epoch in range(1, EPOCHS+1):\n",
  " model.train()\n",
- " running_loss = 0.0\n",
  " all_preds, all_labels = [], []\n",
- " t0 = time.time()\n",
- "\n",
- " # ---- Training ----\n",
  " for imgs, labels in train_loader:\n",
  " imgs = imgs.to(device)\n",
  " labels = labels.to(device)\n",
@@ -1043,75 +573,125 @@
  " loss = loss_fn(logits, labels)\n",
  " loss.backward()\n",
  " optimizer.step()\n",
- " running_loss += loss.item() * imgs.size(0)\n",
  " preds = logits.argmax(dim=1).cpu().numpy()\n",
  " all_preds.extend(preds.tolist())\n",
  " all_labels.extend(labels.cpu().numpy().tolist())\n",
- "\n",
- " train_loss = running_loss / len(train_loader.dataset)\n",
  " train_acc = accuracy_score(all_labels, all_preds)\n",
  "\n",
- " # ---- Validation ----\n",
  " model.eval()\n",
- " val_running_loss = 0.0\n",
  " v_preds, v_labels = [], []\n",
  " with torch.no_grad():\n",
  " for imgs, labels in val_loader:\n",
  " imgs = imgs.to(device)\n",
  " labels = labels.to(device)\n",
  " logits = model(imgs)\n",
- " loss = loss_fn(logits, labels) # <-- compute val loss\n",
- " val_running_loss += loss.item() * imgs.size(0)\n",
  " preds = logits.argmax(dim=1).cpu().numpy()\n",
  " v_preds.extend(preds.tolist())\n",
  " v_labels.extend(labels.cpu().numpy().tolist())\n",
- "\n",
- " val_loss = val_running_loss / len(val_loader.dataset)\n",
  " val_acc = accuracy_score(v_labels, v_preds)\n",
  "\n",
- " elapsed = time.time() - t0\n",
- " print(f\"Epoch {epoch}/{EPOCHS} \"\n",
- " f\"— train_loss: {train_loss:.4f} train_acc: {train_acc:.4f} \"\n",
- " f\"val_loss: {val_loss:.4f} val_acc: {val_acc:.4f} \"\n",
- " f\"time: {elapsed/60:.2f}m\")\n",
  "\n",
- " # ---- Save checkpoint ----\n",
- " ckpt_path = os.path.join(CKPT_DIR, f\"epoch{epoch}.pth\")\n",
- " torch.save(model.state_dict(), ckpt_path)\n",
  " if val_acc > best_val_acc:\n",
  " best_val_acc = val_acc\n",
- " torch.save(model.state_dict(), os.path.join(CKPT_DIR, \"best_model_multiclass.pth\"))\n",
- " print(\"Saved best model:\", best_val_acc)\n"
  ],
  "metadata": {
  "colab": {
  "base_uri": "https://localhost:8080/"
  },
- "id": "nUDYRwowZRGb",
- "outputId": "65d7584c-b90e-4a70-bf44-f6747fbe271e"
  },
- "execution_count": 22,
  "outputs": [
  {
  "output_type": "stream",
  "name": "stdout",
  "text": [
- "Epoch 1/10 — train_loss: 0.0583 train_acc: 0.9800 val_loss: 0.1981 val_acc: 0.9512 time: 0.38m\n",
- "Saved best model: 0.9512195121951219\n",
- "Epoch 2/10 — train_loss: 0.0519 train_acc: 0.9817 val_loss: 0.2222 val_acc: 0.9425 time: 0.45m\n",
- "Epoch 3/10 — train_loss: 0.0665 train_acc: 0.9756 val_loss: 0.1690 val_acc: 0.9582 time: 0.39m\n",
- "Saved best model: 0.9581881533101045\n",
- "Epoch 4/10 — train_loss: 0.0410 train_acc: 0.9861 val_loss: 0.1936 val_acc: 0.9564 time: 0.39m\n",
- "Epoch 5/10 — train_loss: 0.0629 train_acc: 0.9821 val_loss: 0.1469 val_acc: 0.9634 time: 0.39m\n",
- "Saved best model: 0.9634146341463414\n",
- "Epoch 6/10 — train_loss: 0.0530 train_acc: 0.9843 val_loss: 0.1630 val_acc: 0.9652 time: 0.41m\n",
- "Saved best model: 0.9651567944250871\n",
- "Epoch 7/10 — train_loss: 0.0346 train_acc: 0.9869 val_loss: 0.1726 val_acc: 0.9634 time: 0.41m\n",
- "Epoch 8/10 — train_loss: 0.0268 train_acc: 0.9917 val_loss: 0.1771 val_acc: 0.9582 time: 0.39m\n",
- "Epoch 9/10 — train_loss: 0.0349 train_acc: 0.9874 val_loss: 0.1282 val_acc: 0.9756 time: 0.40m\n",
- "Saved best model: 0.975609756097561\n",
- "Epoch 10/10 — train_loss: 0.0338 train_acc: 0.9891 val_loss: 0.1267 val_acc: 0.9791 time: 0.39m\n",
- "Saved best model: 0.9790940766550522\n"
  ]
  }
  ]
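The diff strips the loss bookkeeping and per-epoch checkpointing out of the loop. A consolidated sketch of what that removed logic did, assuming model, loss_fn, optimizer, device, EPOCHS, and the two loaders from the earlier cells are in scope:

    # Shared train/eval pass with running-loss averaging, plus best-checkpoint saving.
    import os
    import torch

    CKPT_DIR = "/content/checkpoints"
    os.makedirs(CKPT_DIR, exist_ok=True)

    def run_epoch(loader, train=False):
        model.train() if train else model.eval()
        total_loss, correct, seen = 0.0, 0, 0
        with torch.set_grad_enabled(train):
            for imgs, labels in loader:
                imgs, labels = imgs.to(device), labels.to(device)
                logits = model(imgs)
                loss = loss_fn(logits, labels)
                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                total_loss += loss.item() * imgs.size(0)  # sum, then average over the dataset
                correct += (logits.argmax(dim=1) == labels).sum().item()
                seen += imgs.size(0)
        return total_loss / seen, correct / seen

    best_val_acc = 0.0
    for epoch in range(1, EPOCHS + 1):
        train_loss, train_acc = run_epoch(train_loader, train=True)
        val_loss, val_acc = run_epoch(val_loader, train=False)
        print(f"Epoch {epoch}: train {train_loss:.4f}/{train_acc:.4f}  val {val_loss:.4f}/{val_acc:.4f}")
        if val_acc > best_val_acc:  # keep only the best weights
            best_val_acc = val_acc
            torch.save(model.state_dict(), os.path.join(CKPT_DIR, "best_model_multiclass.pth"))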
@@ -1120,117 +700,175 @@
  "cell_type": "code",
  "source": [],
  "metadata": {
- "id": "I2WUIxNnZtVI"
  },
  "execution_count": null,
  "outputs": []
  },
  {
  "cell_type": "code",
  "source": [
  "# Load best model & test a single image (inference helper)\n",
  "import torch.nn.functional as F\n",
  "from PIL import Image\n",
  "from torchvision import transforms\n",
  "\n",
- "# load\n",
- "best_ckpt = \"/content/checkpoints/best_model_multiclass.pth\"\n",
- "model.load_state_dict(torch.load(best_ckpt, map_location=device))\n",
- "model.eval()\n",
- "\n",
- "IDX2LABEL = {0:\"glioma_tumor\", 1:\"meningioma_tumor\", 2:\"pituitary_tumor\", 3:\"no_tumor\"}\n",
  "\n",
- "def preprocess(pil_img, img_size=IMG_SIZE):\n",
- " tf = transforms.Compose([\n",
- " transforms.Resize((img_size, img_size)),\n",
- " transforms.ToTensor(),\n",
- " transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n",
- " ])\n",
- " return tf(pil_img).unsqueeze(0).to(device)\n",
- "\n",
- "def predict(pil_img, topk=3):\n",
- " x = preprocess(pil_img)\n",
  " with torch.no_grad():\n",
  " logits = model(x)\n",
  " probs = F.softmax(logits, dim=1).cpu().numpy().ravel()\n",
  " idxs = probs.argsort()[::-1][:topk]\n",
  " return [{\"label\": IDX2LABEL[int(i)], \"score\": float(probs[int(i)])} for i in idxs]\n",
  "\n",
- "# quick test with a sample image from val set\n",
- "sample_path = val_df.iloc[0].image\n",
- "print(\"Sample:\", sample_path)\n",
  "img = Image.open(sample_path).convert(\"RGB\")\n",
- "preds = predict(img, topk=3)\n",
- "preds\n"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "H2BqijX2Zupx",
- "outputId": "8ad2aef5-e047-4054-acca-f50cbc3429be"
- },
- "execution_count": 29,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Sample: /kaggle/input/brain-tumor-classification-mri/Training/meningioma_tumor/m3 (155).jpg\n"
- ]
- },
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "[{'label': 'meningioma_tumor', 'score': 0.9999998807907104},\n",
- " {'label': 'pituitary_tumor', 'score': 5.87046180555717e-08},\n",
- " {'label': 'glioma_tumor', 'score': 4.05661282343317e-08}]"
- ]
- },
- "metadata": {},
- "execution_count": 29
- }
  ]
  },
  {
  "cell_type": "code",
- "source": [],
  "metadata": {
  "id": "RAZRNlq2ZutV"
  },
- "execution_count": null,
- "outputs": []
  },
  {
  "cell_type": "code",
- "source": [],
  "metadata": {
  "id": "chDN6CJmc8h0"
  },
  "execution_count": null,
  "outputs": []
  },
  {
  "cell_type": "code",
  "source": [
  "sample_paths = [\n",
- " \"/kaggle/input/brain-tumor-classification-mri/Testing/glioma_tumor/image(10).jpg\",\n",
- " \"/kaggle/input/brain-tumor-classification-mri/Testing/meningioma_tumor/image(10).jpg\",\n",
- " \"/kaggle/input/brain-tumor-classification-mri/Testing/no_tumor/image(10).jpg\",\n",
- " \"/kaggle/input/brain-tumor-classification-mri/Testing/pituitary_tumor/image(10).jpg\",\n",
  "]\n",
  "\n",
  "OPENAI_API_KEY = \"\""
- ],
- "metadata": {
- "id": "WSygonZYdvpg"
- },
- "execution_count": 36,
- "outputs": []
  },
  {
  "cell_type": "code",
  "source": [
  "import os\n",
  "from pathlib import Path\n",
@@ -1243,23 +881,33 @@
  "import openai\n",
  "from openai import OpenAI\n",
  "\n",
- "\n",
- "# ------------------ USER CONFIG ------------------\n",
- "CKPT_PATH = \"/content/checkpoints/best_model_multiclass.pth\" # path to trained checkpoint\n",
  "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
  "BACKBONE = \"efficientnet_b0\"\n",
- "NUM_CLASSES = 4\n",
  "IMG_SIZE = 224\n",
  "\n",
- "# Paste your 4 sample image absolute paths here\n",
- "# e.g. sample_paths = [\"/content/a.png\", \"/content/b.png\", ...]\n",
- "sample_paths = sample_paths\n",
  "\n",
- "OPENAI_API_KEY = OPENAI_API_KEY\n",
  "if OPENAI_API_KEY:\n",
  " openai.api_key = OPENAI_API_KEY\n",
  "\n",
- "IDX2LABEL = {0: \"glioma_tumor\", 1: \"meningioma_tumor\", 2: \"pituitary_tumor\", 3: \"no_tumor\"}\n",
  "# -------------------------------------------------\n",
  "\n",
  "# Validate sample paths\n",
@@ -1295,7 +943,7 @@
  " idxs = probs.argsort()[::-1][:topk]\n",
  " return [{\"label\": IDX2LABEL[int(i)], \"score\": float(probs[int(i)])} for i in idxs]\n",
  "\n",
- "# OpenAI system prompt\n",
  "SYSTEM_TEMPLATE = \"\"\"You are an educational radiology assistant. The model predicted:\n",
  "{preds}\n",
  "\n",
@@ -1325,13 +973,12 @@
  " )\n",
  " return resp.choices[0].message.content.strip()\n",
  "\n",
- "\n",
  "# Prepare thumbnails (just images)\n",
  "thumbs = [Image.open(p).convert(\"RGB\") for p in sample_paths]\n",
  "\n",
- "# Gradio UI\n",
  "with gr.Blocks() as demo:\n",
- " gr.Markdown(\"## Brain MRI multi-class demo — Educational only\")\n",
  " with gr.Row():\n",
  " with gr.Column():\n",
  " # Use a simple list of PIL thumbnails for the gallery value\n",
@@ -1349,16 +996,16 @@
  " selected_path_state = gr.State(value=None)\n",
  " preds_state = gr.State(value=None)\n",
  "\n",
- " # select image from gallery\n",
  " def on_select(evt: gr.SelectData):\n",
- " idx = evt.index\n",
- " path = sample_paths[idx] # <-- return only the path (string)\n",
  " return path\n",
  "\n",
  " # ensure the gallery select writes only the path to the state\n",
  " gallery.select(fn=on_select, inputs=None, outputs=[selected_path_state])\n",
  "\n",
- " # analyze\n",
  " def analyze(path):\n",
  " try:\n",
  " if path is None:\n",
@@ -1375,7 +1022,7 @@
  "\n",
  " analyze_btn.click(fn=analyze, inputs=selected_path_state, outputs=[preds_output, chatbot, preds_state])\n",
  "\n",
- " # chat\n",
  " def chat(chat_history, msg, preds_text):\n",
  " if preds_text is None:\n",
  " return chat_history+[(\"AI\",\"Please analyze an image first.\")], \"\"\n",
@@ -1406,74 +1053,391 @@
  "\n",
  " gr.Markdown(\"⚠️ Educational use only — not a medical diagnosis.\")\n",
  "\n",
- "demo.launch(share=True)\n"
- ],
  "metadata": {
  "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 680
  },
- "id": "S_rWA1Xcdvx8",
- "outputId": "5735af82-a1ea-4963-d8da-6cc480836bb6"
- },
- "execution_count": 50,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Loading model...\n",
- "Model loaded on cuda\n"
- ]
  },
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "/tmp/ipython-input-4085922474.py:110: UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.\n",
- " chatbot = gr.Chatbot(label=\"Assistant\")\n"
- ]
  },
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n",
- "* Running on public URL: https://53b5220cc19d12e1dc.gradio.live\n",
- "\n",
- "This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
- ]
  },
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- "<IPython.core.display.HTML object>"
- ],
- "text/html": [
- "<div><iframe src=\"https://53b5220cc19d12e1dc.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
- ]
- },
- "metadata": {}
  },
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": []
- },
- "metadata": {},
- "execution_count": 50
  }
- ]
- },
- {
- "cell_type": "code",
- "source": [],
- "metadata": {
- "id": "TrS_HlY7d8s3"
- },
- "execution_count": null,
- "outputs": []
  }
- ]
  }
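The deleted launch cell tied the classifier and the OpenAI chat layer into a Gradio Blocks app. A pared-down sketch of that wiring, assuming the predict helper, model, and sample_paths from above; the chat layer is omitted for brevity:

    # Skeleton of the Gradio demo: gallery selection stores a path in State,
    # a button runs the classifier, and the predictions render as JSON.
    import gradio as gr
    from PIL import Image

    with gr.Blocks() as demo:
        gr.Markdown("## Brain MRI multi-class demo — Educational only")
        gallery = gr.Gallery(value=[Image.open(p).convert("RGB") for p in sample_paths])
        analyze_btn = gr.Button("Analyze")
        preds_output = gr.JSON(label="Predictions")
        selected_path_state = gr.State(value=None)

        def on_select(evt: gr.SelectData):
            return sample_paths[evt.index]  # store only the path string

        gallery.select(fn=on_select, inputs=None, outputs=[selected_path_state])

        def analyze(path):
            if path is None:
                return {"error": "select an image first"}
            return predict(model, Image.open(path).convert("RGB"))

        analyze_btn.click(fn=analyze, inputs=selected_path_state, outputs=[preds_output])
        gr.Markdown("⚠️ Educational use only — not a medical diagnosis.")

    demo.launch(share=True)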
 
  {
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {
7
+ "colab": {
8
+ "base_uri": "https://localhost:8080/"
9
+ },
10
+ "id": "yViCfhqXTD-l",
11
+ "outputId": "e599cb0e-fb7f-44d0-c84d-82ce2d17c0fe"
12
+ },
13
+ "outputs": [
14
+ {
15
+ "output_type": "stream",
16
+ "name": "stdout",
17
+ "text": [
18
+ "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/832.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m829.4/832.4 kB\u001b[0m \u001b[31m36.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m832.4/832.4 kB\u001b[0m \u001b[31m22.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
19
+ "\u001b[?25h\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/983.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m983.2/983.2 kB\u001b[0m \u001b[31m52.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
20
+ "\u001b[?25h"
21
+ ]
22
+ }
23
+ ],
24
  "source": [
25
  "# Colab: install required packages\n",
26
  "!pip install -q kagglehub timm torch torchvision scikit-learn pandas pillow matplotlib gradio openai pytorch-lightning"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": 2,
32
  "metadata": {
33
  "colab": {
34
  "base_uri": "https://localhost:8080/"
35
  },
36
+ "id": "B7iabh3oOyNe",
37
+ "outputId": "fb2920c5-a19b-49c8-814e-640ce263c27b"
38
  },
 
39
  "outputs": [
40
  {
41
  "output_type": "stream",
42
  "name": "stdout",
43
  "text": [
44
+ "Downloading from https://www.kaggle.com/api/v1/datasets/download/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri?dataset_version_number=1...\n"
45
+ ]
46
+ },
47
+ {
48
+ "output_type": "stream",
49
+ "name": "stderr",
50
+ "text": [
51
+ "100%|██████████| 361M/361M [00:02<00:00, 183MB/s]"
52
+ ]
53
+ },
54
+ {
55
+ "output_type": "stream",
56
+ "name": "stdout",
57
+ "text": [
58
+ "Extracting files...\n"
59
+ ]
60
+ },
61
+ {
62
+ "output_type": "stream",
63
+ "name": "stderr",
64
+ "text": [
65
+ "\n"
66
  ]
 
 
 
 
 
 
 
 
 
67
  },
 
 
 
 
68
  {
69
  "output_type": "stream",
70
  "name": "stdout",
71
  "text": [
72
+ "Dataset path: /root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1\n",
73
+ "Dataset\n"
 
 
74
  ]
75
  }
76
  ],
77
  "source": [
78
  "import kagglehub\n",
79
+ "dataset_ref = \"murtozalikhon/brain-tumor-multimodal-image-ct-and-mri\"\n",
80
  "path = kagglehub.dataset_download(dataset_ref)\n",
81
  "print(\"Dataset path:\", path)\n",
82
  "\n",
 
88
  },
89
  {
90
  "cell_type": "code",
91
+ "execution_count": 3,
 
 
 
92
  "metadata": {
93
  "colab": {
94
  "base_uri": "https://localhost:8080/"
95
  },
96
  "id": "gq9zpPkmZVK7",
97
+ "outputId": "174e6148-05a4-4b4b-8ebe-7b62ea806b1b"
98
  },
 
99
  "outputs": [
100
  {
101
  "output_type": "stream",
 
104
  "True\n"
105
  ]
106
  }
107
+ ],
108
+ "source": [
109
+ "import torch\n",
110
+ "print(torch.cuda.is_available())"
111
  ]
112
  },
113
  {
114
  "cell_type": "code",
115
+ "execution_count": null,
116
+ "metadata": {
117
+ "id": "YvSgspBGU3iG"
118
+ },
119
+ "outputs": [],
120
  "source": [
121
  "# Create CSV (multi-class) from folders (in-memory)\n",
122
  "\n",
123
  "# Change DATA_ROOT to the folder that directly contains glioma_tumor, meningioma_tumor, pituitary_tumor, no_tumor."
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "source": [
129
+ "# CONFIG - edit these paths before running the pipeline\n",
130
+ "DATA_ROOT = \"/root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset\"\n",
131
+ "# Where to write outputs / checkpoints\n",
132
+ "OUT_DIR = \"/content/output\"\n",
133
+ "os.makedirs(OUT_DIR, exist_ok=True)\n",
134
+ "\n",
135
+ "# Training hyperparams (change if needed)\n",
136
+ "IMG_SIZE = 224\n",
137
+ "BATCH_SIZE = 8\n",
138
+ "EPOCHS = 10\n",
139
+ "LR = 1e-4\n",
140
+ "NUM_WORKERS = 2 # adjust for Colab\n",
141
+ "TOP_K = 3\n",
142
+ "\n",
143
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n"
144
  ],
145
  "metadata": {
146
+ "colab": {
147
+ "base_uri": "https://localhost:8080/"
148
+ },
149
+ "id": "H5gEKtfMJxVQ",
150
+ "outputId": "fd555a63-b825-4860-86df-e4bc3f0caa76"
151
  },
152
+ "execution_count": 23,
153
+ "outputs": [
154
+ {
155
+ "output_type": "stream",
156
+ "name": "stdout",
157
+ "text": [
158
+ "DATA_ROOT: /root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset\n"
159
+ ]
160
+ }
161
+ ]
162
  },
163
  {
164
  "cell_type": "code",
165
+ "execution_count": 19,
166
+ "metadata": {
167
+ "colab": {
168
+ "base_uri": "https://localhost:8080/"
169
+ },
170
+ "id": "1grjHP2ZU5st",
171
+ "outputId": "050866ba-120a-40d9-a7b3-d33cdcd21394"
172
+ },
173
+ "outputs": [
174
+ {
175
+ "output_type": "stream",
176
+ "name": "stdout",
177
+ "text": [
178
+ "Total images: 9618\n",
179
+ "label\n",
180
+ "no_tumor 4300\n",
181
+ "ct_tumor 2318\n",
182
+ "meningioma 1112\n",
183
+ "glioma 672\n",
184
+ "pituitary 629\n",
185
+ "mri_tumor 587\n",
186
+ "Name: count, dtype: int64\n",
187
+ "Wrote CSV: /content/output/combined_image_labels.csv\n"
188
+ ]
189
+ }
190
+ ],
191
  "source": [
192
+ "# Build combined CSV for your described dataset layout\n",
193
+ "# Paste & run in Colab; set PARENT_DIR to your dataset root\n",
194
+ "\n",
195
  "import os, pandas as pd\n",
196
+ "from pathlib import Path\n",
197
  "\n",
198
+ "# EDIT: point this to the parent folder that contains the two modality folders\n",
199
+ "PARENT_DIR = DATA_ROOT # ← change this\n",
200
  "\n",
201
+ "# expected modality folder names (we'll match case-insensitively)\n",
202
+ "# adjust if your folder names differ\n",
203
+ "CT_FOLDER_KEY = \"ct\"\n",
204
+ "MRI_FOLDER_KEY = \"mri\"\n",
 
 
205
  "\n",
206
  "rows = []\n",
207
+ "for mod_name in sorted(os.listdir(PARENT_DIR)):\n",
208
+ " mod_path = os.path.join(PARENT_DIR, mod_name)\n",
209
+ " if not os.path.isdir(mod_path):\n",
 
 
 
210
  " continue\n",
211
+ " mod_lower = mod_name.lower()\n",
212
+ " modality = \"CT\" if CT_FOLDER_KEY in mod_lower else (\"MRI\" if MRI_FOLDER_KEY in mod_lower else mod_name)\n",
213
+ " # inside each modality we expect subfolders: 'healthy' and 'tumor'\n",
214
+ " for sub in sorted(os.listdir(mod_path)):\n",
215
+ " sub_path = os.path.join(mod_path, sub)\n",
216
+ " if not os.path.isdir(sub_path):\n",
217
+ " continue\n",
218
+ " sub_lower = sub.lower()\n",
219
+ " # CT modality handling\n",
220
+ " if modality == \"CT\":\n",
221
+ " if \"healthy\" in sub_lower:\n",
222
+ " lab = \"no_tumor\"\n",
223
+ " else:\n",
224
+ " lab = \"ct_tumor\"\n",
225
+ " for fn in os.listdir(sub_path):\n",
226
+ " if fn.lower().endswith((\".png\",\".jpg\",\".jpeg\")):\n",
227
+ " rows.append({\"image\": os.path.join(sub_path, fn), \"label\": lab, \"modality\": \"CT\"})\n",
228
+ " # MRI modality handling\n",
229
+ " elif modality == \"MRI\":\n",
230
+ " if \"healthy\" in sub_lower:\n",
231
+ " for fn in os.listdir(sub_path):\n",
232
+ " if fn.lower().endswith((\".png\",\".jpg\",\".jpeg\")):\n",
233
+ " rows.append({\"image\": os.path.join(sub_path, fn), \"label\": \"no_tumor\", \"modality\": \"MRI\"})\n",
234
+ " else:\n",
235
+ " # tumor folder: try to parse subtype from filename, else fallback\n",
236
+ " for fn in os.listdir(sub_path):\n",
237
+ " if not fn.lower().endswith((\".png\",\".jpg\",\".jpeg\")):\n",
238
+ " continue\n",
239
+ " low = fn.lower()\n",
240
+ " if \"glioma\" in low:\n",
241
+ " lab = \"glioma\"\n",
242
+ " elif \"meningioma\" in low:\n",
243
+ " lab = \"meningioma\"\n",
244
+ " elif \"pituitary\" in low:\n",
245
+ " lab = \"pituitary\"\n",
246
+ " else:\n",
247
+ " lab = \"mri_tumor\"\n",
248
+ " rows.append({\"image\": os.path.join(sub_path, fn), \"label\": lab, \"modality\": \"MRI\"})\n",
249
+ " else:\n",
250
+ " # unexpected modality name: treat similarly to MRI fallback\n",
251
+ " for fn in os.listdir(sub_path):\n",
252
+ " if fn.lower().endswith((\".png\",\".jpg\",\".jpeg\")):\n",
253
+ " rows.append({\"image\": os.path.join(sub_path, fn), \"label\": \"unknown_modality\", \"modality\": mod_name})\n",
254
  "\n",
255
  "df = pd.DataFrame(rows).sample(frac=1, random_state=42).reset_index(drop=True)\n",
256
  "print(\"Total images:\", len(df))\n",
257
  "print(df.label.value_counts())\n",
258
+ "out_csv = os.path.join(OUT_DIR, \"combined_image_labels.csv\")\n",
259
+ "df.to_csv(out_csv, index=False)\n",
260
+ "print(\"Wrote CSV:\", out_csv)\n"
261
+ ]
262
+ },
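
The nested os.listdir walk above works, but the same rules can be stated more compactly; a minimal pathlib sketch, assuming the exact <PARENT_DIR>/<modality>/<subfolder>/<file> layout this cell expects (illustrative only, not part of the commit):

# Sketch: pathlib variant of the walk above; assumes files sit exactly
# three levels under PARENT_DIR and reuses the cell's labeling rules.
from pathlib import Path
import pandas as pd

IMG_EXTS = {".png", ".jpg", ".jpeg"}

def collect_images(parent_dir):
    rows = []
    for p in Path(parent_dir).rglob("*"):
        if p.suffix.lower() not in IMG_EXTS:
            continue
        modality = "CT" if "ct" in p.parts[-3].lower() else "MRI"
        if "healthy" in p.parts[-2].lower():
            label = "no_tumor"
        elif modality == "CT":
            label = "ct_tumor"
        else:
            name = p.name.lower()
            label = next((t for t in ("glioma", "meningioma", "pituitary") if t in name), "mri_tumor")
        rows.append({"image": str(p), "label": label, "modality": modality})
    return pd.DataFrame(rows)
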
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": 20,
266
  "metadata": {
267
  "colab": {
268
+ "base_uri": "https://localhost:8080/"
 
269
  },
270
+ "id": "OvqCFuB1U5wR",
271
+ "outputId": "be745f82-a3b2-4b69-ea06-f0615f8cc321"
272
  },
 
273
  "outputs": [
274
  {
275
  "output_type": "stream",
276
  "name": "stdout",
277
  "text": [
278
+ "Train: 7213 Val: 1202 Test: 1203\n"
 
 
 
 
 
 
279
  ]
280
  }
281
+ ],
282
+ "source": [
283
+ "# Stratified split train/val/test\n",
284
+ "from sklearn.model_selection import train_test_split\n",
285
+ "csv_path = os.path.join(OUT_DIR, \"combined_image_labels.csv\")\n",
286
+ "df = pd.read_csv(csv_path)\n",
287
+ "\n",
288
+ "train_df, temp_df = train_test_split(df, test_size=0.25, stratify=df['label'], random_state=42)\n",
289
+ "val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['label'], random_state=42)\n",
290
+ "\n",
291
+ "print(\"Train:\", len(train_df), \"Val:\", len(val_df), \"Test:\", len(test_df))\n",
292
+ "train_df.to_csv(os.path.join(OUT_DIR, \"train.csv\"), index=False)\n",
293
+ "val_df.to_csv(os.path.join(OUT_DIR, \"val.csv\"), index=False)\n",
294
+ "test_df.to_csv(os.path.join(OUT_DIR, \"test.csv\"), index=False)\n"
295
  ]
296
  },
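
The two-stage split gives roughly 75/12.5/12.5, and stratify keeps the label mix identical across the three files; a quick check, assuming the DataFrames from the cell above:

# Sketch: confirm stratification preserved per-class proportions.
for name, part in [("train", train_df), ("val", val_df), ("test", test_df)]:
    print(name, part["label"].value_counts(normalize=True).round(3).to_dict())
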
297
+ {
298
+ "cell_type": "code",
299
+ "source": [],
300
+ "metadata": {
301
+ "id": "MTmULrSFKdxe"
302
+ },
303
+ "execution_count": null,
304
+ "outputs": []
305
+ },
306
  {
307
  "cell_type": "code",
308
  "source": [
309
+ "# Build label -> index mapping from training labels\n",
310
+ "labels = sorted(train_df['label'].unique().tolist())\n",
311
+ "LABEL2IDX = {lab: idx for idx, lab in enumerate(labels)}\n",
312
+ "IDX2LABEL = {v:k for k,v in LABEL2IDX.items()}\n",
313
+ "print(\"Labels:\", LABEL2IDX)\n",
314
+ "NUM_CLASSES = len(labels)\n"
 
 
 
315
  ],
316
  "metadata": {
317
  "colab": {
318
  "base_uri": "https://localhost:8080/"
319
  },
320
+ "id": "LKnSLphiKd6i",
321
+ "outputId": "48c16684-e96c-4146-9ba6-1fe0a28e138d"
322
  },
323
+ "execution_count": 21,
324
  "outputs": [
325
  {
326
  "output_type": "stream",
327
  "name": "stdout",
328
  "text": [
329
+ "Labels: {'ct_tumor': 0, 'glioma': 1, 'meningioma': 2, 'mri_tumor': 3, 'no_tumor': 4, 'pituitary': 5}\n"
 
 
 
 
 
 
 
 
 
 
 
330
  ]
331
  }
332
  ]
333
  },
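
The Gradio cell further down hard-codes IDX2LABEL; persisting the mapping beside the checkpoint avoids drift if the label set ever changes. A minimal sketch, assuming OUT_DIR from the setup cells:

# Sketch: save the label mapping so inference code can load it
# instead of hard-coding IDX2LABEL.
import json, os

with open(os.path.join(OUT_DIR, "label_map.json"), "w") as f:
    json.dump(LABEL2IDX, f, indent=2)
# at inference time:
#   LABEL2IDX = json.load(open(".../label_map.json"))
#   IDX2LABEL = {v: k for k, v in LABEL2IDX.items()}
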
334
  {
335
  "cell_type": "code",
336
+ "execution_count": null,
 
 
337
  "metadata": {
338
  "id": "MYxdFZEiY4_9"
339
  },
340
+ "outputs": [],
341
+ "source": []
342
  },
343
  {
344
  "cell_type": "code",
345
+ "execution_count": 24,
346
+ "metadata": {
347
+ "colab": {
348
+ "base_uri": "https://localhost:8080/"
349
+ },
350
+ "id": "Nd0TbtAlY7to",
351
+ "outputId": "63032d78-ebe7-4370-c780-746cee650c8f"
352
+ },
353
+ "outputs": [
354
+ {
355
+ "output_type": "stream",
356
+ "name": "stdout",
357
+ "text": [
358
+ "batch x shape: torch.Size([8, 3, 224, 224]) y shape: torch.Size([8])\n"
359
+ ]
360
+ }
361
+ ],
362
  "source": [
363
+ "# Define Dataset and Transforms (inline)\n",
364
  "import torch\n",
365
  "from torch.utils.data import Dataset, DataLoader\n",
366
  "from PIL import Image\n",
367
  "import torchvision.transforms as T\n",
368
  "\n",
369
+ "# transforms (same for MRI/CT for MVP)\n",
 
370
  "train_tf = T.Compose([\n",
371
  " T.Resize((IMG_SIZE, IMG_SIZE)),\n",
372
  " T.RandomHorizontalFlip(),\n",
373
  " T.RandomRotation(10),\n",
374
+ " T.ColorJitter(0.08, 0.08, 0.08, 0.02),\n",
375
  " T.ToTensor(),\n",
376
  " T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n",
377
  "])\n",
 
381
  " T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n",
382
  "])\n",
383
  "\n",
384
+ "class MultiModalityDataset(Dataset):\n",
385
+ " def __init__(self, df, transform=None, label_map=None):\n",
386
  " self.df = df.reset_index(drop=True)\n",
387
  " self.transform = transform\n",
388
+ " self.label_map = label_map\n",
389
  " def __len__(self):\n",
390
  " return len(self.df)\n",
 
391
  " def __getitem__(self, idx):\n",
392
  " row = self.df.iloc[idx]\n",
393
+ " img_path = row['image']\n",
394
+ " img = Image.open(img_path).convert('RGB') # ensures 3 channels\n",
395
  " if self.transform:\n",
396
  " img = self.transform(img)\n",
397
+ " label = self.label_map[row['label']]\n",
398
  " return img, label\n",
399
  "\n",
400
+ "train_ds = MultiModalityDataset(train_df, transform=train_tf, label_map=LABEL2IDX)\n",
401
+ "val_ds = MultiModalityDataset(val_df, transform=val_tf, label_map=LABEL2IDX)\n",
402
+ "test_ds = MultiModalityDataset(test_df, transform=val_tf, label_map=LABEL2IDX)\n",
403
+ "\n",
404
+ "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)\n",
405
+ "val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)\n",
406
+ "test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)\n",
407
+ "\n",
408
+ "# sanity check\n",
409
+ "x,y = next(iter(train_loader))\n",
410
+ "print(\"batch x shape:\", x.shape, \"y shape:\", y.shape)"
 
 
 
 
 
 
 
 
 
 
 
411
  ]
412
  },
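
With no_tumor roughly seven times larger than mri_tumor, the class-weighted loss used below is one remedy; oversampling minority classes at batch time is another. A minimal sketch with WeightedRandomSampler, assuming train_ds, train_df, BATCH_SIZE and NUM_WORKERS from the cells above:

# Sketch: draw batches with per-class balancing instead of (or alongside)
# weighting the loss; note sampler and shuffle are mutually exclusive.
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

counts = train_df["label"].value_counts()
weights = train_df["label"].map(lambda l: 1.0 / counts[l]).to_numpy()
sampler = WeightedRandomSampler(torch.as_tensor(weights, dtype=torch.double),
                                num_samples=len(weights), replacement=True)
balanced_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler,
                             num_workers=NUM_WORKERS, pin_memory=True)
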
413
  {
414
  "cell_type": "code",
415
+ "execution_count": null,
416
  "metadata": {
417
  "id": "SrJLOm33ZMpC"
418
  },
419
+ "outputs": [],
420
+ "source": []
421
  },
422
  {
423
  "cell_type": "code",
424
+ "execution_count": null,
 
 
425
  "metadata": {
426
  "id": "bG2JclRDY-Di"
427
  },
428
+ "outputs": [],
429
+ "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  },
431
  {
432
  "cell_type": "code",
433
+ "execution_count": 25,
 
 
 
 
 
 
 
 
 
 
 
434
  "metadata": {
435
+ "id": "Jc4pdqreZGDS",
436
  "colab": {
437
  "base_uri": "https://localhost:8080/",
438
+ "height": 188,
439
  "referenced_widgets": [
440
+ "fd1f533d7d814e10b252bf66f8c0697e",
441
+ "243ec1d23b8e469e833db01f89189639",
442
+ "9bcf41c7fe1e425aae4e37dcfca6a67c",
443
+ "f09e93d2fdc64071b3b9860fd7782c58",
444
+ "0a03432f7ce14c8c98090f846a512beb",
445
+ "1da74b90beb44d85a22ca57ee52b1e9d",
446
+ "a7c35ed82c20419b9b559b217d924a6a",
447
+ "0a302207f9884533b49fb150411169d2",
448
+ "bc007e650c0f4f619946055a6c72d132",
449
+ "20b5ba3ee19446dc89ddb5222be6f5f2",
450
+ "1e35273de451460aa8c82df17b3ca23c"
451
  ]
452
  },
453
+ "outputId": "8f6811b3-e720-4ffd-b49b-45c177d2972e"
 
454
  },
 
455
  "outputs": [
456
+ {
457
+ "output_type": "stream",
458
+ "name": "stdout",
459
+ "text": [
460
+ "Using device: cuda\n"
461
+ ]
462
+ },
463
  {
464
  "output_type": "stream",
465
  "name": "stderr",
 
481
  "application/vnd.jupyter.widget-view+json": {
482
  "version_major": 2,
483
  "version_minor": 0,
484
+ "model_id": "fd1f533d7d814e10b252bf66f8c0697e"
485
  }
486
  },
487
  "metadata": {}
488
+ },
489
+ {
490
+ "output_type": "stream",
491
+ "name": "stdout",
492
+ "text": [
493
+ "Class weights: tensor([0.6917, 2.3853, 1.4414, 2.7322, 0.3728, 2.5470], device='cuda:0')\n"
494
+ ]
495
  }
496
+ ],
497
+ "source": [
498
+ "# Model, loss, optimizer, optional class weights\n",
499
+ "import timm\n",
500
+ "import torch.nn as nn\n",
501
+ "\n",
502
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
503
+ "print(\"Using device:\", device)\n",
504
+ "\n",
505
+ "model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=NUM_CLASSES)\n",
506
+ "model = model.to(device)\n",
507
+ "\n",
508
+ "# Optionally compute class weights if highly imbalanced\n",
509
+ "from sklearn.utils.class_weight import compute_class_weight\n",
510
+ "import numpy as np\n",
511
+ "cls_w = compute_class_weight(class_weight='balanced', classes=np.arange(NUM_CLASSES), y=train_df['label'].map(LABEL2IDX))\n",
512
+ "cls_w = torch.tensor(cls_w, dtype=torch.float).to(device)\n",
513
+ "print(\"Class weights:\", cls_w)\n",
514
+ "\n",
515
+ "loss_fn = nn.CrossEntropyLoss(weight=cls_w)\n",
516
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)\n",
517
+ "scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=2)\n"
518
  ]
519
  },
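
compute_class_weight('balanced') is just w_c = N / (K * n_c): total training images over classes times per-class count. Plugging in the counts printed earlier reproduces the no_tumor weight:

# Sketch: the 'balanced' formula by hand for no_tumor.
N, K = 7213, 6                    # training images, classes
n_no_tumor = round(0.75 * 4300)   # no_tumor share of the 75% train split
print(N / (K * n_no_tumor))       # ~0.3728, matching the printed tensor
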
520
  {
521
  "cell_type": "code",
522
+ "execution_count": 26,
523
+ "metadata": {
524
+ "colab": {
525
+ "base_uri": "https://localhost:8080/"
526
+ },
527
+ "id": "nUDYRwowZRGb",
528
+ "outputId": "bdac381c-00bd-4330-f8d9-f2ba47208e67"
529
+ },
530
+ "outputs": [
531
+ {
532
+ "output_type": "stream",
533
+ "name": "stdout",
534
+ "text": [
535
+ "Epoch 1/12 | train_loss 0.7560 train_acc 0.7829 | val_loss 0.4183 val_acc 0.8968 | 1.48 min\n",
536
+ "Saved best model: 0.8968386023294509\n",
537
+ "Epoch 2/12 | train_loss 0.3758 train_acc 0.8959 | val_loss 0.3714 val_acc 0.9160 | 1.32 min\n",
538
+ "Saved best model: 0.9159733777038269\n",
539
+ "Epoch 3/12 | train_loss 0.2867 train_acc 0.9186 | val_loss 0.3187 val_acc 0.9268 | 1.27 min\n",
540
+ "Saved best model: 0.9267886855241264\n",
541
+ "Epoch 4/12 | train_loss 0.2459 train_acc 0.9326 | val_loss 0.2907 val_acc 0.9368 | 1.28 min\n",
542
+ "Saved best model: 0.9367720465890182\n",
543
+ "Epoch 5/12 | train_loss 0.2312 train_acc 0.9368 | val_loss 0.3151 val_acc 0.9285 | 1.26 min\n",
544
+ "Epoch 6/12 | train_loss 0.2106 train_acc 0.9427 | val_loss 0.3328 val_acc 0.9276 | 1.26 min\n",
545
+ "Epoch 7/12 | train_loss 0.1896 train_acc 0.9473 | val_loss 0.3394 val_acc 0.9301 | 1.28 min\n",
546
+ "Early stopping triggered. Stopping training.\n"
547
+ ]
548
+ }
549
+ ],
550
  "source": [
551
+ "# Training loop with val_loss, val_acc, checkpointing & early stopping\n",
552
  "from sklearn.metrics import accuracy_score\n",
553
  "import time, os\n",
554
+ "import pandas as pd\n",
555
  "\n",
556
+ "os.makedirs(OUT_DIR, exist_ok=True)\n",
 
 
557
  "best_val_acc = 0.0\n",
558
+ "patience = 3\n",
559
+ "wait = 0\n",
560
+ "logs = []\n",
561
  "\n",
562
  "for epoch in range(1, EPOCHS+1):\n",
563
+ " t0 = time.time()\n",
564
+ " # train\n",
565
  " model.train()\n",
566
+ " train_loss = 0.0\n",
567
  " all_preds, all_labels = [], []\n",
 
 
 
568
  " for imgs, labels in train_loader:\n",
569
  " imgs = imgs.to(device)\n",
570
  " labels = labels.to(device)\n",
 
573
  " loss = loss_fn(logits, labels)\n",
574
  " loss.backward()\n",
575
  " optimizer.step()\n",
576
+ " train_loss += loss.item() * imgs.size(0)\n",
577
  " preds = logits.argmax(dim=1).cpu().numpy()\n",
578
  " all_preds.extend(preds.tolist())\n",
579
  " all_labels.extend(labels.cpu().numpy().tolist())\n",
580
+ " train_loss = train_loss / len(train_loader.dataset)\n",
 
581
  " train_acc = accuracy_score(all_labels, all_preds)\n",
582
  "\n",
583
+ " # val\n",
584
  " model.eval()\n",
585
+ " val_loss = 0.0\n",
586
  " v_preds, v_labels = [], []\n",
587
  " with torch.no_grad():\n",
588
  " for imgs, labels in val_loader:\n",
589
  " imgs = imgs.to(device)\n",
590
  " labels = labels.to(device)\n",
591
  " logits = model(imgs)\n",
592
+ " loss = loss_fn(logits, labels)\n",
593
+ " val_loss += loss.item() * imgs.size(0)\n",
594
  " preds = logits.argmax(dim=1).cpu().numpy()\n",
595
  " v_preds.extend(preds.tolist())\n",
596
  " v_labels.extend(labels.cpu().numpy().tolist())\n",
597
+ " val_loss = val_loss / len(val_loader.dataset)\n",
 
598
  " val_acc = accuracy_score(v_labels, v_preds)\n",
599
  "\n",
600
+ " # scheduler step\n",
601
+ " scheduler.step(val_loss)\n",
602
+ "\n",
603
+ " elapsed = (time.time() - t0) / 60\n",
604
+ " print(f\"Epoch {epoch}/{EPOCHS} | train_loss {train_loss:.4f} train_acc {train_acc:.4f} | val_loss {val_loss:.4f} val_acc {val_acc:.4f} | {elapsed:.2f} min\")\n",
605
+ " logs.append({\"epoch\":epoch,\"train_loss\":train_loss,\"train_acc\":train_acc,\"val_loss\":val_loss,\"val_acc\":val_acc})\n",
606
  "\n",
607
+ " # checkpoint\n",
608
+ " ckpt = os.path.join(OUT_DIR, f\"epoch{epoch}.pth\")\n",
609
+ " torch.save(model.state_dict(), ckpt)\n",
610
  " if val_acc > best_val_acc:\n",
611
  " best_val_acc = val_acc\n",
612
+ " torch.save(model.state_dict(), os.path.join(OUT_DIR, \"best_model.pth\"))\n",
613
+ " print(\"Saved best model:\", best_val_acc)\n",
614
+ " wait = 0\n",
615
+ " else:\n",
616
+ " wait += 1\n",
617
+ " if wait >= patience:\n",
618
+ " print(\"Early stopping triggered. Stopping training.\")\n",
619
+ " break\n",
620
+ "\n",
621
+ "# save logs to csv\n",
622
+ "pd.DataFrame(logs).to_csv(os.path.join(OUT_DIR, \"training_log.csv\"), index=False)"
623
+ ]
624
+ },
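
The loop also writes training_log.csv, which makes the divergence after epoch 4 (train loss still falling while val loss rises) easy to see. A minimal plotting sketch, assuming the CSV written above:

# Sketch: visualize the saved training log.
import os
import pandas as pd
import matplotlib.pyplot as plt

log = pd.read_csv(os.path.join(OUT_DIR, "training_log.csv"))
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
log.plot(x="epoch", y=["train_loss", "val_loss"], ax=ax[0], title="loss")
log.plot(x="epoch", y=["train_acc", "val_acc"], ax=ax[1], title="accuracy")
plt.show()
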
625
+ {
626
+ "cell_type": "code",
627
+ "execution_count": null,
628
+ "metadata": {
629
+ "id": "I2WUIxNnZtVI"
630
+ },
631
+ "outputs": [],
632
+ "source": []
633
+ },
634
+ {
635
+ "cell_type": "code",
636
+ "source": [
637
+ "# Evaluate on test set (classification report + confusion matrix)\n",
638
+ "from sklearn.metrics import classification_report, confusion_matrix\n",
639
+ "import numpy as np\n",
640
+ "\n",
641
+ "# load best model\n",
642
+ "best_path = os.path.join(OUT_DIR, \"best_model.pth\")\n",
643
+ "model.load_state_dict(torch.load(best_path, map_location=device))\n",
644
+ "model.eval()\n",
645
+ "\n",
646
+ "y_true, y_pred = [], []\n",
647
+ "with torch.no_grad():\n",
648
+ " for imgs, labels in test_loader:\n",
649
+ " imgs = imgs.to(device)\n",
650
+ " labels = labels.to(device)\n",
651
+ " logits = model(imgs)\n",
652
+ " preds = logits.argmax(dim=1).cpu().numpy()\n",
653
+ " y_pred.extend(preds.tolist())\n",
654
+ " y_true.extend(labels.cpu().numpy().tolist())\n",
655
+ "\n",
656
+ "print(\"Classification Report:\")\n",
657
+ "print(classification_report(y_true, y_pred, target_names=[IDX2LABEL[i] for i in range(NUM_CLASSES)]))\n",
658
+ "print(\"Confusion matrix:\")\n",
659
+ "print(confusion_matrix(y_true, y_pred))\n"
660
  ],
661
  "metadata": {
662
  "colab": {
663
  "base_uri": "https://localhost:8080/"
664
  },
665
+ "id": "nbkHkAPELGYq",
666
+ "outputId": "19358314-2ce1-4bf4-8844-45a0d723100c"
667
  },
668
+ "execution_count": 28,
669
  "outputs": [
670
  {
671
  "output_type": "stream",
672
  "name": "stdout",
673
  "text": [
674
+ "Classification Report:\n",
675
+ " precision recall f1-score support\n",
676
+ "\n",
677
+ " ct_tumor 0.98 0.99 0.99 290\n",
678
+ " glioma 0.95 0.93 0.94 84\n",
679
+ " meningioma 0.77 0.95 0.85 139\n",
680
+ " mri_tumor 0.80 0.53 0.64 73\n",
681
+ " no_tumor 1.00 0.99 0.99 538\n",
682
+ " pituitary 0.99 0.94 0.96 79\n",
683
+ "\n",
684
+ " accuracy 0.95 1203\n",
685
+ " macro avg 0.91 0.89 0.90 1203\n",
686
+ "weighted avg 0.95 0.95 0.95 1203\n",
687
+ "\n",
688
+ "Confusion matrix:\n",
689
+ "[[288 0 0 0 2 0]\n",
690
+ " [ 0 78 5 1 0 0]\n",
691
+ " [ 0 0 132 7 0 0]\n",
692
+ " [ 0 4 29 39 0 1]\n",
693
+ " [ 5 0 2 0 531 0]\n",
694
+ " [ 0 0 3 2 0 74]]\n"
695
  ]
696
  }
697
  ]
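
The matrix shows the catch-all mri_tumor class absorbing most of the errors: 29 of its 73 test images land in meningioma. Per-class recall can be read straight off the diagonal; a minimal sketch, assuming y_true and y_pred from the cell above:

# Sketch: per-class recall from the confusion matrix.
import numpy as np
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_true, y_pred)
recall = cm.diagonal() / cm.sum(axis=1)
for i, r in enumerate(recall):
    print(f"{IDX2LABEL[i]:>12}: recall {r:.2f}")
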
 
700
  "cell_type": "code",
701
  "source": [],
702
  "metadata": {
703
+ "id": "QZCaUaWzLGbS"
704
+ },
705
+ "execution_count": null,
706
+ "outputs": []
707
+ },
708
+ {
709
+ "cell_type": "code",
710
+ "source": [],
711
+ "metadata": {
712
+ "id": "QMaVFK9xLGgu"
713
  },
714
  "execution_count": null,
715
  "outputs": []
716
  },
717
  {
718
  "cell_type": "code",
719
+ "execution_count": 33,
720
+ "metadata": {
721
+ "colab": {
722
+ "base_uri": "https://localhost:8080/"
723
+ },
724
+ "id": "H2BqijX2Zupx",
725
+ "outputId": "3ed96f69-87a9-4b2b-839e-2a8680c672e9"
726
+ },
727
+ "outputs": [
728
+ {
729
+ "output_type": "stream",
730
+ "name": "stdout",
731
+ "text": [
732
+ "Sample: /root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset/Brain Tumor CT scan Images/Tumor/ct_tumor (137).jpg\n",
733
+ "[{'label': 'ct_tumor', 'score': 0.9998412132263184}, {'label': 'meningioma', 'score': 0.00015001322026364505}, {'label': 'mri_tumor', 'score': 5.044169483880978e-06}]\n"
734
+ ]
735
+ }
736
+ ],
737
  "source": [
738
  "# Load best model & test a single image (inference helper)\n",
739
  "import torch.nn.functional as F\n",
740
  "from PIL import Image\n",
741
  "from torchvision import transforms\n",
742
  "\n",
743
+ "def load_model_for_infer(ckpt_path=os.path.join(OUT_DIR,\"best_model.pth\")):\n",
744
+ " m = timm.create_model('efficientnet_b0', pretrained=False, num_classes=NUM_CLASSES)\n",
745
+ " m.load_state_dict(torch.load(ckpt_path, map_location=device))\n",
746
+ " m.to(device).eval()\n",
747
+ " return m\n",
 
748
  "\n",
749
+ "# reuse tf from earlier\n",
750
+ "def predict_topk_pil(model, pil_img, topk=TOP_K):\n",
751
+ " x = val_tf(pil_img).unsqueeze(0).to(device)\n",
 
 
 
 
 
 
 
752
  " with torch.no_grad():\n",
753
  " logits = model(x)\n",
754
  " probs = F.softmax(logits, dim=1).cpu().numpy().ravel()\n",
755
  " idxs = probs.argsort()[::-1][:topk]\n",
756
  " return [{\"label\": IDX2LABEL[int(i)], \"score\": float(probs[int(i)])} for i in idxs]\n",
757
  "\n",
758
+ "# quick example\n",
759
+ "infer_model = load_model_for_infer()\n",
760
+ "sample_path = val_df.iloc[110].image\n",
761
  "img = Image.open(sample_path).convert(\"RGB\")\n",
762
+ "print(\"Sample:\", sample_path)\n",
763
+ "print(predict_topk_pil(infer_model, img, topk=3))"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
764
  ]
765
  },
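
A small usage sketch for the helper, assuming infer_model, val_df and predict_topk_pil from the cell above:

# Sketch: batch the single-image helper over a few validation files.
from PIL import Image

for p in val_df["image"].head(4):
    img = Image.open(p).convert("RGB")
    top1 = predict_topk_pil(infer_model, img, topk=1)[0]
    print(p, "->", top1["label"], f"{top1['score']:.3f}")
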
766
  {
767
  "cell_type": "code",
768
+ "execution_count": null,
769
  "metadata": {
770
  "id": "RAZRNlq2ZutV"
771
  },
772
+ "outputs": [],
773
+ "source": []
774
  },
775
  {
776
  "cell_type": "code",
777
+ "execution_count": null,
778
  "metadata": {
779
  "id": "chDN6CJmc8h0"
780
  },
781
+ "outputs": [],
782
+ "source": []
783
+ },
784
+ {
785
+ "cell_type": "code",
786
+ "source": [
787
+ "# gradio App plus Openai"
788
+ ],
789
+ "metadata": {
790
+ "id": "LsQ5FIkOOHch"
791
+ },
792
  "execution_count": null,
793
  "outputs": []
794
  },
795
  {
796
  "cell_type": "code",
797
+ "execution_count": 39,
798
+ "metadata": {
799
+ "id": "WSygonZYdvpg"
800
+ },
801
+ "outputs": [],
802
  "source": [
803
  "sample_paths = [\n",
804
+ " \"/root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset/Brain Tumor CT scan Images/Tumor/ct_tumor (10).jpg\",\n",
805
+ " \"/root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset/Brain Tumor CT scan Images/Healthy/ct_healthy (1).jpg\",\n",
806
+ " \"/root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset/Brain Tumor MRI images/Tumor/glioma (10).jpg\",\n",
807
+ " \"/root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset/Brain Tumor MRI images/Tumor/meningioma (1056).jpg\",\n",
808
  "]\n",
809
  "\n",
810
  "OPENAI_API_KEY = \"\""
811
+ ]
812
  },
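
Assigning the key as a literal in a cell is easy to commit by accident; app.py in this repo already reads it from the environment, and the notebook can do the same. A minimal sketch:

# Sketch: read the key from the environment (or Colab's secrets panel)
# instead of a hard-coded literal.
import os

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
# In Colab, assuming a secret named OPENAI_API_KEY has been added:
#   from google.colab import userdata
#   OPENAI_API_KEY = userdata.get("OPENAI_API_KEY")
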
813
  {
814
  "cell_type": "code",
815
+ "execution_count": 43,
816
+ "metadata": {
817
+ "colab": {
818
+ "base_uri": "https://localhost:8080/",
819
+ "height": 680
820
+ },
821
+ "id": "S_rWA1Xcdvx8",
822
+ "outputId": "55c18697-f04d-45bd-bddc-f59e4a729c3d"
823
+ },
824
+ "outputs": [
825
+ {
826
+ "output_type": "stream",
827
+ "name": "stdout",
828
+ "text": [
829
+ "Loading model...\n",
830
+ "Model loaded on cuda\n"
831
+ ]
832
+ },
833
+ {
834
+ "output_type": "stream",
835
+ "name": "stderr",
836
+ "text": [
837
+ "/tmp/ipython-input-2087682989.py:119: UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.\n",
838
+ " chatbot = gr.Chatbot(label=\"Assistant\")\n"
839
+ ]
840
+ },
841
+ {
842
+ "output_type": "stream",
843
+ "name": "stdout",
844
+ "text": [
845
+ "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n",
846
+ "* Running on public URL: https://a1ab9ef721b239799a.gradio.live\n",
847
+ "\n",
848
+ "This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
849
+ ]
850
+ },
851
+ {
852
+ "output_type": "display_data",
853
+ "data": {
854
+ "text/plain": [
855
+ "<IPython.core.display.HTML object>"
856
+ ],
857
+ "text/html": [
858
+ "<div><iframe src=\"https://a1ab9ef721b239799a.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
859
+ ]
860
+ },
861
+ "metadata": {}
862
+ },
863
+ {
864
+ "output_type": "execute_result",
865
+ "data": {
866
+ "text/plain": []
867
+ },
868
+ "metadata": {},
869
+ "execution_count": 43
870
+ }
871
+ ],
872
  "source": [
873
  "import os\n",
874
  "from pathlib import Path\n",
 
881
  "import openai\n",
882
  "from openai import OpenAI\n",
883
  "\n",
884
+ "# ------------------ USER CONFIG (edit as needed) ------------------\n",
885
+ "CKPT_PATH = \"/content/output/best_model.pth\" # path to trained checkpoint\n",
 
886
  "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
887
  "BACKBONE = \"efficientnet_b0\"\n",
 
888
  "IMG_SIZE = 224\n",
889
  "\n",
890
+ "# Paste your 4 sample image absolute paths here (or ensure sample_paths variable already exists)\n",
891
+ "# Example:\n",
892
+ "# sample_paths = [\"/content/samples/ct1.jpg\", \"/content/samples/mri_glioma1.jpg\", ...]\n",
893
+ "sample_paths = sample_paths # keep your existing variable/assignment\n",
894
  "\n",
895
+ "OPENAI_API_KEY = OPENAI_API_KEY # keep as before (e.g., from env or variable)\n",
896
  "if OPENAI_API_KEY:\n",
897
  " openai.api_key = OPENAI_API_KEY\n",
898
  "\n",
899
+ "# ------------------ UPDATED LABEL MAP for combined CT+MRI ------------------\n",
900
+ "# Minimal change: include ct_tumor and MRI subtypes + no_tumor\n",
901
+ "IDX2LABEL = {\n",
902
+ " 0: \"ct_tumor\",\n",
903
+ " 1: \"glioma\",\n",
904
+ " 2: \"meningioma\",\n",
905
+ " 3: \"mri_tumor\",\n",
906
+ " 4: \"no_tumor\",\n",
907
+ " 5: \"pituitary\"\n",
908
+ "}\n",
909
+ "\n",
910
+ "NUM_CLASSES = len(IDX2LABEL)\n",
911
  "# -------------------------------------------------\n",
912
  "\n",
913
  "# Validate sample paths\n",
 
943
  " idxs = probs.argsort()[::-1][:topk]\n",
944
  " return [{\"label\": IDX2LABEL[int(i)], \"score\": float(probs[int(i)])} for i in idxs]\n",
945
  "\n",
946
+ "# OpenAI system prompt (unchanged)\n",
947
  "SYSTEM_TEMPLATE = \"\"\"You are an educational radiology assistant. The model predicted:\n",
948
  "{preds}\n",
949
  "\n",
 
973
  " )\n",
974
  " return resp.choices[0].message.content.strip()\n",
975
  "\n",
 
976
  "# Prepare thumbnails (just images)\n",
977
  "thumbs = [Image.open(p).convert(\"RGB\") for p in sample_paths]\n",
978
  "\n",
979
+ "# Gradio UI (unchanged structure)\n",
980
  "with gr.Blocks() as demo:\n",
981
+ " gr.Markdown(\"## Brain MRI+CT multi-class demo — Educational only\")\n",
982
  " with gr.Row():\n",
983
  " with gr.Column():\n",
984
  " # Use a simple list of PIL thumbnails for the gallery value\n",
 
996
  " selected_path_state = gr.State(value=None)\n",
997
  " preds_state = gr.State(value=None)\n",
998
  "\n",
999
+ " # select image from gallery -> store selected path string\n",
1000
  " def on_select(evt: gr.SelectData):\n",
1001
+ " idx = int(evt.index)\n",
1002
+ " path = sample_paths[idx]\n",
1003
  " return path\n",
1004
  "\n",
1005
  " # ensure the gallery select writes only the path to the state\n",
1006
  " gallery.select(fn=on_select, inputs=None, outputs=[selected_path_state])\n",
1007
  "\n",
1008
+ " # analyze (unchanged except label names)\n",
1009
  " def analyze(path):\n",
1010
  " try:\n",
1011
  " if path is None:\n",
 
1022
  "\n",
1023
  " analyze_btn.click(fn=analyze, inputs=selected_path_state, outputs=[preds_output, chatbot, preds_state])\n",
1024
  "\n",
1025
+ " # chat (keeps same logic)\n",
1026
  " def chat(chat_history, msg, preds_text):\n",
1027
  " if preds_text is None:\n",
1028
  " return chat_history+[(\"AI\",\"Please analyze an image first.\")], \"\"\n",
 
1053
  "\n",
1054
  " gr.Markdown(\"⚠️ Educational use only — not a medical diagnosis.\")\n",
1055
  "\n",
1056
+ "demo.launch(share=True)"
1057
+ ]
1058
+ },
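
The stderr above flags tuple-style Chatbot history as deprecated; the warning's own suggestion is type='messages'. A minimal sketch of that format (the chat handlers would then emit dicts rather than tuples):

# Sketch: openai-style chat history, as the deprecation warning suggests.
import gradio as gr

chatbot = gr.Chatbot(label="Assistant", type="messages")
# history entries then look like:
#   {"role": "user", "content": "..."}
#   {"role": "assistant", "content": "..."}
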
1059
+ {
1060
+ "cell_type": "code",
1061
+ "execution_count": null,
1062
+ "metadata": {
1063
+ "id": "TrS_HlY7d8s3"
1064
+ },
1065
+ "outputs": [],
1066
+ "source": []
1067
+ },
1068
+ {
1069
+ "cell_type": "code",
1070
+ "execution_count": null,
1071
  "metadata": {
1072
  "colab": {
1073
+ "background_save": true
1074
+ },
1075
+ "id": "6Yj2jOOY0swp"
1076
+ },
1077
+ "outputs": [],
1078
+ "source": []
1079
+ }
1080
+ ],
1081
+ "metadata": {
1082
+ "accelerator": "GPU",
1083
+ "colab": {
1084
+ "gpuType": "T4",
1085
+ "provenance": []
1086
+ },
1087
+ "kernelspec": {
1088
+ "display_name": "Python 3",
1089
+ "name": "python3"
1090
+ },
1091
+ "language_info": {
1092
+ "name": "python"
1093
+ },
1094
+ "widgets": {
1095
+ "application/vnd.jupyter.widget-state+json": {
1096
+ "fd1f533d7d814e10b252bf66f8c0697e": {
1097
+ "model_module": "@jupyter-widgets/controls",
1098
+ "model_name": "HBoxModel",
1099
+ "model_module_version": "1.5.0",
1100
+ "state": {
1101
+ "_dom_classes": [],
1102
+ "_model_module": "@jupyter-widgets/controls",
1103
+ "_model_module_version": "1.5.0",
1104
+ "_model_name": "HBoxModel",
1105
+ "_view_count": null,
1106
+ "_view_module": "@jupyter-widgets/controls",
1107
+ "_view_module_version": "1.5.0",
1108
+ "_view_name": "HBoxView",
1109
+ "box_style": "",
1110
+ "children": [
1111
+ "IPY_MODEL_243ec1d23b8e469e833db01f89189639",
1112
+ "IPY_MODEL_9bcf41c7fe1e425aae4e37dcfca6a67c",
1113
+ "IPY_MODEL_f09e93d2fdc64071b3b9860fd7782c58"
1114
+ ],
1115
+ "layout": "IPY_MODEL_0a03432f7ce14c8c98090f846a512beb"
1116
+ }
1117
+ },
1118
+ "243ec1d23b8e469e833db01f89189639": {
1119
+ "model_module": "@jupyter-widgets/controls",
1120
+ "model_name": "HTMLModel",
1121
+ "model_module_version": "1.5.0",
1122
+ "state": {
1123
+ "_dom_classes": [],
1124
+ "_model_module": "@jupyter-widgets/controls",
1125
+ "_model_module_version": "1.5.0",
1126
+ "_model_name": "HTMLModel",
1127
+ "_view_count": null,
1128
+ "_view_module": "@jupyter-widgets/controls",
1129
+ "_view_module_version": "1.5.0",
1130
+ "_view_name": "HTMLView",
1131
+ "description": "",
1132
+ "description_tooltip": null,
1133
+ "layout": "IPY_MODEL_1da74b90beb44d85a22ca57ee52b1e9d",
1134
+ "placeholder": "​",
1135
+ "style": "IPY_MODEL_a7c35ed82c20419b9b559b217d924a6a",
1136
+ "value": "model.safetensors: 100%"
1137
+ }
1138
+ },
1139
+ "9bcf41c7fe1e425aae4e37dcfca6a67c": {
1140
+ "model_module": "@jupyter-widgets/controls",
1141
+ "model_name": "FloatProgressModel",
1142
+ "model_module_version": "1.5.0",
1143
+ "state": {
1144
+ "_dom_classes": [],
1145
+ "_model_module": "@jupyter-widgets/controls",
1146
+ "_model_module_version": "1.5.0",
1147
+ "_model_name": "FloatProgressModel",
1148
+ "_view_count": null,
1149
+ "_view_module": "@jupyter-widgets/controls",
1150
+ "_view_module_version": "1.5.0",
1151
+ "_view_name": "ProgressView",
1152
+ "bar_style": "success",
1153
+ "description": "",
1154
+ "description_tooltip": null,
1155
+ "layout": "IPY_MODEL_0a302207f9884533b49fb150411169d2",
1156
+ "max": 21355344,
1157
+ "min": 0,
1158
+ "orientation": "horizontal",
1159
+ "style": "IPY_MODEL_bc007e650c0f4f619946055a6c72d132",
1160
+ "value": 21355344
1161
+ }
1162
+ },
1163
+ "f09e93d2fdc64071b3b9860fd7782c58": {
1164
+ "model_module": "@jupyter-widgets/controls",
1165
+ "model_name": "HTMLModel",
1166
+ "model_module_version": "1.5.0",
1167
+ "state": {
1168
+ "_dom_classes": [],
1169
+ "_model_module": "@jupyter-widgets/controls",
1170
+ "_model_module_version": "1.5.0",
1171
+ "_model_name": "HTMLModel",
1172
+ "_view_count": null,
1173
+ "_view_module": "@jupyter-widgets/controls",
1174
+ "_view_module_version": "1.5.0",
1175
+ "_view_name": "HTMLView",
1176
+ "description": "",
1177
+ "description_tooltip": null,
1178
+ "layout": "IPY_MODEL_20b5ba3ee19446dc89ddb5222be6f5f2",
1179
+ "placeholder": "​",
1180
+ "style": "IPY_MODEL_1e35273de451460aa8c82df17b3ca23c",
1181
+ "value": " 21.4M/21.4M [00:00&lt;00:00, 43.5MB/s]"
1182
+ }
1183
+ },
1184
+ "0a03432f7ce14c8c98090f846a512beb": {
1185
+ "model_module": "@jupyter-widgets/base",
1186
+ "model_name": "LayoutModel",
1187
+ "model_module_version": "1.2.0",
1188
+ "state": {
1189
+ "_model_module": "@jupyter-widgets/base",
1190
+ "_model_module_version": "1.2.0",
1191
+ "_model_name": "LayoutModel",
1192
+ "_view_count": null,
1193
+ "_view_module": "@jupyter-widgets/base",
1194
+ "_view_module_version": "1.2.0",
1195
+ "_view_name": "LayoutView",
1196
+ "align_content": null,
1197
+ "align_items": null,
1198
+ "align_self": null,
1199
+ "border": null,
1200
+ "bottom": null,
1201
+ "display": null,
1202
+ "flex": null,
1203
+ "flex_flow": null,
1204
+ "grid_area": null,
1205
+ "grid_auto_columns": null,
1206
+ "grid_auto_flow": null,
1207
+ "grid_auto_rows": null,
1208
+ "grid_column": null,
1209
+ "grid_gap": null,
1210
+ "grid_row": null,
1211
+ "grid_template_areas": null,
1212
+ "grid_template_columns": null,
1213
+ "grid_template_rows": null,
1214
+ "height": null,
1215
+ "justify_content": null,
1216
+ "justify_items": null,
1217
+ "left": null,
1218
+ "margin": null,
1219
+ "max_height": null,
1220
+ "max_width": null,
1221
+ "min_height": null,
1222
+ "min_width": null,
1223
+ "object_fit": null,
1224
+ "object_position": null,
1225
+ "order": null,
1226
+ "overflow": null,
1227
+ "overflow_x": null,
1228
+ "overflow_y": null,
1229
+ "padding": null,
1230
+ "right": null,
1231
+ "top": null,
1232
+ "visibility": null,
1233
+ "width": null
1234
+ }
1235
  },
1236
+ "1da74b90beb44d85a22ca57ee52b1e9d": {
1237
+ "model_module": "@jupyter-widgets/base",
1238
+ "model_name": "LayoutModel",
1239
+ "model_module_version": "1.2.0",
1240
+ "state": {
1241
+ "_model_module": "@jupyter-widgets/base",
1242
+ "_model_module_version": "1.2.0",
1243
+ "_model_name": "LayoutModel",
1244
+ "_view_count": null,
1245
+ "_view_module": "@jupyter-widgets/base",
1246
+ "_view_module_version": "1.2.0",
1247
+ "_view_name": "LayoutView",
1248
+ "align_content": null,
1249
+ "align_items": null,
1250
+ "align_self": null,
1251
+ "border": null,
1252
+ "bottom": null,
1253
+ "display": null,
1254
+ "flex": null,
1255
+ "flex_flow": null,
1256
+ "grid_area": null,
1257
+ "grid_auto_columns": null,
1258
+ "grid_auto_flow": null,
1259
+ "grid_auto_rows": null,
1260
+ "grid_column": null,
1261
+ "grid_gap": null,
1262
+ "grid_row": null,
1263
+ "grid_template_areas": null,
1264
+ "grid_template_columns": null,
1265
+ "grid_template_rows": null,
1266
+ "height": null,
1267
+ "justify_content": null,
1268
+ "justify_items": null,
1269
+ "left": null,
1270
+ "margin": null,
1271
+ "max_height": null,
1272
+ "max_width": null,
1273
+ "min_height": null,
1274
+ "min_width": null,
1275
+ "object_fit": null,
1276
+ "object_position": null,
1277
+ "order": null,
1278
+ "overflow": null,
1279
+ "overflow_x": null,
1280
+ "overflow_y": null,
1281
+ "padding": null,
1282
+ "right": null,
1283
+ "top": null,
1284
+ "visibility": null,
1285
+ "width": null
1286
+ }
1287
  },
1288
+ "a7c35ed82c20419b9b559b217d924a6a": {
1289
+ "model_module": "@jupyter-widgets/controls",
1290
+ "model_name": "DescriptionStyleModel",
1291
+ "model_module_version": "1.5.0",
1292
+ "state": {
1293
+ "_model_module": "@jupyter-widgets/controls",
1294
+ "_model_module_version": "1.5.0",
1295
+ "_model_name": "DescriptionStyleModel",
1296
+ "_view_count": null,
1297
+ "_view_module": "@jupyter-widgets/base",
1298
+ "_view_module_version": "1.2.0",
1299
+ "_view_name": "StyleView",
1300
+ "description_width": ""
1301
+ }
1302
  },
1303
+ "0a302207f9884533b49fb150411169d2": {
1304
+ "model_module": "@jupyter-widgets/base",
1305
+ "model_name": "LayoutModel",
1306
+ "model_module_version": "1.2.0",
1307
+ "state": {
1308
+ "_model_module": "@jupyter-widgets/base",
1309
+ "_model_module_version": "1.2.0",
1310
+ "_model_name": "LayoutModel",
1311
+ "_view_count": null,
1312
+ "_view_module": "@jupyter-widgets/base",
1313
+ "_view_module_version": "1.2.0",
1314
+ "_view_name": "LayoutView",
1315
+ "align_content": null,
1316
+ "align_items": null,
1317
+ "align_self": null,
1318
+ "border": null,
1319
+ "bottom": null,
1320
+ "display": null,
1321
+ "flex": null,
1322
+ "flex_flow": null,
1323
+ "grid_area": null,
1324
+ "grid_auto_columns": null,
1325
+ "grid_auto_flow": null,
1326
+ "grid_auto_rows": null,
1327
+ "grid_column": null,
1328
+ "grid_gap": null,
1329
+ "grid_row": null,
1330
+ "grid_template_areas": null,
1331
+ "grid_template_columns": null,
1332
+ "grid_template_rows": null,
1333
+ "height": null,
1334
+ "justify_content": null,
1335
+ "justify_items": null,
1336
+ "left": null,
1337
+ "margin": null,
1338
+ "max_height": null,
1339
+ "max_width": null,
1340
+ "min_height": null,
1341
+ "min_width": null,
1342
+ "object_fit": null,
1343
+ "object_position": null,
1344
+ "order": null,
1345
+ "overflow": null,
1346
+ "overflow_x": null,
1347
+ "overflow_y": null,
1348
+ "padding": null,
1349
+ "right": null,
1350
+ "top": null,
1351
+ "visibility": null,
1352
+ "width": null
1353
+ }
1354
  },
1355
+ "bc007e650c0f4f619946055a6c72d132": {
1356
+ "model_module": "@jupyter-widgets/controls",
1357
+ "model_name": "ProgressStyleModel",
1358
+ "model_module_version": "1.5.0",
1359
+ "state": {
1360
+ "_model_module": "@jupyter-widgets/controls",
1361
+ "_model_module_version": "1.5.0",
1362
+ "_model_name": "ProgressStyleModel",
1363
+ "_view_count": null,
1364
+ "_view_module": "@jupyter-widgets/base",
1365
+ "_view_module_version": "1.2.0",
1366
+ "_view_name": "StyleView",
1367
+ "bar_color": null,
1368
+ "description_width": ""
1369
+ }
1370
  },
1371
+ "20b5ba3ee19446dc89ddb5222be6f5f2": {
1372
+ "model_module": "@jupyter-widgets/base",
1373
+ "model_name": "LayoutModel",
1374
+ "model_module_version": "1.2.0",
1375
+ "state": {
1376
+ "_model_module": "@jupyter-widgets/base",
1377
+ "_model_module_version": "1.2.0",
1378
+ "_model_name": "LayoutModel",
1379
+ "_view_count": null,
1380
+ "_view_module": "@jupyter-widgets/base",
1381
+ "_view_module_version": "1.2.0",
1382
+ "_view_name": "LayoutView",
1383
+ "align_content": null,
1384
+ "align_items": null,
1385
+ "align_self": null,
1386
+ "border": null,
1387
+ "bottom": null,
1388
+ "display": null,
1389
+ "flex": null,
1390
+ "flex_flow": null,
1391
+ "grid_area": null,
1392
+ "grid_auto_columns": null,
1393
+ "grid_auto_flow": null,
1394
+ "grid_auto_rows": null,
1395
+ "grid_column": null,
1396
+ "grid_gap": null,
1397
+ "grid_row": null,
1398
+ "grid_template_areas": null,
1399
+ "grid_template_columns": null,
1400
+ "grid_template_rows": null,
1401
+ "height": null,
1402
+ "justify_content": null,
1403
+ "justify_items": null,
1404
+ "left": null,
1405
+ "margin": null,
1406
+ "max_height": null,
1407
+ "max_width": null,
1408
+ "min_height": null,
1409
+ "min_width": null,
1410
+ "object_fit": null,
1411
+ "object_position": null,
1412
+ "order": null,
1413
+ "overflow": null,
1414
+ "overflow_x": null,
1415
+ "overflow_y": null,
1416
+ "padding": null,
1417
+ "right": null,
1418
+ "top": null,
1419
+ "visibility": null,
1420
+ "width": null
1421
+ }
1422
+ },
1423
+ "1e35273de451460aa8c82df17b3ca23c": {
1424
+ "model_module": "@jupyter-widgets/controls",
1425
+ "model_name": "DescriptionStyleModel",
1426
+ "model_module_version": "1.5.0",
1427
+ "state": {
1428
+ "_model_module": "@jupyter-widgets/controls",
1429
+ "_model_module_version": "1.5.0",
1430
+ "_model_name": "DescriptionStyleModel",
1431
+ "_view_count": null,
1432
+ "_view_module": "@jupyter-widgets/base",
1433
+ "_view_module_version": "1.2.0",
1434
+ "_view_name": "StyleView",
1435
+ "description_width": ""
1436
+ }
1437
  }
1438
+ }
1439
  }
1440
+ },
1441
+ "nbformat": 4,
1442
+ "nbformat_minor": 0
1443
  }
app.py CHANGED
@@ -9,28 +9,45 @@ import gradio as gr
9
  import openai
10
  from openai import OpenAI
11
 
12
 
13
- # ------------------ USER CONFIG ------------------
14
- CKPT_PATH = "model/best_model_multiclass.pth" # path to trained checkpoint
 
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
  BACKBONE = "efficientnet_b0"
17
- NUM_CLASSES = 4
18
  IMG_SIZE = 224
19
 
20
- # Paste your 4 sample image absolute paths here
21
- # e.g. sample_paths = ["/content/a.png", "/content/b.png", ...]
22
- sample_paths = [
23
- "sample_images/image.jpg",
24
- "sample_images/image(2).jpg",
25
- "sample_images/image(3).jpg",
26
- "sample_images/image(23).jpg",
27
- ]
28
 
29
- OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
30
  if OPENAI_API_KEY:
31
  openai.api_key = OPENAI_API_KEY
32
 
33
- IDX2LABEL = {0: "glioma_tumor", 1: "meningioma_tumor", 2: "pituitary_tumor", 3: "no_tumor"}
 
 
 
 
 
 
 
 
 
 
 
34
  # -------------------------------------------------
35
 
36
  # Validate sample paths
@@ -66,7 +83,7 @@ def predict_topk_from_pil(pil_img, topk=3):
66
  idxs = probs.argsort()[::-1][:topk]
67
  return [{"label": IDX2LABEL[int(i)], "score": float(probs[int(i)])} for i in idxs]
68
 
69
- # OpenAI system prompt
70
  SYSTEM_TEMPLATE = """You are an educational radiology assistant. The model predicted:
71
  {preds}
72
 
@@ -96,13 +113,12 @@ def call_openai_seed(preds_text):
96
  )
97
  return resp.choices[0].message.content.strip()
98
 
99
-
100
  # Prepare thumbnails (just images)
101
  thumbs = [Image.open(p).convert("RGB") for p in sample_paths]
102
 
103
- # Gradio UI
104
  with gr.Blocks() as demo:
105
- gr.Markdown("## Brain MRI multi-class demo — Educational only")
106
  with gr.Row():
107
  with gr.Column():
108
  # Use a simple list of PIL thumbnails for the gallery value
@@ -120,16 +136,16 @@ with gr.Blocks() as demo:
120
  selected_path_state = gr.State(value=None)
121
  preds_state = gr.State(value=None)
122
 
123
- # select image from gallery
124
  def on_select(evt: gr.SelectData):
125
- idx = evt.index
126
- path = sample_paths[idx] # <-- return only the path (string)
127
  return path
128
 
129
  # ensure the gallery select writes only the path to the state
130
  gallery.select(fn=on_select, inputs=None, outputs=[selected_path_state])
131
 
132
- # analyze
133
  def analyze(path):
134
  try:
135
  if path is None:
@@ -146,7 +162,7 @@ with gr.Blocks() as demo:
146
 
147
  analyze_btn.click(fn=analyze, inputs=selected_path_state, outputs=[preds_output, chatbot, preds_state])
148
 
149
- # chat
150
  def chat(chat_history, msg, preds_text):
151
  if preds_text is None:
152
  return chat_history+[("AI","Please analyze an image first.")], ""
@@ -177,4 +193,4 @@ with gr.Blocks() as demo:
177
 
178
  gr.Markdown("⚠️ Educational use only — not a medical diagnosis.")
179
 
180
- demo.launch(share=True)
 
9
  import openai
10
  from openai import OpenAI
11
 
12
+ sample_paths = [
13
+ "sample_images/ct_healthy (1).jpg",
14
+ "sample_images/ct_tumor (1).png",
15
+ "sample_images/glioma (1).jpg",
16
+ "sample_images/pituitary (244).jpg",
17
+ ]
18
+
19
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
20
+ if OPENAI_API_KEY:
21
+ openai.api_key = OPENAI_API_KEY
22
 
23
+
24
+ # ------------------ USER CONFIG (edit as needed) ------------------
25
+ CKPT_PATH = "model/best_model.pth" # path to trained checkpoint
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
  BACKBONE = "efficientnet_b0"
 
28
  IMG_SIZE = 224
29
 
30
+ # Paste your 4 sample image absolute paths here (or ensure sample_paths variable already exists)
31
+ # Example:
32
+ # sample_paths = ["/content/samples/ct1.jpg", "/content/samples/mri_glioma1.jpg", ...]
33
+ sample_paths = sample_paths # keep your existing variable/assignment
34
 
35
+ OPENAI_API_KEY = OPENAI_API_KEY # keep as before (e.g., from env or variable)
36
  if OPENAI_API_KEY:
37
  openai.api_key = OPENAI_API_KEY
38
 
39
+ # ------------------ UPDATED LABEL MAP for combined CT+MRI ------------------
40
+ # Minimal change: include ct_tumor and MRI subtypes + no_tumor
41
+ IDX2LABEL = {
42
+ 0: "ct_tumor",
43
+ 1: "glioma",
44
+ 2: "meningioma",
45
+ 3: "mri_tumor",
46
+ 4: "no_tumor",
47
+ 5: "pituitary"
48
+ }
49
+
50
+ NUM_CLASSES = len(IDX2LABEL)
51
  # -------------------------------------------------
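
Since the checkpoint was renamed along with the label map, a cheap startup sanity check is to compare the classifier head stored in the file against NUM_CLASSES; a minimal sketch, assuming timm's efficientnet_b0 keeps its final linear layer under the "classifier" key:

# Sketch: fail fast if the checkpoint head size does not match the label map.
state = torch.load(CKPT_PATH, map_location="cpu")
head = state["classifier.weight"].shape[0]
assert head == NUM_CLASSES, f"checkpoint head={head}, expected {NUM_CLASSES}"
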
52
 
53
  # Validate sample paths
 
83
  idxs = probs.argsort()[::-1][:topk]
84
  return [{"label": IDX2LABEL[int(i)], "score": float(probs[int(i)])} for i in idxs]
85
 
86
+ # OpenAI system prompt (unchanged)
87
  SYSTEM_TEMPLATE = """You are an educational radiology assistant. The model predicted:
88
  {preds}
89
 
 
113
  )
114
  return resp.choices[0].message.content.strip()
115
 
 
116
  # Prepare thumbnails (just images)
117
  thumbs = [Image.open(p).convert("RGB") for p in sample_paths]
118
 
119
+ # Gradio UI (unchanged structure)
120
  with gr.Blocks() as demo:
121
+ gr.Markdown("## Brain MRI+CT multi-class demo — Educational only")
122
  with gr.Row():
123
  with gr.Column():
124
  # Use a simple list of PIL thumbnails for the gallery value
 
136
  selected_path_state = gr.State(value=None)
137
  preds_state = gr.State(value=None)
138
 
139
+ # select image from gallery -> store selected path string
140
  def on_select(evt: gr.SelectData):
141
+ idx = int(evt.index)
142
+ path = sample_paths[idx]
143
  return path
144
 
145
  # ensure the gallery select writes only the path to the state
146
  gallery.select(fn=on_select, inputs=None, outputs=[selected_path_state])
147
 
148
+ # analyze (unchanged except label names)
149
  def analyze(path):
150
  try:
151
  if path is None:
 
162
 
163
  analyze_btn.click(fn=analyze, inputs=selected_path_state, outputs=[preds_output, chatbot, preds_state])
164
 
165
+ # chat (keeps same logic)
166
  def chat(chat_history, msg, preds_text):
167
  if preds_text is None:
168
  return chat_history+[("AI","Please analyze an image first.")], ""
 
193
 
194
  gr.Markdown("⚠️ Educational use only — not a medical diagnosis.")
195
 
196
+ demo.launch(share=True)
model/{best_model_multiclass.pth → best_model.pth} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:90f3a7461ec275012fc223ca25422cfdcfab75379e2b406841a32f34b3f6619c
3
- size 16350387
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66488a18d397630b3f6eda091504415329f432fbf90b0a14d1fae5be90bc7864
3
+ size 16356601
sample_images/ct_healthy (1).jpg ADDED

Git LFS Details

  • SHA256: 39d6ec15338bea283052bf3db28b6b5889b68cbc7a3ff34dc0178ad0a25917fb
  • Pointer size: 129 Bytes
  • Size of remote file: 7.56 kB
sample_images/ct_tumor (1).png ADDED

Git LFS Details

  • SHA256: 70630f19d7605a73d0d65deada213a2b402cd52871af1d78f26fb006404d1a87
  • Pointer size: 131 Bytes
  • Size of remote file: 371 kB
sample_images/glioma (1).jpg ADDED

Git LFS Details

  • SHA256: 13509e45c804999bb70316d0d4ad406f19bd36ebbc97dd8ecc6279699f7f9978
  • Pointer size: 130 Bytes
  • Size of remote file: 46.3 kB
sample_images/image(2).jpg DELETED
Binary file (33.8 kB)
 
sample_images/image(23).jpg DELETED
Binary file (34.2 kB)
 
sample_images/image(3).jpg DELETED
Binary file (16.9 kB)
 
sample_images/image.jpg DELETED
Binary file (49.5 kB)
 
sample_images/pituitary (244).jpg ADDED

Git LFS Details

  • SHA256: 7884435698fdf05b17408e4f304b4876f0c288f0066fee5006ba69e62772afe4
  • Pointer size: 130 Bytes
  • Size of remote file: 53.1 kB