{
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10"
  },
  "kaggle": {
   "accelerator": "nvidiaTeslaT4",
   "dataSources": [
    {
     "sourceType": "datasetVersion",
     "sourceId": 15854996,
     "datasetId": 10122310,
     "databundleVersionId": 16806564
    },
    {
     "sourceType": "datasetVersion",
     "sourceId": 15849207,
     "datasetId": 10121644,
     "databundleVersionId": 16800280
    }
   ],
   "dockerImageVersionId": 31329,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook",
   "isGpuEnabled": true
  }
 },
 "nbformat_minor": 5,
 "nbformat": 4,
 "cells": [
  {
   "id": "c622dfc2-9f77-4cc0-9121-9b7a86af172c",
   "cell_type": "markdown",
   "source": "# CXR-VLM — Kaggle Training Notebook (consolidated)\n\nTrains the 2-stage CXR-VLM (Vicuna-7B + BioViL-T fallback to timm ViT + LoRA) on a Kaggle **T4** GPU.\n\nSupports **two datasets**, selected by `DATASET_NAME` in section 0:\n- **`MIMIC-CXR`** — full 3 tasks (findings, impression, VQA).\n- **`IU-Xray`**   — 2 tasks only (findings, impression). Much lighter dataset (~7.5k images).\n\n### Before you run\n\nAttach Kaggle Datasets via `+ Add Input`:\n\n| Dataset slug | Contents | When needed |\n|---|---|---|\n| `cxr-vlm-code` | entire `D:\\USTH\\KLTN` folder (configs/, data/*.py, model/, training/, evaluation/, utils/, requirements.txt) | **always** |\n| `cxr-vlm-data` | holds **both** datasets: `MIMIC-CXR/{train,valid,test}/p*/...` + `MIMIC-Ext-MIMIC-CXR-VQA/...` and/or `IU-Xray/images/` + `IU-Xray/labels/` | **always** |\n\n**Settings (right panel):**\n- Accelerator: **T4 x2** (only GPU 0 will be used)\n- Persistence: **Variables and Files**\n- Internet: **On**\n\n**Kaggle Secrets** (Add-ons → Secrets):\n- `HF_TOKEN` — HuggingFace token with write access to the runs repo.",
   "metadata": {
    "id": "cell-0"
   }
  },
  {
   "id": "d8523ec6-1cbe-43c1-8a24-86deb1ea2e6d",
   "cell_type": "markdown",
   "source": "## 0. Select dataset\n\nChange this one variable to switch between datasets. Everything else (data loading, config patching, training, evaluation) is driven by it.",
   "metadata": {
    "id": "cell-select-md"
   }
  },
  {
   "id": "659be15e-ce36-4f97-9362-019ba31d9ad7",
   "cell_type": "code",
   "source": "DATASET_NAME = 'IU-Xray'     # 'MIMIC-CXR' | 'IU-Xray'\nassert DATASET_NAME in ('MIMIC-CXR', 'IU-Xray')\nprint('DATASET_NAME =', DATASET_NAME)",
   "metadata": {
    "id": "cell-select",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:16:15.022329Z",
     "iopub.execute_input": "2026-04-21T10:16:15.023127Z",
     "iopub.status.idle": "2026-04-21T10:16:15.027751Z",
     "shell.execute_reply.started": "2026-04-21T10:16:15.023086Z",
     "shell.execute_reply": "2026-04-21T10:16:15.027045Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "text": "DATASET_NAME = IU-Xray\n",
     "output_type": "stream"
    }
   ],
   "execution_count": 26
  },
  {
   "id": "2c0bf73e-acf3-464a-91bd-30a7a0eb56e4",
   "cell_type": "markdown",
   "source": "## 1. Environment setup\n\nForce single-GPU (device 0) **before** any `import torch` to avoid the `cuda:0/cuda:1` tensor-mismatch we hit on T4x2.",
   "metadata": {
    "id": "cell-1-md"
   }
  },
  {
   "id": "7f0f48f6-8d2d-4f73-8bef-d427559934f0",
   "cell_type": "code",
   "source": "import os\nos.environ['CUDA_VISIBLE_DEVICES']            = '0'      # single-GPU\nos.environ['TOKENIZERS_PARALLELISM']          = 'false'  # silence HF tokenizers fork warning\nos.environ['BITSANDBYTES_NOWELCOME']          = '1'\nos.environ['HF_HUB_DISABLE_PROGRESS_BARS']    = '1'      # kill per-shard download bars\nos.environ['TRANSFORMERS_VERBOSITY']          = 'warning'\nos.environ['PYTHONUNBUFFERED']                = '1'\n\nimport sys, shutil, subprocess\nfrom pathlib import Path\n",
   "metadata": {
    "id": "cell-env",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:16:15.029209Z",
     "iopub.execute_input": "2026-04-21T10:16:15.029495Z",
     "iopub.status.idle": "2026-04-21T10:16:15.045903Z",
     "shell.execute_reply.started": "2026-04-21T10:16:15.029473Z",
     "shell.execute_reply": "2026-04-21T10:16:15.045209Z"
    }
   },
   "outputs": [],
   "execution_count": 27
  },
  {
   "id": "76ee1a2e-ccda-4916-8b7c-79044971cc99",
   "cell_type": "code",
   "source": "# ── Auto-detect the attached datasets (Kaggle sometimes nests them under /kaggle/input/datasets/<user>/) ──\nINPUT_ROOT = Path('/kaggle/input')\n\ndef find_dataset(slug: str, required: bool = True) -> Path:\n    # Common mount points, in order\n    for cand in [INPUT_ROOT / slug, *INPUT_ROOT.rglob(slug)]:\n        if cand.is_dir():\n            return cand\n    if required:\n        raise FileNotFoundError(f'Dataset {slug!r} not attached to this notebook')\n    return None\n\nCODE_SRC = find_dataset('cxr-vlm-code')\n# One slug holds both datasets — MIMIC-CXR (+ VQA) and/or IU-Xray.\nDATA_SRC = find_dataset('cxr-vlm-data')\n\nWORK     = Path('/kaggle/working')\nPROJECT  = WORK / 'cxr_vlm'\n\n# Kaggle inputs are read-only → copy code to /kaggle/working so we can write configs / __init__ / ckpt state\nif not PROJECT.exists():\n    shutil.copytree(CODE_SRC, PROJECT)\n\nos.chdir(PROJECT)\nsys.path.insert(0, str(PROJECT))\nprint('CODE_SRC :', CODE_SRC)\nprint('DATA_SRC :', DATA_SRC)\nprint('PROJECT  :', PROJECT)\nprint('Contents :', sorted(os.listdir(PROJECT)))",
   "metadata": {
    "id": "cell-paths",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:16:15.047039Z",
     "iopub.execute_input": "2026-04-21T10:16:15.047351Z",
     "iopub.status.idle": "2026-04-21T10:25:54.218040Z",
     "shell.execute_reply.started": "2026-04-21T10:16:15.047328Z",
     "shell.execute_reply": "2026-04-21T10:25:54.217191Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "text": "CODE_SRC : /kaggle/input/datasets/mycatis/cxr-vlm-code\nDATA_SRC : /kaggle/input/datasets/mycatis/cxr-vlm-data\nPROJECT  : /kaggle/working/cxr_vlm\nContents : ['configs', 'data', 'evaluation', 'model', 'requirements.txt', 'scripts', 'training', 'utils']\n",
     "output_type": "stream"
    }
   ],
   "execution_count": 28
  },
  {
   "id": "a596814c-23ff-4364-84cf-e8bf8ceb320f",
   "cell_type": "code",
   "source": "# ── pip install (matches the versions the code was debugged against) ──\n# torchao 0.10.0 ships on Kaggle and breaks peft 0.13+ — remove it, pin peft<0.13.\n!pip uninstall -y -q torchao || true\n\n!pip install -q \\\n    'transformers>=4.41,<4.46' \\\n    'peft>=0.11,<0.13' \\\n    'accelerate>=0.30' \\\n    'bitsandbytes>=0.43' \\\n    'huggingface_hub>=0.23' \\\n    omegaconf sentencepiece 'protobuf>=3.20'\n\n!pip install -q nltk rouge-score bert-score sacrebleu\n# NOTE: hi-ml-multimodal is intentionally NOT installed — model/image_encoder.py\n# has a timm ViT-B/16 fallback (backend='auto' picks 'vit' when BioViL-T is unavailable).",
   "metadata": {
    "id": "cell-pip",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:25:54.219100Z",
     "iopub.execute_input": "2026-04-21T10:25:54.219419Z",
     "iopub.status.idle": "2026-04-21T10:26:02.536805Z",
     "shell.execute_reply.started": "2026-04-21T10:25:54.219395Z",
     "shell.execute_reply": "2026-04-21T10:26:02.535704Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "text": "\u001b[33mWARNING: Skipping torchao as it is not installed.\u001b[0m\u001b[33m\n\u001b[0m",
     "output_type": "stream"
    }
   ],
   "execution_count": 29
  },
  {
   "id": "d6e9c24c-bfa1-4965-966b-0967a5506ca8",
   "cell_type": "code",
   "source": "import torch\nprint('torch', torch.__version__,\n      '| cuda', torch.cuda.is_available(),\n      '|', torch.cuda.get_device_name(0) if torch.cuda.is_available() else '')\nprint('cuda cap:', torch.cuda.get_device_capability(0) if torch.cuda.is_available() else 'n/a')\n# Sanity: sm_60 (P100) is NOT supported by the prebuilt wheels in PyTorch 2.7+.\n# If you see (6, 0) here, switch the accelerator to T4 x2 or L4.",
   "metadata": {
    "id": "cell-torch",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:26:02.539794Z",
     "iopub.execute_input": "2026-04-21T10:26:02.540165Z",
     "iopub.status.idle": "2026-04-21T10:26:02.546873Z",
     "shell.execute_reply.started": "2026-04-21T10:26:02.540135Z",
     "shell.execute_reply": "2026-04-21T10:26:02.545858Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "text": "torch 2.10.0+cu128 | cuda True | Tesla T4\ncuda cap: (7, 5)\n",
     "output_type": "stream"
    }
   ],
   "execution_count": 30
  },
  {
   "id": "d92272c6-e912-4035-ac03-e25de0ace4d0",
   "cell_type": "markdown",
   "source": "## 2. Locate data on Kaggle\n\nBoth datasets live under the single `cxr-vlm-data` slug. Expected layouts:\n\n**MIMIC-CXR**:\n```\nDATA_SRC/\n├── MIMIC-CXR/ (or at root)\n│   ├── train/p10/pXXXXXX/sYYYYY/*.jpg + sYYYYY.txt\n│   ├── valid/p10/...\n│   └── test/p10/...\n└── .../MIMIC-Ext-MIMIC-CXR-VQA/dataset/{train,valid,test}.json\n```\n\n**IU-Xray** (added alongside MIMIC under the same slug):\n```\nDATA_SRC/\n└── IU-Xray/\n    ├── images/        # CXR*_IM-*-*.png (~7.5k files)\n    └── labels/        # {1..3999}.xml   (~3.9k files, flat — no ecgen-radiology subfolder)\n```",
   "metadata": {
    "id": "cell-data-md"
   }
  },