akushwaha.ext commited on
Commit ·
7db32ef
1
Parent(s): 8bf68a3
ct images also added
Browse files- .gitattributes +1 -0
- Brain_Tumor_MRI_Diagnose_plus_openai.ipynb → Brain_Tumor_MRI_CT_Images_Diagnose_plus_openai.ipynb +975 -1011
- app.py +39 -23
- model/{best_model_multiclass.pth → best_model.pth} +2 -2
- sample_images/ct_healthy (1).jpg +3 -0
- sample_images/ct_tumor (1).png +3 -0
- sample_images/glioma (1).jpg +3 -0
- sample_images/image(2).jpg +0 -0
- sample_images/image(23).jpg +0 -0
- sample_images/image(3).jpg +0 -0
- sample_images/image.jpg +0 -0
- sample_images/pituitary (244).jpg +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
sample_images/* filter=lfs diff=lfs merge=lfs -text
|
Brain_Tumor_MRI_Diagnose_plus_openai.ipynb → Brain_Tumor_MRI_CT_Images_Diagnose_plus_openai.ipynb
RENAMED
|
@@ -1,418 +1,82 @@
|
|
| 1 |
{
|
| 2 |
-
"nbformat": 4,
|
| 3 |
-
"nbformat_minor": 0,
|
| 4 |
-
"metadata": {
|
| 5 |
-
"colab": {
|
| 6 |
-
"provenance": [],
|
| 7 |
-
"gpuType": "T4"
|
| 8 |
-
},
|
| 9 |
-
"kernelspec": {
|
| 10 |
-
"name": "python3",
|
| 11 |
-
"display_name": "Python 3"
|
| 12 |
-
},
|
| 13 |
-
"language_info": {
|
| 14 |
-
"name": "python"
|
| 15 |
-
},
|
| 16 |
-
"accelerator": "GPU",
|
| 17 |
-
"widgets": {
|
| 18 |
-
"application/vnd.jupyter.widget-state+json": {
|
| 19 |
-
"1b0dc97d07a94fca9a24670e3813ff29": {
|
| 20 |
-
"model_module": "@jupyter-widgets/controls",
|
| 21 |
-
"model_name": "HBoxModel",
|
| 22 |
-
"model_module_version": "1.5.0",
|
| 23 |
-
"state": {
|
| 24 |
-
"_dom_classes": [],
|
| 25 |
-
"_model_module": "@jupyter-widgets/controls",
|
| 26 |
-
"_model_module_version": "1.5.0",
|
| 27 |
-
"_model_name": "HBoxModel",
|
| 28 |
-
"_view_count": null,
|
| 29 |
-
"_view_module": "@jupyter-widgets/controls",
|
| 30 |
-
"_view_module_version": "1.5.0",
|
| 31 |
-
"_view_name": "HBoxView",
|
| 32 |
-
"box_style": "",
|
| 33 |
-
"children": [
|
| 34 |
-
"IPY_MODEL_d19c38e4492c4b9bb4c03fb285df62a2",
|
| 35 |
-
"IPY_MODEL_815e13d941ee4a26929ca4ca01c8c615",
|
| 36 |
-
"IPY_MODEL_1fa8e074d5954a6daed1c025aa8191ab"
|
| 37 |
-
],
|
| 38 |
-
"layout": "IPY_MODEL_489118d586fb4c0bb2a148796d39d0e3"
|
| 39 |
-
}
|
| 40 |
-
},
|
| 41 |
-
"d19c38e4492c4b9bb4c03fb285df62a2": {
|
| 42 |
-
"model_module": "@jupyter-widgets/controls",
|
| 43 |
-
"model_name": "HTMLModel",
|
| 44 |
-
"model_module_version": "1.5.0",
|
| 45 |
-
"state": {
|
| 46 |
-
"_dom_classes": [],
|
| 47 |
-
"_model_module": "@jupyter-widgets/controls",
|
| 48 |
-
"_model_module_version": "1.5.0",
|
| 49 |
-
"_model_name": "HTMLModel",
|
| 50 |
-
"_view_count": null,
|
| 51 |
-
"_view_module": "@jupyter-widgets/controls",
|
| 52 |
-
"_view_module_version": "1.5.0",
|
| 53 |
-
"_view_name": "HTMLView",
|
| 54 |
-
"description": "",
|
| 55 |
-
"description_tooltip": null,
|
| 56 |
-
"layout": "IPY_MODEL_1764933ac18c4c228ad8a852719c4f1a",
|
| 57 |
-
"placeholder": "",
|
| 58 |
-
"style": "IPY_MODEL_07db2fdc39c849bf9fb4be93aca7cf4e",
|
| 59 |
-
"value": "model.safetensors: 100%"
|
| 60 |
-
}
|
| 61 |
-
},
|
| 62 |
-
"815e13d941ee4a26929ca4ca01c8c615": {
|
| 63 |
-
"model_module": "@jupyter-widgets/controls",
|
| 64 |
-
"model_name": "FloatProgressModel",
|
| 65 |
-
"model_module_version": "1.5.0",
|
| 66 |
-
"state": {
|
| 67 |
-
"_dom_classes": [],
|
| 68 |
-
"_model_module": "@jupyter-widgets/controls",
|
| 69 |
-
"_model_module_version": "1.5.0",
|
| 70 |
-
"_model_name": "FloatProgressModel",
|
| 71 |
-
"_view_count": null,
|
| 72 |
-
"_view_module": "@jupyter-widgets/controls",
|
| 73 |
-
"_view_module_version": "1.5.0",
|
| 74 |
-
"_view_name": "ProgressView",
|
| 75 |
-
"bar_style": "success",
|
| 76 |
-
"description": "",
|
| 77 |
-
"description_tooltip": null,
|
| 78 |
-
"layout": "IPY_MODEL_6b82e7b5673a4ad4bb112905754d0cda",
|
| 79 |
-
"max": 21355344,
|
| 80 |
-
"min": 0,
|
| 81 |
-
"orientation": "horizontal",
|
| 82 |
-
"style": "IPY_MODEL_693ce4424a8240a38ccebaabf030c5ff",
|
| 83 |
-
"value": 21355344
|
| 84 |
-
}
|
| 85 |
-
},
|
| 86 |
-
"1fa8e074d5954a6daed1c025aa8191ab": {
|
| 87 |
-
"model_module": "@jupyter-widgets/controls",
|
| 88 |
-
"model_name": "HTMLModel",
|
| 89 |
-
"model_module_version": "1.5.0",
|
| 90 |
-
"state": {
|
| 91 |
-
"_dom_classes": [],
|
| 92 |
-
"_model_module": "@jupyter-widgets/controls",
|
| 93 |
-
"_model_module_version": "1.5.0",
|
| 94 |
-
"_model_name": "HTMLModel",
|
| 95 |
-
"_view_count": null,
|
| 96 |
-
"_view_module": "@jupyter-widgets/controls",
|
| 97 |
-
"_view_module_version": "1.5.0",
|
| 98 |
-
"_view_name": "HTMLView",
|
| 99 |
-
"description": "",
|
| 100 |
-
"description_tooltip": null,
|
| 101 |
-
"layout": "IPY_MODEL_31420da8c458472c85381cf0e7a7a242",
|
| 102 |
-
"placeholder": "",
|
| 103 |
-
"style": "IPY_MODEL_bd6b1fcf472a4860bbd8adb8a7f8c4ab",
|
| 104 |
-
"value": " 21.4M/21.4M [00:00<00:00, 53.6MB/s]"
|
| 105 |
-
}
|
| 106 |
-
},
|
| 107 |
-
"489118d586fb4c0bb2a148796d39d0e3": {
|
| 108 |
-
"model_module": "@jupyter-widgets/base",
|
| 109 |
-
"model_name": "LayoutModel",
|
| 110 |
-
"model_module_version": "1.2.0",
|
| 111 |
-
"state": {
|
| 112 |
-
"_model_module": "@jupyter-widgets/base",
|
| 113 |
-
"_model_module_version": "1.2.0",
|
| 114 |
-
"_model_name": "LayoutModel",
|
| 115 |
-
"_view_count": null,
|
| 116 |
-
"_view_module": "@jupyter-widgets/base",
|
| 117 |
-
"_view_module_version": "1.2.0",
|
| 118 |
-
"_view_name": "LayoutView",
|
| 119 |
-
"align_content": null,
|
| 120 |
-
"align_items": null,
|
| 121 |
-
"align_self": null,
|
| 122 |
-
"border": null,
|
| 123 |
-
"bottom": null,
|
| 124 |
-
"display": null,
|
| 125 |
-
"flex": null,
|
| 126 |
-
"flex_flow": null,
|
| 127 |
-
"grid_area": null,
|
| 128 |
-
"grid_auto_columns": null,
|
| 129 |
-
"grid_auto_flow": null,
|
| 130 |
-
"grid_auto_rows": null,
|
| 131 |
-
"grid_column": null,
|
| 132 |
-
"grid_gap": null,
|
| 133 |
-
"grid_row": null,
|
| 134 |
-
"grid_template_areas": null,
|
| 135 |
-
"grid_template_columns": null,
|
| 136 |
-
"grid_template_rows": null,
|
| 137 |
-
"height": null,
|
| 138 |
-
"justify_content": null,
|
| 139 |
-
"justify_items": null,
|
| 140 |
-
"left": null,
|
| 141 |
-
"margin": null,
|
| 142 |
-
"max_height": null,
|
| 143 |
-
"max_width": null,
|
| 144 |
-
"min_height": null,
|
| 145 |
-
"min_width": null,
|
| 146 |
-
"object_fit": null,
|
| 147 |
-
"object_position": null,
|
| 148 |
-
"order": null,
|
| 149 |
-
"overflow": null,
|
| 150 |
-
"overflow_x": null,
|
| 151 |
-
"overflow_y": null,
|
| 152 |
-
"padding": null,
|
| 153 |
-
"right": null,
|
| 154 |
-
"top": null,
|
| 155 |
-
"visibility": null,
|
| 156 |
-
"width": null
|
| 157 |
-
}
|
| 158 |
-
},
|
| 159 |
-
"1764933ac18c4c228ad8a852719c4f1a": {
|
| 160 |
-
"model_module": "@jupyter-widgets/base",
|
| 161 |
-
"model_name": "LayoutModel",
|
| 162 |
-
"model_module_version": "1.2.0",
|
| 163 |
-
"state": {
|
| 164 |
-
"_model_module": "@jupyter-widgets/base",
|
| 165 |
-
"_model_module_version": "1.2.0",
|
| 166 |
-
"_model_name": "LayoutModel",
|
| 167 |
-
"_view_count": null,
|
| 168 |
-
"_view_module": "@jupyter-widgets/base",
|
| 169 |
-
"_view_module_version": "1.2.0",
|
| 170 |
-
"_view_name": "LayoutView",
|
| 171 |
-
"align_content": null,
|
| 172 |
-
"align_items": null,
|
| 173 |
-
"align_self": null,
|
| 174 |
-
"border": null,
|
| 175 |
-
"bottom": null,
|
| 176 |
-
"display": null,
|
| 177 |
-
"flex": null,
|
| 178 |
-
"flex_flow": null,
|
| 179 |
-
"grid_area": null,
|
| 180 |
-
"grid_auto_columns": null,
|
| 181 |
-
"grid_auto_flow": null,
|
| 182 |
-
"grid_auto_rows": null,
|
| 183 |
-
"grid_column": null,
|
| 184 |
-
"grid_gap": null,
|
| 185 |
-
"grid_row": null,
|
| 186 |
-
"grid_template_areas": null,
|
| 187 |
-
"grid_template_columns": null,
|
| 188 |
-
"grid_template_rows": null,
|
| 189 |
-
"height": null,
|
| 190 |
-
"justify_content": null,
|
| 191 |
-
"justify_items": null,
|
| 192 |
-
"left": null,
|
| 193 |
-
"margin": null,
|
| 194 |
-
"max_height": null,
|
| 195 |
-
"max_width": null,
|
| 196 |
-
"min_height": null,
|
| 197 |
-
"min_width": null,
|
| 198 |
-
"object_fit": null,
|
| 199 |
-
"object_position": null,
|
| 200 |
-
"order": null,
|
| 201 |
-
"overflow": null,
|
| 202 |
-
"overflow_x": null,
|
| 203 |
-
"overflow_y": null,
|
| 204 |
-
"padding": null,
|
| 205 |
-
"right": null,
|
| 206 |
-
"top": null,
|
| 207 |
-
"visibility": null,
|
| 208 |
-
"width": null
|
| 209 |
-
}
|
| 210 |
-
},
|
| 211 |
-
"07db2fdc39c849bf9fb4be93aca7cf4e": {
|
| 212 |
-
"model_module": "@jupyter-widgets/controls",
|
| 213 |
-
"model_name": "DescriptionStyleModel",
|
| 214 |
-
"model_module_version": "1.5.0",
|
| 215 |
-
"state": {
|
| 216 |
-
"_model_module": "@jupyter-widgets/controls",
|
| 217 |
-
"_model_module_version": "1.5.0",
|
| 218 |
-
"_model_name": "DescriptionStyleModel",
|
| 219 |
-
"_view_count": null,
|
| 220 |
-
"_view_module": "@jupyter-widgets/base",
|
| 221 |
-
"_view_module_version": "1.2.0",
|
| 222 |
-
"_view_name": "StyleView",
|
| 223 |
-
"description_width": ""
|
| 224 |
-
}
|
| 225 |
-
},
|
| 226 |
-
"6b82e7b5673a4ad4bb112905754d0cda": {
|
| 227 |
-
"model_module": "@jupyter-widgets/base",
|
| 228 |
-
"model_name": "LayoutModel",
|
| 229 |
-
"model_module_version": "1.2.0",
|
| 230 |
-
"state": {
|
| 231 |
-
"_model_module": "@jupyter-widgets/base",
|
| 232 |
-
"_model_module_version": "1.2.0",
|
| 233 |
-
"_model_name": "LayoutModel",
|
| 234 |
-
"_view_count": null,
|
| 235 |
-
"_view_module": "@jupyter-widgets/base",
|
| 236 |
-
"_view_module_version": "1.2.0",
|
| 237 |
-
"_view_name": "LayoutView",
|
| 238 |
-
"align_content": null,
|
| 239 |
-
"align_items": null,
|
| 240 |
-
"align_self": null,
|
| 241 |
-
"border": null,
|
| 242 |
-
"bottom": null,
|
| 243 |
-
"display": null,
|
| 244 |
-
"flex": null,
|
| 245 |
-
"flex_flow": null,
|
| 246 |
-
"grid_area": null,
|
| 247 |
-
"grid_auto_columns": null,
|
| 248 |
-
"grid_auto_flow": null,
|
| 249 |
-
"grid_auto_rows": null,
|
| 250 |
-
"grid_column": null,
|
| 251 |
-
"grid_gap": null,
|
| 252 |
-
"grid_row": null,
|
| 253 |
-
"grid_template_areas": null,
|
| 254 |
-
"grid_template_columns": null,
|
| 255 |
-
"grid_template_rows": null,
|
| 256 |
-
"height": null,
|
| 257 |
-
"justify_content": null,
|
| 258 |
-
"justify_items": null,
|
| 259 |
-
"left": null,
|
| 260 |
-
"margin": null,
|
| 261 |
-
"max_height": null,
|
| 262 |
-
"max_width": null,
|
| 263 |
-
"min_height": null,
|
| 264 |
-
"min_width": null,
|
| 265 |
-
"object_fit": null,
|
| 266 |
-
"object_position": null,
|
| 267 |
-
"order": null,
|
| 268 |
-
"overflow": null,
|
| 269 |
-
"overflow_x": null,
|
| 270 |
-
"overflow_y": null,
|
| 271 |
-
"padding": null,
|
| 272 |
-
"right": null,
|
| 273 |
-
"top": null,
|
| 274 |
-
"visibility": null,
|
| 275 |
-
"width": null
|
| 276 |
-
}
|
| 277 |
-
},
|
| 278 |
-
"693ce4424a8240a38ccebaabf030c5ff": {
|
| 279 |
-
"model_module": "@jupyter-widgets/controls",
|
| 280 |
-
"model_name": "ProgressStyleModel",
|
| 281 |
-
"model_module_version": "1.5.0",
|
| 282 |
-
"state": {
|
| 283 |
-
"_model_module": "@jupyter-widgets/controls",
|
| 284 |
-
"_model_module_version": "1.5.0",
|
| 285 |
-
"_model_name": "ProgressStyleModel",
|
| 286 |
-
"_view_count": null,
|
| 287 |
-
"_view_module": "@jupyter-widgets/base",
|
| 288 |
-
"_view_module_version": "1.2.0",
|
| 289 |
-
"_view_name": "StyleView",
|
| 290 |
-
"bar_color": null,
|
| 291 |
-
"description_width": ""
|
| 292 |
-
}
|
| 293 |
-
},
|
| 294 |
-
"31420da8c458472c85381cf0e7a7a242": {
|
| 295 |
-
"model_module": "@jupyter-widgets/base",
|
| 296 |
-
"model_name": "LayoutModel",
|
| 297 |
-
"model_module_version": "1.2.0",
|
| 298 |
-
"state": {
|
| 299 |
-
"_model_module": "@jupyter-widgets/base",
|
| 300 |
-
"_model_module_version": "1.2.0",
|
| 301 |
-
"_model_name": "LayoutModel",
|
| 302 |
-
"_view_count": null,
|
| 303 |
-
"_view_module": "@jupyter-widgets/base",
|
| 304 |
-
"_view_module_version": "1.2.0",
|
| 305 |
-
"_view_name": "LayoutView",
|
| 306 |
-
"align_content": null,
|
| 307 |
-
"align_items": null,
|
| 308 |
-
"align_self": null,
|
| 309 |
-
"border": null,
|
| 310 |
-
"bottom": null,
|
| 311 |
-
"display": null,
|
| 312 |
-
"flex": null,
|
| 313 |
-
"flex_flow": null,
|
| 314 |
-
"grid_area": null,
|
| 315 |
-
"grid_auto_columns": null,
|
| 316 |
-
"grid_auto_flow": null,
|
| 317 |
-
"grid_auto_rows": null,
|
| 318 |
-
"grid_column": null,
|
| 319 |
-
"grid_gap": null,
|
| 320 |
-
"grid_row": null,
|
| 321 |
-
"grid_template_areas": null,
|
| 322 |
-
"grid_template_columns": null,
|
| 323 |
-
"grid_template_rows": null,
|
| 324 |
-
"height": null,
|
| 325 |
-
"justify_content": null,
|
| 326 |
-
"justify_items": null,
|
| 327 |
-
"left": null,
|
| 328 |
-
"margin": null,
|
| 329 |
-
"max_height": null,
|
| 330 |
-
"max_width": null,
|
| 331 |
-
"min_height": null,
|
| 332 |
-
"min_width": null,
|
| 333 |
-
"object_fit": null,
|
| 334 |
-
"object_position": null,
|
| 335 |
-
"order": null,
|
| 336 |
-
"overflow": null,
|
| 337 |
-
"overflow_x": null,
|
| 338 |
-
"overflow_y": null,
|
| 339 |
-
"padding": null,
|
| 340 |
-
"right": null,
|
| 341 |
-
"top": null,
|
| 342 |
-
"visibility": null,
|
| 343 |
-
"width": null
|
| 344 |
-
}
|
| 345 |
-
},
|
| 346 |
-
"bd6b1fcf472a4860bbd8adb8a7f8c4ab": {
|
| 347 |
-
"model_module": "@jupyter-widgets/controls",
|
| 348 |
-
"model_name": "DescriptionStyleModel",
|
| 349 |
-
"model_module_version": "1.5.0",
|
| 350 |
-
"state": {
|
| 351 |
-
"_model_module": "@jupyter-widgets/controls",
|
| 352 |
-
"_model_module_version": "1.5.0",
|
| 353 |
-
"_model_name": "DescriptionStyleModel",
|
| 354 |
-
"_view_count": null,
|
| 355 |
-
"_view_module": "@jupyter-widgets/base",
|
| 356 |
-
"_view_module_version": "1.2.0",
|
| 357 |
-
"_view_name": "StyleView",
|
| 358 |
-
"description_width": ""
|
| 359 |
-
}
|
| 360 |
-
}
|
| 361 |
-
}
|
| 362 |
-
}
|
| 363 |
-
},
|
| 364 |
"cells": [
|
| 365 |
{
|
| 366 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
"source": [
|
| 368 |
"# Colab: install required packages\n",
|
| 369 |
"!pip install -q kagglehub timm torch torchvision scikit-learn pandas pillow matplotlib gradio openai pytorch-lightning"
|
| 370 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
"metadata": {
|
| 372 |
"colab": {
|
| 373 |
"base_uri": "https://localhost:8080/"
|
| 374 |
},
|
| 375 |
-
"id": "
|
| 376 |
-
"outputId": "
|
| 377 |
},
|
| 378 |
-
"execution_count": 10,
|
| 379 |
"outputs": [
|
| 380 |
{
|
| 381 |
"output_type": "stream",
|
| 382 |
"name": "stdout",
|
| 383 |
"text": [
|
| 384 |
-
"
|
| 385 |
-
|
| 386 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
]
|
| 388 |
-
}
|
| 389 |
-
]
|
| 390 |
-
},
|
| 391 |
-
{
|
| 392 |
-
"cell_type": "code",
|
| 393 |
-
"execution_count": 11,
|
| 394 |
-
"metadata": {
|
| 395 |
-
"colab": {
|
| 396 |
-
"base_uri": "https://localhost:8080/"
|
| 397 |
},
|
| 398 |
-
"id": "B7iabh3oOyNe",
|
| 399 |
-
"outputId": "b487c9fb-aa5c-4457-e23b-067c427c15ab"
|
| 400 |
-
},
|
| 401 |
-
"outputs": [
|
| 402 |
{
|
| 403 |
"output_type": "stream",
|
| 404 |
"name": "stdout",
|
| 405 |
"text": [
|
| 406 |
-
"
|
| 407 |
-
"Dataset
|
| 408 |
-
"Training\n",
|
| 409 |
-
"Testing\n"
|
| 410 |
]
|
| 411 |
}
|
| 412 |
],
|
| 413 |
"source": [
|
| 414 |
"import kagglehub\n",
|
| 415 |
-
"dataset_ref = \"
|
| 416 |
"path = kagglehub.dataset_download(dataset_ref)\n",
|
| 417 |
"print(\"Dataset path:\", path)\n",
|
| 418 |
"\n",
|
|
@@ -424,18 +88,14 @@
|
|
| 424 |
},
|
| 425 |
{
|
| 426 |
"cell_type": "code",
|
| 427 |
-
"
|
| 428 |
-
"import torch\n",
|
| 429 |
-
"print(torch.cuda.is_available())"
|
| 430 |
-
],
|
| 431 |
"metadata": {
|
| 432 |
"colab": {
|
| 433 |
"base_uri": "https://localhost:8080/"
|
| 434 |
},
|
| 435 |
"id": "gq9zpPkmZVK7",
|
| 436 |
-
"outputId": "
|
| 437 |
},
|
| 438 |
-
"execution_count": 20,
|
| 439 |
"outputs": [
|
| 440 |
{
|
| 441 |
"output_type": "stream",
|
|
@@ -444,431 +104,274 @@
|
|
| 444 |
"True\n"
|
| 445 |
]
|
| 446 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
]
|
| 448 |
},
|
| 449 |
{
|
| 450 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
"source": [
|
| 452 |
"# Create CSV (multi-class) from folders (in-memory)\n",
|
| 453 |
"\n",
|
| 454 |
"# Change DATA_ROOT to the folder that directly contains glioma_tumor, meningioma_tumor, pituitary_tumor, no_tumor."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
],
|
| 456 |
"metadata": {
|
| 457 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
},
|
| 459 |
-
"execution_count":
|
| 460 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
},
|
| 462 |
{
|
| 463 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
"source": [
|
|
|
|
|
|
|
|
|
|
| 465 |
"import os, pandas as pd\n",
|
|
|
|
| 466 |
"\n",
|
| 467 |
-
"#
|
| 468 |
-
"
|
| 469 |
"\n",
|
| 470 |
-
"
|
| 471 |
-
"
|
| 472 |
-
"
|
| 473 |
-
"
|
| 474 |
-
" \"no_tumor\": 3\n",
|
| 475 |
-
"}\n",
|
| 476 |
"\n",
|
| 477 |
"rows = []\n",
|
| 478 |
-
"for
|
| 479 |
-
"
|
| 480 |
-
" if not os.path.isdir(
|
| 481 |
-
" continue\n",
|
| 482 |
-
" if folder not in LABEL_MAP:\n",
|
| 483 |
-
" print(\"Skipping unknown folder:\", folder)\n",
|
| 484 |
" continue\n",
|
| 485 |
-
"
|
| 486 |
-
"
|
| 487 |
-
"
|
| 488 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
"\n",
|
| 490 |
"df = pd.DataFrame(rows).sample(frac=1, random_state=42).reset_index(drop=True)\n",
|
| 491 |
"print(\"Total images:\", len(df))\n",
|
| 492 |
"print(df.label.value_counts())\n",
|
| 493 |
-
"
|
| 494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
"metadata": {
|
| 496 |
"colab": {
|
| 497 |
-
"base_uri": "https://localhost:8080/"
|
| 498 |
-
"height": 367
|
| 499 |
},
|
| 500 |
-
"id": "
|
| 501 |
-
"outputId": "
|
| 502 |
},
|
| 503 |
-
"execution_count": 13,
|
| 504 |
"outputs": [
|
| 505 |
{
|
| 506 |
"output_type": "stream",
|
| 507 |
"name": "stdout",
|
| 508 |
"text": [
|
| 509 |
-
"
|
| 510 |
-
"label\n",
|
| 511 |
-
"2 827\n",
|
| 512 |
-
"0 826\n",
|
| 513 |
-
"1 822\n",
|
| 514 |
-
"3 395\n",
|
| 515 |
-
"Name: count, dtype: int64\n"
|
| 516 |
]
|
| 517 |
-
},
|
| 518 |
-
{
|
| 519 |
-
"output_type": "execute_result",
|
| 520 |
-
"data": {
|
| 521 |
-
"text/plain": [
|
| 522 |
-
" image label\n",
|
| 523 |
-
"0 /kaggle/input/brain-tumor-classification-mri/T... 0\n",
|
| 524 |
-
"1 /kaggle/input/brain-tumor-classification-mri/T... 2\n",
|
| 525 |
-
"2 /kaggle/input/brain-tumor-classification-mri/T... 0\n",
|
| 526 |
-
"3 /kaggle/input/brain-tumor-classification-mri/T... 0\n",
|
| 527 |
-
"4 /kaggle/input/brain-tumor-classification-mri/T... 2"
|
| 528 |
-
],
|
| 529 |
-
"text/html": [
|
| 530 |
-
"\n",
|
| 531 |
-
" <div id=\"df-870f654a-e6d9-4f6c-938c-91d0b4e98446\" class=\"colab-df-container\">\n",
|
| 532 |
-
" <div>\n",
|
| 533 |
-
"<style scoped>\n",
|
| 534 |
-
" .dataframe tbody tr th:only-of-type {\n",
|
| 535 |
-
" vertical-align: middle;\n",
|
| 536 |
-
" }\n",
|
| 537 |
-
"\n",
|
| 538 |
-
" .dataframe tbody tr th {\n",
|
| 539 |
-
" vertical-align: top;\n",
|
| 540 |
-
" }\n",
|
| 541 |
-
"\n",
|
| 542 |
-
" .dataframe thead th {\n",
|
| 543 |
-
" text-align: right;\n",
|
| 544 |
-
" }\n",
|
| 545 |
-
"</style>\n",
|
| 546 |
-
"<table border=\"1\" class=\"dataframe\">\n",
|
| 547 |
-
" <thead>\n",
|
| 548 |
-
" <tr style=\"text-align: right;\">\n",
|
| 549 |
-
" <th></th>\n",
|
| 550 |
-
" <th>image</th>\n",
|
| 551 |
-
" <th>label</th>\n",
|
| 552 |
-
" </tr>\n",
|
| 553 |
-
" </thead>\n",
|
| 554 |
-
" <tbody>\n",
|
| 555 |
-
" <tr>\n",
|
| 556 |
-
" <th>0</th>\n",
|
| 557 |
-
" <td>/kaggle/input/brain-tumor-classification-mri/T...</td>\n",
|
| 558 |
-
" <td>0</td>\n",
|
| 559 |
-
" </tr>\n",
|
| 560 |
-
" <tr>\n",
|
| 561 |
-
" <th>1</th>\n",
|
| 562 |
-
" <td>/kaggle/input/brain-tumor-classification-mri/T...</td>\n",
|
| 563 |
-
" <td>2</td>\n",
|
| 564 |
-
" </tr>\n",
|
| 565 |
-
" <tr>\n",
|
| 566 |
-
" <th>2</th>\n",
|
| 567 |
-
" <td>/kaggle/input/brain-tumor-classification-mri/T...</td>\n",
|
| 568 |
-
" <td>0</td>\n",
|
| 569 |
-
" </tr>\n",
|
| 570 |
-
" <tr>\n",
|
| 571 |
-
" <th>3</th>\n",
|
| 572 |
-
" <td>/kaggle/input/brain-tumor-classification-mri/T...</td>\n",
|
| 573 |
-
" <td>0</td>\n",
|
| 574 |
-
" </tr>\n",
|
| 575 |
-
" <tr>\n",
|
| 576 |
-
" <th>4</th>\n",
|
| 577 |
-
" <td>/kaggle/input/brain-tumor-classification-mri/T...</td>\n",
|
| 578 |
-
" <td>2</td>\n",
|
| 579 |
-
" </tr>\n",
|
| 580 |
-
" </tbody>\n",
|
| 581 |
-
"</table>\n",
|
| 582 |
-
"</div>\n",
|
| 583 |
-
" <div class=\"colab-df-buttons\">\n",
|
| 584 |
-
"\n",
|
| 585 |
-
" <div class=\"colab-df-container\">\n",
|
| 586 |
-
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-870f654a-e6d9-4f6c-938c-91d0b4e98446')\"\n",
|
| 587 |
-
" title=\"Convert this dataframe to an interactive table.\"\n",
|
| 588 |
-
" style=\"display:none;\">\n",
|
| 589 |
-
"\n",
|
| 590 |
-
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
|
| 591 |
-
" <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
|
| 592 |
-
" </svg>\n",
|
| 593 |
-
" </button>\n",
|
| 594 |
-
"\n",
|
| 595 |
-
" <style>\n",
|
| 596 |
-
" .colab-df-container {\n",
|
| 597 |
-
" display:flex;\n",
|
| 598 |
-
" gap: 12px;\n",
|
| 599 |
-
" }\n",
|
| 600 |
-
"\n",
|
| 601 |
-
" .colab-df-convert {\n",
|
| 602 |
-
" background-color: #E8F0FE;\n",
|
| 603 |
-
" border: none;\n",
|
| 604 |
-
" border-radius: 50%;\n",
|
| 605 |
-
" cursor: pointer;\n",
|
| 606 |
-
" display: none;\n",
|
| 607 |
-
" fill: #1967D2;\n",
|
| 608 |
-
" height: 32px;\n",
|
| 609 |
-
" padding: 0 0 0 0;\n",
|
| 610 |
-
" width: 32px;\n",
|
| 611 |
-
" }\n",
|
| 612 |
-
"\n",
|
| 613 |
-
" .colab-df-convert:hover {\n",
|
| 614 |
-
" background-color: #E2EBFA;\n",
|
| 615 |
-
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
|
| 616 |
-
" fill: #174EA6;\n",
|
| 617 |
-
" }\n",
|
| 618 |
-
"\n",
|
| 619 |
-
" .colab-df-buttons div {\n",
|
| 620 |
-
" margin-bottom: 4px;\n",
|
| 621 |
-
" }\n",
|
| 622 |
-
"\n",
|
| 623 |
-
" [theme=dark] .colab-df-convert {\n",
|
| 624 |
-
" background-color: #3B4455;\n",
|
| 625 |
-
" fill: #D2E3FC;\n",
|
| 626 |
-
" }\n",
|
| 627 |
-
"\n",
|
| 628 |
-
" [theme=dark] .colab-df-convert:hover {\n",
|
| 629 |
-
" background-color: #434B5C;\n",
|
| 630 |
-
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
|
| 631 |
-
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
|
| 632 |
-
" fill: #FFFFFF;\n",
|
| 633 |
-
" }\n",
|
| 634 |
-
" </style>\n",
|
| 635 |
-
"\n",
|
| 636 |
-
" <script>\n",
|
| 637 |
-
" const buttonEl =\n",
|
| 638 |
-
" document.querySelector('#df-870f654a-e6d9-4f6c-938c-91d0b4e98446 button.colab-df-convert');\n",
|
| 639 |
-
" buttonEl.style.display =\n",
|
| 640 |
-
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
|
| 641 |
-
"\n",
|
| 642 |
-
" async function convertToInteractive(key) {\n",
|
| 643 |
-
" const element = document.querySelector('#df-870f654a-e6d9-4f6c-938c-91d0b4e98446');\n",
|
| 644 |
-
" const dataTable =\n",
|
| 645 |
-
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
|
| 646 |
-
" [key], {});\n",
|
| 647 |
-
" if (!dataTable) return;\n",
|
| 648 |
-
"\n",
|
| 649 |
-
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
|
| 650 |
-
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
|
| 651 |
-
" + ' to learn more about interactive tables.';\n",
|
| 652 |
-
" element.innerHTML = '';\n",
|
| 653 |
-
" dataTable['output_type'] = 'display_data';\n",
|
| 654 |
-
" await google.colab.output.renderOutput(dataTable, element);\n",
|
| 655 |
-
" const docLink = document.createElement('div');\n",
|
| 656 |
-
" docLink.innerHTML = docLinkHtml;\n",
|
| 657 |
-
" element.appendChild(docLink);\n",
|
| 658 |
-
" }\n",
|
| 659 |
-
" </script>\n",
|
| 660 |
-
" </div>\n",
|
| 661 |
-
"\n",
|
| 662 |
-
"\n",
|
| 663 |
-
" <div id=\"df-39e3087a-f1c1-4313-b033-254945753780\">\n",
|
| 664 |
-
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-39e3087a-f1c1-4313-b033-254945753780')\"\n",
|
| 665 |
-
" title=\"Suggest charts\"\n",
|
| 666 |
-
" style=\"display:none;\">\n",
|
| 667 |
-
"\n",
|
| 668 |
-
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
|
| 669 |
-
" width=\"24px\">\n",
|
| 670 |
-
" <g>\n",
|
| 671 |
-
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
|
| 672 |
-
" </g>\n",
|
| 673 |
-
"</svg>\n",
|
| 674 |
-
" </button>\n",
|
| 675 |
-
"\n",
|
| 676 |
-
"<style>\n",
|
| 677 |
-
" .colab-df-quickchart {\n",
|
| 678 |
-
" --bg-color: #E8F0FE;\n",
|
| 679 |
-
" --fill-color: #1967D2;\n",
|
| 680 |
-
" --hover-bg-color: #E2EBFA;\n",
|
| 681 |
-
" --hover-fill-color: #174EA6;\n",
|
| 682 |
-
" --disabled-fill-color: #AAA;\n",
|
| 683 |
-
" --disabled-bg-color: #DDD;\n",
|
| 684 |
-
" }\n",
|
| 685 |
-
"\n",
|
| 686 |
-
" [theme=dark] .colab-df-quickchart {\n",
|
| 687 |
-
" --bg-color: #3B4455;\n",
|
| 688 |
-
" --fill-color: #D2E3FC;\n",
|
| 689 |
-
" --hover-bg-color: #434B5C;\n",
|
| 690 |
-
" --hover-fill-color: #FFFFFF;\n",
|
| 691 |
-
" --disabled-bg-color: #3B4455;\n",
|
| 692 |
-
" --disabled-fill-color: #666;\n",
|
| 693 |
-
" }\n",
|
| 694 |
-
"\n",
|
| 695 |
-
" .colab-df-quickchart {\n",
|
| 696 |
-
" background-color: var(--bg-color);\n",
|
| 697 |
-
" border: none;\n",
|
| 698 |
-
" border-radius: 50%;\n",
|
| 699 |
-
" cursor: pointer;\n",
|
| 700 |
-
" display: none;\n",
|
| 701 |
-
" fill: var(--fill-color);\n",
|
| 702 |
-
" height: 32px;\n",
|
| 703 |
-
" padding: 0;\n",
|
| 704 |
-
" width: 32px;\n",
|
| 705 |
-
" }\n",
|
| 706 |
-
"\n",
|
| 707 |
-
" .colab-df-quickchart:hover {\n",
|
| 708 |
-
" background-color: var(--hover-bg-color);\n",
|
| 709 |
-
" box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
|
| 710 |
-
" fill: var(--button-hover-fill-color);\n",
|
| 711 |
-
" }\n",
|
| 712 |
-
"\n",
|
| 713 |
-
" .colab-df-quickchart-complete:disabled,\n",
|
| 714 |
-
" .colab-df-quickchart-complete:disabled:hover {\n",
|
| 715 |
-
" background-color: var(--disabled-bg-color);\n",
|
| 716 |
-
" fill: var(--disabled-fill-color);\n",
|
| 717 |
-
" box-shadow: none;\n",
|
| 718 |
-
" }\n",
|
| 719 |
-
"\n",
|
| 720 |
-
" .colab-df-spinner {\n",
|
| 721 |
-
" border: 2px solid var(--fill-color);\n",
|
| 722 |
-
" border-color: transparent;\n",
|
| 723 |
-
" border-bottom-color: var(--fill-color);\n",
|
| 724 |
-
" animation:\n",
|
| 725 |
-
" spin 1s steps(1) infinite;\n",
|
| 726 |
-
" }\n",
|
| 727 |
-
"\n",
|
| 728 |
-
" @keyframes spin {\n",
|
| 729 |
-
" 0% {\n",
|
| 730 |
-
" border-color: transparent;\n",
|
| 731 |
-
" border-bottom-color: var(--fill-color);\n",
|
| 732 |
-
" border-left-color: var(--fill-color);\n",
|
| 733 |
-
" }\n",
|
| 734 |
-
" 20% {\n",
|
| 735 |
-
" border-color: transparent;\n",
|
| 736 |
-
" border-left-color: var(--fill-color);\n",
|
| 737 |
-
" border-top-color: var(--fill-color);\n",
|
| 738 |
-
" }\n",
|
| 739 |
-
" 30% {\n",
|
| 740 |
-
" border-color: transparent;\n",
|
| 741 |
-
" border-left-color: var(--fill-color);\n",
|
| 742 |
-
" border-top-color: var(--fill-color);\n",
|
| 743 |
-
" border-right-color: var(--fill-color);\n",
|
| 744 |
-
" }\n",
|
| 745 |
-
" 40% {\n",
|
| 746 |
-
" border-color: transparent;\n",
|
| 747 |
-
" border-right-color: var(--fill-color);\n",
|
| 748 |
-
" border-top-color: var(--fill-color);\n",
|
| 749 |
-
" }\n",
|
| 750 |
-
" 60% {\n",
|
| 751 |
-
" border-color: transparent;\n",
|
| 752 |
-
" border-right-color: var(--fill-color);\n",
|
| 753 |
-
" }\n",
|
| 754 |
-
" 80% {\n",
|
| 755 |
-
" border-color: transparent;\n",
|
| 756 |
-
" border-right-color: var(--fill-color);\n",
|
| 757 |
-
" border-bottom-color: var(--fill-color);\n",
|
| 758 |
-
" }\n",
|
| 759 |
-
" 90% {\n",
|
| 760 |
-
" border-color: transparent;\n",
|
| 761 |
-
" border-bottom-color: var(--fill-color);\n",
|
| 762 |
-
" }\n",
|
| 763 |
-
" }\n",
|
| 764 |
-
"</style>\n",
|
| 765 |
-
"\n",
|
| 766 |
-
" <script>\n",
|
| 767 |
-
" async function quickchart(key) {\n",
|
| 768 |
-
" const quickchartButtonEl =\n",
|
| 769 |
-
" document.querySelector('#' + key + ' button');\n",
|
| 770 |
-
" quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n",
|
| 771 |
-
" quickchartButtonEl.classList.add('colab-df-spinner');\n",
|
| 772 |
-
" try {\n",
|
| 773 |
-
" const charts = await google.colab.kernel.invokeFunction(\n",
|
| 774 |
-
" 'suggestCharts', [key], {});\n",
|
| 775 |
-
" } catch (error) {\n",
|
| 776 |
-
" console.error('Error during call to suggestCharts:', error);\n",
|
| 777 |
-
" }\n",
|
| 778 |
-
" quickchartButtonEl.classList.remove('colab-df-spinner');\n",
|
| 779 |
-
" quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
|
| 780 |
-
" }\n",
|
| 781 |
-
" (() => {\n",
|
| 782 |
-
" let quickchartButtonEl =\n",
|
| 783 |
-
" document.querySelector('#df-39e3087a-f1c1-4313-b033-254945753780 button');\n",
|
| 784 |
-
" quickchartButtonEl.style.display =\n",
|
| 785 |
-
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
|
| 786 |
-
" })();\n",
|
| 787 |
-
" </script>\n",
|
| 788 |
-
" </div>\n",
|
| 789 |
-
"\n",
|
| 790 |
-
" </div>\n",
|
| 791 |
-
" </div>\n"
|
| 792 |
-
],
|
| 793 |
-
"application/vnd.google.colaboratory.intrinsic+json": {
|
| 794 |
-
"type": "dataframe",
|
| 795 |
-
"variable_name": "df",
|
| 796 |
-
"summary": "{\n \"name\": \"df\",\n \"rows\": 2870,\n \"fields\": [\n {\n \"column\": \"image\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 2870,\n \"samples\": [\n \"/kaggle/input/brain-tumor-classification-mri/Training/glioma_tumor/gg (87).jpg\",\n \"/kaggle/input/brain-tumor-classification-mri/Training/meningioma_tumor/m2 (51).jpg\",\n \"/kaggle/input/brain-tumor-classification-mri/Training/meningioma_tumor/m1(109).jpg\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"label\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 3,\n \"num_unique_values\": 4,\n \"samples\": [\n 2,\n 3,\n 0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
|
| 797 |
-
}
|
| 798 |
-
},
|
| 799 |
-
"metadata": {},
|
| 800 |
-
"execution_count": 13
|
| 801 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 802 |
]
|
| 803 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 804 |
{
|
| 805 |
"cell_type": "code",
|
| 806 |
"source": [
|
| 807 |
-
"#
|
| 808 |
-
"
|
| 809 |
-
"\n",
|
| 810 |
-
"
|
| 811 |
-
"
|
| 812 |
-
"
|
| 813 |
-
"\n",
|
| 814 |
-
"print(\"Train:\", len(train_df), \"Val:\", len(val_df))\n",
|
| 815 |
-
"print(train_df.label.value_counts(), val_df.label.value_counts())\n"
|
| 816 |
],
|
| 817 |
"metadata": {
|
| 818 |
"colab": {
|
| 819 |
"base_uri": "https://localhost:8080/"
|
| 820 |
},
|
| 821 |
-
"id": "
|
| 822 |
-
"outputId": "
|
| 823 |
},
|
| 824 |
-
"execution_count":
|
| 825 |
"outputs": [
|
| 826 |
{
|
| 827 |
"output_type": "stream",
|
| 828 |
"name": "stdout",
|
| 829 |
"text": [
|
| 830 |
-
"
|
| 831 |
-
"label\n",
|
| 832 |
-
"2 662\n",
|
| 833 |
-
"0 661\n",
|
| 834 |
-
"1 657\n",
|
| 835 |
-
"3 316\n",
|
| 836 |
-
"Name: count, dtype: int64 label\n",
|
| 837 |
-
"1 165\n",
|
| 838 |
-
"2 165\n",
|
| 839 |
-
"0 165\n",
|
| 840 |
-
"3 79\n",
|
| 841 |
-
"Name: count, dtype: int64\n"
|
| 842 |
]
|
| 843 |
}
|
| 844 |
]
|
| 845 |
},
|
| 846 |
{
|
| 847 |
"cell_type": "code",
|
| 848 |
-
"
|
| 849 |
-
"# Define Dataset and Transforms (inline)"
|
| 850 |
-
],
|
| 851 |
"metadata": {
|
| 852 |
"id": "MYxdFZEiY4_9"
|
| 853 |
},
|
| 854 |
-
"
|
| 855 |
-
"
|
| 856 |
},
|
| 857 |
{
|
| 858 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 859 |
"source": [
|
|
|
|
| 860 |
"import torch\n",
|
| 861 |
"from torch.utils.data import Dataset, DataLoader\n",
|
| 862 |
"from PIL import Image\n",
|
| 863 |
"import torchvision.transforms as T\n",
|
| 864 |
"\n",
|
| 865 |
-
"
|
| 866 |
-
"\n",
|
| 867 |
"train_tf = T.Compose([\n",
|
| 868 |
" T.Resize((IMG_SIZE, IMG_SIZE)),\n",
|
| 869 |
" T.RandomHorizontalFlip(),\n",
|
| 870 |
" T.RandomRotation(10),\n",
|
| 871 |
-
" T.ColorJitter(
|
| 872 |
" T.ToTensor(),\n",
|
| 873 |
" T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n",
|
| 874 |
"])\n",
|
|
@@ -878,117 +381,85 @@
|
|
| 878 |
" T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n",
|
| 879 |
"])\n",
|
| 880 |
"\n",
|
| 881 |
-
"class
|
| 882 |
-
" def __init__(self, df, transform=None):\n",
|
| 883 |
" self.df = df.reset_index(drop=True)\n",
|
| 884 |
" self.transform = transform\n",
|
| 885 |
-
"\n",
|
| 886 |
" def __len__(self):\n",
|
| 887 |
" return len(self.df)\n",
|
| 888 |
-
"\n",
|
| 889 |
" def __getitem__(self, idx):\n",
|
| 890 |
" row = self.df.iloc[idx]\n",
|
| 891 |
-
"
|
|
|
|
| 892 |
" if self.transform:\n",
|
| 893 |
" img = self.transform(img)\n",
|
| 894 |
-
" label =
|
| 895 |
" return img, label\n",
|
| 896 |
"\n",
|
| 897 |
-
"
|
| 898 |
-
"
|
| 899 |
-
"
|
| 900 |
-
"
|
| 901 |
-
"
|
| 902 |
-
|
| 903 |
-
|
| 904 |
-
"
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
"
|
| 908 |
-
"outputId": "78520285-f5f9-4987-ef9d-65be96356ca1"
|
| 909 |
-
},
|
| 910 |
-
"execution_count": 16,
|
| 911 |
-
"outputs": [
|
| 912 |
-
{
|
| 913 |
-
"output_type": "stream",
|
| 914 |
-
"name": "stdout",
|
| 915 |
-
"text": [
|
| 916 |
-
"torch.Size([3, 224, 224]) 0\n"
|
| 917 |
-
]
|
| 918 |
-
}
|
| 919 |
]
|
| 920 |
},
|
| 921 |
{
|
| 922 |
"cell_type": "code",
|
| 923 |
-
"
|
| 924 |
"metadata": {
|
| 925 |
"id": "SrJLOm33ZMpC"
|
| 926 |
},
|
| 927 |
-
"
|
| 928 |
-
"
|
| 929 |
},
|
| 930 |
{
|
| 931 |
"cell_type": "code",
|
| 932 |
-
"
|
| 933 |
-
"# Build dataloaders"
|
| 934 |
-
],
|
| 935 |
"metadata": {
|
| 936 |
"id": "bG2JclRDY-Di"
|
| 937 |
},
|
| 938 |
-
"
|
| 939 |
-
"
|
| 940 |
-
},
|
| 941 |
-
{
|
| 942 |
-
"cell_type": "code",
|
| 943 |
-
"source": [
|
| 944 |
-
"BATCH_SIZE = 8 # adjust to your GPU memory\n",
|
| 945 |
-
"\n",
|
| 946 |
-
"train_loader = DataLoader(MedicalMultiClassDataset(train_df, transform=train_tf), batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)\n",
|
| 947 |
-
"val_loader = DataLoader(MedicalMultiClassDataset(val_df, transform=val_tf), batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)\n"
|
| 948 |
-
],
|
| 949 |
-
"metadata": {
|
| 950 |
-
"id": "Jc4pdqreZGDS"
|
| 951 |
-
},
|
| 952 |
-
"execution_count": 18,
|
| 953 |
-
"outputs": []
|
| 954 |
},
|
| 955 |
{
|
| 956 |
"cell_type": "code",
|
| 957 |
-
"
|
| 958 |
-
"# Define model, optimizer, loss, device\n",
|
| 959 |
-
"import timm, torch.nn as nn, torch\n",
|
| 960 |
-
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 961 |
-
"NUM_CLASSES = 4\n",
|
| 962 |
-
"\n",
|
| 963 |
-
"model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=NUM_CLASSES)\n",
|
| 964 |
-
"model = model.to(device)\n",
|
| 965 |
-
"\n",
|
| 966 |
-
"loss_fn = nn.CrossEntropyLoss()\n",
|
| 967 |
-
"optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)\n"
|
| 968 |
-
],
|
| 969 |
"metadata": {
|
|
|
|
| 970 |
"colab": {
|
| 971 |
"base_uri": "https://localhost:8080/",
|
| 972 |
-
"height":
|
| 973 |
"referenced_widgets": [
|
| 974 |
-
"
|
| 975 |
-
"
|
| 976 |
-
"
|
| 977 |
-
"
|
| 978 |
-
"
|
| 979 |
-
"
|
| 980 |
-
"
|
| 981 |
-
"
|
| 982 |
-
"
|
| 983 |
-
"
|
| 984 |
-
"
|
| 985 |
]
|
| 986 |
},
|
| 987 |
-
"
|
| 988 |
-
"outputId": "2e9229ac-0260-41ce-ea49-9c3982ca7ab6"
|
| 989 |
},
|
| 990 |
-
"execution_count": 19,
|
| 991 |
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 992 |
{
|
| 993 |
"output_type": "stream",
|
| 994 |
"name": "stderr",
|
|
@@ -1010,31 +481,90 @@
|
|
| 1010 |
"application/vnd.jupyter.widget-view+json": {
|
| 1011 |
"version_major": 2,
|
| 1012 |
"version_minor": 0,
|
| 1013 |
-
"model_id": "
|
| 1014 |
}
|
| 1015 |
},
|
| 1016 |
"metadata": {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1017 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1018 |
]
|
| 1019 |
},
|
| 1020 |
{
|
| 1021 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1022 |
"source": [
|
|
|
|
| 1023 |
"from sklearn.metrics import accuracy_score\n",
|
| 1024 |
"import time, os\n",
|
|
|
|
| 1025 |
"\n",
|
| 1026 |
-
"
|
| 1027 |
-
"CKPT_DIR = \"/content/checkpoints\"\n",
|
| 1028 |
-
"os.makedirs(CKPT_DIR, exist_ok=True)\n",
|
| 1029 |
"best_val_acc = 0.0\n",
|
|
|
|
|
|
|
|
|
|
| 1030 |
"\n",
|
| 1031 |
"for epoch in range(1, EPOCHS+1):\n",
|
|
|
|
|
|
|
| 1032 |
" model.train()\n",
|
| 1033 |
-
"
|
| 1034 |
" all_preds, all_labels = [], []\n",
|
| 1035 |
-
" t0 = time.time()\n",
|
| 1036 |
-
"\n",
|
| 1037 |
-
" # ---- Training ----\n",
|
| 1038 |
" for imgs, labels in train_loader:\n",
|
| 1039 |
" imgs = imgs.to(device)\n",
|
| 1040 |
" labels = labels.to(device)\n",
|
|
@@ -1043,75 +573,125 @@
|
|
| 1043 |
" loss = loss_fn(logits, labels)\n",
|
| 1044 |
" loss.backward()\n",
|
| 1045 |
" optimizer.step()\n",
|
| 1046 |
-
"
|
| 1047 |
" preds = logits.argmax(dim=1).cpu().numpy()\n",
|
| 1048 |
" all_preds.extend(preds.tolist())\n",
|
| 1049 |
" all_labels.extend(labels.cpu().numpy().tolist())\n",
|
| 1050 |
-
"\n",
|
| 1051 |
-
" train_loss = running_loss / len(train_loader.dataset)\n",
|
| 1052 |
" train_acc = accuracy_score(all_labels, all_preds)\n",
|
| 1053 |
"\n",
|
| 1054 |
-
" #
|
| 1055 |
" model.eval()\n",
|
| 1056 |
-
"
|
| 1057 |
" v_preds, v_labels = [], []\n",
|
| 1058 |
" with torch.no_grad():\n",
|
| 1059 |
" for imgs, labels in val_loader:\n",
|
| 1060 |
" imgs = imgs.to(device)\n",
|
| 1061 |
" labels = labels.to(device)\n",
|
| 1062 |
" logits = model(imgs)\n",
|
| 1063 |
-
" loss = loss_fn(logits, labels)
|
| 1064 |
-
"
|
| 1065 |
" preds = logits.argmax(dim=1).cpu().numpy()\n",
|
| 1066 |
" v_preds.extend(preds.tolist())\n",
|
| 1067 |
" v_labels.extend(labels.cpu().numpy().tolist())\n",
|
| 1068 |
-
"\n",
|
| 1069 |
-
" val_loss = val_running_loss / len(val_loader.dataset)\n",
|
| 1070 |
" val_acc = accuracy_score(v_labels, v_preds)\n",
|
| 1071 |
"\n",
|
| 1072 |
-
"
|
| 1073 |
-
"
|
| 1074 |
-
"
|
| 1075 |
-
"
|
| 1076 |
-
"
|
|
|
|
| 1077 |
"\n",
|
| 1078 |
-
" #
|
| 1079 |
-
"
|
| 1080 |
-
" torch.save(model.state_dict(),
|
| 1081 |
" if val_acc > best_val_acc:\n",
|
| 1082 |
" best_val_acc = val_acc\n",
|
| 1083 |
-
" torch.save(model.state_dict(), os.path.join(
|
| 1084 |
-
" print(\"Saved best model:\", best_val_acc)\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1085 |
],
|
| 1086 |
"metadata": {
|
| 1087 |
"colab": {
|
| 1088 |
"base_uri": "https://localhost:8080/"
|
| 1089 |
},
|
| 1090 |
-
"id": "
|
| 1091 |
-
"outputId": "
|
| 1092 |
},
|
| 1093 |
-
"execution_count":
|
| 1094 |
"outputs": [
|
| 1095 |
{
|
| 1096 |
"output_type": "stream",
|
| 1097 |
"name": "stdout",
|
| 1098 |
"text": [
|
| 1099 |
-
"
|
| 1100 |
-
"
|
| 1101 |
-
"
|
| 1102 |
-
"
|
| 1103 |
-
"
|
| 1104 |
-
"
|
| 1105 |
-
"
|
| 1106 |
-
"
|
| 1107 |
-
"
|
| 1108 |
-
"
|
| 1109 |
-
"
|
| 1110 |
-
"
|
| 1111 |
-
"
|
| 1112 |
-
"
|
| 1113 |
-
"
|
| 1114 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1115 |
]
|
| 1116 |
}
|
| 1117 |
]
|
|
@@ -1120,117 +700,175 @@
|
|
| 1120 |
"cell_type": "code",
|
| 1121 |
"source": [],
|
| 1122 |
"metadata": {
|
| 1123 |
-
"id": "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1124 |
},
|
| 1125 |
"execution_count": null,
|
| 1126 |
"outputs": []
|
| 1127 |
},
|
| 1128 |
{
|
| 1129 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1130 |
"source": [
|
| 1131 |
"# Load best model & test a single image (inference helper)\n",
|
| 1132 |
"import torch.nn.functional as F\n",
|
| 1133 |
"from PIL import Image\n",
|
| 1134 |
"from torchvision import transforms\n",
|
| 1135 |
"\n",
|
| 1136 |
-
"
|
| 1137 |
-
"
|
| 1138 |
-
"
|
| 1139 |
-
"
|
| 1140 |
-
"\n",
|
| 1141 |
-
"IDX2LABEL = {0:\"glioma_tumor\", 1:\"meningioma_tumor\", 2:\"pituitary_tumor\", 3:\"no_tumor\"}\n",
|
| 1142 |
"\n",
|
| 1143 |
-
"
|
| 1144 |
-
"
|
| 1145 |
-
"
|
| 1146 |
-
" transforms.ToTensor(),\n",
|
| 1147 |
-
" transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n",
|
| 1148 |
-
" ])\n",
|
| 1149 |
-
" return tf(pil_img).unsqueeze(0).to(device)\n",
|
| 1150 |
-
"\n",
|
| 1151 |
-
"def predict(pil_img, topk=3):\n",
|
| 1152 |
-
" x = preprocess(pil_img)\n",
|
| 1153 |
" with torch.no_grad():\n",
|
| 1154 |
" logits = model(x)\n",
|
| 1155 |
" probs = F.softmax(logits, dim=1).cpu().numpy().ravel()\n",
|
| 1156 |
" idxs = probs.argsort()[::-1][:topk]\n",
|
| 1157 |
" return [{\"label\": IDX2LABEL[int(i)], \"score\": float(probs[int(i)])} for i in idxs]\n",
|
| 1158 |
"\n",
|
| 1159 |
-
"# quick
|
| 1160 |
-
"
|
| 1161 |
-
"
|
| 1162 |
"img = Image.open(sample_path).convert(\"RGB\")\n",
|
| 1163 |
-
"
|
| 1164 |
-
"
|
| 1165 |
-
],
|
| 1166 |
-
"metadata": {
|
| 1167 |
-
"colab": {
|
| 1168 |
-
"base_uri": "https://localhost:8080/"
|
| 1169 |
-
},
|
| 1170 |
-
"id": "H2BqijX2Zupx",
|
| 1171 |
-
"outputId": "8ad2aef5-e047-4054-acca-f50cbc3429be"
|
| 1172 |
-
},
|
| 1173 |
-
"execution_count": 29,
|
| 1174 |
-
"outputs": [
|
| 1175 |
-
{
|
| 1176 |
-
"output_type": "stream",
|
| 1177 |
-
"name": "stdout",
|
| 1178 |
-
"text": [
|
| 1179 |
-
"Sample: /kaggle/input/brain-tumor-classification-mri/Training/meningioma_tumor/m3 (155).jpg\n"
|
| 1180 |
-
]
|
| 1181 |
-
},
|
| 1182 |
-
{
|
| 1183 |
-
"output_type": "execute_result",
|
| 1184 |
-
"data": {
|
| 1185 |
-
"text/plain": [
|
| 1186 |
-
"[{'label': 'meningioma_tumor', 'score': 0.9999998807907104},\n",
|
| 1187 |
-
" {'label': 'pituitary_tumor', 'score': 5.87046180555717e-08},\n",
|
| 1188 |
-
" {'label': 'glioma_tumor', 'score': 4.05661282343317e-08}]"
|
| 1189 |
-
]
|
| 1190 |
-
},
|
| 1191 |
-
"metadata": {},
|
| 1192 |
-
"execution_count": 29
|
| 1193 |
-
}
|
| 1194 |
]
|
| 1195 |
},
|
| 1196 |
{
|
| 1197 |
"cell_type": "code",
|
| 1198 |
-
"
|
| 1199 |
"metadata": {
|
| 1200 |
"id": "RAZRNlq2ZutV"
|
| 1201 |
},
|
| 1202 |
-
"
|
| 1203 |
-
"
|
| 1204 |
},
|
| 1205 |
{
|
| 1206 |
"cell_type": "code",
|
| 1207 |
-
"
|
| 1208 |
"metadata": {
|
| 1209 |
"id": "chDN6CJmc8h0"
|
| 1210 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1211 |
"execution_count": null,
|
| 1212 |
"outputs": []
|
| 1213 |
},
|
| 1214 |
{
|
| 1215 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1216 |
"source": [
|
| 1217 |
"sample_paths = [\n",
|
| 1218 |
-
" \"/
|
| 1219 |
-
" \"/
|
| 1220 |
-
" \"/
|
| 1221 |
-
" \"/
|
| 1222 |
"]\n",
|
| 1223 |
"\n",
|
| 1224 |
"OPENAI_API_KEY = \"\""
|
| 1225 |
-
]
|
| 1226 |
-
"metadata": {
|
| 1227 |
-
"id": "WSygonZYdvpg"
|
| 1228 |
-
},
|
| 1229 |
-
"execution_count": 36,
|
| 1230 |
-
"outputs": []
|
| 1231 |
},
|
| 1232 |
{
|
| 1233 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1234 |
"source": [
|
| 1235 |
"import os\n",
|
| 1236 |
"from pathlib import Path\n",
|
|
@@ -1243,23 +881,33 @@
|
|
| 1243 |
"import openai\n",
|
| 1244 |
"from openai import OpenAI\n",
|
| 1245 |
"\n",
|
| 1246 |
-
"
|
| 1247 |
-
"#
|
| 1248 |
-
"CKPT_PATH = \"/content/checkpoints/best_model_multiclass.pth\" # path to trained checkpoint\n",
|
| 1249 |
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 1250 |
"BACKBONE = \"efficientnet_b0\"\n",
|
| 1251 |
-
"NUM_CLASSES = 4\n",
|
| 1252 |
"IMG_SIZE = 224\n",
|
| 1253 |
"\n",
|
| 1254 |
-
"# Paste your 4 sample image absolute paths here\n",
|
| 1255 |
-
"#
|
| 1256 |
-
"sample_paths =
|
|
|
|
| 1257 |
"\n",
|
| 1258 |
-
"OPENAI_API_KEY = OPENAI_API_KEY\n",
|
| 1259 |
"if OPENAI_API_KEY:\n",
|
| 1260 |
" openai.api_key = OPENAI_API_KEY\n",
|
| 1261 |
"\n",
|
| 1262 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1263 |
"# -------------------------------------------------\n",
|
| 1264 |
"\n",
|
| 1265 |
"# Validate sample paths\n",
|
|
@@ -1295,7 +943,7 @@
|
|
| 1295 |
" idxs = probs.argsort()[::-1][:topk]\n",
|
| 1296 |
" return [{\"label\": IDX2LABEL[int(i)], \"score\": float(probs[int(i)])} for i in idxs]\n",
|
| 1297 |
"\n",
|
| 1298 |
-
"# OpenAI system prompt\n",
|
| 1299 |
"SYSTEM_TEMPLATE = \"\"\"You are an educational radiology assistant. The model predicted:\n",
|
| 1300 |
"{preds}\n",
|
| 1301 |
"\n",
|
|
@@ -1325,13 +973,12 @@
|
|
| 1325 |
" )\n",
|
| 1326 |
" return resp.choices[0].message.content.strip()\n",
|
| 1327 |
"\n",
|
| 1328 |
-
"\n",
|
| 1329 |
"# Prepare thumbnails (just images)\n",
|
| 1330 |
"thumbs = [Image.open(p).convert(\"RGB\") for p in sample_paths]\n",
|
| 1331 |
"\n",
|
| 1332 |
-
"# Gradio UI\n",
|
| 1333 |
"with gr.Blocks() as demo:\n",
|
| 1334 |
-
" gr.Markdown(\"## Brain MRI multi-class demo — Educational only\")\n",
|
| 1335 |
" with gr.Row():\n",
|
| 1336 |
" with gr.Column():\n",
|
| 1337 |
" # Use a simple list of PIL thumbnails for the gallery value\n",
|
|
@@ -1349,16 +996,16 @@
|
|
| 1349 |
" selected_path_state = gr.State(value=None)\n",
|
| 1350 |
" preds_state = gr.State(value=None)\n",
|
| 1351 |
"\n",
|
| 1352 |
-
" # select image from gallery\n",
|
| 1353 |
" def on_select(evt: gr.SelectData):\n",
|
| 1354 |
-
" idx = evt.index\n",
|
| 1355 |
-
" path = sample_paths[idx]
|
| 1356 |
" return path\n",
|
| 1357 |
"\n",
|
| 1358 |
" # ensure the gallery select writes only the path to the state\n",
|
| 1359 |
" gallery.select(fn=on_select, inputs=None, outputs=[selected_path_state])\n",
|
| 1360 |
"\n",
|
| 1361 |
-
" # analyze\n",
|
| 1362 |
" def analyze(path):\n",
|
| 1363 |
" try:\n",
|
| 1364 |
" if path is None:\n",
|
|
@@ -1375,7 +1022,7 @@
|
|
| 1375 |
"\n",
|
| 1376 |
" analyze_btn.click(fn=analyze, inputs=selected_path_state, outputs=[preds_output, chatbot, preds_state])\n",
|
| 1377 |
"\n",
|
| 1378 |
-
" # chat\n",
|
| 1379 |
" def chat(chat_history, msg, preds_text):\n",
|
| 1380 |
" if preds_text is None:\n",
|
| 1381 |
" return chat_history+[(\"AI\",\"Please analyze an image first.\")], \"\"\n",
|
|
@@ -1406,74 +1053,391 @@
|
|
| 1406 |
"\n",
|
| 1407 |
" gr.Markdown(\"⚠️ Educational use only — not a medical diagnosis.\")\n",
|
| 1408 |
"\n",
|
| 1409 |
-
"demo.launch(share=True)
|
| 1410 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1411 |
"metadata": {
|
| 1412 |
"colab": {
|
| 1413 |
-
"
|
| 1414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1415 |
},
|
| 1416 |
-
"
|
| 1417 |
-
|
| 1418 |
-
|
| 1419 |
-
|
| 1420 |
-
|
| 1421 |
-
|
| 1422 |
-
|
| 1423 |
-
|
| 1424 |
-
|
| 1425 |
-
"
|
| 1426 |
-
"
|
| 1427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1428 |
},
|
| 1429 |
-
{
|
| 1430 |
-
"
|
| 1431 |
-
"
|
| 1432 |
-
"
|
| 1433 |
-
|
| 1434 |
-
"
|
| 1435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1436 |
},
|
| 1437 |
-
{
|
| 1438 |
-
"
|
| 1439 |
-
"
|
| 1440 |
-
"
|
| 1441 |
-
|
| 1442 |
-
"
|
| 1443 |
-
"
|
| 1444 |
-
"
|
| 1445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1446 |
},
|
| 1447 |
-
{
|
| 1448 |
-
"
|
| 1449 |
-
"
|
| 1450 |
-
|
| 1451 |
-
|
| 1452 |
-
|
| 1453 |
-
"
|
| 1454 |
-
|
| 1455 |
-
|
| 1456 |
-
|
| 1457 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1458 |
},
|
| 1459 |
-
{
|
| 1460 |
-
"
|
| 1461 |
-
"
|
| 1462 |
-
|
| 1463 |
-
|
| 1464 |
-
|
| 1465 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1466 |
}
|
| 1467 |
-
|
| 1468 |
-
},
|
| 1469 |
-
{
|
| 1470 |
-
"cell_type": "code",
|
| 1471 |
-
"source": [],
|
| 1472 |
-
"metadata": {
|
| 1473 |
-
"id": "TrS_HlY7d8s3"
|
| 1474 |
-
},
|
| 1475 |
-
"execution_count": null,
|
| 1476 |
-
"outputs": []
|
| 1477 |
}
|
| 1478 |
-
|
|
|
|
|
|
|
| 1479 |
}
|
|
|
|
| 1 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {
|
| 7 |
+
"colab": {
|
| 8 |
+
"base_uri": "https://localhost:8080/"
|
| 9 |
+
},
|
| 10 |
+
"id": "yViCfhqXTD-l",
|
| 11 |
+
"outputId": "e599cb0e-fb7f-44d0-c84d-82ce2d17c0fe"
|
| 12 |
+
},
|
| 13 |
+
"outputs": [
|
| 14 |
+
{
|
| 15 |
+
"output_type": "stream",
|
| 16 |
+
"name": "stdout",
|
| 17 |
+
"text": [
|
| 18 |
+
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/832.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m829.4/832.4 kB\u001b[0m \u001b[31m36.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m832.4/832.4 kB\u001b[0m \u001b[31m22.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 19 |
+
"\u001b[?25h\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/983.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m983.2/983.2 kB\u001b[0m \u001b[31m52.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 20 |
+
"\u001b[?25h"
|
| 21 |
+
]
|
| 22 |
+
}
|
| 23 |
+
],
|
| 24 |
"source": [
|
| 25 |
"# Colab: install required packages\n",
|
| 26 |
"!pip install -q kagglehub timm torch torchvision scikit-learn pandas pillow matplotlib gradio openai pytorch-lightning"
|
| 27 |
+
]
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "code",
|
| 31 |
+
"execution_count": 2,
|
| 32 |
"metadata": {
|
| 33 |
"colab": {
|
| 34 |
"base_uri": "https://localhost:8080/"
|
| 35 |
},
|
| 36 |
+
"id": "B7iabh3oOyNe",
|
| 37 |
+
"outputId": "fb2920c5-a19b-49c8-814e-640ce263c27b"
|
| 38 |
},
|
|
|
|
| 39 |
"outputs": [
|
| 40 |
{
|
| 41 |
"output_type": "stream",
|
| 42 |
"name": "stdout",
|
| 43 |
"text": [
|
| 44 |
+
"Downloading from https://www.kaggle.com/api/v1/datasets/download/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri?dataset_version_number=1...\n"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"output_type": "stream",
|
| 49 |
+
"name": "stderr",
|
| 50 |
+
"text": [
|
| 51 |
+
"100%|██████████| 361M/361M [00:02<00:00, 183MB/s]"
|
| 52 |
+
]
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"output_type": "stream",
|
| 56 |
+
"name": "stdout",
|
| 57 |
+
"text": [
|
| 58 |
+
"Extracting files...\n"
|
| 59 |
+
]
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"output_type": "stream",
|
| 63 |
+
"name": "stderr",
|
| 64 |
+
"text": [
|
| 65 |
+
"\n"
|
| 66 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
{
|
| 69 |
"output_type": "stream",
|
| 70 |
"name": "stdout",
|
| 71 |
"text": [
|
| 72 |
+
"Dataset path: /root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1\n",
|
| 73 |
+
"Dataset\n"
|
|
|
|
|
|
|
| 74 |
]
|
| 75 |
}
|
| 76 |
],
|
| 77 |
"source": [
|
| 78 |
"import kagglehub\n",
|
| 79 |
+
"dataset_ref = \"murtozalikhon/brain-tumor-multimodal-image-ct-and-mri\"\n",
|
| 80 |
"path = kagglehub.dataset_download(dataset_ref)\n",
|
| 81 |
"print(\"Dataset path:\", path)\n",
|
| 82 |
"\n",
|
|
|
|
| 88 |
},
|
| 89 |
{
|
| 90 |
"cell_type": "code",
|
| 91 |
+
"execution_count": 3,
|
|
|
|
|
|
|
|
|
|
| 92 |
"metadata": {
|
| 93 |
"colab": {
|
| 94 |
"base_uri": "https://localhost:8080/"
|
| 95 |
},
|
| 96 |
"id": "gq9zpPkmZVK7",
|
| 97 |
+
"outputId": "174e6148-05a4-4b4b-8ebe-7b62ea806b1b"
|
| 98 |
},
|
|
|
|
| 99 |
"outputs": [
|
| 100 |
{
|
| 101 |
"output_type": "stream",
|
|
|
|
| 104 |
"True\n"
|
| 105 |
]
|
| 106 |
}
|
| 107 |
+
],
|
| 108 |
+
"source": [
|
| 109 |
+
"import torch\n",
|
| 110 |
+
"print(torch.cuda.is_available())"
|
| 111 |
]
|
| 112 |
},
|
| 113 |
{
|
| 114 |
"cell_type": "code",
|
| 115 |
+
"execution_count": null,
|
| 116 |
+
"metadata": {
|
| 117 |
+
"id": "YvSgspBGU3iG"
|
| 118 |
+
},
|
| 119 |
+
"outputs": [],
|
| 120 |
"source": [
|
| 121 |
"# Create CSV (multi-class) from folders (in-memory)\n",
|
| 122 |
"\n",
|
| 123 |
"# Change DATA_ROOT to the folder that directly contains glioma_tumor, meningioma_tumor, pituitary_tumor, no_tumor."
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
{
|
| 127 |
+
"cell_type": "code",
|
| 128 |
+
"source": [
|
| 129 |
+
"# CONFIG - edit these paths before running the pipeline\n",
|
| 130 |
+
"DATA_ROOT = \"/root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset\"\n",
|
| 131 |
+
"# Where to write outputs / checkpoints\n",
|
| 132 |
+
"OUT_DIR = \"/content/output\"\n",
|
| 133 |
+
"os.makedirs(OUT_DIR, exist_ok=True)\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"# Training hyperparams (change if needed)\n",
|
| 136 |
+
"IMG_SIZE = 224\n",
|
| 137 |
+
"BATCH_SIZE = 8\n",
|
| 138 |
+
"EPOCHS = 10\n",
|
| 139 |
+
"LR = 1e-4\n",
|
| 140 |
+
"NUM_WORKERS = 2 # adjust for Colab\n",
|
| 141 |
+
"TOP_K = 3\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n"
|
| 144 |
],
|
| 145 |
"metadata": {
|
| 146 |
+
"colab": {
|
| 147 |
+
"base_uri": "https://localhost:8080/"
|
| 148 |
+
},
|
| 149 |
+
"id": "H5gEKtfMJxVQ",
|
| 150 |
+
"outputId": "fd555a63-b825-4860-86df-e4bc3f0caa76"
|
| 151 |
},
|
| 152 |
+
"execution_count": 23,
|
| 153 |
+
"outputs": [
|
| 154 |
+
{
|
| 155 |
+
"output_type": "stream",
|
| 156 |
+
"name": "stdout",
|
| 157 |
+
"text": [
|
| 158 |
+
"DATA_ROOT: /root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset\n"
|
| 159 |
+
]
|
| 160 |
+
}
|
| 161 |
+
]
|
| 162 |
},
|
| 163 |
{
|
| 164 |
"cell_type": "code",
|
| 165 |
+
"execution_count": 19,
|
| 166 |
+
"metadata": {
|
| 167 |
+
"colab": {
|
| 168 |
+
"base_uri": "https://localhost:8080/"
|
| 169 |
+
},
|
| 170 |
+
"id": "1grjHP2ZU5st",
|
| 171 |
+
"outputId": "050866ba-120a-40d9-a7b3-d33cdcd21394"
|
| 172 |
+
},
|
| 173 |
+
"outputs": [
|
| 174 |
+
{
|
| 175 |
+
"output_type": "stream",
|
| 176 |
+
"name": "stdout",
|
| 177 |
+
"text": [
|
| 178 |
+
"Total images: 9618\n",
|
| 179 |
+
"label\n",
|
| 180 |
+
"no_tumor 4300\n",
|
| 181 |
+
"ct_tumor 2318\n",
|
| 182 |
+
"meningioma 1112\n",
|
| 183 |
+
"glioma 672\n",
|
| 184 |
+
"pituitary 629\n",
|
| 185 |
+
"mri_tumor 587\n",
|
| 186 |
+
"Name: count, dtype: int64\n",
|
| 187 |
+
"Wrote CSV: /content/output/combined_image_labels.csv\n"
|
| 188 |
+
]
|
| 189 |
+
}
|
| 190 |
+
],
|
| 191 |
"source": [
|
| 192 |
+
"# Build combined CSV for your described dataset layout\n",
|
| 193 |
+
"# Paste & run in Colab; set PARENT_DIR to your dataset root\n",
|
| 194 |
+
"\n",
|
| 195 |
"import os, pandas as pd\n",
|
| 196 |
+
"from pathlib import Path\n",
|
| 197 |
"\n",
|
| 198 |
+
"# EDIT: point this to the parent folder that contains the two modality folders\n",
|
| 199 |
+
"PARENT_DIR = DATA_ROOT # ← change this\n",
|
| 200 |
"\n",
|
| 201 |
+
"# expected modality folder names (we'll match case-insensitively)\n",
|
| 202 |
+
"# adjust if your folder names differ\n",
|
| 203 |
+
"CT_FOLDER_KEY = \"ct\"\n",
|
| 204 |
+
"MRI_FOLDER_KEY = \"mri\"\n",
|
|
|
|
|
|
|
| 205 |
"\n",
|
| 206 |
"rows = []\n",
|
| 207 |
+
"for mod_name in sorted(os.listdir(PARENT_DIR)):\n",
|
| 208 |
+
" mod_path = os.path.join(PARENT_DIR, mod_name)\n",
|
| 209 |
+
" if not os.path.isdir(mod_path):\n",
|
|
|
|
|
|
|
|
|
|
| 210 |
" continue\n",
|
| 211 |
+
" mod_lower = mod_name.lower()\n",
|
| 212 |
+
" modality = \"CT\" if CT_FOLDER_KEY in mod_lower else (\"MRI\" if MRI_FOLDER_KEY in mod_lower else mod_name)\n",
|
| 213 |
+
" # inside each modality we expect subfolders: 'healthy' and 'tumor'\n",
|
| 214 |
+
" for sub in sorted(os.listdir(mod_path)):\n",
|
| 215 |
+
" sub_path = os.path.join(mod_path, sub)\n",
|
| 216 |
+
" if not os.path.isdir(sub_path):\n",
|
| 217 |
+
" continue\n",
|
| 218 |
+
" sub_lower = sub.lower()\n",
|
| 219 |
+
" # CT modality handling\n",
|
| 220 |
+
" if modality == \"CT\":\n",
|
| 221 |
+
" if \"healthy\" in sub_lower:\n",
|
| 222 |
+
" lab = \"no_tumor\"\n",
|
| 223 |
+
" else:\n",
|
| 224 |
+
" lab = \"ct_tumor\"\n",
|
| 225 |
+
" for fn in os.listdir(sub_path):\n",
|
| 226 |
+
" if fn.lower().endswith((\".png\",\".jpg\",\".jpeg\")):\n",
|
| 227 |
+
" rows.append({\"image\": os.path.join(sub_path, fn), \"label\": lab, \"modality\": \"CT\"})\n",
|
| 228 |
+
" # MRI modality handling\n",
|
| 229 |
+
" elif modality == \"MRI\":\n",
|
| 230 |
+
" if \"healthy\" in sub_lower:\n",
|
| 231 |
+
" for fn in os.listdir(sub_path):\n",
|
| 232 |
+
" if fn.lower().endswith((\".png\",\".jpg\",\".jpeg\")):\n",
|
| 233 |
+
" rows.append({\"image\": os.path.join(sub_path, fn), \"label\": \"no_tumor\", \"modality\": \"MRI\"})\n",
|
| 234 |
+
" else:\n",
|
| 235 |
+
" # tumor folder: try to parse subtype from filename, else fallback\n",
|
| 236 |
+
" for fn in os.listdir(sub_path):\n",
|
| 237 |
+
" if not fn.lower().endswith((\".png\",\".jpg\",\".jpeg\")):\n",
|
| 238 |
+
" continue\n",
|
| 239 |
+
" low = fn.lower()\n",
|
| 240 |
+
" if \"glioma\" in low:\n",
|
| 241 |
+
" lab = \"glioma\"\n",
|
| 242 |
+
" elif \"meningioma\" in low:\n",
|
| 243 |
+
" lab = \"meningioma\"\n",
|
| 244 |
+
" elif \"pituitary\" in low:\n",
|
| 245 |
+
" lab = \"pituitary\"\n",
|
| 246 |
+
" else:\n",
|
| 247 |
+
" lab = \"mri_tumor\"\n",
|
| 248 |
+
" rows.append({\"image\": os.path.join(sub_path, fn), \"label\": lab, \"modality\": \"MRI\"})\n",
|
| 249 |
+
" else:\n",
|
| 250 |
+
" # unexpected modality name: treat similarly to MRI fallback\n",
|
| 251 |
+
" for fn in os.listdir(sub_path):\n",
|
| 252 |
+
" if fn.lower().endswith((\".png\",\".jpg\",\".jpeg\")):\n",
|
| 253 |
+
" rows.append({\"image\": os.path.join(sub_path, fn), \"label\": \"unknown_modality\", \"modality\": mod_name})\n",
|
| 254 |
"\n",
|
| 255 |
"df = pd.DataFrame(rows).sample(frac=1, random_state=42).reset_index(drop=True)\n",
|
| 256 |
"print(\"Total images:\", len(df))\n",
|
| 257 |
"print(df.label.value_counts())\n",
|
| 258 |
+
"out_csv = os.path.join(OUT_DIR, \"combined_image_labels.csv\")\n",
|
| 259 |
+
"df.to_csv(out_csv, index=False)\n",
|
| 260 |
+
"print(\"Wrote CSV:\", out_csv)\n"
|
| 261 |
+
]
|
| 262 |
+
},
|
| 263 |
+
{
|
| 264 |
+
"cell_type": "code",
|
| 265 |
+
"execution_count": 20,
|
| 266 |
"metadata": {
|
| 267 |
"colab": {
|
| 268 |
+
"base_uri": "https://localhost:8080/"
|
|
|
|
| 269 |
},
|
| 270 |
+
"id": "OvqCFuB1U5wR",
|
| 271 |
+
"outputId": "be745f82-a3b2-4b69-ea06-f0615f8cc321"
|
| 272 |
},
|
|
|
|
| 273 |
"outputs": [
|
| 274 |
{
|
| 275 |
"output_type": "stream",
|
| 276 |
"name": "stdout",
|
| 277 |
"text": [
|
| 278 |
+
"Train: 7213 Val: 1202 Test: 1203\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
}
|
| 281 |
+
],
|
| 282 |
+
"source": [
|
| 283 |
+
"# Stratified split train/val/test\n",
|
| 284 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 285 |
+
"csv_path = os.path.join(OUT_DIR, \"combined_image_labels.csv\")\n",
|
| 286 |
+
"df = pd.read_csv(csv_path)\n",
|
| 287 |
+
"\n",
|
| 288 |
+
"train_df, temp_df = train_test_split(df, test_size=0.25, stratify=df['label'], random_state=42)\n",
|
| 289 |
+
"val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['label'], random_state=42)\n",
|
| 290 |
+
"\n",
|
| 291 |
+
"print(\"Train:\", len(train_df), \"Val:\", len(val_df), \"Test:\", len(test_df))\n",
|
| 292 |
+
"train_df.to_csv(os.path.join(OUT_DIR, \"train.csv\"), index=False)\n",
|
| 293 |
+
"val_df.to_csv(os.path.join(OUT_DIR, \"val.csv\"), index=False)\n",
|
| 294 |
+
"test_df.to_csv(os.path.join(OUT_DIR, \"test.csv\"), index=False)\n"
|
| 295 |
]
|
| 296 |
},
|
| 297 |
+
{
|
| 298 |
+
"cell_type": "code",
|
| 299 |
+
"source": [],
|
| 300 |
+
"metadata": {
|
| 301 |
+
"id": "MTmULrSFKdxe"
|
| 302 |
+
},
|
| 303 |
+
"execution_count": null,
|
| 304 |
+
"outputs": []
|
| 305 |
+
},
|
| 306 |
{
|
| 307 |
"cell_type": "code",
|
| 308 |
"source": [
|
| 309 |
+
"# Build label -> index mapping from training labels\n",
|
| 310 |
+
"labels = sorted(train_df['label'].unique().tolist())\n",
|
| 311 |
+
"LABEL2IDX = {lab: idx for idx, lab in enumerate(labels)}\n",
|
| 312 |
+
"IDX2LABEL = {v:k for k,v in LABEL2IDX.items()}\n",
|
| 313 |
+
"print(\"Labels:\", LABEL2IDX)\n",
|
| 314 |
+
"NUM_CLASSES = len(labels)\n"
|
|
|
|
|
|
|
|
|
|
| 315 |
],
|
| 316 |
"metadata": {
|
| 317 |
"colab": {
|
| 318 |
"base_uri": "https://localhost:8080/"
|
| 319 |
},
|
| 320 |
+
"id": "LKnSLphiKd6i",
|
| 321 |
+
"outputId": "48c16684-e96c-4146-9ba6-1fe0a28e138d"
|
| 322 |
},
|
| 323 |
+
"execution_count": 21,
|
| 324 |
"outputs": [
|
| 325 |
{
|
| 326 |
"output_type": "stream",
|
| 327 |
"name": "stdout",
|
| 328 |
"text": [
|
| 329 |
+
"Labels: {'ct_tumor': 0, 'glioma': 1, 'meningioma': 2, 'mri_tumor': 3, 'no_tumor': 4, 'pituitary': 5}\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
]
|
| 331 |
}
|
| 332 |
]
|
| 333 |
},
|
| 334 |
{
|
| 335 |
"cell_type": "code",
|
| 336 |
+
"execution_count": null,
|
|
|
|
|
|
|
| 337 |
"metadata": {
|
| 338 |
"id": "MYxdFZEiY4_9"
|
| 339 |
},
|
| 340 |
+
"outputs": [],
|
| 341 |
+
"source": []
|
| 342 |
},
|
| 343 |
{
|
| 344 |
"cell_type": "code",
|
| 345 |
+
"execution_count": 24,
|
| 346 |
+
"metadata": {
|
| 347 |
+
"colab": {
|
| 348 |
+
"base_uri": "https://localhost:8080/"
|
| 349 |
+
},
|
| 350 |
+
"id": "Nd0TbtAlY7to",
|
| 351 |
+
"outputId": "63032d78-ebe7-4370-c780-746cee650c8f"
|
| 352 |
+
},
|
| 353 |
+
"outputs": [
|
| 354 |
+
{
|
| 355 |
+
"output_type": "stream",
|
| 356 |
+
"name": "stdout",
|
| 357 |
+
"text": [
|
| 358 |
+
"batch x shape: torch.Size([8, 3, 224, 224]) y shape: torch.Size([8])\n"
|
| 359 |
+
]
|
| 360 |
+
}
|
| 361 |
+
],
|
| 362 |
"source": [
|
| 363 |
+
"# Define Dataset and Transforms (inline)\n",
|
| 364 |
"import torch\n",
|
| 365 |
"from torch.utils.data import Dataset, DataLoader\n",
|
| 366 |
"from PIL import Image\n",
|
| 367 |
"import torchvision.transforms as T\n",
|
| 368 |
"\n",
|
| 369 |
+
"# transforms (same for MRI/CT for MVP)\n",
|
|
|
|
| 370 |
"train_tf = T.Compose([\n",
|
| 371 |
" T.Resize((IMG_SIZE, IMG_SIZE)),\n",
|
| 372 |
" T.RandomHorizontalFlip(),\n",
|
| 373 |
" T.RandomRotation(10),\n",
|
| 374 |
+
" T.ColorJitter(0.08, 0.08, 0.08, 0.02),\n",
|
| 375 |
" T.ToTensor(),\n",
|
| 376 |
" T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n",
|
| 377 |
"])\n",
|
|
|
|
| 381 |
" T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n",
|
| 382 |
"])\n",
|
| 383 |
"\n",
|
| 384 |
+
"class MultiModalityDataset(Dataset):\n",
|
| 385 |
+
" def __init__(self, df, transform=None, label_map=None):\n",
|
| 386 |
" self.df = df.reset_index(drop=True)\n",
|
| 387 |
" self.transform = transform\n",
|
| 388 |
+
" self.label_map = label_map\n",
|
| 389 |
" def __len__(self):\n",
|
| 390 |
" return len(self.df)\n",
|
|
|
|
| 391 |
" def __getitem__(self, idx):\n",
|
| 392 |
" row = self.df.iloc[idx]\n",
|
| 393 |
+
" img_path = row['image']\n",
|
| 394 |
+
" img = Image.open(img_path).convert('RGB') # ensures 3 channels\n",
|
| 395 |
" if self.transform:\n",
|
| 396 |
" img = self.transform(img)\n",
|
| 397 |
+
" label = self.label_map[row['label']]\n",
|
| 398 |
" return img, label\n",
|
| 399 |
"\n",
|
| 400 |
+
"train_ds = MultiModalityDataset(train_df, transform=train_tf, label_map=LABEL2IDX)\n",
|
| 401 |
+
"val_ds = MultiModalityDataset(val_df, transform=val_tf, label_map=LABEL2IDX)\n",
|
| 402 |
+
"test_ds = MultiModalityDataset(test_df, transform=val_tf, label_map=LABEL2IDX)\n",
|
| 403 |
+
"\n",
|
| 404 |
+
"train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)\n",
|
| 405 |
+
"val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)\n",
|
| 406 |
+
"test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)\n",
|
| 407 |
+
"\n",
|
| 408 |
+
"# sanity check\n",
|
| 409 |
+
"x,y = next(iter(train_loader))\n",
|
| 410 |
+
"print(\"batch x shape:\", x.shape, \"y shape:\", y.shape)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
]
|
| 412 |
},
|
| 413 |
{
|
| 414 |
"cell_type": "code",
|
| 415 |
+
"execution_count": null,
|
| 416 |
"metadata": {
|
| 417 |
"id": "SrJLOm33ZMpC"
|
| 418 |
},
|
| 419 |
+
"outputs": [],
|
| 420 |
+
"source": []
|
| 421 |
},
|
| 422 |
{
|
| 423 |
"cell_type": "code",
|
| 424 |
+
"execution_count": null,
|
|
|
|
|
|
|
| 425 |
"metadata": {
|
| 426 |
"id": "bG2JclRDY-Di"
|
| 427 |
},
|
| 428 |
+
"outputs": [],
|
| 429 |
+
"source": []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
},
|
| 431 |
{
|
| 432 |
"cell_type": "code",
|
| 433 |
+
"execution_count": 25,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
"metadata": {
|
| 435 |
+
"id": "Jc4pdqreZGDS",
|
| 436 |
"colab": {
|
| 437 |
"base_uri": "https://localhost:8080/",
|
| 438 |
+
"height": 188,
|
| 439 |
"referenced_widgets": [
|
| 440 |
+
"fd1f533d7d814e10b252bf66f8c0697e",
|
| 441 |
+
"243ec1d23b8e469e833db01f89189639",
|
| 442 |
+
"9bcf41c7fe1e425aae4e37dcfca6a67c",
|
| 443 |
+
"f09e93d2fdc64071b3b9860fd7782c58",
|
| 444 |
+
"0a03432f7ce14c8c98090f846a512beb",
|
| 445 |
+
"1da74b90beb44d85a22ca57ee52b1e9d",
|
| 446 |
+
"a7c35ed82c20419b9b559b217d924a6a",
|
| 447 |
+
"0a302207f9884533b49fb150411169d2",
|
| 448 |
+
"bc007e650c0f4f619946055a6c72d132",
|
| 449 |
+
"20b5ba3ee19446dc89ddb5222be6f5f2",
|
| 450 |
+
"1e35273de451460aa8c82df17b3ca23c"
|
| 451 |
]
|
| 452 |
},
|
| 453 |
+
"outputId": "8f6811b3-e720-4ffd-b49b-45c177d2972e"
|
|
|
|
| 454 |
},
|
|
|
|
| 455 |
"outputs": [
|
| 456 |
+
{
|
| 457 |
+
"output_type": "stream",
|
| 458 |
+
"name": "stdout",
|
| 459 |
+
"text": [
|
| 460 |
+
"Using device: cuda\n"
|
| 461 |
+
]
|
| 462 |
+
},
|
| 463 |
{
|
| 464 |
"output_type": "stream",
|
| 465 |
"name": "stderr",
|
|
|
|
| 481 |
"application/vnd.jupyter.widget-view+json": {
|
| 482 |
"version_major": 2,
|
| 483 |
"version_minor": 0,
|
| 484 |
+
"model_id": "fd1f533d7d814e10b252bf66f8c0697e"
|
| 485 |
}
|
| 486 |
},
|
| 487 |
"metadata": {}
|
| 488 |
+
},
|
| 489 |
+
{
|
| 490 |
+
"output_type": "stream",
|
| 491 |
+
"name": "stdout",
|
| 492 |
+
"text": [
|
| 493 |
+
"Class weights: tensor([0.6917, 2.3853, 1.4414, 2.7322, 0.3728, 2.5470], device='cuda:0')\n"
|
| 494 |
+
]
|
| 495 |
}
|
| 496 |
+
],
|
| 497 |
+
"source": [
|
| 498 |
+
"# Model, loss, optimizer, optional class weights\n",
|
| 499 |
+
"import timm\n",
|
| 500 |
+
"import torch.nn as nn\n",
|
| 501 |
+
"\n",
|
| 502 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 503 |
+
"print(\"Using device:\", device)\n",
|
| 504 |
+
"\n",
|
| 505 |
+
"model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=NUM_CLASSES)\n",
|
| 506 |
+
"model = model.to(device)\n",
|
| 507 |
+
"\n",
|
| 508 |
+
"# Optionally compute class weights if highly imbalanced\n",
|
| 509 |
+
"from sklearn.utils.class_weight import compute_class_weight\n",
|
| 510 |
+
"import numpy as np\n",
|
| 511 |
+
"cls_w = compute_class_weight(class_weight='balanced', classes=np.arange(NUM_CLASSES), y=train_df['label'].map(LABEL2IDX))\n",
|
| 512 |
+
"cls_w = torch.tensor(cls_w, dtype=torch.float).to(device)\n",
|
| 513 |
+
"print(\"Class weights:\", cls_w)\n",
|
| 514 |
+
"\n",
|
| 515 |
+
"loss_fn = nn.CrossEntropyLoss(weight=cls_w)\n",
|
| 516 |
+
"optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)\n",
|
| 517 |
+
"scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=2)\n"
|
| 518 |
]
|
| 519 |
},
|
| 520 |
{
|
| 521 |
"cell_type": "code",
|
| 522 |
+
"execution_count": 26,
|
| 523 |
+
"metadata": {
|
| 524 |
+
"colab": {
|
| 525 |
+
"base_uri": "https://localhost:8080/"
|
| 526 |
+
},
|
| 527 |
+
"id": "nUDYRwowZRGb",
|
| 528 |
+
"outputId": "bdac381c-00bd-4330-f8d9-f2ba47208e67"
|
| 529 |
+
},
|
| 530 |
+
"outputs": [
|
| 531 |
+
{
|
| 532 |
+
"output_type": "stream",
|
| 533 |
+
"name": "stdout",
|
| 534 |
+
"text": [
|
| 535 |
+
"Epoch 1/12 | train_loss 0.7560 train_acc 0.7829 | val_loss 0.4183 val_acc 0.8968 | 1.48 min\n",
|
| 536 |
+
"Saved best model: 0.8968386023294509\n",
|
| 537 |
+
"Epoch 2/12 | train_loss 0.3758 train_acc 0.8959 | val_loss 0.3714 val_acc 0.9160 | 1.32 min\n",
|
| 538 |
+
"Saved best model: 0.9159733777038269\n",
|
| 539 |
+
"Epoch 3/12 | train_loss 0.2867 train_acc 0.9186 | val_loss 0.3187 val_acc 0.9268 | 1.27 min\n",
|
| 540 |
+
"Saved best model: 0.9267886855241264\n",
|
| 541 |
+
"Epoch 4/12 | train_loss 0.2459 train_acc 0.9326 | val_loss 0.2907 val_acc 0.9368 | 1.28 min\n",
|
| 542 |
+
"Saved best model: 0.9367720465890182\n",
|
| 543 |
+
"Epoch 5/12 | train_loss 0.2312 train_acc 0.9368 | val_loss 0.3151 val_acc 0.9285 | 1.26 min\n",
|
| 544 |
+
"Epoch 6/12 | train_loss 0.2106 train_acc 0.9427 | val_loss 0.3328 val_acc 0.9276 | 1.26 min\n",
|
| 545 |
+
"Epoch 7/12 | train_loss 0.1896 train_acc 0.9473 | val_loss 0.3394 val_acc 0.9301 | 1.28 min\n",
|
| 546 |
+
"Early stopping triggered. Stopping training.\n"
|
| 547 |
+
]
|
| 548 |
+
}
|
| 549 |
+
],
|
| 550 |
"source": [
|
| 551 |
+
"# Training loop with val_loss, val_acc, checkpointing & early stopping\n",
|
| 552 |
"from sklearn.metrics import accuracy_score\n",
|
| 553 |
"import time, os\n",
|
| 554 |
+
"import pandas as pd\n",
|
| 555 |
"\n",
|
| 556 |
+
"os.makedirs(OUT_DIR, exist_ok=True)\n",
|
|
|
|
|
|
|
| 557 |
"best_val_acc = 0.0\n",
|
| 558 |
+
"patience = 3\n",
|
| 559 |
+
"wait = 0\n",
|
| 560 |
+
"logs = []\n",
|
| 561 |
"\n",
|
| 562 |
"for epoch in range(1, EPOCHS+1):\n",
|
| 563 |
+
" t0 = time.time()\n",
|
| 564 |
+
" # train\n",
|
| 565 |
" model.train()\n",
|
| 566 |
+
" train_loss = 0.0\n",
|
| 567 |
" all_preds, all_labels = [], []\n",
|
|
|
|
|
|
|
|
|
|
| 568 |
" for imgs, labels in train_loader:\n",
|
| 569 |
" imgs = imgs.to(device)\n",
|
| 570 |
" labels = labels.to(device)\n",
|
|
|
|
| 573 |
" loss = loss_fn(logits, labels)\n",
|
| 574 |
" loss.backward()\n",
|
| 575 |
" optimizer.step()\n",
|
| 576 |
+
" train_loss += loss.item() * imgs.size(0)\n",
|
| 577 |
" preds = logits.argmax(dim=1).cpu().numpy()\n",
|
| 578 |
" all_preds.extend(preds.tolist())\n",
|
| 579 |
" all_labels.extend(labels.cpu().numpy().tolist())\n",
|
| 580 |
+
" train_loss = train_loss / len(train_loader.dataset)\n",
|
|
|
|
| 581 |
" train_acc = accuracy_score(all_labels, all_preds)\n",
|
| 582 |
"\n",
|
| 583 |
+
" # val\n",
|
| 584 |
" model.eval()\n",
|
| 585 |
+
" val_loss = 0.0\n",
|
| 586 |
" v_preds, v_labels = [], []\n",
|
| 587 |
" with torch.no_grad():\n",
|
| 588 |
" for imgs, labels in val_loader:\n",
|
| 589 |
" imgs = imgs.to(device)\n",
|
| 590 |
" labels = labels.to(device)\n",
|
| 591 |
" logits = model(imgs)\n",
|
| 592 |
+
" loss = loss_fn(logits, labels)\n",
|
| 593 |
+
" val_loss += loss.item() * imgs.size(0)\n",
|
| 594 |
" preds = logits.argmax(dim=1).cpu().numpy()\n",
|
| 595 |
" v_preds.extend(preds.tolist())\n",
|
| 596 |
" v_labels.extend(labels.cpu().numpy().tolist())\n",
|
| 597 |
+
" val_loss = val_loss / len(val_loader.dataset)\n",
|
|
|
|
| 598 |
" val_acc = accuracy_score(v_labels, v_preds)\n",
|
| 599 |
"\n",
|
| 600 |
+
" # scheduler step\n",
|
| 601 |
+
" scheduler.step(val_loss)\n",
|
| 602 |
+
"\n",
|
| 603 |
+
" elapsed = (time.time() - t0) / 60\n",
|
| 604 |
+
" print(f\"Epoch {epoch}/{EPOCHS} | train_loss {train_loss:.4f} train_acc {train_acc:.4f} | val_loss {val_loss:.4f} val_acc {val_acc:.4f} | {elapsed:.2f} min\")\n",
|
| 605 |
+
" logs.append({\"epoch\":epoch,\"train_loss\":train_loss,\"train_acc\":train_acc,\"val_loss\":val_loss,\"val_acc\":val_acc})\n",
|
| 606 |
"\n",
|
| 607 |
+
" # checkpoint\n",
|
| 608 |
+
" ckpt = os.path.join(OUT_DIR, f\"epoch{epoch}.pth\")\n",
|
| 609 |
+
" torch.save(model.state_dict(), ckpt)\n",
|
| 610 |
" if val_acc > best_val_acc:\n",
|
| 611 |
" best_val_acc = val_acc\n",
|
| 612 |
+
" torch.save(model.state_dict(), os.path.join(OUT_DIR, \"best_model.pth\"))\n",
|
| 613 |
+
" print(\"Saved best model:\", best_val_acc)\n",
|
| 614 |
+
" wait = 0\n",
|
| 615 |
+
" else:\n",
|
| 616 |
+
" wait += 1\n",
|
| 617 |
+
" if wait >= patience:\n",
|
| 618 |
+
" print(\"Early stopping triggered. Stopping training.\")\n",
|
| 619 |
+
" break\n",
|
| 620 |
+
"\n",
|
| 621 |
+
"# save logs to csv\n",
|
| 622 |
+
"pd.DataFrame(logs).to_csv(os.path.join(OUT_DIR, \"training_log.csv\"), index=False)"
|
| 623 |
+
]
|
| 624 |
+
},
|
| 625 |
+
{
|
| 626 |
+
"cell_type": "code",
|
| 627 |
+
"execution_count": null,
|
| 628 |
+
"metadata": {
|
| 629 |
+
"id": "I2WUIxNnZtVI"
|
| 630 |
+
},
|
| 631 |
+
"outputs": [],
|
| 632 |
+
"source": []
|
| 633 |
+
},
|
| 634 |
+
{
|
| 635 |
+
"cell_type": "code",
|
| 636 |
+
"source": [
|
| 637 |
+
"# Evaluate on test set (classification report + confusion matrix)\n",
|
| 638 |
+
"from sklearn.metrics import classification_report, confusion_matrix\n",
|
| 639 |
+
"import numpy as np\n",
|
| 640 |
+
"\n",
|
| 641 |
+
"# load best model\n",
|
| 642 |
+
"best_path = os.path.join(OUT_DIR, \"best_model.pth\")\n",
|
| 643 |
+
"model.load_state_dict(torch.load(best_path, map_location=device))\n",
|
| 644 |
+
"model.eval()\n",
|
| 645 |
+
"\n",
|
| 646 |
+
"y_true, y_pred = [], []\n",
|
| 647 |
+
"with torch.no_grad():\n",
|
| 648 |
+
" for imgs, labels in test_loader:\n",
|
| 649 |
+
" imgs = imgs.to(device)\n",
|
| 650 |
+
" labels = labels.to(device)\n",
|
| 651 |
+
" logits = model(imgs)\n",
|
| 652 |
+
" preds = logits.argmax(dim=1).cpu().numpy()\n",
|
| 653 |
+
" y_pred.extend(preds.tolist())\n",
|
| 654 |
+
" y_true.extend(labels.cpu().numpy().tolist())\n",
|
| 655 |
+
"\n",
|
| 656 |
+
"print(\"Classification Report:\")\n",
|
| 657 |
+
"print(classification_report(y_true, y_pred, target_names=[IDX2LABEL[i] for i in range(NUM_CLASSES)]))\n",
|
| 658 |
+
"print(\"Confusion matrix:\")\n",
|
| 659 |
+
"print(confusion_matrix(y_true, y_pred))\n"
|
| 660 |
],
|
| 661 |
"metadata": {
|
| 662 |
"colab": {
|
| 663 |
"base_uri": "https://localhost:8080/"
|
| 664 |
},
|
| 665 |
+
"id": "nbkHkAPELGYq",
|
| 666 |
+
"outputId": "19358314-2ce1-4bf4-8844-45a0d723100c"
|
| 667 |
},
|
| 668 |
+
"execution_count": 28,
|
| 669 |
"outputs": [
|
| 670 |
{
|
| 671 |
"output_type": "stream",
|
| 672 |
"name": "stdout",
|
| 673 |
"text": [
|
| 674 |
+
"Classification Report:\n",
|
| 675 |
+
" precision recall f1-score support\n",
|
| 676 |
+
"\n",
|
| 677 |
+
" ct_tumor 0.98 0.99 0.99 290\n",
|
| 678 |
+
" glioma 0.95 0.93 0.94 84\n",
|
| 679 |
+
" meningioma 0.77 0.95 0.85 139\n",
|
| 680 |
+
" mri_tumor 0.80 0.53 0.64 73\n",
|
| 681 |
+
" no_tumor 1.00 0.99 0.99 538\n",
|
| 682 |
+
" pituitary 0.99 0.94 0.96 79\n",
|
| 683 |
+
"\n",
|
| 684 |
+
" accuracy 0.95 1203\n",
|
| 685 |
+
" macro avg 0.91 0.89 0.90 1203\n",
|
| 686 |
+
"weighted avg 0.95 0.95 0.95 1203\n",
|
| 687 |
+
"\n",
|
| 688 |
+
"Confusion matrix:\n",
|
| 689 |
+
"[[288 0 0 0 2 0]\n",
|
| 690 |
+
" [ 0 78 5 1 0 0]\n",
|
| 691 |
+
" [ 0 0 132 7 0 0]\n",
|
| 692 |
+
" [ 0 4 29 39 0 1]\n",
|
| 693 |
+
" [ 5 0 2 0 531 0]\n",
|
| 694 |
+
" [ 0 0 3 2 0 74]]\n"
|
| 695 |
]
|
| 696 |
}
|
| 697 |
]
|
|
|
|
| 700 |
"cell_type": "code",
|
| 701 |
"source": [],
|
| 702 |
"metadata": {
|
| 703 |
+
"id": "QZCaUaWzLGbS"
|
| 704 |
+
},
|
| 705 |
+
"execution_count": null,
|
| 706 |
+
"outputs": []
|
| 707 |
+
},
|
| 708 |
+
{
|
| 709 |
+
"cell_type": "code",
|
| 710 |
+
"source": [],
|
| 711 |
+
"metadata": {
|
| 712 |
+
"id": "QMaVFK9xLGgu"
|
| 713 |
},
|
| 714 |
"execution_count": null,
|
| 715 |
"outputs": []
|
| 716 |
},
|
| 717 |
{
|
| 718 |
"cell_type": "code",
|
| 719 |
+
"execution_count": 33,
|
| 720 |
+
"metadata": {
|
| 721 |
+
"colab": {
|
| 722 |
+
"base_uri": "https://localhost:8080/"
|
| 723 |
+
},
|
| 724 |
+
"id": "H2BqijX2Zupx",
|
| 725 |
+
"outputId": "3ed96f69-87a9-4b2b-839e-2a8680c672e9"
|
| 726 |
+
},
|
| 727 |
+
"outputs": [
|
| 728 |
+
{
|
| 729 |
+
"output_type": "stream",
|
| 730 |
+
"name": "stdout",
|
| 731 |
+
"text": [
|
| 732 |
+
"Sample: /root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset/Brain Tumor CT scan Images/Tumor/ct_tumor (137).jpg\n",
|
| 733 |
+
"[{'label': 'ct_tumor', 'score': 0.9998412132263184}, {'label': 'meningioma', 'score': 0.00015001322026364505}, {'label': 'mri_tumor', 'score': 5.044169483880978e-06}]\n"
|
| 734 |
+
]
|
| 735 |
+
}
|
| 736 |
+
],
|
| 737 |
"source": [
|
| 738 |
"# Load best model & test a single image (inference helper)\n",
|
| 739 |
"import torch.nn.functional as F\n",
|
| 740 |
"from PIL import Image\n",
|
| 741 |
"from torchvision import transforms\n",
|
| 742 |
"\n",
|
| 743 |
+
"def load_model_for_infer(ckpt_path=os.path.join(OUT_DIR,\"best_model.pth\")):\n",
|
| 744 |
+
" m = timm.create_model('efficientnet_b0', pretrained=False, num_classes=NUM_CLASSES)\n",
|
| 745 |
+
" m.load_state_dict(torch.load(ckpt_path, map_location=device))\n",
|
| 746 |
+
" m.to(device).eval()\n",
|
| 747 |
+
" return m\n",
|
|
|
|
| 748 |
"\n",
|
| 749 |
+
"# reuse tf from earlier\n",
|
| 750 |
+
"def predict_topk_pil(model, pil_img, topk=TOP_K):\n",
|
| 751 |
+
" x = val_tf(pil_img).unsqueeze(0).to(device)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 752 |
" with torch.no_grad():\n",
|
| 753 |
" logits = model(x)\n",
|
| 754 |
" probs = F.softmax(logits, dim=1).cpu().numpy().ravel()\n",
|
| 755 |
" idxs = probs.argsort()[::-1][:topk]\n",
|
| 756 |
" return [{\"label\": IDX2LABEL[int(i)], \"score\": float(probs[int(i)])} for i in idxs]\n",
|
| 757 |
"\n",
|
| 758 |
+
"# quick example\n",
|
| 759 |
+
"infer_model = load_model_for_infer()\n",
|
| 760 |
+
"sample_path = val_df.iloc[110].image\n",
|
| 761 |
"img = Image.open(sample_path).convert(\"RGB\")\n",
|
| 762 |
+
"print(\"Sample:\", sample_path)\n",
|
| 763 |
+
"print(predict_topk_pil(infer_model, img, topk=3))"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
]
|
| 765 |
},
|
| 766 |
{
|
| 767 |
"cell_type": "code",
|
| 768 |
+
"execution_count": null,
|
| 769 |
"metadata": {
|
| 770 |
"id": "RAZRNlq2ZutV"
|
| 771 |
},
|
| 772 |
+
"outputs": [],
|
| 773 |
+
"source": []
|
| 774 |
},
|
| 775 |
{
|
| 776 |
"cell_type": "code",
|
| 777 |
+
"execution_count": null,
|
| 778 |
"metadata": {
|
| 779 |
"id": "chDN6CJmc8h0"
|
| 780 |
},
|
| 781 |
+
"outputs": [],
|
| 782 |
+
"source": []
|
| 783 |
+
},
|
| 784 |
+
{
|
| 785 |
+
"cell_type": "code",
|
| 786 |
+
"source": [
|
| 787 |
+
"# gradio App plus Openai"
|
| 788 |
+
],
|
| 789 |
+
"metadata": {
|
| 790 |
+
"id": "LsQ5FIkOOHch"
|
| 791 |
+
},
|
| 792 |
"execution_count": null,
|
| 793 |
"outputs": []
|
| 794 |
},
|
| 795 |
{
|
| 796 |
"cell_type": "code",
|
| 797 |
+
"execution_count": 39,
|
| 798 |
+
"metadata": {
|
| 799 |
+
"id": "WSygonZYdvpg"
|
| 800 |
+
},
|
| 801 |
+
"outputs": [],
|
| 802 |
"source": [
|
| 803 |
"sample_paths = [\n",
|
| 804 |
+
" \"/root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset/Brain Tumor CT scan Images/Tumor/ct_tumor (10).jpg\",\n",
|
| 805 |
+
" \"/root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset/Brain Tumor CT scan Images/Healthy/ct_healthy (1).jpg\",\n",
|
| 806 |
+
" \"/root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset/Brain Tumor MRI images/Tumor/glioma (10).jpg\",\n",
|
| 807 |
+
" \"/root/.cache/kagglehub/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/versions/1/Dataset/Brain Tumor MRI images/Tumor/meningioma (1056).jpg\",\n",
|
| 808 |
"]\n",
|
| 809 |
"\n",
|
| 810 |
"OPENAI_API_KEY = \"\""
|
| 811 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 812 |
},
|
| 813 |
{
|
| 814 |
"cell_type": "code",
|
| 815 |
+
"execution_count": 43,
|
| 816 |
+
"metadata": {
|
| 817 |
+
"colab": {
|
| 818 |
+
"base_uri": "https://localhost:8080/",
|
| 819 |
+
"height": 680
|
| 820 |
+
},
|
| 821 |
+
"id": "S_rWA1Xcdvx8",
|
| 822 |
+
"outputId": "55c18697-f04d-45bd-bddc-f59e4a729c3d"
|
| 823 |
+
},
|
| 824 |
+
"outputs": [
|
| 825 |
+
{
|
| 826 |
+
"output_type": "stream",
|
| 827 |
+
"name": "stdout",
|
| 828 |
+
"text": [
|
| 829 |
+
"Loading model...\n",
|
| 830 |
+
"Model loaded on cuda\n"
|
| 831 |
+
]
|
| 832 |
+
},
|
| 833 |
+
{
|
| 834 |
+
"output_type": "stream",
|
| 835 |
+
"name": "stderr",
|
| 836 |
+
"text": [
|
| 837 |
+
"/tmp/ipython-input-2087682989.py:119: UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.\n",
|
| 838 |
+
" chatbot = gr.Chatbot(label=\"Assistant\")\n"
|
| 839 |
+
]
|
| 840 |
+
},
|
| 841 |
+
{
|
| 842 |
+
"output_type": "stream",
|
| 843 |
+
"name": "stdout",
|
| 844 |
+
"text": [
|
| 845 |
+
"Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n",
|
| 846 |
+
"* Running on public URL: https://a1ab9ef721b239799a.gradio.live\n",
|
| 847 |
+
"\n",
|
| 848 |
+
"This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
|
| 849 |
+
]
|
| 850 |
+
},
|
| 851 |
+
{
|
| 852 |
+
"output_type": "display_data",
|
| 853 |
+
"data": {
|
| 854 |
+
"text/plain": [
|
| 855 |
+
"<IPython.core.display.HTML object>"
|
| 856 |
+
],
|
| 857 |
+
"text/html": [
|
| 858 |
+
"<div><iframe src=\"https://a1ab9ef721b239799a.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
| 859 |
+
]
|
| 860 |
+
},
|
| 861 |
+
"metadata": {}
|
| 862 |
+
},
|
| 863 |
+
{
|
| 864 |
+
"output_type": "execute_result",
|
| 865 |
+
"data": {
|
| 866 |
+
"text/plain": []
|
| 867 |
+
},
|
| 868 |
+
"metadata": {},
|
| 869 |
+
"execution_count": 43
|
| 870 |
+
}
|
| 871 |
+
],
|
| 872 |
"source": [
|
| 873 |
"import os\n",
|
| 874 |
"from pathlib import Path\n",
|
|
|
|
| 881 |
"import openai\n",
|
| 882 |
"from openai import OpenAI\n",
|
| 883 |
"\n",
|
| 884 |
+
"# ------------------ USER CONFIG (edit as needed) ------------------\n",
|
| 885 |
+
"CKPT_PATH = \"/content/output/best_model.pth\" # path to trained checkpoint\n",
|
|
|
|
| 886 |
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 887 |
"BACKBONE = \"efficientnet_b0\"\n",
|
|
|
|
| 888 |
"IMG_SIZE = 224\n",
|
| 889 |
"\n",
|
| 890 |
+
"# Paste your 4 sample image absolute paths here (or ensure sample_paths variable already exists)\n",
|
| 891 |
+
"# Example:\n",
|
| 892 |
+
"# sample_paths = [\"/content/samples/ct1.jpg\", \"/content/samples/mri_glioma1.jpg\", ...]\n",
|
| 893 |
+
"sample_paths = sample_paths # keep your existing variable/assignment\n",
|
| 894 |
"\n",
|
| 895 |
+
"OPENAI_API_KEY = OPENAI_API_KEY # keep as before (e.g., from env or variable)\n",
|
| 896 |
"if OPENAI_API_KEY:\n",
|
| 897 |
" openai.api_key = OPENAI_API_KEY\n",
|
| 898 |
"\n",
|
| 899 |
+
"# ------------------ UPDATED LABEL MAP for combined CT+MRI ------------------\n",
|
| 900 |
+
"# Minimal change: include ct_tumor and MRI subtypes + no_tumor\n",
|
| 901 |
+
"IDX2LABEL = {\n",
|
| 902 |
+
" 0: \"ct_tumor\",\n",
|
| 903 |
+
" 1: \"glioma\",\n",
|
| 904 |
+
" 2: \"meningioma\",\n",
|
| 905 |
+
" 3: \"mri_tumor\",\n",
|
| 906 |
+
" 4: \"no_tumor\",\n",
|
| 907 |
+
" 5: \"pituitary\"\n",
|
| 908 |
+
"}\n",
|
| 909 |
+
"\n",
|
| 910 |
+
"NUM_CLASSES = len(IDX2LABEL)\n",
|
| 911 |
"# -------------------------------------------------\n",
|
| 912 |
"\n",
|
| 913 |
"# Validate sample paths\n",
|
|
|
|
| 943 |
" idxs = probs.argsort()[::-1][:topk]\n",
|
| 944 |
" return [{\"label\": IDX2LABEL[int(i)], \"score\": float(probs[int(i)])} for i in idxs]\n",
|
| 945 |
"\n",
|
| 946 |
+
"# OpenAI system prompt (unchanged)\n",
|
| 947 |
"SYSTEM_TEMPLATE = \"\"\"You are an educational radiology assistant. The model predicted:\n",
|
| 948 |
"{preds}\n",
|
| 949 |
"\n",
|
|
|
|
| 973 |
" )\n",
|
| 974 |
" return resp.choices[0].message.content.strip()\n",
|
| 975 |
"\n",
|
|
|
|
| 976 |
"# Prepare thumbnails (just images)\n",
|
| 977 |
"thumbs = [Image.open(p).convert(\"RGB\") for p in sample_paths]\n",
|
| 978 |
"\n",
|
| 979 |
+
"# Gradio UI (unchanged structure)\n",
|
| 980 |
"with gr.Blocks() as demo:\n",
|
| 981 |
+
" gr.Markdown(\"## Brain MRI+CT multi-class demo — Educational only\")\n",
|
| 982 |
" with gr.Row():\n",
|
| 983 |
" with gr.Column():\n",
|
| 984 |
" # Use a simple list of PIL thumbnails for the gallery value\n",
|
|
|
|
| 996 |
" selected_path_state = gr.State(value=None)\n",
|
| 997 |
" preds_state = gr.State(value=None)\n",
|
| 998 |
"\n",
|
| 999 |
+
" # select image from gallery -> store selected path string\n",
|
| 1000 |
" def on_select(evt: gr.SelectData):\n",
|
| 1001 |
+
" idx = int(evt.index)\n",
|
| 1002 |
+
" path = sample_paths[idx]\n",
|
| 1003 |
" return path\n",
|
| 1004 |
"\n",
|
| 1005 |
" # ensure the gallery select writes only the path to the state\n",
|
| 1006 |
" gallery.select(fn=on_select, inputs=None, outputs=[selected_path_state])\n",
|
| 1007 |
"\n",
|
| 1008 |
+
" # analyze (unchanged except label names)\n",
|
| 1009 |
" def analyze(path):\n",
|
| 1010 |
" try:\n",
|
| 1011 |
" if path is None:\n",
|
|
|
|
| 1022 |
"\n",
|
| 1023 |
" analyze_btn.click(fn=analyze, inputs=selected_path_state, outputs=[preds_output, chatbot, preds_state])\n",
|
| 1024 |
"\n",
|
| 1025 |
+
" # chat (keeps same logic)\n",
|
| 1026 |
" def chat(chat_history, msg, preds_text):\n",
|
| 1027 |
" if preds_text is None:\n",
|
| 1028 |
" return chat_history+[(\"AI\",\"Please analyze an image first.\")], \"\"\n",
|
|
|
|
| 1053 |
"\n",
|
| 1054 |
" gr.Markdown(\"⚠️ Educational use only — not a medical diagnosis.\")\n",
|
| 1055 |
"\n",
|
| 1056 |
+
"demo.launch(share=True)"
|
| 1057 |
+
]
|
| 1058 |
+
},
|
| 1059 |
+
{
|
| 1060 |
+
"cell_type": "code",
|
| 1061 |
+
"execution_count": null,
|
| 1062 |
+
"metadata": {
|
| 1063 |
+
"id": "TrS_HlY7d8s3"
|
| 1064 |
+
},
|
| 1065 |
+
"outputs": [],
|
| 1066 |
+
"source": []
|
| 1067 |
+
},
|
| 1068 |
+
{
|
| 1069 |
+
"cell_type": "code",
|
| 1070 |
+
"execution_count": null,
|
| 1071 |
"metadata": {
|
| 1072 |
"colab": {
|
| 1073 |
+
"background_save": true
|
| 1074 |
+
},
|
| 1075 |
+
"id": "6Yj2jOOY0swp"
|
| 1076 |
+
},
|
| 1077 |
+
"outputs": [],
|
| 1078 |
+
"source": []
|
| 1079 |
+
}
|
| 1080 |
+
],
|
| 1081 |
+
"metadata": {
|
| 1082 |
+
"accelerator": "GPU",
|
| 1083 |
+
"colab": {
|
| 1084 |
+
"gpuType": "T4",
|
| 1085 |
+
"provenance": []
|
| 1086 |
+
},
|
| 1087 |
+
"kernelspec": {
|
| 1088 |
+
"display_name": "Python 3",
|
| 1089 |
+
"name": "python3"
|
| 1090 |
+
},
|
| 1091 |
+
"language_info": {
|
| 1092 |
+
"name": "python"
|
| 1093 |
+
},
|
| 1094 |
+
"widgets": {
|
| 1095 |
+
"application/vnd.jupyter.widget-state+json": {
|
| 1096 |
+
"fd1f533d7d814e10b252bf66f8c0697e": {
|
| 1097 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1098 |
+
"model_name": "HBoxModel",
|
| 1099 |
+
"model_module_version": "1.5.0",
|
| 1100 |
+
"state": {
|
| 1101 |
+
"_dom_classes": [],
|
| 1102 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1103 |
+
"_model_module_version": "1.5.0",
|
| 1104 |
+
"_model_name": "HBoxModel",
|
| 1105 |
+
"_view_count": null,
|
| 1106 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1107 |
+
"_view_module_version": "1.5.0",
|
| 1108 |
+
"_view_name": "HBoxView",
|
| 1109 |
+
"box_style": "",
|
| 1110 |
+
"children": [
|
| 1111 |
+
"IPY_MODEL_243ec1d23b8e469e833db01f89189639",
|
| 1112 |
+
"IPY_MODEL_9bcf41c7fe1e425aae4e37dcfca6a67c",
|
| 1113 |
+
"IPY_MODEL_f09e93d2fdc64071b3b9860fd7782c58"
|
| 1114 |
+
],
|
| 1115 |
+
"layout": "IPY_MODEL_0a03432f7ce14c8c98090f846a512beb"
|
| 1116 |
+
}
|
| 1117 |
+
},
|
| 1118 |
+
"243ec1d23b8e469e833db01f89189639": {
|
| 1119 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1120 |
+
"model_name": "HTMLModel",
|
| 1121 |
+
"model_module_version": "1.5.0",
|
| 1122 |
+
"state": {
|
| 1123 |
+
"_dom_classes": [],
|
| 1124 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1125 |
+
"_model_module_version": "1.5.0",
|
| 1126 |
+
"_model_name": "HTMLModel",
|
| 1127 |
+
"_view_count": null,
|
| 1128 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1129 |
+
"_view_module_version": "1.5.0",
|
| 1130 |
+
"_view_name": "HTMLView",
|
| 1131 |
+
"description": "",
|
| 1132 |
+
"description_tooltip": null,
|
| 1133 |
+
"layout": "IPY_MODEL_1da74b90beb44d85a22ca57ee52b1e9d",
|
| 1134 |
+
"placeholder": "",
|
| 1135 |
+
"style": "IPY_MODEL_a7c35ed82c20419b9b559b217d924a6a",
|
| 1136 |
+
"value": "model.safetensors: 100%"
|
| 1137 |
+
}
|
| 1138 |
+
},
|
| 1139 |
+
"9bcf41c7fe1e425aae4e37dcfca6a67c": {
|
| 1140 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1141 |
+
"model_name": "FloatProgressModel",
|
| 1142 |
+
"model_module_version": "1.5.0",
|
| 1143 |
+
"state": {
|
| 1144 |
+
"_dom_classes": [],
|
| 1145 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1146 |
+
"_model_module_version": "1.5.0",
|
| 1147 |
+
"_model_name": "FloatProgressModel",
|
| 1148 |
+
"_view_count": null,
|
| 1149 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1150 |
+
"_view_module_version": "1.5.0",
|
| 1151 |
+
"_view_name": "ProgressView",
|
| 1152 |
+
"bar_style": "success",
|
| 1153 |
+
"description": "",
|
| 1154 |
+
"description_tooltip": null,
|
| 1155 |
+
"layout": "IPY_MODEL_0a302207f9884533b49fb150411169d2",
|
| 1156 |
+
"max": 21355344,
|
| 1157 |
+
"min": 0,
|
| 1158 |
+
"orientation": "horizontal",
|
| 1159 |
+
"style": "IPY_MODEL_bc007e650c0f4f619946055a6c72d132",
|
| 1160 |
+
"value": 21355344
|
| 1161 |
+
}
|
| 1162 |
+
},
|
| 1163 |
+
"f09e93d2fdc64071b3b9860fd7782c58": {
|
| 1164 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1165 |
+
"model_name": "HTMLModel",
|
| 1166 |
+
"model_module_version": "1.5.0",
|
| 1167 |
+
"state": {
|
| 1168 |
+
"_dom_classes": [],
|
| 1169 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1170 |
+
"_model_module_version": "1.5.0",
|
| 1171 |
+
"_model_name": "HTMLModel",
|
| 1172 |
+
"_view_count": null,
|
| 1173 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1174 |
+
"_view_module_version": "1.5.0",
|
| 1175 |
+
"_view_name": "HTMLView",
|
| 1176 |
+
"description": "",
|
| 1177 |
+
"description_tooltip": null,
|
| 1178 |
+
"layout": "IPY_MODEL_20b5ba3ee19446dc89ddb5222be6f5f2",
|
| 1179 |
+
"placeholder": "",
|
| 1180 |
+
"style": "IPY_MODEL_1e35273de451460aa8c82df17b3ca23c",
|
| 1181 |
+
"value": " 21.4M/21.4M [00:00<00:00, 43.5MB/s]"
|
| 1182 |
+
}
|
| 1183 |
+
},
|
| 1184 |
+
"0a03432f7ce14c8c98090f846a512beb": {
|
| 1185 |
+
"model_module": "@jupyter-widgets/base",
|
| 1186 |
+
"model_name": "LayoutModel",
|
| 1187 |
+
"model_module_version": "1.2.0",
|
| 1188 |
+
"state": {
|
| 1189 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1190 |
+
"_model_module_version": "1.2.0",
|
| 1191 |
+
"_model_name": "LayoutModel",
|
| 1192 |
+
"_view_count": null,
|
| 1193 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1194 |
+
"_view_module_version": "1.2.0",
|
| 1195 |
+
"_view_name": "LayoutView",
|
| 1196 |
+
"align_content": null,
|
| 1197 |
+
"align_items": null,
|
| 1198 |
+
"align_self": null,
|
| 1199 |
+
"border": null,
|
| 1200 |
+
"bottom": null,
|
| 1201 |
+
"display": null,
|
| 1202 |
+
"flex": null,
|
| 1203 |
+
"flex_flow": null,
|
| 1204 |
+
"grid_area": null,
|
| 1205 |
+
"grid_auto_columns": null,
|
| 1206 |
+
"grid_auto_flow": null,
|
| 1207 |
+
"grid_auto_rows": null,
|
| 1208 |
+
"grid_column": null,
|
| 1209 |
+
"grid_gap": null,
|
| 1210 |
+
"grid_row": null,
|
| 1211 |
+
"grid_template_areas": null,
|
| 1212 |
+
"grid_template_columns": null,
|
| 1213 |
+
"grid_template_rows": null,
|
| 1214 |
+
"height": null,
|
| 1215 |
+
"justify_content": null,
|
| 1216 |
+
"justify_items": null,
|
| 1217 |
+
"left": null,
|
| 1218 |
+
"margin": null,
|
| 1219 |
+
"max_height": null,
|
| 1220 |
+
"max_width": null,
|
| 1221 |
+
"min_height": null,
|
| 1222 |
+
"min_width": null,
|
| 1223 |
+
"object_fit": null,
|
| 1224 |
+
"object_position": null,
|
| 1225 |
+
"order": null,
|
| 1226 |
+
"overflow": null,
|
| 1227 |
+
"overflow_x": null,
|
| 1228 |
+
"overflow_y": null,
|
| 1229 |
+
"padding": null,
|
| 1230 |
+
"right": null,
|
| 1231 |
+
"top": null,
|
| 1232 |
+
"visibility": null,
|
| 1233 |
+
"width": null
|
| 1234 |
+
}
|
| 1235 |
},
|
| 1236 |
+
"1da74b90beb44d85a22ca57ee52b1e9d": {
|
| 1237 |
+
"model_module": "@jupyter-widgets/base",
|
| 1238 |
+
"model_name": "LayoutModel",
|
| 1239 |
+
"model_module_version": "1.2.0",
|
| 1240 |
+
"state": {
|
| 1241 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1242 |
+
"_model_module_version": "1.2.0",
|
| 1243 |
+
"_model_name": "LayoutModel",
|
| 1244 |
+
"_view_count": null,
|
| 1245 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1246 |
+
"_view_module_version": "1.2.0",
|
| 1247 |
+
"_view_name": "LayoutView",
|
| 1248 |
+
"align_content": null,
|
| 1249 |
+
"align_items": null,
|
| 1250 |
+
"align_self": null,
|
| 1251 |
+
"border": null,
|
| 1252 |
+
"bottom": null,
|
| 1253 |
+
"display": null,
|
| 1254 |
+
"flex": null,
|
| 1255 |
+
"flex_flow": null,
|
| 1256 |
+
"grid_area": null,
|
| 1257 |
+
"grid_auto_columns": null,
|
| 1258 |
+
"grid_auto_flow": null,
|
| 1259 |
+
"grid_auto_rows": null,
|
| 1260 |
+
"grid_column": null,
|
| 1261 |
+
"grid_gap": null,
|
| 1262 |
+
"grid_row": null,
|
| 1263 |
+
"grid_template_areas": null,
|
| 1264 |
+
"grid_template_columns": null,
|
| 1265 |
+
"grid_template_rows": null,
|
| 1266 |
+
"height": null,
|
| 1267 |
+
"justify_content": null,
|
| 1268 |
+
"justify_items": null,
|
| 1269 |
+
"left": null,
|
| 1270 |
+
"margin": null,
|
| 1271 |
+
"max_height": null,
|
| 1272 |
+
"max_width": null,
|
| 1273 |
+
"min_height": null,
|
| 1274 |
+
"min_width": null,
|
| 1275 |
+
"object_fit": null,
|
| 1276 |
+
"object_position": null,
|
| 1277 |
+
"order": null,
|
| 1278 |
+
"overflow": null,
|
| 1279 |
+
"overflow_x": null,
|
| 1280 |
+
"overflow_y": null,
|
| 1281 |
+
"padding": null,
|
| 1282 |
+
"right": null,
|
| 1283 |
+
"top": null,
|
| 1284 |
+
"visibility": null,
|
| 1285 |
+
"width": null
|
| 1286 |
+
}
|
| 1287 |
},
|
| 1288 |
+
"a7c35ed82c20419b9b559b217d924a6a": {
|
| 1289 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1290 |
+
"model_name": "DescriptionStyleModel",
|
| 1291 |
+
"model_module_version": "1.5.0",
|
| 1292 |
+
"state": {
|
| 1293 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1294 |
+
"_model_module_version": "1.5.0",
|
| 1295 |
+
"_model_name": "DescriptionStyleModel",
|
| 1296 |
+
"_view_count": null,
|
| 1297 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1298 |
+
"_view_module_version": "1.2.0",
|
| 1299 |
+
"_view_name": "StyleView",
|
| 1300 |
+
"description_width": ""
|
| 1301 |
+
}
|
| 1302 |
},
|
| 1303 |
+
"0a302207f9884533b49fb150411169d2": {
|
| 1304 |
+
"model_module": "@jupyter-widgets/base",
|
| 1305 |
+
"model_name": "LayoutModel",
|
| 1306 |
+
"model_module_version": "1.2.0",
|
| 1307 |
+
"state": {
|
| 1308 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1309 |
+
"_model_module_version": "1.2.0",
|
| 1310 |
+
"_model_name": "LayoutModel",
|
| 1311 |
+
"_view_count": null,
|
| 1312 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1313 |
+
"_view_module_version": "1.2.0",
|
| 1314 |
+
"_view_name": "LayoutView",
|
| 1315 |
+
"align_content": null,
|
| 1316 |
+
"align_items": null,
|
| 1317 |
+
"align_self": null,
|
| 1318 |
+
"border": null,
|
| 1319 |
+
"bottom": null,
|
| 1320 |
+
"display": null,
|
| 1321 |
+
"flex": null,
|
| 1322 |
+
"flex_flow": null,
|
| 1323 |
+
"grid_area": null,
|
| 1324 |
+
"grid_auto_columns": null,
|
| 1325 |
+
"grid_auto_flow": null,
|
| 1326 |
+
"grid_auto_rows": null,
|
| 1327 |
+
"grid_column": null,
|
| 1328 |
+
"grid_gap": null,
|
| 1329 |
+
"grid_row": null,
|
| 1330 |
+
"grid_template_areas": null,
|
| 1331 |
+
"grid_template_columns": null,
|
| 1332 |
+
"grid_template_rows": null,
|
| 1333 |
+
"height": null,
|
| 1334 |
+
"justify_content": null,
|
| 1335 |
+
"justify_items": null,
|
| 1336 |
+
"left": null,
|
| 1337 |
+
"margin": null,
|
| 1338 |
+
"max_height": null,
|
| 1339 |
+
"max_width": null,
|
| 1340 |
+
"min_height": null,
|
| 1341 |
+
"min_width": null,
|
| 1342 |
+
"object_fit": null,
|
| 1343 |
+
"object_position": null,
|
| 1344 |
+
"order": null,
|
| 1345 |
+
"overflow": null,
|
| 1346 |
+
"overflow_x": null,
|
| 1347 |
+
"overflow_y": null,
|
| 1348 |
+
"padding": null,
|
| 1349 |
+
"right": null,
|
| 1350 |
+
"top": null,
|
| 1351 |
+
"visibility": null,
|
| 1352 |
+
"width": null
|
| 1353 |
+
}
|
| 1354 |
},
|
| 1355 |
+
"bc007e650c0f4f619946055a6c72d132": {
|
| 1356 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1357 |
+
"model_name": "ProgressStyleModel",
|
| 1358 |
+
"model_module_version": "1.5.0",
|
| 1359 |
+
"state": {
|
| 1360 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1361 |
+
"_model_module_version": "1.5.0",
|
| 1362 |
+
"_model_name": "ProgressStyleModel",
|
| 1363 |
+
"_view_count": null,
|
| 1364 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1365 |
+
"_view_module_version": "1.2.0",
|
| 1366 |
+
"_view_name": "StyleView",
|
| 1367 |
+
"bar_color": null,
|
| 1368 |
+
"description_width": ""
|
| 1369 |
+
}
|
| 1370 |
},
|
| 1371 |
+
"20b5ba3ee19446dc89ddb5222be6f5f2": {
|
| 1372 |
+
"model_module": "@jupyter-widgets/base",
|
| 1373 |
+
"model_name": "LayoutModel",
|
| 1374 |
+
"model_module_version": "1.2.0",
|
| 1375 |
+
"state": {
|
| 1376 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1377 |
+
"_model_module_version": "1.2.0",
|
| 1378 |
+
"_model_name": "LayoutModel",
|
| 1379 |
+
"_view_count": null,
|
| 1380 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1381 |
+
"_view_module_version": "1.2.0",
|
| 1382 |
+
"_view_name": "LayoutView",
|
| 1383 |
+
"align_content": null,
|
| 1384 |
+
"align_items": null,
|
| 1385 |
+
"align_self": null,
|
| 1386 |
+
"border": null,
|
| 1387 |
+
"bottom": null,
|
| 1388 |
+
"display": null,
|
| 1389 |
+
"flex": null,
|
| 1390 |
+
"flex_flow": null,
|
| 1391 |
+
"grid_area": null,
|
| 1392 |
+
"grid_auto_columns": null,
|
| 1393 |
+
"grid_auto_flow": null,
|
| 1394 |
+
"grid_auto_rows": null,
|
| 1395 |
+
"grid_column": null,
|
| 1396 |
+
"grid_gap": null,
|
| 1397 |
+
"grid_row": null,
|
| 1398 |
+
"grid_template_areas": null,
|
| 1399 |
+
"grid_template_columns": null,
|
| 1400 |
+
"grid_template_rows": null,
|
| 1401 |
+
"height": null,
|
| 1402 |
+
"justify_content": null,
|
| 1403 |
+
"justify_items": null,
|
| 1404 |
+
"left": null,
|
| 1405 |
+
"margin": null,
|
| 1406 |
+
"max_height": null,
|
| 1407 |
+
"max_width": null,
|
| 1408 |
+
"min_height": null,
|
| 1409 |
+
"min_width": null,
|
| 1410 |
+
"object_fit": null,
|
| 1411 |
+
"object_position": null,
|
| 1412 |
+
"order": null,
|
| 1413 |
+
"overflow": null,
|
| 1414 |
+
"overflow_x": null,
|
| 1415 |
+
"overflow_y": null,
|
| 1416 |
+
"padding": null,
|
| 1417 |
+
"right": null,
|
| 1418 |
+
"top": null,
|
| 1419 |
+
"visibility": null,
|
| 1420 |
+
"width": null
|
| 1421 |
+
}
|
| 1422 |
+
},
|
| 1423 |
+
"1e35273de451460aa8c82df17b3ca23c": {
|
| 1424 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1425 |
+
"model_name": "DescriptionStyleModel",
|
| 1426 |
+
"model_module_version": "1.5.0",
|
| 1427 |
+
"state": {
|
| 1428 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1429 |
+
"_model_module_version": "1.5.0",
|
| 1430 |
+
"_model_name": "DescriptionStyleModel",
|
| 1431 |
+
"_view_count": null,
|
| 1432 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1433 |
+
"_view_module_version": "1.2.0",
|
| 1434 |
+
"_view_name": "StyleView",
|
| 1435 |
+
"description_width": ""
|
| 1436 |
+
}
|
| 1437 |
}
|
| 1438 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1439 |
}
|
| 1440 |
+
},
|
| 1441 |
+
"nbformat": 4,
|
| 1442 |
+
"nbformat_minor": 0
|
| 1443 |
}
|
app.py
CHANGED
|
@@ -9,28 +9,45 @@ import gradio as gr
|
|
| 9 |
import openai
|
| 10 |
from openai import OpenAI
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
|
|
|
| 15 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
BACKBONE = "efficientnet_b0"
|
| 17 |
-
NUM_CLASSES = 4
|
| 18 |
IMG_SIZE = 224
|
| 19 |
|
| 20 |
-
# Paste your 4 sample image absolute paths here
|
| 21 |
-
#
|
| 22 |
-
sample_paths = [
|
| 23 |
-
|
| 24 |
-
"sample_images/image(2).jpg",
|
| 25 |
-
"sample_images/image(3).jpg",
|
| 26 |
-
"sample_images/image(23).jpg",
|
| 27 |
-
]
|
| 28 |
|
| 29 |
-
OPENAI_API_KEY =
|
| 30 |
if OPENAI_API_KEY:
|
| 31 |
openai.api_key = OPENAI_API_KEY
|
| 32 |
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# -------------------------------------------------
|
| 35 |
|
| 36 |
# Validate sample paths
|
|
@@ -66,7 +83,7 @@ def predict_topk_from_pil(pil_img, topk=3):
|
|
| 66 |
idxs = probs.argsort()[::-1][:topk]
|
| 67 |
return [{"label": IDX2LABEL[int(i)], "score": float(probs[int(i)])} for i in idxs]
|
| 68 |
|
| 69 |
-
# OpenAI system prompt
|
| 70 |
SYSTEM_TEMPLATE = """You are an educational radiology assistant. The model predicted:
|
| 71 |
{preds}
|
| 72 |
|
|
@@ -96,13 +113,12 @@ def call_openai_seed(preds_text):
|
|
| 96 |
)
|
| 97 |
return resp.choices[0].message.content.strip()
|
| 98 |
|
| 99 |
-
|
| 100 |
# Prepare thumbnails (just images)
|
| 101 |
thumbs = [Image.open(p).convert("RGB") for p in sample_paths]
|
| 102 |
|
| 103 |
-
# Gradio UI
|
| 104 |
with gr.Blocks() as demo:
|
| 105 |
-
gr.Markdown("## Brain MRI multi-class demo — Educational only")
|
| 106 |
with gr.Row():
|
| 107 |
with gr.Column():
|
| 108 |
# Use a simple list of PIL thumbnails for the gallery value
|
|
@@ -120,16 +136,16 @@ with gr.Blocks() as demo:
|
|
| 120 |
selected_path_state = gr.State(value=None)
|
| 121 |
preds_state = gr.State(value=None)
|
| 122 |
|
| 123 |
-
# select image from gallery
|
| 124 |
def on_select(evt: gr.SelectData):
|
| 125 |
-
idx = evt.index
|
| 126 |
-
path = sample_paths[idx]
|
| 127 |
return path
|
| 128 |
|
| 129 |
# ensure the gallery select writes only the path to the state
|
| 130 |
gallery.select(fn=on_select, inputs=None, outputs=[selected_path_state])
|
| 131 |
|
| 132 |
-
# analyze
|
| 133 |
def analyze(path):
|
| 134 |
try:
|
| 135 |
if path is None:
|
|
@@ -146,7 +162,7 @@ with gr.Blocks() as demo:
|
|
| 146 |
|
| 147 |
analyze_btn.click(fn=analyze, inputs=selected_path_state, outputs=[preds_output, chatbot, preds_state])
|
| 148 |
|
| 149 |
-
# chat
|
| 150 |
def chat(chat_history, msg, preds_text):
|
| 151 |
if preds_text is None:
|
| 152 |
return chat_history+[("AI","Please analyze an image first.")], ""
|
|
@@ -177,4 +193,4 @@ with gr.Blocks() as demo:
|
|
| 177 |
|
| 178 |
gr.Markdown("⚠️ Educational use only — not a medical diagnosis.")
|
| 179 |
|
| 180 |
-
demo.launch(share=True)
|
|
|
|
| 9 |
import openai
|
| 10 |
from openai import OpenAI
|
| 11 |
|
| 12 |
+
sample_paths = [
|
| 13 |
+
"sample_images/ct_healthy (1).jpg",
|
| 14 |
+
"sample_images/ct_tumor (1).png",
|
| 15 |
+
"sample_images/glioma (1).jpg",
|
| 16 |
+
"sample_images/pituitary (244).jpg",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
|
| 20 |
+
if OPENAI_API_KEY:
|
| 21 |
+
openai.api_key = OPENAI_API_KEY
|
| 22 |
|
| 23 |
+
|
| 24 |
+
# ------------------ USER CONFIG (edit as needed) ------------------
|
| 25 |
+
CKPT_PATH = "model/best_model.pth" # path to trained checkpoint
|
| 26 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
BACKBONE = "efficientnet_b0"
|
|
|
|
| 28 |
IMG_SIZE = 224
|
| 29 |
|
| 30 |
+
# Paste your 4 sample image absolute paths here (or ensure sample_paths variable already exists)
|
| 31 |
+
# Example:
|
| 32 |
+
# sample_paths = ["/content/samples/ct1.jpg", "/content/samples/mri_glioma1.jpg", ...]
|
| 33 |
+
sample_paths = sample_paths # keep your existing variable/assignment
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
OPENAI_API_KEY = OPENAI_API_KEY # keep as before (e.g., from env or variable)
|
| 36 |
if OPENAI_API_KEY:
|
| 37 |
openai.api_key = OPENAI_API_KEY
|
| 38 |
|
| 39 |
+
# ------------------ UPDATED LABEL MAP for combined CT+MRI ------------------
|
| 40 |
+
# Minimal change: include ct_tumor and MRI subtypes + no_tumor
|
| 41 |
+
IDX2LABEL = {
|
| 42 |
+
0: "ct_tumor",
|
| 43 |
+
1: "glioma",
|
| 44 |
+
2: "meningioma",
|
| 45 |
+
3: "mri_tumor",
|
| 46 |
+
4: "no_tumor",
|
| 47 |
+
5: "pituitary"
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
NUM_CLASSES = len(IDX2LABEL)
|
| 51 |
# -------------------------------------------------
|
| 52 |
|
| 53 |
# Validate sample paths
|
|
|
|
| 83 |
idxs = probs.argsort()[::-1][:topk]
|
| 84 |
return [{"label": IDX2LABEL[int(i)], "score": float(probs[int(i)])} for i in idxs]
|
| 85 |
|
| 86 |
+
# OpenAI system prompt (unchanged)
|
| 87 |
SYSTEM_TEMPLATE = """You are an educational radiology assistant. The model predicted:
|
| 88 |
{preds}
|
| 89 |
|
|
|
|
| 113 |
)
|
| 114 |
return resp.choices[0].message.content.strip()
|
| 115 |
|
|
|
|
| 116 |
# Prepare thumbnails (just images)
|
| 117 |
thumbs = [Image.open(p).convert("RGB") for p in sample_paths]
|
| 118 |
|
| 119 |
+
# Gradio UI (unchanged structure)
|
| 120 |
with gr.Blocks() as demo:
|
| 121 |
+
gr.Markdown("## Brain MRI+CT multi-class demo — Educational only")
|
| 122 |
with gr.Row():
|
| 123 |
with gr.Column():
|
| 124 |
# Use a simple list of PIL thumbnails for the gallery value
|
|
|
|
| 136 |
selected_path_state = gr.State(value=None)
|
| 137 |
preds_state = gr.State(value=None)
|
| 138 |
|
| 139 |
+
# select image from gallery -> store selected path string
|
| 140 |
def on_select(evt: gr.SelectData):
|
| 141 |
+
idx = int(evt.index)
|
| 142 |
+
path = sample_paths[idx]
|
| 143 |
return path
|
| 144 |
|
| 145 |
# ensure the gallery select writes only the path to the state
|
| 146 |
gallery.select(fn=on_select, inputs=None, outputs=[selected_path_state])
|
| 147 |
|
| 148 |
+
# analyze (unchanged except label names)
|
| 149 |
def analyze(path):
|
| 150 |
try:
|
| 151 |
if path is None:
|
|
|
|
| 162 |
|
| 163 |
analyze_btn.click(fn=analyze, inputs=selected_path_state, outputs=[preds_output, chatbot, preds_state])
|
| 164 |
|
| 165 |
+
# chat (keeps same logic)
|
| 166 |
def chat(chat_history, msg, preds_text):
|
| 167 |
if preds_text is None:
|
| 168 |
return chat_history+[("AI","Please analyze an image first.")], ""
|
|
|
|
| 193 |
|
| 194 |
gr.Markdown("⚠️ Educational use only — not a medical diagnosis.")
|
| 195 |
|
| 196 |
+
demo.launch(share=True)
|
model/{best_model_multiclass.pth → best_model.pth}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:66488a18d397630b3f6eda091504415329f432fbf90b0a14d1fae5be90bc7864
|
| 3 |
+
size 16356601
|
sample_images/ct_healthy (1).jpg
ADDED
|
Git LFS Details
|
sample_images/ct_tumor (1).png
ADDED
|
Git LFS Details
|
sample_images/glioma (1).jpg
ADDED
|
Git LFS Details
|
sample_images/image(2).jpg
DELETED
|
Binary file (33.8 kB)
|
|
|
sample_images/image(23).jpg
DELETED
|
Binary file (34.2 kB)
|
|
|
sample_images/image(3).jpg
DELETED
|
Binary file (16.9 kB)
|
|
|
sample_images/image.jpg
DELETED
|
Binary file (49.5 kB)
|
|
|
sample_images/pituitary (244).jpg
ADDED
|
Git LFS Details
|