convitom commited on
Commit
35d4872
·
1 Parent(s): 320063f
Files changed (3) hide show
  1. model/cxr_vlm.py +15 -0
  2. opti.py +4 -0
  3. scripts/resize_and_shard.ipynb +373 -20
model/cxr_vlm.py CHANGED
@@ -554,3 +554,18 @@ class CXRVisionLanguageModel(nn.Module):
554
  trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
555
  print(f"Trainable params: {trainable:,} / {total:,} "
556
  f"({100 * trainable / total:.2f}%)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
554
  trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
555
  print(f"Trainable params: {trainable:,} / {total:,} "
556
  f"({100 * trainable / total:.2f}%)")
557
+
558
+ # Tensor-count breakdown matching HF Trainer's optimizer param_groups
559
+ # (group 0 = weight-decay params, group 1 = biases + LayerNorm).
560
+ # Useful for diagnosing "optimizer state dict size mismatch" on resume.
561
+ try:
562
+ from transformers.trainer_pt_utils import get_parameter_names
563
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
564
+ decay = set(get_parameter_names(self, ALL_LAYERNORM_LAYERS))
565
+ decay = {n for n in decay if "bias" not in n}
566
+ g0 = [n for n, p in self.named_parameters() if n in decay and p.requires_grad]
567
+ g1 = [n for n, p in self.named_parameters() if n not in decay and p.requires_grad]
568
+ print(f" optimizer group 0 (decay): {len(g0)} tensors")
569
+ print(f" optimizer group 1 (no decay): {len(g1)} tensors")
570
+ except Exception as e:
571
+ print(f" [param-group breakdown skipped: {e}]")
opti.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import torch
2
+ ckpt = torch.load(r"C:\Users\admin\Downloads\optimizer.pt", map_location="cpu", weights_only=False)
3
+ for i, g in enumerate(ckpt["param_groups"]):
4
+ print(f"saved group {i}: {len(g['params'])} params")
scripts/resize_and_shard.ipynb CHANGED
@@ -4,97 +4,450 @@
4
  "cell_type": "markdown",
5
  "id": "c00",
6
  "metadata": {},
7
- "source": "# CXR-VLM -- Resize + tar-shard dataset (one-time, offline)\n\nShrinks the original MIMIC-CXR tree (~2-3 MP/image) to RAD-DINO's working\nresolution and packs it into a few tar shards, so cloud training boxes\n(Vast.ai / Lightning.ai / Colab) pull a small, transfer-friendly dataset\ninstead of ~100 GB of huge JPGs read every epoch.\n\n**Source / destination (HF dataset repo `hieu3636/cxr-vlm-data`):**\n- read : `MIMIC-CXR_processed/shards/*.tar` (already tar-sharded source)\n- write : `MIMIC-CXR_resized/shards/cxr-NNNN.tar` (+ `_manifest.json`, `SHARDS.txt`)\n\n**Flow: streaming per shard** -- for each source shard: download one tar ->\nextract -> resize/copy contents into the cumulative `resized/` tree ->\ndelete the tar + extract scratch. Peak disk usage ~10 GB instead of\n~200 GB if you downloaded + extracted everything first.\n\n**Why HF and not Google Drive:** notebook is meant to run on arbitrary\ncloud GPUs. Drive only mounts conveniently on Colab and rate-limits badly\non bulk many-file reads. HF works everywhere with just a token,\n`hf_hub_download` is parallel + resumable, and the data already lives\nthere. Both download and upload go through HF.\n\n**This step does NOT change what the model sees** -- RAD-DINO's processor\nresizes the shortest edge to 518 and center-crops 518x518 anyway; we just\ndo that downscale once, offline, instead of every epoch on full-res\nimages. Reports (`.txt`), CheXpert (`.csv`), and any other non-image\nfiles are copied verbatim so the resized release is a faithful mirror.\n\n**Prerequisite:** an `HF_TOKEN` with **write** access to the repo."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  },
9
  {
10
  "cell_type": "markdown",
11
  "id": "c01",
12
  "metadata": {},
13
- "source": "## 0. Config -- edit here"
 
 
14
  },
15
  {
16
  "cell_type": "code",
17
- "id": "c02",
18
  "execution_count": null,
 
19
  "metadata": {},
20
  "outputs": [],
21
- "source": "import os\nfrom pathlib import Path\n\nREPO_ID = \"hieu3636/cxr-vlm-data\"\nREPO_TYPE = \"dataset\"\nSRC_SUBDIR = \"MIMIC-CXR_processed\" # source folder on HF (tar-sharded under shards/)\nDST_SUBDIR = \"MIMIC-CXR_resized\" # output folder on HF (will hold shards/ too)\n\n# Big scratch disk on the VM (Vast/Lightning: /workspace, Colab: /content).\nWORK_DIR = Path(os.environ.get(\"WORK_DIR\", \"/workspace/cxr_resize\"))\n\n# --- resize params -------------------------------------------------------\nTARGET = 518 # shortest-edge target. MUST be >= 518 (RAD-DINO crops 518).\nSQUARE = False # False: keep aspect (518xN), flexible, processor crops at\n # train time. ~20% bigger.\n # True : also center-crop to 518x518 here -> file is exactly\n # 518x518 and the processor is a true no-op. Smaller,\n # but BAKES the crop (changing backbone/img_size later\n # needs a full rebuild). Recommended off for a thesis.\nQUALITY = 90 # JPEG quality (q90 + 4:4:4 = near-lossless for CXR)\nSHARD_GB = 2.0 # approx GB per tar shard (output)\nWORKERS = min(32, (os.cpu_count() or 8) * 4) # I/O-bound; PIL frees the GIL\n\n# Derived local paths -- streaming flow keeps disk small:\n# one tar at a time in DL_DIR, one tar extracted in SCRATCH, resized\n# tree accumulates in DST_TREE.\nDL_DIR = WORK_DIR / \"_dl\" # per-shard download buffer (~2 GB at a time)\nSCRATCH = WORK_DIR / \"_extract\" # per-shard extraction scratch (~2 GB at a time)\nDST_TREE = WORK_DIR / \"resized\" # cumulative resized tree (~5-8 GB final)\nSHARDS_DIR = WORK_DIR / \"shards\" # output tar shards (~5-8 GB final)\nfor p in (WORK_DIR, DL_DIR, SCRATCH, DST_TREE, SHARDS_DIR):\n p.mkdir(parents=True, exist_ok=True)\n\nassert TARGET >= 518, \"TARGET must be >= 518 (RAD-DINO upscales shortest edge to 518)\"\nprint(\"WORK_DIR:\", WORK_DIR, \"| TARGET:\", TARGET, \"| SQUARE:\", SQUARE, \"| WORKERS:\", WORKERS)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  },
23
  {
24
  "cell_type": "markdown",
25
  "id": "c03",
26
  "metadata": {},
27
- "source": "## 1. Setup -- deps + HF token\n\nToken resolution: env `HF_TOKEN` -> Colab `userdata` -> Kaggle secret."
 
 
 
 
28
  },
29
  {
30
  "cell_type": "code",
31
- "id": "c04",
32
  "execution_count": null,
 
33
  "metadata": {},
34
  "outputs": [],
35
- "source": "import os, sys, subprocess\nsubprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n \"huggingface_hub>=0.24,<0.27\", \"Pillow>=10\", \"tqdm\"], check=True)\n\n# Be tolerant of slow/flaky chunks on cloud networks (default is ~10s).\nos.environ.setdefault(\"HF_HUB_DOWNLOAD_TIMEOUT\", \"60\")\n# Optional: faster + more robust large-file downloads via the Rust backend.\n# Set to \"1\" and `pip install hf_transfer` if you keep hitting broken\n# connections; leave off if it causes trouble on your network.\n# os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\n\nif not os.environ.get(\"HF_TOKEN\"):\n try:\n from google.colab import userdata\n os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")\n except Exception:\n try:\n from kaggle_secrets import UserSecretsClient\n os.environ[\"HF_TOKEN\"] = UserSecretsClient().get_secret(\"HF_TOKEN\")\n except Exception:\n pass\n\nHF_TOKEN = os.environ.get(\"HF_TOKEN\")\nassert HF_TOKEN, \"HF_TOKEN missing -- set it via env var or platform secrets (needs WRITE access).\"\nprint(\"HF_TOKEN loaded OK\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  },
37
  {
38
  "cell_type": "markdown",
39
  "id": "c05",
40
  "metadata": {},
41
- "source": "## 2. Resize + pack logic (inlined, mirrors `scripts/build_resized_dataset.py`)\n\nUses a thread pool (not processes): PIL releases the GIL during\ndecode/resize/encode, so threads parallelise well and avoid notebook\nmultiprocessing pickling issues.\n\nSplit into two layers so the streaming orchestrator below can call the\nworker repeatedly (once per source shard) and accumulate counts before\nwriting a single final manifest:\n\n- `_walk_and_process(src, dst, ...)` -- walk one tree, resize images,\n copy non-images, return `(counts, errors, n_img, n_other)`. No I/O\n beyond reading src and writing dst; no manifest.\n- `resize_tree(...)` -- thin wrapper for standalone use (one src ->\n one dst -> manifest). Used by the script CLI.\n- `_write_manifest(...)` -- shared manifest writer.\n- `pack_shards(...)` -- bundle the final tree into ~2 GB tar shards."
 
 
 
 
 
 
 
42
  },
43
  {
44
  "cell_type": "code",
45
- "id": "c06",
46
  "execution_count": null,
 
47
  "metadata": {},
48
  "outputs": [],
49
- "source": "import os, json, shutil, tarfile, time\nfrom pathlib import Path\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\nfrom PIL import Image\nfrom tqdm.auto import tqdm\n\nImage.MAX_IMAGE_PIXELS = None # don't abort on large medical images\nIMG_EXTS = (\".jpg\", \".jpeg\", \".png\")\n\n\ndef _resize_one(src_path, dst_path, target, quality, square):\n \"\"\"Returns one of: resized | squared | copied | skipped | error:<msg>.\"\"\"\n try:\n dst_path = Path(dst_path)\n if dst_path.exists() and dst_path.stat().st_size > 0:\n return \"skipped\" # resumable\n dst_path.parent.mkdir(parents=True, exist_ok=True)\n with Image.open(src_path) as im:\n w, h = im.size\n shorter = min(w, h)\n # Non-square: if shorter side already <= target, downscaling would\n # push it below 518 -> copy verbatim (lossless, never worsens a\n # low-res source). Square mode must always emit exactly target^2.\n if not square and shorter <= target:\n shutil.copy2(src_path, dst_path)\n return \"copied\"\n if im.mode not in (\"L\", \"RGB\"):\n im = im.convert(\"RGB\")\n # shorter axis EXACTLY = target; longer scales proportionally\n if w <= h:\n new_size = (target, round(h * target / w))\n else:\n new_size = (round(w * target / h), target)\n # square mode reproduces the processor exactly -> bicubic\n im = im.resize(new_size, Image.BICUBIC if square else Image.LANCZOS)\n if square:\n W, H = im.size\n left, top = (W - target) // 2, (H - target) // 2\n im = im.crop((left, top, left + target, top + target))\n im.save(dst_path, \"JPEG\", quality=quality, optimize=True, subsampling=0)\n return \"squared\" if square else \"resized\"\n except Exception as e:\n return f\"error:{type(e).__name__}: {e}\"\n\n\ndef _copy_one(src_path, dst_path):\n \"\"\"Copy non-image files (reports .txt, chexpert .csv, metadata .json, ...)\n verbatim so the shipped tree mirrors the source exactly.\"\"\"\n try:\n dst_path = Path(dst_path)\n if dst_path.exists() and dst_path.stat().st_size > 0:\n return \"skipped\"\n dst_path.parent.mkdir(parents=True, exist_ok=True)\n shutil.copy2(src_path, dst_path)\n return \"copied_other\"\n except Exception as e:\n return f\"error:{type(e).__name__}: {e}\"\n\n\ndef _walk_and_process(src: Path, dst: Path, target, quality, workers, square):\n \"\"\"Walk one src tree -> write resized/copied files into dst tree.\n Returns (counts, errors, n_img, n_other). Does NOT write manifest.\"\"\"\n img_jobs, other_jobs = [], []\n for root, _, files in os.walk(src):\n for fn in files:\n sp = Path(root) / fn\n rel = sp.relative_to(src)\n dp = dst / rel\n if fn.lower().endswith(IMG_EXTS):\n img_jobs.append((str(sp), str(dp)))\n else:\n other_jobs.append((str(sp), str(dp)))\n counts = {\"resized\": 0, \"squared\": 0, \"copied\": 0,\n \"copied_other\": 0, \"skipped\": 0, \"error\": 0}\n errors = []\n with ThreadPoolExecutor(max_workers=workers) as ex:\n futs = {}\n for s, d in img_jobs:\n futs[ex.submit(_resize_one, s, d, target, quality, square)] = d\n for s, d in other_jobs:\n futs[ex.submit(_copy_one, s, d)] = d\n for f in as_completed(futs):\n st = f.result()\n if st.startswith(\"error:\"):\n counts[\"error\"] += 1\n errors.append(f\"{futs[f]}\\t{st}\")\n else:\n counts[st] += 1\n return counts, errors, len(img_jobs), len(other_jobs)\n\n\ndef _write_manifest(dst: Path, *, src, target, quality, square,\n counts, errors, n_img, n_oth):\n dst.mkdir(parents=True, exist_ok=True)\n out_bytes = sum(p.stat().st_size for p in dst.rglob(\"*\") if p.is_file())\n total = n_img + n_oth\n (dst / \"_manifest.json\").write_text(json.dumps({\n \"source\": src, \"target\": target,\n \"mode\": \"square\" if square else \"shortest_edge\",\n \"jpeg_quality\": quality, \"subsampling\": \"4:4:4\",\n \"resampling\": \"BICUBIC\" if square else \"LANCZOS\",\n \"counts\": counts, \"total\": total,\n \"images\": n_img, \"non_image\": n_oth,\n \"output_bytes\": out_bytes,\n \"built_at\": time.strftime(\"%Y-%m-%dT%H:%M:%S\"),\n }, indent=2), encoding=\"utf-8\")\n if errors:\n (dst / \"_errors.txt\").write_text(\"\\n\".join(errors), encoding=\"utf-8\")\n print(f\"WARNING: {len(errors)} failures -> {dst/'_errors.txt'}\")\n print(f\"done: {counts}\")\n print(f\"output size: {out_bytes/1024**3:.2f} GB \"\n f\"({out_bytes/max(1,n_img)/1024:.0f} KB/image avg)\")\n\n\ndef resize_tree(src: Path, dst: Path, target, quality, workers, square):\n \"\"\"Standalone API: one src tree -> resized dst + manifest. (Not used by\n the streaming flow below; kept for parity with the script CLI.)\"\"\"\n print(f\"[resize] scanning {src} ...\")\n counts, errors, n_img, n_oth = _walk_and_process(\n src, dst, target, quality, workers, square)\n mode = f\"square {target}x{target}\" if square else f\"shortest-edge {target}px\"\n print(f\"[resize] {n_img:,} images + {n_oth:,} non-image -> {dst} \"\n f\"({mode}, q{quality}, {workers} threads)\")\n _write_manifest(dst, src=str(src), target=target, quality=quality,\n square=square, counts=counts, errors=errors,\n n_img=n_img, n_oth=n_oth)\n\n\ndef pack_shards(dst: Path, shards_dir: Path, shard_gb, prefix=\"cxr\"):\n shard_bytes = int(shard_gb * 1024**3)\n shards_dir.mkdir(parents=True, exist_ok=True)\n files = sorted(p for p in dst.rglob(\"*\")\n if p.is_file() and p.name not in (\"_manifest.json\", \"_errors.txt\"))\n if not files:\n raise SystemExit(f\"ERROR: nothing to pack under {dst}\")\n print(f\"[pack] {len(files):,} files -> tar shards (~{shard_gb} GB each)\")\n written, idx, cur = [], 0, 0\n\n def _open(i):\n path = shards_dir / f\"{prefix}-{i:04d}.tar\"\n written.append(path)\n return tarfile.open(path, \"w\")\n\n tar = _open(0)\n for fp in tqdm(files, unit=\"file\"):\n if cur >= shard_bytes:\n tar.close(); idx += 1; tar = _open(idx); cur = 0\n tar.add(fp, arcname=str(fp.relative_to(dst))) # rel path -> tree rebuilt on extract\n cur += fp.stat().st_size\n tar.close()\n man = dst / \"_manifest.json\"\n if man.exists():\n shutil.copy2(man, shards_dir / \"_manifest.json\")\n (shards_dir / \"SHARDS.txt\").write_text(\"\\n\".join(p.name for p in written), encoding=\"utf-8\")\n print(f\"[pack] wrote {len(written)} shards -> {shards_dir}\")\n return written\n\nprint(\"functions ready\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  },
51
  {
52
  "cell_type": "markdown",
53
  "id": "c07",
54
  "metadata": {},
55
- "source": "## 3. Stream source shards: download -> extract -> resize -> cleanup (per shard)\n\nLoops over every `MIMIC-CXR_processed/shards/*.tar` on HF and processes\nthem one at a time. For each shard: pull it down, extract into a scratch\ndir, run `_walk_and_process` to resize images + copy reports into the\ncumulative `DST_TREE`, then delete the tar and the scratch. Peak disk\nstays around ~10 GB regardless of total source size, and the run is\nfully resumable -- already-resized files are skipped.\n\nTar arcname layout auto-detected on the first shard: handles both\n`files/p10/...` (rooted at content) and `MIMIC-CXR_processed/files/...`\n(rooted at parent).\""
 
 
 
 
 
56
  },
57
  {
58
  "cell_type": "code",
 
59
  "id": "c08",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  "execution_count": null,
 
61
  "metadata": {},
62
  "outputs": [],
63
- "source": "import os, time\nfrom huggingface_hub import HfApi, hf_hub_download\n\napi = HfApi(token=HF_TOKEN)\nall_files = api.list_repo_files(REPO_ID, repo_type=REPO_TYPE)\nsrc_shards = sorted(f for f in all_files\n if f.startswith(f\"{SRC_SUBDIR}/shards/\") and f.endswith(\".tar\"))\nif not src_shards:\n raise SystemExit(\n f\"ERROR: no .tar shards under {SRC_SUBDIR}/shards/ on {REPO_ID}. \"\n f\"Check the path / your HF token has read access.\")\nprint(f\"found {len(src_shards)} source shards in {SRC_SUBDIR}/shards/\")\n\n# Per-shard 'done' markers, kept OUTSIDE DST_TREE so they're never packed.\n# A shard is marked done only after full success, so a resumed run skips\n# finished shards WITHOUT re-downloading them (~100 GB saved on a retry).\nDONE_DIR = WORK_DIR / \"_done_shards\"\nDONE_DIR.mkdir(parents=True, exist_ok=True)\n\n\ndef _download_retry(filename, retries=6, base_delay=5):\n \"\"\"hf_hub_download resumes partial downloads, so retrying after a dropped\n connection (ChunkedEncodingError / IncompleteRead) continues from where it\n broke rather than restarting. Linear backoff.\"\"\"\n for attempt in range(1, retries + 1):\n try:\n return hf_hub_download(\n repo_id=REPO_ID, repo_type=REPO_TYPE, filename=filename,\n local_dir=str(DL_DIR), token=HF_TOKEN)\n except Exception as e:\n if attempt == retries:\n raise\n wait = base_delay * attempt\n print(f\" [retry {attempt}/{retries}] {filename}: \"\n f\"{type(e).__name__}: {e} -> waiting {wait}s\")\n time.sleep(wait)\n\n\ndef _detect_content_root(extracted):\n \"\"\"Return the dir under `extracted` that holds `files/`. Handles arcnames\n rooted at 'files/...' (=> extracted itself) or\n '{SRC_SUBDIR}/files/...' (=> extracted/{SRC_SUBDIR}).\"\"\"\n if (extracted / \"files\").is_dir():\n return extracted\n cand = extracted / SRC_SUBDIR\n if (cand / \"files\").is_dir():\n return cand\n for p in extracted.rglob(\"files\"):\n if p.is_dir():\n return p.parent\n return extracted # last resort -- process whatever's there\n\n\ncum = {\"resized\": 0, \"squared\": 0, \"copied\": 0,\n \"copied_other\": 0, \"skipped\": 0, \"error\": 0}\nall_errors = []\ncontent_offset = None\n\nfor shard in tqdm(src_shards, unit=\"shard\", desc=\"shards\"):\n marker = DONE_DIR / (Path(shard).name + \".done\")\n if marker.exists():\n continue # done in a previous run -> skip\n # 1. Download this single shard (auto-retry + resume on broken connection)\n local_tar = _download_retry(shard)\n # 2. Fresh scratch + extract\n if SCRATCH.exists():\n shutil.rmtree(SCRATCH)\n SCRATCH.mkdir(parents=True, exist_ok=True)\n with tarfile.open(local_tar) as tf:\n tf.extractall(SCRATCH)\n # 3. Free the tar bytes\n os.remove(local_tar)\n # 4. Locate content root once (assume consistent across shards)\n content_root = _detect_content_root(SCRATCH)\n if content_offset is None:\n content_offset = str(content_root.relative_to(SCRATCH)) or \"<top>\"\n print(f\"[stream] tar content root: '{content_offset}/' \"\n f\"(arcnames rooted at {content_offset})\")\n # 5. Resize + copy this shard's tree into the cumulative DST_TREE\n counts, errors, n_img, n_oth = _walk_and_process(\n content_root, DST_TREE, TARGET, QUALITY, WORKERS, SQUARE)\n for k, v in counts.items():\n cum[k] += v\n all_errors.extend(errors)\n # 6. Free scratch + mark this shard done (only after full success)\n shutil.rmtree(SCRATCH)\n marker.write_text(\"ok\")\n\n# 7. Final manifest -- recount from the actual tree so totals are correct\n# even across resumed runs (cum reflects only THIS run's work).\nfinal_img = sum(1 for p in DST_TREE.rglob(\"*\")\n if p.is_file() and p.suffix.lower() in IMG_EXTS)\nfinal_oth = sum(1 for p in DST_TREE.rglob(\"*\")\n if p.is_file() and p.suffix.lower() not in IMG_EXTS\n and p.name not in (\"_manifest.json\", \"_errors.txt\"))\n_write_manifest(\n DST_TREE,\n src=f\"{REPO_ID}:{SRC_SUBDIR}/shards ({len(src_shards)} shards)\",\n target=TARGET, quality=QUALITY, square=SQUARE,\n counts=cum, errors=all_errors, n_img=final_img, n_oth=final_oth,\n)\nprint(f\"\\nresized tree -> {DST_TREE} ({final_img:,} images, {final_oth:,} non-image)\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  },
65
  {
66
  "cell_type": "markdown",
67
  "id": "c13",
68
  "metadata": {},
69
- "source": "## 4. Pack the resized tree into tar shards (output)"
 
 
70
  },
71
  {
72
  "cell_type": "code",
73
- "id": "c14",
74
  "execution_count": null,
 
75
  "metadata": {},
76
  "outputs": [],
77
- "source": "shards = pack_shards(DST_TREE, SHARDS_DIR, SHARD_GB)\nprint(\"\\n\".join(p.name for p in shards))"
 
 
 
78
  },
79
  {
80
  "cell_type": "markdown",
81
  "id": "c15",
82
  "metadata": {},
83
- "source": "## 5. Upload shards to HF (`MIMIC-CXR_resized/shards/`)\n\nMirrors the source layout: output sits at\n`hieu3636/cxr-vlm-data/MIMIC-CXR_resized/shards/cxr-NNNN.tar`.\""
 
 
84
  },
85
  {
86
  "cell_type": "code",
87
- "id": "c16",
88
  "execution_count": null,
 
89
  "metadata": {},
90
  "outputs": [],
91
- "source": "from huggingface_hub import HfApi\n\nHfApi(token=HF_TOKEN).upload_folder(\n folder_path = str(SHARDS_DIR),\n path_in_repo = f\"{DST_SUBDIR}/shards\", # mirror source: <subdir>/shards/\n repo_id = REPO_ID,\n repo_type = REPO_TYPE,\n token = HF_TOKEN,\n commit_message = f\"Add resized+sharded dataset ({DST_SUBDIR}, target={TARGET}, square={SQUARE})\",\n)\nprint(f\"OK: pushed -> https://huggingface.co/datasets/{REPO_ID}/tree/main/{DST_SUBDIR}/shards\")"
 
 
 
 
 
 
 
 
 
 
 
 
92
  },
93
  {
94
  "cell_type": "markdown",
95
  "id": "c17",
96
  "metadata": {},
97
- "source": "## Done. On the training box, consume it like this\n\n```python\nfrom huggingface_hub import snapshot_download\nimport glob, tarfile, os\n\nDST = \"/workspace/MIMIC-CXR_resized\"\ndl = snapshot_download(\"hieu3636/cxr-vlm-data\", repo_type=\"dataset\",\n allow_patterns=\"MIMIC-CXR_resized/shards/*.tar\",\n local_dir=\"/workspace/dl\")\nos.makedirs(DST, exist_ok=True)\nfor t in sorted(glob.glob(\"/workspace/dl/MIMIC-CXR_resized/shards/*.tar\")):\n with tarfile.open(t) as tf:\n tf.extractall(DST)\n# -> DST now holds files/p10/... (same tree as the original, smaller JPGs)\n```\n\nThen point training at it -- edit `configs/train_config.yaml`:\n\n```yaml\nmimic_cxr_root: /workspace/MIMIC-CXR_resized\n```\n\nNo change to `dataset.py` / `cxr_vlm.py` -- the image tree is identical,\nonly the JPGs are smaller; reports / chexpert.csv come through too so\nauto-build of the instruct JSON works if needed. Extract once per VM\nsession, then train any number of epochs from the extracted tree.\n\n(Equivalent CLI using the repo script: `python scripts/build_resized_dataset.py\n--extract \"/workspace/dl/MIMIC-CXR_resized/shards/*.tar\" /workspace/MIMIC-CXR_resized`.)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  }
99
  ],
100
  "metadata": {
@@ -109,4 +462,4 @@
109
  },
110
  "nbformat": 4,
111
  "nbformat_minor": 5
112
- }
 
4
  "cell_type": "markdown",
5
  "id": "c00",
6
  "metadata": {},
7
+ "source": [
8
+ "# CXR-VLM -- Resize + tar-shard dataset (one-time, offline)\n",
9
+ "\n",
10
+ "Shrinks the original MIMIC-CXR tree (~2-3 MP/image) to RAD-DINO's working\n",
11
+ "resolution and packs it into a few tar shards, so cloud training boxes\n",
12
+ "(Vast.ai / Lightning.ai / Colab) pull a small, transfer-friendly dataset\n",
13
+ "instead of ~100 GB of huge JPGs read every epoch.\n",
14
+ "\n",
15
+ "**Source / destination (HF dataset repo `hieu3636/cxr-vlm-data`):**\n",
16
+ "- read : `MIMIC-CXR_processed/` (tree `files/p{10-19}/.../*.jpg`)\n",
17
+ "- write : `MIMIC-CXR_resized/` (tar shards `cxr-NNNN.tar` + manifest)\n",
18
+ "\n",
19
+ "**Why HF and not Google Drive for the transfer:** this notebook is meant to\n",
20
+ "run on arbitrary cloud GPUs. Drive only mounts conveniently on Colab and\n",
21
+ "rate-limits badly on bulk many-file reads. HF works everywhere with just a\n",
22
+ "token, `snapshot_download` is parallel + resumable, and the data already\n",
23
+ "lives there. So both download and upload go through HF.\n",
24
+ "\n",
25
+ "**This step does NOT change what the model sees** -- RAD-DINO's processor\n",
26
+ "resizes the shortest edge to 518 and center-crops 518x518 anyway; we just do\n",
27
+ "that downscale once, offline, instead of every epoch on full-res images.\n",
28
+ "See `scripts/build_resized_dataset.py` (this notebook mirrors its logic).\n",
29
+ "\n",
30
+ "**Prerequisite:** an `HF_TOKEN` with **write** access to the repo.\n",
31
+ "\n",
32
+ "**Disk:** needs room for source (~100 GB) + resized tree (~5-8 GB) +\n",
33
+ "shards (~5-8 GB). Set `DELETE_SOURCE_AFTER_RESIZE=True` to free the ~100 GB\n",
34
+ "before packing if the box is tight."
35
+ ]
36
  },
37
  {
38
  "cell_type": "markdown",
39
  "id": "c01",
40
  "metadata": {},
41
+ "source": [
42
+ "## 0. Config -- edit here"
43
+ ]
44
  },
45
  {
46
  "cell_type": "code",
 
47
  "execution_count": null,
48
+ "id": "c02",
49
  "metadata": {},
50
  "outputs": [],
51
+ "source": [
52
+ "import os\n",
53
+ "from pathlib import Path\n",
54
+ "\n",
55
+ "REPO_ID = \"hieu3636/cxr-vlm-data\"\n",
56
+ "REPO_TYPE = \"dataset\"\n",
57
+ "SRC_SUBDIR = \"MIMIC-CXR_processed\" # folder in the repo holding files/p{10-19}/...\n",
58
+ "DST_SUBDIR = \"MIMIC-CXR_resized\" # where the shards get uploaded back\n",
59
+ "\n",
60
+ "# Big scratch disk on the VM (Vast/Lightning: /workspace, Colab: /content).\n",
61
+ "WORK_DIR = Path(os.environ.get(\"WORK_DIR\", \"/content/cxr_resize\"))\n",
62
+ "\n",
63
+ "# --- resize params -------------------------------------------------------\n",
64
+ "TARGET = 518 # shortest-edge target. MUST be >= 518 (RAD-DINO crops 518).\n",
65
+ "SQUARE = True # False: keep aspect (518xN), flexible, processor crops at\n",
66
+ " # train time. ~20% bigger.\n",
67
+ " # True : also center-crop to 518x518 here -> file is exactly\n",
68
+ " # 518x518 and the processor is a true no-op. Smaller,\n",
69
+ " # but BAKES the crop (changing backbone/img_size later\n",
70
+ " # needs a full rebuild). Recommended off for a thesis.\n",
71
+ "QUALITY = 90 # JPEG quality (q90 + 4:4:4 = near-lossless for CXR)\n",
72
+ "SHARD_GB = 2.0 # approx GB per tar shard\n",
73
+ "WORKERS = min(32, (os.cpu_count() or 8) * 4) # I/O-bound; PIL frees the GIL\n",
74
+ "\n",
75
+ "DELETE_SOURCE_AFTER_RESIZE = False # True to free ~100 GB before packing\n",
76
+ "\n",
77
+ "# Derived local paths\n",
78
+ "DL_DIR = WORK_DIR / \"download\" # snapshot_download target\n",
79
+ "SRC_TREE = DL_DIR / SRC_SUBDIR # contains files/p10/...\n",
80
+ "DST_TREE = WORK_DIR / \"resized\" / SRC_SUBDIR # mirrors files/p10/...\n",
81
+ "SHARDS_DIR = WORK_DIR / \"shards\"\n",
82
+ "for p in (WORK_DIR, DL_DIR, SHARDS_DIR):\n",
83
+ " p.mkdir(parents=True, exist_ok=True)\n",
84
+ "\n",
85
+ "assert TARGET >= 518, \"TARGET must be >= 518 (RAD-DINO upscales shortest edge to 518)\"\n",
86
+ "print(\"WORK_DIR:\", WORK_DIR, \"| TARGET:\", TARGET, \"| SQUARE:\", SQUARE, \"| WORKERS:\", WORKERS)"
87
+ ]
88
  },
89
  {
90
  "cell_type": "markdown",
91
  "id": "c03",
92
  "metadata": {},
93
+ "source": [
94
+ "## 1. Setup -- deps + HF token\n",
95
+ "\n",
96
+ "Token resolution: env `HF_TOKEN` -> Colab `userdata` -> Kaggle secret."
97
+ ]
98
  },
99
  {
100
  "cell_type": "code",
 
101
  "execution_count": null,
102
+ "id": "c04",
103
  "metadata": {},
104
  "outputs": [],
105
+ "source": [
106
+ "import sys, subprocess\n",
107
+ "subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n",
108
+ " \"huggingface_hub>=0.24,<0.27\", \"Pillow>=10\", \"tqdm\"], check=True)\n",
109
+ "\n",
110
+ "if not os.environ.get(\"HF_TOKEN\"):\n",
111
+ " try:\n",
112
+ " from google.colab import userdata\n",
113
+ " os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")\n",
114
+ " except Exception:\n",
115
+ " try:\n",
116
+ " from kaggle_secrets import UserSecretsClient\n",
117
+ " os.environ[\"HF_TOKEN\"] = UserSecretsClient().get_secret(\"HF_TOKEN\")\n",
118
+ " except Exception:\n",
119
+ " pass\n",
120
+ "\n",
121
+ "HF_TOKEN = os.environ.get(\"HF_TOKEN\")\n",
122
+ "assert HF_TOKEN, \"HF_TOKEN missing -- set it via env var or platform secrets (needs WRITE access).\"\n",
123
+ "print(\"HF_TOKEN loaded OK\")"
124
+ ]
125
  },
126
  {
127
  "cell_type": "markdown",
128
  "id": "c05",
129
  "metadata": {},
130
+ "source": [
131
+ "## 2. Resize + pack logic (inlined, mirrors `scripts/build_resized_dataset.py`)\n",
132
+ "\n",
133
+ "Uses a thread pool (not processes): PIL releases the GIL during\n",
134
+ "decode/resize/encode, so threads parallelise well and avoid any\n",
135
+ "notebook multiprocessing pickling issues. Self-contained -- safe to\n",
136
+ "re-run this cell alone."
137
+ ]
138
  },
139
  {
140
  "cell_type": "code",
 
141
  "execution_count": null,
142
+ "id": "c06",
143
  "metadata": {},
144
  "outputs": [],
145
+ "source": [
146
+ "import os, json, shutil, tarfile, time\n",
147
+ "from pathlib import Path\n",
148
+ "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
149
+ "from PIL import Image\n",
150
+ "from tqdm.auto import tqdm\n",
151
+ "\n",
152
+ "Image.MAX_IMAGE_PIXELS = None # don't abort on large medical images\n",
153
+ "IMG_EXTS = (\".jpg\", \".jpeg\", \".png\")\n",
154
+ "\n",
155
+ "\n",
156
+ "def _resize_one(src_path, dst_path, target, quality, square):\n",
157
+ " \"\"\"Returns one of: resized | squared | copied | skipped | error:<msg>.\"\"\"\n",
158
+ " try:\n",
159
+ " dst_path = Path(dst_path)\n",
160
+ " if dst_path.exists() and dst_path.stat().st_size > 0:\n",
161
+ " return \"skipped\" # resumable\n",
162
+ " dst_path.parent.mkdir(parents=True, exist_ok=True)\n",
163
+ " with Image.open(src_path) as im:\n",
164
+ " w, h = im.size\n",
165
+ " shorter = min(w, h)\n",
166
+ " # Non-square: if shorter side already <= target, downscaling would\n",
167
+ " # push it below 518 -> copy verbatim (lossless, never worsens a\n",
168
+ " # low-res source). Square mode must always emit exactly target^2.\n",
169
+ " if not square and shorter <= target:\n",
170
+ " shutil.copy2(src_path, dst_path)\n",
171
+ " return \"copied\"\n",
172
+ " if im.mode not in (\"L\", \"RGB\"):\n",
173
+ " im = im.convert(\"RGB\")\n",
174
+ " # shorter axis EXACTLY = target; longer scales proportionally\n",
175
+ " if w <= h:\n",
176
+ " new_size = (target, round(h * target / w))\n",
177
+ " else:\n",
178
+ " new_size = (round(w * target / h), target)\n",
179
+ " # square mode reproduces the processor exactly -> bicubic\n",
180
+ " im = im.resize(new_size, Image.BICUBIC if square else Image.LANCZOS)\n",
181
+ " if square:\n",
182
+ " W, H = im.size\n",
183
+ " left, top = (W - target) // 2, (H - target) // 2\n",
184
+ " im = im.crop((left, top, left + target, top + target))\n",
185
+ " im.save(dst_path, \"JPEG\", quality=quality, optimize=True, subsampling=0)\n",
186
+ " return \"squared\" if square else \"resized\"\n",
187
+ " except Exception as e:\n",
188
+ " return f\"error:{type(e).__name__}: {e}\"\n",
189
+ "\n",
190
+ "\n",
191
+ "def _copy_one(src_path, dst_path):\n",
192
+ " \"\"\"Copy non-image files (reports .txt, chexpert .csv, metadata .json, ...)\n",
193
+ " verbatim so the shipped tree mirrors MIMIC-CXR_processed exactly.\"\"\"\n",
194
+ " try:\n",
195
+ " dst_path = Path(dst_path)\n",
196
+ " if dst_path.exists() and dst_path.stat().st_size > 0:\n",
197
+ " return \"skipped\"\n",
198
+ " dst_path.parent.mkdir(parents=True, exist_ok=True)\n",
199
+ " shutil.copy2(src_path, dst_path)\n",
200
+ " return \"copied_other\"\n",
201
+ " except Exception as e:\n",
202
+ " return f\"error:{type(e).__name__}: {e}\"\n",
203
+ "\n",
204
+ "\n",
205
+ "def resize_tree(src: Path, dst: Path, target, quality, workers, square):\n",
206
+ " print(f\"[resize] scanning {src} ...\")\n",
207
+ " img_jobs, other_jobs = [], []\n",
208
+ " for root, _, files in os.walk(src):\n",
209
+ " for fn in files:\n",
210
+ " sp = Path(root) / fn\n",
211
+ " rel = sp.relative_to(src)\n",
212
+ " dp = dst / rel\n",
213
+ " if fn.lower().endswith(IMG_EXTS):\n",
214
+ " img_jobs.append((str(sp), str(dp)))\n",
215
+ " else:\n",
216
+ " other_jobs.append((str(sp), str(dp)))\n",
217
+ " if not img_jobs and not other_jobs:\n",
218
+ " raise SystemExit(f\"ERROR: nothing found under {src}\")\n",
219
+ " mode = f\"square {target}x{target}\" if square else f\"shortest-edge {target}px\"\n",
220
+ " print(f\"[resize] {len(img_jobs):,} images + {len(other_jobs):,} non-image \"\n",
221
+ " f\"-> {dst} ({mode}, q{quality}, {workers} threads)\")\n",
222
+ "\n",
223
+ " counts = {\"resized\": 0, \"squared\": 0, \"copied\": 0,\n",
224
+ " \"copied_other\": 0, \"skipped\": 0, \"error\": 0}\n",
225
+ " errors = []\n",
226
+ " with ThreadPoolExecutor(max_workers=workers) as ex:\n",
227
+ " futs = {}\n",
228
+ " for s, d in img_jobs:\n",
229
+ " futs[ex.submit(_resize_one, s, d, target, quality, square)] = d\n",
230
+ " for s, d in other_jobs:\n",
231
+ " futs[ex.submit(_copy_one, s, d)] = d\n",
232
+ " for f in tqdm(as_completed(futs), total=len(futs), unit=\"file\"):\n",
233
+ " st = f.result()\n",
234
+ " if st.startswith(\"error:\"):\n",
235
+ " counts[\"error\"] += 1\n",
236
+ " errors.append(f\"{futs[f]}\\t{st}\")\n",
237
+ " else:\n",
238
+ " counts[st] += 1\n",
239
+ "\n",
240
+ " dst.mkdir(parents=True, exist_ok=True)\n",
241
+ " total = len(img_jobs) + len(other_jobs)\n",
242
+ " out_bytes = sum(p.stat().st_size for p in dst.rglob(\"*\") if p.is_file())\n",
243
+ " (dst / \"_manifest.json\").write_text(json.dumps({\n",
244
+ " \"source\": str(src), \"target\": target,\n",
245
+ " \"mode\": \"square\" if square else \"shortest_edge\",\n",
246
+ " \"jpeg_quality\": quality, \"subsampling\": \"4:4:4\",\n",
247
+ " \"resampling\": \"BICUBIC\" if square else \"LANCZOS\",\n",
248
+ " \"counts\": counts, \"total\": total,\n",
249
+ " \"images\": len(img_jobs), \"non_image\": len(other_jobs),\n",
250
+ " \"output_bytes\": out_bytes,\n",
251
+ " \"built_at\": time.strftime(\"%Y-%m-%dT%H:%M:%S\"),\n",
252
+ " }, indent=2), encoding=\"utf-8\")\n",
253
+ " if errors:\n",
254
+ " (dst / \"_errors.txt\").write_text(\"\\n\".join(errors), encoding=\"utf-8\")\n",
255
+ " print(f\"[resize] WARNING: {len(errors)} failures -> {dst/'_errors.txt'}\")\n",
256
+ " print(f\"[resize] done: {counts}\")\n",
257
+ " print(f\"[resize] output size: {out_bytes/1024**3:.2f} GB \"\n",
258
+ " f\"({out_bytes/max(1,len(img_jobs))/1024:.0f} KB/image avg)\")\n",
259
+ "\n",
260
+ "\n",
261
+ "def pack_shards(dst: Path, shards_dir: Path, shard_gb, prefix=\"cxr\"):\n",
262
+ " shard_bytes = int(shard_gb * 1024**3)\n",
263
+ " shards_dir.mkdir(parents=True, exist_ok=True)\n",
264
+ " files = sorted(p for p in dst.rglob(\"*\")\n",
265
+ " if p.is_file() and p.name not in (\"_manifest.json\", \"_errors.txt\"))\n",
266
+ " if not files:\n",
267
+ " raise SystemExit(f\"ERROR: nothing to pack under {dst}\")\n",
268
+ " print(f\"[pack] {len(files):,} files -> tar shards (~{shard_gb} GB each)\")\n",
269
+ " written, idx, cur = [], 0, 0\n",
270
+ "\n",
271
+ " def _open(i):\n",
272
+ " path = shards_dir / f\"{prefix}-{i:04d}.tar\"\n",
273
+ " written.append(path)\n",
274
+ " return tarfile.open(path, \"w\")\n",
275
+ "\n",
276
+ " tar = _open(0)\n",
277
+ " for fp in tqdm(files, unit=\"file\"):\n",
278
+ " if cur >= shard_bytes:\n",
279
+ " tar.close(); idx += 1; tar = _open(idx); cur = 0\n",
280
+ " tar.add(fp, arcname=str(fp.relative_to(dst))) # rel path -> tree rebuilt on extract\n",
281
+ " cur += fp.stat().st_size\n",
282
+ " tar.close()\n",
283
+ " man = dst / \"_manifest.json\"\n",
284
+ " if man.exists():\n",
285
+ " shutil.copy2(man, shards_dir / \"_manifest.json\")\n",
286
+ " (shards_dir / \"SHARDS.txt\").write_text(\"\\n\".join(p.name for p in written), encoding=\"utf-8\")\n",
287
+ " print(f\"[pack] wrote {len(written)} shards -> {shards_dir}\")\n",
288
+ " return written\n",
289
+ "\n",
290
+ "print(\"functions ready\")"
291
+ ]
292
  },
293
  {
294
  "cell_type": "markdown",
295
  "id": "c07",
296
  "metadata": {},
297
+ "source": [
298
+ "## 3. Download source from HF (`MIMIC-CXR_processed/`)\n",
299
+ "\n",
300
+ "Parallel + resumable. Re-running skips already-downloaded files. This is\n",
301
+ "the slow step (~100 GB of full-res JPGs)."
302
+ ]
303
  },
304
  {
305
  "cell_type": "code",
306
+ "execution_count": null,
307
  "id": "c08",
308
+ "metadata": {},
309
+ "outputs": [],
310
+ "source": [
311
+ "from huggingface_hub import snapshot_download\n",
312
+ "\n",
313
+ "local = snapshot_download(\n",
314
+ " repo_id = REPO_ID,\n",
315
+ " repo_type = REPO_TYPE,\n",
316
+ " allow_patterns = f\"{SRC_SUBDIR}/**\", # only the source folder\n",
317
+ " local_dir = str(DL_DIR),\n",
318
+ " token = HF_TOKEN,\n",
319
+ " max_workers = 16,\n",
320
+ ")\n",
321
+ "assert SRC_TREE.is_dir(), f\"expected {SRC_TREE} after download, not found\"\n",
322
+ "n = sum(1 for _ in SRC_TREE.rglob(\"*\") if _.suffix.lower() in IMG_EXTS)\n",
323
+ "print(f\"downloaded -> {SRC_TREE} ({n:,} images)\")"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "markdown",
328
+ "id": "c09",
329
+ "metadata": {},
330
+ "source": [
331
+ "## 4. Resize\n",
332
+ "\n",
333
+ "Reads the printed `output size` line + writes `_manifest.json` so you get\n",
334
+ "the real GB on your actual data. Resumable -- safe to re-run."
335
+ ]
336
+ },
337
+ {
338
+ "cell_type": "code",
339
  "execution_count": null,
340
+ "id": "c10",
341
  "metadata": {},
342
  "outputs": [],
343
+ "source": [
344
+ "resize_tree(SRC_TREE, DST_TREE, TARGET, QUALITY, WORKERS, SQUARE)"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "markdown",
349
+ "id": "c11",
350
+ "metadata": {},
351
+ "source": [
352
+ "## 5. (Optional) Free the ~100 GB source before packing"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "execution_count": null,
358
+ "id": "c12",
359
+ "metadata": {},
360
+ "outputs": [],
361
+ "source": [
362
+ "if DELETE_SOURCE_AFTER_RESIZE:\n",
363
+ " shutil.rmtree(DL_DIR, ignore_errors=True)\n",
364
+ " print(\"removed source download dir:\", DL_DIR)\n",
365
+ "else:\n",
366
+ " print(\"keeping source (set DELETE_SOURCE_AFTER_RESIZE=True to free disk)\")"
367
+ ]
368
  },
369
  {
370
  "cell_type": "markdown",
371
  "id": "c13",
372
  "metadata": {},
373
+ "source": [
374
+ "## 6. Pack into tar shards"
375
+ ]
376
  },
377
  {
378
  "cell_type": "code",
 
379
  "execution_count": null,
380
+ "id": "c14",
381
  "metadata": {},
382
  "outputs": [],
383
+ "source": [
384
+ "shards = pack_shards(DST_TREE, SHARDS_DIR, SHARD_GB)\n",
385
+ "print(\"\\n\".join(p.name for p in shards))"
386
+ ]
387
  },
388
  {
389
  "cell_type": "markdown",
390
  "id": "c15",
391
  "metadata": {},
392
+ "source": [
393
+ "## 7. Upload shards to HF (`MIMIC-CXR_resized/`)"
394
+ ]
395
  },
396
  {
397
  "cell_type": "code",
 
398
  "execution_count": null,
399
+ "id": "c16",
400
  "metadata": {},
401
  "outputs": [],
402
+ "source": [
403
+ "from huggingface_hub import HfApi\n",
404
+ "\n",
405
+ "HfApi(token=HF_TOKEN).upload_folder(\n",
406
+ " folder_path = str(SHARDS_DIR),\n",
407
+ " path_in_repo = DST_SUBDIR,\n",
408
+ " repo_id = REPO_ID,\n",
409
+ " repo_type = REPO_TYPE,\n",
410
+ " token = HF_TOKEN,\n",
411
+ " commit_message = f\"Add resized+sharded dataset ({DST_SUBDIR}, target={TARGET}, square={SQUARE})\",\n",
412
+ ")\n",
413
+ "print(f\"OK: pushed -> https://huggingface.co/datasets/{REPO_ID}/tree/main/{DST_SUBDIR}\")"
414
+ ]
415
  },
416
  {
417
  "cell_type": "markdown",
418
  "id": "c17",
419
  "metadata": {},
420
+ "source": [
421
+ "## Done. On the training box, consume it like this\n",
422
+ "\n",
423
+ "```python\n",
424
+ "from huggingface_hub import snapshot_download\n",
425
+ "import glob, tarfile, os\n",
426
+ "\n",
427
+ "DST = \"/workspace/MIMIC-CXR_resized\"\n",
428
+ "dl = snapshot_download(\"hieu3636/cxr-vlm-data\", repo_type=\"dataset\",\n",
429
+ " allow_patterns=\"MIMIC-CXR_resized/**\",\n",
430
+ " local_dir=\"/workspace/dl\")\n",
431
+ "os.makedirs(DST, exist_ok=True)\n",
432
+ "for t in sorted(glob.glob(\"/workspace/dl/MIMIC-CXR_resized/*.tar\")):\n",
433
+ " with tarfile.open(t) as tf:\n",
434
+ " tf.extractall(DST)\n",
435
+ "# -> DST now holds files/p10/... (same tree as the original)\n",
436
+ "```\n",
437
+ "\n",
438
+ "Then point training at it -- edit `configs/train_config.yaml`:\n",
439
+ "\n",
440
+ "```yaml\n",
441
+ "mimic_cxr_root: /workspace/MIMIC-CXR_resized\n",
442
+ "```\n",
443
+ "\n",
444
+ "No change to `dataset.py` / `cxr_vlm.py` -- the image tree is identical,\n",
445
+ "only the JPGs are smaller. Extract once per VM session, then train any\n",
446
+ "number of epochs from the extracted tree.\n",
447
+ "\n",
448
+ "(Equivalent CLI using the repo script: `python scripts/build_resized_dataset.py\n",
449
+ "--extract \"/workspace/dl/MIMIC-CXR_resized/*.tar\" /workspace/MIMIC-CXR_resized`.)"
450
+ ]
451
  }
452
  ],
453
  "metadata": {
 
462
  },
463
  "nbformat": 4,
464
  "nbformat_minor": 5
465
+ }