
  {
   "id": "e8e2572c-ff50-4337-9a2d-fc837cdb1dc1",
   "cell_type": "code",
   "source": "def find_split_parent(root: Path) -> Path:\n    for cand in [root, root / 'MIMIC-CXR', root / 'data' / 'MIMIC-CXR']:\n        if (cand / 'train').exists() and (cand / 'valid').exists() and (cand / 'test').exists():\n            return cand\n    for p in root.rglob('train'):\n        if p.is_dir() and (p.parent / 'valid').exists() and (p.parent / 'test').exists():\n            return p.parent\n    raise FileNotFoundError('Could not find train/ valid/ test/ under ' + str(root))\n\n\ndef find_iu_dirs(root: Path):\n    \"\"\"Locate IU-Xray `images/` and `labels/` (flat XMLs) under `root`.\n\n    Resolution order:\n      1. `{root}/IU-Xray/{images,labels}` — canonical layout.\n      2. Any nested `IU-Xray` folder that contains both.\n      3. Fallback: any folder containing CXR*.png (images) and\n         any folder containing *.xml — whichever comes first.\n\n    The labels subfolder is treated as a flat directory of XMLs (we no\n    longer require the legacy `ecgen-radiology/` subfolder).\n    \"\"\"\n    # Canonical + nested\n    for cand in [root / 'IU-Xray', *root.rglob('IU-Xray')]:\n        if not cand.is_dir():\n            continue\n        imgs = cand / 'images'\n        lbls = cand / 'labels'\n        if imgs.is_dir() and lbls.is_dir() and any(lbls.glob('*.xml')):\n            return imgs, lbls\n        # Legacy: labels/ecgen-radiology/*.xml\n        legacy = lbls / 'ecgen-radiology'\n        if imgs.is_dir() and legacy.is_dir() and any(legacy.glob('*.xml')):\n            return imgs, legacy\n\n    # Fallback: any images/ with CXR*.png + any folder with XML\n    img_dir = lbl_dir = None\n    for cand in [root / 'images', *root.rglob('images')]:\n        if cand.is_dir() and any(cand.glob('CXR*.png')):\n            img_dir = cand; break\n    for cand in [root / 'labels', *root.rglob('labels')]:\n        if cand.is_dir() and any(cand.glob('*.xml')):\n            lbl_dir = cand; break\n    if lbl_dir is None:\n        # very last resort — any 
ecgen-radiology folder with XMLs\n        for cand in root.rglob('ecgen-radiology'):\n            if cand.is_dir() and any(cand.glob('*.xml')):\n                lbl_dir = cand; break\n    return img_dir, lbl_dir\n\n\n# Filled in below depending on DATASET_NAME\nCXR_ROOT      = None                  # MIMIC-CXR root (with train/valid/test subdirs)\nSPLIT_DIRS    = None                  # MIMIC only\nVQA_ROOT      = None                  # MIMIC only\nIU_IMAGES_DIR = None                  # IU-Xray only\nIU_LABELS_DIR = None                  # IU-Xray only\n\nif DATASET_NAME == 'MIMIC-CXR':\n    CXR_ROOT = find_split_parent(DATA_SRC)\n    print('MIMIC-CXR root:', CXR_ROOT)\n\n    SPLIT_DIRS = {\n        'train'   : ('train', CXR_ROOT / 'train'),\n        'validate': ('valid', CXR_ROOT / 'valid'),\n        'test'    : ('test',  CXR_ROOT / 'test'),\n    }\n    for s, (sub, d) in SPLIT_DIRS.items():\n        assert d.exists(), f'Missing split dir: {d}'\n        print(f'  {s:<9s} → {d}')\n\n    for p in DATA_SRC.rglob('MIMIC-Ext-MIMIC-CXR-VQA'):\n        cand = p / 'dataset'\n        if cand.exists() and (cand / 'train.json').exists():\n            VQA_ROOT = cand\n            break\n    assert VQA_ROOT is not None, 'VQA dataset folder not found under ' + str(DATA_SRC)\n    print('VQA root:', VQA_ROOT)\n\nelse:   # IU-Xray\n    IU_IMAGES_DIR, IU_LABELS_DIR = find_iu_dirs(DATA_SRC)\n    assert IU_IMAGES_DIR is not None, f'IU images/ not found under {DATA_SRC}'\n    assert IU_LABELS_DIR is not None, f'IU labels/ (with *.xml) not found under {DATA_SRC}'\n    print('IU images dir:', IU_IMAGES_DIR, '→', len(list(IU_IMAGES_DIR.glob('*.png'))), 'PNGs')\n    print('IU labels dir:', IU_LABELS_DIR, '→', len(list(IU_LABELS_DIR.glob('*.xml'))), 'XMLs')",
   "metadata": {
    "id": "cell-find-data-mimic",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:26:02.547941Z",
     "iopub.execute_input": "2026-04-21T10:26:02.548677Z",
     "iopub.status.idle": "2026-04-21T10:28:48.873813Z",
     "shell.execute_reply.started": "2026-04-21T10:26:02.548639Z",
     "shell.execute_reply": "2026-04-21T10:28:48.872952Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "text": "IU images dir: /kaggle/input/datasets/mycatis/cxr-vlm-data/IU-Xray/images → 7470 PNGs\nIU labels dir: /kaggle/input/datasets/mycatis/cxr-vlm-data/IU-Xray/labels → 3955 XMLs\n",
     "output_type": "stream"
    }
   ],
   "execution_count": 31
  },
  {
   "id": "f7f0c889-b3a8-4192-ab5e-b903decb8fe3",
   "cell_type": "markdown",
   "source": "## 3. Build the unified instruction JSON\n\n- **MIMIC-CXR**: parse report `.txt` files for findings/impression + attach VQA rows from `MIMIC-Ext-MIMIC-CXR-VQA`.\n- **IU-Xray**: handled automatically by `data.iu_xray_builder` when `train.py` runs (or you can trigger the build manually in the cell below).\n\nEither way, the resulting JSON has the same schema (`image_path`, `task`, `target`, `question`, `structured_findings`, `split`, ...) so `CXRInstructDataset` loads it unchanged.",
   "metadata": {
    "id": "cell-json-md"
   }
  },
  {
   "id": "702ac89c-802a-4a46-b7c6-9c94fa948e1e",
   "cell_type": "code",
   "source": "import json, re\nfrom tqdm.auto import tqdm\n\nif DATASET_NAME == 'MIMIC-CXR':\n    FINDINGS_RE   = re.compile(r'FINDINGS\\s*:\\s*(.*?)(?=\\n\\s*[A-Z ]{3,}\\s*:|\\Z)', re.S | re.I)\n    IMPRESSION_RE = re.compile(r'IMPRESSION\\s*:\\s*(.*?)(?=\\n\\s*[A-Z ]{3,}\\s*:|\\Z)', re.S | re.I)\n\n    def clean(txt: str) -> str:\n        return re.sub(r'\\s+', ' ', txt).strip()\n\n    def parse_report(txt_path: Path):\n        t = txt_path.read_text(errors='ignore')\n        f = FINDINGS_RE.search(t)\n        i = IMPRESSION_RE.search(t)\n        return (clean(f.group(1)) if f else None,\n                clean(i.group(1)) if i else None)\n\n    # image_index keyed on the *subject-relative* path (no split prefix) so VQA json can look it up.\n    image_index = {}                   # 'p10/pXXX/sYYY/img.jpg' -> (split_tag, split_sub, abs_path)\n    studies     = {}                   # (split_tag, subj, study) -> {report_txt, [sub_relpaths]}\n\n    for split_tag, (split_sub, split_dir) in SPLIT_DIRS.items():\n        for p_dir in sorted(split_dir.glob('p*')):\n            for pat_dir in p_dir.glob('p*'):\n                for study_dir in pat_dir.glob('s*'):\n                    imgs = sorted(study_dir.glob('*.jpg'))\n                    txts = list(study_dir.glob('*.txt'))\n                    if not imgs:\n                        continue\n                    report_txt = txts[0] if txts else None\n                    sub_relpaths = []\n                    for img in imgs:\n                        sub_rel = f'{p_dir.name}/{pat_dir.name}/{study_dir.name}/{img.name}'\n                        image_index[sub_rel] = (split_tag, split_sub, img)\n                        sub_relpaths.append(sub_rel)\n                    studies[(split_tag, pat_dir.name, study_dir.name)] = {\n                        'report_txt': report_txt,\n                        'images':     sub_relpaths,\n                        'split_sub':  split_sub,\n                    }\n\n    print(f'Indexed 
{len(image_index):,} images across {len(studies):,} studies')\nelse:\n    print('IU-Xray: skipping MIMIC indexing cell.')",
   "metadata": {
    "id": "cell-parse",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:28:48.874868Z",
     "iopub.execute_input": "2026-04-21T10:28:48.875199Z",
     "iopub.status.idle": "2026-04-21T10:28:48.884052Z",
     "shell.execute_reply.started": "2026-04-21T10:28:48.875173Z",
     "shell.execute_reply": "2026-04-21T10:28:48.883270Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "text": "IU-Xray: skipping MIMIC indexing cell.\n",
     "output_type": "stream"
    }
   ],
   "execution_count": 32
  },