""" build_resized_dataset.py ------------------------ One-time data shrinker for cloud training (Vast.ai / Lightning.ai / Colab). MIMIC-CXR-JPG originals are ~2-3 MP each; RAD-DINO downscales to ~518 px internally anyway, so storing full-res images just wastes I/O. This script re-encodes every JPG to a small longer-side cap, preserving the EXACT directory tree so you only have to re-point `data.mimic_cxr_root` at the output -- no change to dataset.py / cxr_vlm.py. Why image-resize and NOT feature-cache: a frozen-encoder feature tensor is ~2 MB/image (1369x768 fp16, incompressible) -- larger than the source JPG. The encoder is also only ~1-2% of per-step compute (Vicuna-7B dominates), so caching it barely speeds training. Shrinking the JPG instead removes the real bottleneck (decode of huge images) at ~1/30th the storage, with no architecture risk and augmentation still possible later. Pipeline (each step skippable): 1. resize : src tree -> dst tree (only downscales; skips up-to-date files, resumable) 2. pack : dst tree -> tar shards (~2 GB each, keeps the tree on extract) 3. push : shards -> HF Hub private dataset repo Usage (from project root): # resize + pack python scripts/build_resized_dataset.py \ --src /data/MIMIC-CXR --dst /data/MIMIC-CXR-518 # resize + pack + push to HF $env:HF_TOKEN='hf_xxx' python scripts/build_resized_dataset.py \ --src /data/MIMIC-CXR --dst /data/MIMIC-CXR-518 \ --push --hf_repo /cxr-vlm-data-518 # on the training box: pull shards then rebuild the tree onto fast NVMe python scripts/build_resized_dataset.py --extract "shards/*.tar" /content/MIMIC-CXR-518 # -> set data.mimic_cxr_root: /content/MIMIC-CXR-518 """ from __future__ import annotations import argparse import glob import json import os import shutil import sys import tarfile import time from concurrent.futures import ProcessPoolExecutor, as_completed from pathlib import Path from PIL import Image from tqdm import tqdm # CXR can be large; don't let Pillow's bomb guard abort on legit medical images. Image.MAX_IMAGE_PIXELS = None PROJECT_ROOT = Path(__file__).resolve().parents[1] IMG_EXTS = (".jpg", ".jpeg", ".png") # -- Phase 1: resize --------------------------------------------------------- def _resize_one(args) -> tuple[str, str]: """Worker: resize a single image. Returns (status, rel_path). status is one of: "resized", "squared", "copied", "skipped", "error:". "copied" = source shorter side already <= target (non-square mode only); re-encoding would only lose quality. "skipped" = up-to-date output already exists (makes the run resumable). Two modes: default : resize shortest edge -> target, KEEP aspect ratio. The RAD-DINO processor will center-crop to 518x518 at train time. Flexible (crop/backbone choices stay open), ~20% bigger than square. --square : also replicate the processor's center-crop here, so every file is exactly target x target and the processor becomes a true no-op. Geometry is IDENTICAL to baseline (we reproduce its resize+crop, not a distorting squash). Bakes the crop in -> changing crop/img_size/backbone later needs a rebuild. """ src_path, dst_path, rel, target, quality, square = args try: dst_path = Path(dst_path) if dst_path.exists() and dst_path.stat().st_size > 0: return "skipped", rel dst_path.parent.mkdir(parents=True, exist_ok=True) with Image.open(src_path) as im: w, h = im.size shorter = min(w, h) # Non-square: if shorter side already <= target, downscaling would # push it below 518 -> copy verbatim (lossless, never worsens a # low-res source). In square mode we must always produce exactly # target^2, replicating the processor (which itself upscales a # sub-518 image), so don't short-circuit there. if not square and shorter <= target: shutil.copy2(src_path, dst_path) return "copied", rel # Match training-time load (dataset.py does .convert("RGB")); # collapse exotic modes so JPEG save can't fail. if im.mode not in ("L", "RGB"): im = im.convert("RGB") # Resize shorter axis EXACTLY to target (no rounding drift below # it); longer axis scales proportionally. if w <= h: new_size = (target, round(h * target / w)) else: new_size = (round(w * target / h), target) # square mode mirrors the processor exactly -> bicubic (resample=3) # so this output IS what the processor would have produced. im = im.resize(new_size, Image.BICUBIC if square else Image.LANCZOS) if square: W, H = im.size left, top = (W - target) // 2, (H - target) // 2 im = im.crop((left, top, left + target, top + target)) # subsampling=0 (4:4:4) preserves thin findings (e.g. pneumothorax line). im.save(dst_path, "JPEG", quality=quality, optimize=True, subsampling=0) return ("squared" if square else "resized"), rel except Exception as e: # corrupt/unreadable source -- log & continue return f"error:{type(e).__name__}: {e}", rel def _copy_one(args) -> tuple[str, str]: """Worker: copy a non-image file verbatim, preserving the tree. Used for reports (.txt), CheXpert labels (.csv), metadata (.json) and anything else interleaved in the source tree -- so the tar shards carry a complete copy of MIMIC-CXR_processed, not just images. """ src_path, dst_path, rel = args try: dst_path = Path(dst_path) if dst_path.exists() and dst_path.stat().st_size > 0: return "skipped", rel dst_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(src_path, dst_path) return "copied_other", rel except Exception as e: return f"error:{type(e).__name__}: {e}", rel def resize_tree(src: Path, dst: Path, target: int, quality: int, workers: int, square: bool) -> None: print(f"[resize] scanning {src} ...") img_jobs, other_jobs = [], [] for root, _, files in os.walk(src): for fn in files: sp = Path(root) / fn rel = sp.relative_to(src) dp = dst / rel if fn.lower().endswith(IMG_EXTS): img_jobs.append((str(sp), str(dp), str(rel), target, quality, square)) else: # non-image: reports/csv/json/etc. copied verbatim so the # shipped tree mirrors the source exactly (no data loss). other_jobs.append((str(sp), str(dp), str(rel))) if not img_jobs and not other_jobs: sys.exit(f"ERROR: nothing found under {src}") mode = f"square {target}x{target}" if square else f"shortest-edge {target}px" print(f"[resize] {len(img_jobs):,} images + {len(other_jobs):,} non-image " f"-> {dst} ({mode}, q{quality}, {workers} workers)") counts = {"resized": 0, "squared": 0, "copied": 0, "copied_other": 0, "skipped": 0, "error": 0} errors: list[str] = [] with ProcessPoolExecutor(max_workers=workers) as ex: futs = [ex.submit(_resize_one, j) for j in img_jobs] futs += [ex.submit(_copy_one, j) for j in other_jobs] for f in tqdm(as_completed(futs), total=len(futs), unit="file"): status, rel = f.result() if status.startswith("error:"): counts["error"] += 1 errors.append(f"{rel}\t{status}") else: counts[status] += 1 dst.mkdir(parents=True, exist_ok=True) total = len(img_jobs) + len(other_jobs) out_bytes = sum(p.stat().st_size for p in dst.rglob("*") if p.is_file()) (dst / "_manifest.json").write_text(json.dumps({ "source": str(src), "target": target, "mode": "square" if square else "shortest_edge", "jpeg_quality": quality, "subsampling": "4:4:4", "resampling": "BICUBIC" if square else "LANCZOS", "counts": counts, "total": total, "images": len(img_jobs), "non_image": len(other_jobs), "output_bytes": out_bytes, "built_at": time.strftime("%Y-%m-%dT%H:%M:%S"), }, indent=2), encoding="utf-8") if errors: (dst / "_errors.txt").write_text("\n".join(errors), encoding="utf-8") print(f"[resize] WARNING: {len(errors)} failures -> {dst/'_errors.txt'}") print(f"[resize] done: {counts}") print(f"[resize] output size: {out_bytes / 1024**3:.2f} GB " f"({out_bytes / max(1, len(img_jobs)) / 1024:.0f} KB/image avg)") # -- Phase 2: pack into tar shards ------------------------------------------- def pack_shards(dst: Path, shards_dir: Path, prefix: str, shard_gb: float) -> list[Path]: shard_bytes = int(shard_gb * (1024 ** 3)) shards_dir.mkdir(parents=True, exist_ok=True) files = sorted( p for p in dst.rglob("*") if p.is_file() and p.name not in ("_manifest.json", "_errors.txt") ) if not files: sys.exit(f"ERROR: nothing to pack under {dst} (run resize first)") print(f"[pack] {len(files):,} files -> tar shards (~{shard_gb} GB each) in {shards_dir}") written: list[Path] = [] idx, cur_bytes = 0, 0 def _open(i: int) -> tarfile.TarFile: path = shards_dir / f"{prefix}-{i:04d}.tar" written.append(path) return tarfile.open(path, "w") tar = _open(0) for fp in tqdm(files, unit="file"): if cur_bytes >= shard_bytes: tar.close() idx += 1 tar = _open(idx) cur_bytes = 0 # arcname = path relative to dst -> extracting any shard rebuilds the tree. tar.add(fp, arcname=str(fp.relative_to(dst))) cur_bytes += fp.stat().st_size tar.close() # ship the manifest alongside the shards (not inside them) man = dst / "_manifest.json" if man.exists(): shutil.copy2(man, shards_dir / "_manifest.json") (shards_dir / "SHARDS.txt").write_text( "\n".join(p.name for p in written), encoding="utf-8") print(f"[pack] wrote {len(written)} shards -> {shards_dir}") return written # -- Phase 3: push to HF Hub ------------------------------------------------- def push_hf(shards_dir: Path, repo_id: str, path_in_repo: str, private: bool) -> None: token = os.environ.get("HF_TOKEN") if not token: sys.exit("ERROR: --push needs HF_TOKEN env var (write-scope token).") from huggingface_hub import HfApi, create_repo print(f"[push] {shards_dir} -> {repo_id}:{path_in_repo}") create_repo(repo_id, repo_type="dataset", private=private, token=token, exist_ok=True) HfApi(token=token).upload_folder( folder_path=str(shards_dir), path_in_repo=path_in_repo, repo_id=repo_id, repo_type="dataset", token=token, ) print(f"OK: pushed -> https://huggingface.co/datasets/{repo_id}") # -- Extract helper (run on the training box) -------------------------------- def extract_shards(pattern: str, dest: Path) -> None: tars = sorted(glob.glob(pattern)) if not tars: sys.exit(f"ERROR: no tar shards match {pattern!r}") dest.mkdir(parents=True, exist_ok=True) print(f"[extract] {len(tars)} shards -> {dest}") for t in tqdm(tars, unit="shard"): with tarfile.open(t, "r") as tf: tf.extractall(dest) print(f"[extract] done. Set data.mimic_cxr_root: {dest}") # -- CLI --------------------------------------------------------------------- def parse_args(): ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) ap.add_argument("--src", help="Original dataset root (mirrors recursively)") ap.add_argument("--dst", help="Output root for the resized tree") ap.add_argument("--target", type=int, default=518, help="Shortest-edge target in px. MUST be >= 518 (RAD-DINO's " "processor resizes shortest edge to 518); 518 = smallest " "files, zero extra upscaling. Default 518.") ap.add_argument("--quality", type=int, default=90, help="JPEG quality (default 90)") ap.add_argument("--square", action="store_true", help="Also do the processor's center-crop here -> every file " "is exactly target x target and the RAD-DINO processor " "becomes a true no-op. Geometry identical to baseline " "(reproduces resize+crop, NOT a distorting squash). " "~20%% smaller but BAKES the crop in: changing " "crop/img_size/backbone later needs a full rebuild. " "Default off (keep aspect ratio, stay flexible).") ap.add_argument("--workers", type=int, default=os.cpu_count(), help="Parallel resize workers (default: all cores)") ap.add_argument("--no_resize", action="store_true", help="Skip phase 1") ap.add_argument("--no_pack", action="store_true", help="Skip phase 2 (tar shards)") ap.add_argument("--shards_dir", help="Where to write tar shards (default: _shards)") ap.add_argument("--shard_prefix", default="cxr", help="Shard filename prefix") ap.add_argument("--shard_gb", type=float, default=2.0, help="Approx GB per shard") ap.add_argument("--push", action="store_true", help="Phase 3: upload shards to HF Hub") ap.add_argument("--hf_repo", help="HF dataset repo id, e.g. /cxr-vlm-data-518") ap.add_argument("--hf_path", default="shards", help="Path inside the HF repo") ap.add_argument("--public", action="store_true", help="Make the HF repo public") ap.add_argument("--extract", nargs=2, metavar=("PATTERN", "DEST"), help='Standalone: rebuild the tree from shards, e.g. ' '--extract "shards/*.tar" /content/MIMIC-CXR-518') return ap.parse_args() def main(): a = parse_args() if a.extract: extract_shards(a.extract[0], Path(a.extract[1])) return # --dst is always needed (resize writes it, pack reads it); --src only # when actually resizing. Lets you re-pack/push an existing tree. if not a.dst: sys.exit("ERROR: --dst is required (or use --extract).") if not a.no_resize and a.target < 518: sys.exit(f"ERROR: --target {a.target} < 518. RAD-DINO upscales the " f"shortest edge to 518, so storing smaller only adds blur. " f"Use 518 (default) or higher.") dst = Path(a.dst) shards_dir = Path(a.shards_dir) if a.shards_dir else dst.parent / f"{dst.name}_shards" if not a.no_resize: if not a.src: sys.exit("ERROR: --src is required for the resize step " "(pass --no_resize to pack/push an existing tree).") src = Path(a.src) if not src.is_dir(): sys.exit(f"ERROR: --src not a directory: {src}") resize_tree(src, dst, a.target, a.quality, a.workers, a.square) if not a.no_pack: pack_shards(dst, shards_dir, a.shard_prefix, a.shard_gb) if a.push: if not a.hf_repo: sys.exit("ERROR: --push requires --hf_repo /") push_hf(shards_dir, a.hf_repo, a.hf_path, private=not a.public) print("\nAll done.") if __name__ == "__main__": main()