
  {
   "id": "d300c250-4213-42f9-8567-8b3ce79ad7df",
   "cell_type": "code",
   "source": "if DATASET_NAME == 'MIMIC-CXR':\n    samples = []\n    missing_report = no_findings = no_impression = 0\n\n    for (split_tag, subj, study), meta in tqdm(studies.items(), desc='reports'):\n        if meta['report_txt'] is None:\n            missing_report += 1\n            continue\n        findings, impression = parse_report(meta['report_txt'])\n        split_sub = meta['split_sub']\n        for sub_rel in meta['images']:\n            rel_with_split = f'{split_sub}/{sub_rel}'\n            if findings:\n                samples.append({\n                    'image_path': rel_with_split,\n                    'task':       'findings',\n                    'target':     findings,\n                    'question':   None,\n                    'structured_findings': None,\n                    'split':      split_tag,\n                    'study_id':   study,\n                    'subject_id': subj,\n                })\n            else:\n                no_findings += 1\n            if impression:\n                samples.append({\n                    'image_path': rel_with_split,\n                    'task':       'impression',\n                    'target':     impression,\n                    'question':   None,\n                    'structured_findings': None,\n                    'split':      split_tag,\n                    'study_id':   study,\n                    'subject_id': subj,\n                })\n            else:\n                no_impression += 1\n\n    print(f'findings+impression samples: {len(samples):,}')\n    print(f'  missing report.txt: {missing_report}, w/o FINDINGS: {no_findings}, w/o IMPRESSION: {no_impression}')\nelse:\n    samples = None\n    print('IU-Xray: skipping MIMIC report parsing cell.')",
   "metadata": {
    "id": "cell-build-findings",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:28:48.885047Z",
     "iopub.execute_input": "2026-04-21T10:28:48.885349Z",
     "iopub.status.idle": "2026-04-21T10:28:48.901431Z",
     "shell.execute_reply.started": "2026-04-21T10:28:48.885327Z",
     "shell.execute_reply": "2026-04-21T10:28:48.900806Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "text": "IU-Xray: skipping MIMIC report parsing cell.\n",
     "output_type": "stream"
    }
   ],
   "execution_count": 33
  },
  {
   "id": "f36553bb-50d7-481e-8e33-7fea94013b63",
   "cell_type": "code",
   "source": "if DATASET_NAME == 'MIMIC-CXR':\n    vqa_split_map = {'train': 'train', 'valid': 'validate', 'test': 'test'}\n    vqa_added = vqa_missed = 0\n\n    def normalize_vqa_relpath(p: str) -> str:\n        p = p.lstrip('/')\n        if p.startswith('files/'):\n            p = p[len('files/'):]\n        return p\n\n    for vqa_file_name, split_tag in vqa_split_map.items():\n        vqa_path = VQA_ROOT / f'{vqa_file_name}.json'\n        if not vqa_path.exists():\n            print(f'skip {vqa_path} (not present)')\n            continue\n        with open(vqa_path) as f:\n            vqa_rows = json.load(f)\n        for row in tqdm(vqa_rows, desc=f'vqa/{vqa_file_name}'):\n            sub_rel = normalize_vqa_relpath(row['image_path'])\n            if sub_rel not in image_index:\n                vqa_missed += 1\n                continue\n            _, split_sub, _ = image_index[sub_rel]\n            ans = row.get('answer', [])\n            if isinstance(ans, list):\n                answer = ', '.join(map(str, ans)) if ans else 'No.'\n            else:\n                answer = str(ans)\n            samples.append({\n                'image_path': f'{split_sub}/{sub_rel}',\n                'task':       'vqa',\n                'target':     answer,\n                'question':   row['question'],\n                'structured_findings': None,\n                'split':      split_tag,\n                'study_id':   row.get('study_id'),\n                'subject_id': row.get('subject_id'),\n            })\n            vqa_added += 1\n\n    print(f'VQA samples added: {vqa_added:,}, dropped (image outside subset): {vqa_missed:,}')\nelse:\n    print('IU-Xray: skipping MIMIC VQA cell.')",
   "metadata": {
    "id": "cell-build-vqa",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:28:48.902314Z",
     "iopub.execute_input": "2026-04-21T10:28:48.902574Z",
     "iopub.status.idle": "2026-04-21T10:28:48.917460Z",
     "shell.execute_reply.started": "2026-04-21T10:28:48.902547Z",
     "shell.execute_reply": "2026-04-21T10:28:48.916661Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "text": "IU-Xray: skipping MIMIC VQA cell.\n",
     "output_type": "stream"
    }
   ],
   "execution_count": 34
  },
  {
   "id": "4e0561c4-87b4-49a2-8829-3e3a4e5830be",
   "cell_type": "code",
   "source": "if DATASET_NAME == 'MIMIC-CXR':\n    before = len(samples)\n    samples = [s for s in samples if (CXR_ROOT / s['image_path']).exists()]\n    print(f'filtered out {before - len(samples)} samples w/ missing image files — kept {len(samples):,}')\nelse:\n    print('IU-Xray: skipping.')",
   "metadata": {
    "id": "cell-filter",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:28:48.918395Z",
     "iopub.execute_input": "2026-04-21T10:28:48.918702Z",
     "iopub.status.idle": "2026-04-21T10:28:48.933600Z",
     "shell.execute_reply.started": "2026-04-21T10:28:48.918679Z",
     "shell.execute_reply": "2026-04-21T10:28:48.933057Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "text": "IU-Xray: skipping.\n",
     "output_type": "stream"
    }
   ],
   "execution_count": 35
  },
  {
   "id": "7c2237cb-03c3-46c5-9ff8-4fb08eb00ef7",
   "cell_type": "code",
   "source": "if DATASET_NAME == 'MIMIC-CXR':\n    from collections import Counter\n    print('By task :', Counter(s['task']  for s in samples))\n    print('By split:', Counter(s['split'] for s in samples))\n\n    out_dir  = PROJECT / 'data' / 'data_files'\n    out_dir.mkdir(parents=True, exist_ok=True)\n    mimic_json_path = out_dir / 'mimic_cxr_instruct_unified.json'\n    with open(mimic_json_path, 'w') as f:\n        json.dump(samples, f)\n    print('Wrote', mimic_json_path, f'({len(samples):,} samples)')\nelse:\n    # Build IU-Xray JSON (the resolver would also do this lazily, but doing it\n    # here gives us a nice summary log in the notebook).\n    from data.iu_xray_builder import build_iu_xray_instruct_json\n    out_dir  = PROJECT / 'data' / 'data_files'\n    out_dir.mkdir(parents=True, exist_ok=True)\n    iu_json_path = out_dir / 'iu_xray_instruct.json'\n    build_iu_xray_instruct_json(\n        images_dir  = str(IU_IMAGES_DIR),\n        labels_dir  = str(IU_LABELS_DIR),\n        output_path = str(iu_json_path),\n        train_ratio = 0.70, val_ratio = 0.15, test_ratio = 0.15, seed = 42,\n    )",
   "metadata": {
    "id": "cell-save-json",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:28:48.934683Z",
     "iopub.execute_input": "2026-04-21T10:28:48.935048Z",
     "iopub.status.idle": "2026-04-21T10:29:39.749241Z",
     "shell.execute_reply.started": "2026-04-21T10:28:48.935012Z",
     "shell.execute_reply": "2026-04-21T10:29:39.748211Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "text": "[iu_xray_builder] wrote 13891 samples → /kaggle/working/cxr_vlm/data/data_files/iu_xray_instruct.json\n  XMLs scanned     : 3955\n  reports kept     : 3826\n  skipped no_text  : 28\n  skipped no_image : 101\n  by split         : {'train': 9760, 'validate': 2040, 'test': 2091}\n  by task          : {'findings': 6473, 'impression': 7418}\n",
     "output_type": "stream"
    }
   ],
   "execution_count": 36
  },
  {
   "id": "1d1f7c35-ecb5-4219-bfcf-f134a62fb9a1",
   "cell_type": "markdown",
   "source": "## 4. Patch configs for the Kaggle environment\n\n- Sets `data.dataset_name` based on DATASET_NAME.\n- For **MIMIC-CXR**: points `mimic_cxr_root` + `instruct_json` at what we built above.\n- For **IU-Xray**: points `iu_xray.images_dir` + `iu_xray.labels_dir` + `iu_xray.instruct_json` at the Kaggle mount.\n- `training.output_root = /kaggle/working/ckpt` — Persistence keeps it. Run folders become e.g. `ckpt/IU-Xray_run_1/`.\n- **4-bit QLoRA** (8-bit can trip `named symbol not found` on some bnb/CUDA combos).\n- WandB off, HF hub on.\n- Edit `hf_hub.repo_id` to your own HF model repo.",
   "metadata": {
    "id": "cell-cfg-md"
   }
  },
  {
   "id": "3f3ad9bf-90a2-4353-a252-647740e7a6d2",
   "cell_type": "code",
   "source": [
    "from omegaconf import OmegaConf\n",
    "\n",
    "train_cfg = OmegaConf.load(PROJECT / 'configs' / 'train_config.yaml')\n",
    "model_cfg = OmegaConf.load(PROJECT / 'configs' / 'model_config.yaml')\n",
    "\n",
    "# ── dataset selector ──\n",
    "train_cfg.data.dataset_name = DATASET_NAME\n",
    "\n",
    "# ── dataset-specific paths ──\n",
    "if DATASET_NAME == 'MIMIC-CXR':\n",
    "    train_cfg.data.mimic_cxr_root = str(CXR_ROOT)\n",
    "    train_cfg.data.instruct_json  = str(mimic_json_path)\n",
    "else:  # IU-Xray\n",
    "    train_cfg.data.iu_xray.images_dir    = str(IU_IMAGES_DIR)\n",
    "    train_cfg.data.iu_xray.labels_dir    = str(IU_LABELS_DIR)\n",
    "    train_cfg.data.iu_xray.instruct_json = str(iu_json_path)\n",
    "    train_cfg.data.iu_xray.auto_build    = True\n",
    "\n",
    "train_cfg.data.train_split = 'train'\n",
    "train_cfg.data.val_split   = 'validate'\n",
    "train_cfg.data.test_split  = 'test'\n",
    "\n",
    "# ── checkpoint root (Persistence keeps it across kernel restarts) ──\n",
    "CKPT_ROOT = WORK / 'ckpt'\n",
    "train_cfg.training.output_root = str(CKPT_ROOT)\n",
    "\n",
    "# ─────────────────────────────────────────────────────────────────────\n",
    "# GPU profile — change the 6 lines below if you switch GPU.\n",
    "# Effective batch = per_device * gradient_accumulation_steps.\n",
    "# Keep effective ≈ 16 across GPUs so learning rate stays valid.\n",
    "#\n",
    "#                          per_device  accum   precision   workers   effective\n",
    "# T4    (15 GB, sm_75) :        1        16    fp16  ✓        4         16\n",
    "# L4    (22 GB, sm_89) :        2         8    bf16  ✓        6         16\n",
    "# A100  (40 GB, sm_80) :        4         4    bf16  ✓        8         16\n",
    "# A100  (80 GB)        :        8         2    bf16  ✓       16         16\n",
    "# H100  (80 GB, sm_90) :        8         2    bf16  ✓       16         16\n",
    "#\n",
    "# bf16 chỉ chạy trên Ampere+ (sm_80+); T4/V100 không có bf16 → ép fp16.\n",
    "# ─────────────────────────────────────────────────────────────────────\n",
    "# ── ACTIVE values: T4-15GB (Kaggle) ──\n",
    "train_cfg.training.per_device_train_batch_size = 1     # T4 active | A100-40: 4 | A100-80: 8\n",
    "train_cfg.training.per_device_eval_batch_size  = 1\n",
    "train_cfg.training.gradient_accumulation_steps = 16    # T4 active | A100-40: 4 | A100-80: 2\n",
    "train_cfg.training.fp16                        = True  # T4 active (T4 không có bf16)\n",
    "train_cfg.training.bf16                        = False # T4 active\n",
    "train_cfg.training.dataloader_num_workers      = 4     # T4 active | A100-40: 8 | A100-80: 16\n",
    "\n",
    "# (Optional) tăng cutoff_len trên A100 nếu báo cáo dài bị truncate.\n",
    "# train_cfg.training.cutoff_len = 1024   # T4: 512\n",
    "\n",
    "# ── stage epochs (typo của bản cũ là num_epoch — không có 's' → bị bỏ qua) ──\n",
    "train_cfg.stage2.num_epochs = 5\n",
    "\n",
    "# ── wandb off ──\n",
    "train_cfg.wandb.enabled = False\n",
    "\n",
    "# ── HuggingFace Hub run tracking ──\n",
    "train_cfg.hf_hub.enabled        = True\n",
    "train_cfg.hf_hub.repo_id        = 'hieu3636/cxr-vlm-runs'   # <<< EDIT ME\n",
    "train_cfg.hf_hub.token_env      = 'HF_TOKEN'\n",
    "train_cfg.hf_hub.private        = True\n",
    "train_cfg.hf_hub.run_state_file = str(CKPT_ROOT / 'run_id.txt')\n",
    "\n",
    "# ─────────────────────────────────────────────────────────────────────\n",
    "# LLM dtype + quantization\n",
    "#\n",
    "# Vicuna-7B base size:       fp16/bf16 ≈ 13 GB | 4-bit ≈ 3.5 GB\n",
    "# T4 (15 GB)   → must use 4-bit QLoRA, fp16 compute.\n",
    "# A100 (40 GB) → có thể tắt 4-bit, dùng bf16 thuần (chất lượng tốt hơn).\n",
    "# A100 (80 GB) → tắt 4-bit, batch=8.\n",
    "# ─────────────────────────────────────────────────────────────────────\n",
    "model_cfg.llm.load_in_8bit = False\n",
    "model_cfg.llm.load_in_4bit = True              # T4: bắt buộc | A100: tùy chọn\n",
    "model_cfg.llm.torch_dtype  = 'float16'         # T4: 'float16' | A100: 'bfloat16'\n",
    "model_cfg.chexpert_classifier.enabled = False\n",
    "\n",
    "OmegaConf.save(train_cfg, PROJECT / 'configs' / 'train_config.yaml')\n",
    "OmegaConf.save(model_cfg, PROJECT / 'configs' / 'model_config.yaml')\n",
    "\n",
    "# Sanity print: confirm GPU + active settings line up.\n",
    "import torch\n",
    "gpu  = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'\n",
    "vram = (torch.cuda.get_device_properties(0).total_memory / 1e9\n",
    "        if torch.cuda.is_available() else 0)\n",
    "print(f'>>> GPU detected: {gpu} ({vram:.1f} GB VRAM)')\n",
    "print(f'>>> per_device batch={train_cfg.training.per_device_train_batch_size}, '\n",
    "      f'accum={train_cfg.training.gradient_accumulation_steps}, '\n",
    "      f'fp16={train_cfg.training.fp16}, bf16={train_cfg.training.bf16}')\n",
    "if 'T4' in gpu and train_cfg.training.bf16:\n",
    "    print('!!! WARNING: bf16=True nhưng đang chạy T4 — T4 không hỗ trợ bf16. '\n",
    "          'Đổi sang fp16=True / bf16=False trước khi train.')\n",
    "if ('A100' in gpu or 'H100' in gpu or 'L4' in gpu) and train_cfg.training.fp16:\n",
    "    print('!!! Note: đang dùng fp16 trên Ampere+ — bf16 ổn định hơn. Cân nhắc đổi.')\n",
    "\n",
    "print()\n",
    "print('--- train_cfg.data ---');     print(OmegaConf.to_yaml(train_cfg.data))\n",
    "print('--- train_cfg.training ---'); print(OmegaConf.to_yaml(train_cfg.training))\n",
    "print('--- train_cfg.hf_hub ---');   print(OmegaConf.to_yaml(train_cfg.hf_hub))\n",
    "print('--- model_cfg.llm ---');      print(OmegaConf.to_yaml(model_cfg.llm))\n"
   ],
   "metadata": {
    "id": "cell-cfg",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:29:39.750563Z",
     "iopub.execute_input": "2026-04-21T10:29:39.750960Z",
     "iopub.status.idle": "2026-04-21T10:29:39.804753Z",
     "shell.execute_reply.started": "2026-04-21T10:29:39.750913Z",
     "shell.execute_reply": "2026-04-21T10:29:39.803917Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "id": "85182004-6fcb-4546-909e-01b798d0f662",
   "cell_type": "markdown",
   "source": "## 5. HF auth\n\nRequires a Kaggle secret named **`HF_TOKEN`** (Add-ons → Secrets).",
   "metadata": {
    "id": "cell-hf-md"
   }
  },
  {
   "id": "b18a9613-e27e-40f4-ad6b-3bac49bc9ea3",
   "cell_type": "code",
   "source": "try:\n    from kaggle_secrets import UserSecretsClient\n    os.environ['HF_TOKEN'] = UserSecretsClient().get_secret('HF_TOKEN')\n    print('HF_TOKEN loaded from Kaggle secret ✓')\nexcept Exception as e:\n    print('No HF_TOKEN secret — Vicuna-7B download may rate-limit and hub upload will be disabled:', e)",
   "metadata": {
    "id": "cell-hf-token",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:29:39.806107Z",
     "iopub.execute_input": "2026-04-21T10:29:39.806402Z",
     "iopub.status.idle": "2026-04-21T10:29:39.881692Z",
     "shell.execute_reply.started": "2026-04-21T10:29:39.806368Z",
     "shell.execute_reply": "2026-04-21T10:29:39.881062Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "text": "HF_TOKEN loaded from Kaggle secret ✓\n",
     "output_type": "stream"
    }
   ],
   "execution_count": 38
  },
  {
   "id": "c20528c5-1ea1-4da6-a12a-d4caea5f3ab3",
   "cell_type": "code",
   "source": [
    "# Sanity-import the package modules.\n",
    "import importlib, model, data, utils\n",
    "importlib.reload(model); importlib.reload(data); importlib.reload(utils)\n",
    "from model import CXRVisionLanguageModel\n",
    "from data  import CXRInstructDataset, CXRDataCollator\n",
    "from utils.hf_uploader      import build_tracker_from_cfg\n",
    "from utils.dataset_resolver import resolve_dataset_spec, resolve_run_id\n",
    "print('Imports OK')\n",
    "\n",
    "# Show what will be used and what run_id will be picked.\n",
    "# Pass hf_repo_id + hf_token so the resolver scans BOTH local disk AND\n",
    "# the HF Hub when picking the next N — same logic train.py uses.\n",
    "spec = resolve_dataset_spec(train_cfg)\n",
    "run_id = resolve_run_id(\n",
    "    dataset_name = spec.dataset_name,\n",
    "    output_root  = str(CKPT_ROOT),\n",
    "    state_file   = str(CKPT_ROOT / 'run_id.txt'),\n",
    "    resuming     = False,\n",
    "    explicit     = None,\n",
    "    hf_repo_id   = train_cfg.hf_hub.repo_id,\n",
    "    hf_token     = os.environ.get('HF_TOKEN'),\n",
    ")\n",
    "print(f'Resolved spec : {spec}')\n",
    "print(f'Next run_id   : {run_id}')\n",
    "print(f'Checkpoints will land under: {CKPT_ROOT / run_id}')\n"
   ],
   "metadata": {
    "id": "cell-sanity",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:29:39.883693Z",
     "iopub.execute_input": "2026-04-21T10:29:39.884050Z",
     "iopub.status.idle": "2026-04-21T10:29:39.892955Z",
     "shell.execute_reply.started": "2026-04-21T10:29:39.884013Z",
     "shell.execute_reply": "2026-04-21T10:29:39.892162Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "cell-resume-md",
   "metadata": {},
   "source": [
    "## 5b. Resume a previous run (only if you were interrupted)\n",
    "\n",
    "**Skip this section if you're starting fresh.** Set `RESUME_STAGE = None` in the cell below and run Stage 1 → Stage 2 normally.\n",
    "\n",
    "### Khi nào cần resume\n",
    "\n",
    "| Tình huống | Cần làm |\n",
    "|---|---|\n",
    "| Stage 1 đang train dở, cùng VM | `RESUME_STAGE=1`, `EXPLICIT_RUN_ID=None` |\n",
    "| Stage 1 dở, **VM mới** | `RESUME_STAGE=1`, `EXPLICIT_RUN_ID=\"IU-Xray_run_1\"` |\n",
    "| Stage 1 xong, chạy tiếp stage 2 | `RESUME_STAGE=2`, `EXPLICIT_RUN_ID=None` (cùng VM) hoặc set (VM mới) |\n",
    "| Stage 2 đang dở, cùng VM | `RESUME_STAGE=2`, `EXPLICIT_RUN_ID=None` |\n",
    "| Stage 2 dở, **VM mới** | `RESUME_STAGE=2`, `EXPLICIT_RUN_ID=\"IU-Xray_run_1\"` |\n",
    "\n",
    "### Cách hoạt động\n",
    "\n",
    "Cell controller chỉ set 2 biến (`RESUME_STAGE` + `EXPLICIT_RUN_ID`). Việc pull checkpoint từ HF Hub giao cho `train.py` qua cờ `--resume_from_hf` — nó tự download `<run_id>/<stage>/last/` về local rồi resume từ đó.\n",
    "\n",
    "### Train tiếp sẽ lưu ở đâu trên HF?\n",
    "\n",
    "**Vẫn cùng folder `<RUN_ID>/`** trên HF, không tạo run mới. Sau resume:\n",
    "- `<RUN_ID>/stage2/last/` — checkpoint mới nhất (overwrite mỗi `save_steps`).\n",
    "- `<RUN_ID>/stage2/best/` — checkpoint có `eval_loss` thấp nhất.\n",
    "- `<RUN_ID>/stage2/training_log.jsonl` — log_history (replace mỗi save).\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "cell-resume",
   "metadata": {},
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Resume controller — set 2 variables, run once, then run cell-stage1 / cell-stage2.\n",
    "RESUME_STAGE    = None     # None | 1 | 2     (None = fresh run, skip this cell)\n",
    "EXPLICIT_RUN_ID = None     # None | \"IU-Xray_run_1\"   (set if VM is fresh and you want to continue)\n",
    "\n",
    "USE_RESUME_FROM_HF = False\n",
    "RESUME_RUN_ID      = None\n",
    "\n",
    "if RESUME_STAGE is not None:\n",
    "    assert RESUME_STAGE in (1, 2), \"RESUME_STAGE must be 1 or 2\"\n",
    "    USE_RESUME_FROM_HF = True\n",
    "    if EXPLICIT_RUN_ID:\n",
    "        RESUME_RUN_ID = EXPLICIT_RUN_ID\n",
    "        CKPT_ROOT.mkdir(parents=True, exist_ok=True)\n",
    "        (CKPT_ROOT / \"run_id.txt\").write_text(RESUME_RUN_ID)\n",
    "        print(f\"Will resume run {RESUME_RUN_ID} stage {RESUME_STAGE} (EXPLICIT_RUN_ID)\")\n",
    "    else:\n",
    "        state_file = CKPT_ROOT / \"run_id.txt\"\n",
    "        assert state_file.exists(), (\n",
    "            \"No local run_id.txt — looks like a fresh VM. \"\n",
    "            \"Set EXPLICIT_RUN_ID to the run folder on HF (e.g. \\\"IU-Xray_run_1\\\").\"\n",
    "        )\n",
    "        RESUME_RUN_ID = state_file.read_text().strip()\n",
    "        print(f\"Will resume run {RESUME_RUN_ID} stage {RESUME_STAGE} (from state file)\")\n",
    "    print(f\"\\u2192 train.py se tu pull {RESUME_RUN_ID}/stage{RESUME_STAGE}/last/ tu HF Hub\")\n",
    "    print(f\"\\u2192 Bay gio chay cell-stage{RESUME_STAGE} ben duoi.\")\n",
    "else:\n",
    "    print(\"RESUME_STAGE=None — fresh run. Skip; chạy thẳng cell-stage1.\")\n"
   ]
  },
  {
   "id": "3b6b5da1-a916-4f9e-9e30-7536eff992fa",
   "cell_type": "markdown",
   "source": "## 6. Stage 1 — projection layer only (~2 epochs)\n\nFirst launch creates `{DATASET_NAME}_run_1/` on HF and on disk. Subsequent fresh launches auto-increment to `run_2`, `run_3`, … — tracked via `ckpt/run_id.txt`.\n\nIf you need to continue training from an existing checkpoint, pass `--resume_from <ckpt>` — that reuses the same `run_N` folder.",
   "metadata": {
    "id": "cell-stage1-md"
   }
  },
  {
   "id": "453e3585-5e35-4108-8de0-81824b8ca459",
   "cell_type": "code",
   "source": [
    "# Picks up RESUME_STAGE / RESUME_RUN_ID / USE_RESUME_FROM_HF from cell-resume.\n",
    "_resume_args = \"\"\n",
    "if \"RESUME_STAGE\" in dir() and RESUME_STAGE == 1 and USE_RESUME_FROM_HF:\n",
    "    _resume_args = f'--resume_from_hf --run_id \"{RESUME_RUN_ID}\"'\n",
    "    print(\"\\u25b6 STAGE 1 resuming from HF Hub run\", RESUME_RUN_ID)\n",
    "elif \"RESUME_RUN_ID\" in dir() and RESUME_RUN_ID:\n",
    "    # Edge case: stage1 đã xong session trước, nay re-run cùng run_id.\n",
    "    _resume_args = f'--run_id \"{RESUME_RUN_ID}\"'\n",
    "    print(\"\\u25b6 STAGE 1 fresh, pinned to run_id\", RESUME_RUN_ID)\n",
    "else:\n",
    "    print(\"\\u25b6 STAGE 1 fresh run\")\n",
    "\n",
    "!HF_HUB_DISABLE_PROGRESS_BARS=1 TRANSFORMERS_VERBOSITY=warning TOKENIZERS_PARALLELISM=false BITSANDBYTES_NOWELCOME=1 PYTHONUNBUFFERED=1 \\\n",
    "python -u -m training.train \\\n",
    "    --model_config configs/model_config.yaml \\\n",
    "    --train_config configs/train_config.yaml \\\n",
    "    --stage 1 {_resume_args}\n"
   ],
   "metadata": {
    "id": "cell-stage1",
    "trusted": true,
    "execution": {
     "iopub.status.busy": "2026-04-21T10:29:39.893833Z",
     "iopub.execute_input": "2026-04-21T10:29:39.894148Z",
     "execution_failed": "2026-04-21T17:01:21.580Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "id": "71763eb0-6a77-4184-a591-eba3947ba41e",
   "cell_type": "markdown",
   "source": [
    "## 7. Stage 2 — projection + LoRA instruction tuning\n",
    "\n",
    "If a session is killed mid-train, you can pick up exactly where you left off:\n",
    "\n",
    "1. **Same VM** (Persistence on, files still on disk) → cell-resume with `RESUME_STAGE=2`, `EXPLICIT_RUN_ID=None`.\n",
    "2. **Fresh VM** (kernel killed, disk wiped) → cell-resume with `RESUME_STAGE=2`, `EXPLICIT_RUN_ID=\"<your run id>\"`. `train.py` tự pull `<run>/stage2/last/` từ HF Hub về và resume.\n",
    "\n",
    "Cả 2 trường hợp đều **reuse cùng run_id trên HF** — không tạo `run_N+1` mới.\n"
   ],
   "metadata": {
    "id": "cell-stage2-md"
   }
  },
  {
   "id": "ae16b58b-fa7a-4993-9e64-ec48a42aa70e",
   "cell_type": "code",
   "source": [
    "# Picks up RESUME_STAGE / RESUME_RUN_ID / USE_RESUME_FROM_HF from cell-resume.\n",
    "_resume_args = \"\"\n",
    "if \"RESUME_STAGE\" in dir() and RESUME_STAGE == 2 and USE_RESUME_FROM_HF:\n",
    "    _resume_args = f'--resume_from_hf --run_id \"{RESUME_RUN_ID}\"'\n",
    "    print(\"\\u25b6 STAGE 2 resuming from HF Hub run\", RESUME_RUN_ID)\n",
    "elif \"RESUME_RUN_ID\" in dir() and RESUME_RUN_ID:\n",
    "    _resume_args = f'--run_id \"{RESUME_RUN_ID}\"'\n",
    "    print(\"\\u25b6 STAGE 2 fresh, pinned to run_id\", RESUME_RUN_ID)\n",
    "else:\n",
    "    print(\"\\u25b6 STAGE 2 fresh (cùng session với stage1)\")\n",
    "\n",
    "!HF_HUB_DISABLE_PROGRESS_BARS=1 TRANSFORMERS_VERBOSITY=warning TOKENIZERS_PARALLELISM=false BITSANDBYTES_NOWELCOME=1 PYTHONUNBUFFERED=1 \\\n",
    "python -u -m training.train \\\n",
    "    --model_config configs/model_config.yaml \\\n",
    "    --train_config configs/train_config.yaml \\\n",
    "    --stage 2 {_resume_args}\n"
   ],
   "metadata": {
    "id": "cell-stage2",
    "trusted": true,
    "execution": {
     "execution_failed": "2026-04-21T17:01:21.580Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "id": "237e4b68-5584-4df8-b0f3-9295bed9c6a8",
   "cell_type": "markdown",
   "source": "## 8. Evaluate\n\nUploads `results/{RUN_ID}/predictions_*.json` + `metrics_summary.json` into the same `{RUN_ID}/results/` folder on HF.\n\nFor IU-Xray only `findings` + `impression` are evaluated (VQA is skipped automatically).",
   "metadata": {
    "id": "cell-eval-md"
   }
  },
  {
   "id": "be093922-fb4c-43f1-b87b-4da68035d89f",
   "cell_type": "code",
   "source": [
    "# Resolve run_id from state file\n",
    "RUN_ID      = (CKPT_ROOT / 'run_id.txt').read_text().strip()\n",
    "RESULTS_DIR = WORK / 'results'\n",
    "\n",
    "# Best checkpoint = <RUN_ID>/stage2/best/ on HF Hub.\n",
    "# Pull it locally if not already present.\n",
    "local_best_dir  = CKPT_ROOT / RUN_ID / 'stage2_instruct' / '_best_from_hf'\n",
    "local_best_proj = local_best_dir / 'checkpoint_projection.pt'\n",
    "\n",
    "if not local_best_proj.exists():\n",
    "    print(f'Pulling {RUN_ID}/stage2/best/ from HF Hub …')\n",
    "    from huggingface_hub import snapshot_download\n",
    "    pulled = snapshot_download(\n",
    "        repo_id        = train_cfg.hf_hub.repo_id,\n",
    "        repo_type      = 'model',\n",
    "        token          = os.environ['HF_TOKEN'],\n",
    "        allow_patterns = [f'{RUN_ID}/stage2/best/**'],\n",
    "        local_dir      = str(WORK / 'hf_eval_pull'),\n",
    "    )\n",
    "    hub_best = Path(pulled) / RUN_ID / 'stage2' / 'best'\n",
    "    assert hub_best.exists() and (hub_best / 'checkpoint_projection.pt').exists(), (\n",
    "        f'Could not find {RUN_ID}/stage2/best/checkpoint_projection.pt on HF repo '\n",
    "        f'{train_cfg.hf_hub.repo_id}. Did stage 2 finish?'\n",
    "    )\n",
    "    if local_best_dir.exists():\n",
    "        shutil.rmtree(local_best_dir)\n",
    "    shutil.copytree(hub_best, local_best_dir)\n",
    "    print(f'  pulled \\u2192 {local_best_dir}')\n",
    "\n",
    "CKPT_PATH = local_best_dir / 'checkpoint_projection.pt'\n",
    "print('Evaluating run_id:', RUN_ID)\n",
    "print('Checkpoint       :', CKPT_PATH)\n",
    "print('Results          \\u2192', RESULTS_DIR)\n",
    "\n",
    "!HF_HUB_DISABLE_PROGRESS_BARS=1 TRANSFORMERS_VERBOSITY=warning TOKENIZERS_PARALLELISM=false BITSANDBYTES_NOWELCOME=1 PYTHONUNBUFFERED=1 \\\n",
    "python -u -m evaluation.evaluate \\\n",
    "    --model_config configs/model_config.yaml \\\n",
    "    --train_config configs/train_config.yaml \\\n",
    "    --checkpoint \"{CKPT_PATH}\" \\\n",
    "    --run_id     \"{RUN_ID}\" \\\n",
    "    --task       all \\\n",
    "    --output_dir \"{RESULTS_DIR}\"\n"
   ],
   "metadata": {
    "id": "cell-eval",
    "trusted": true,
    "execution": {
     "execution_failed": "2026-04-21T17:01:21.580Z"
    }
   },
   "outputs": [],
   "execution_count": null
  }
 ]
}