# cxr-vlm-code / scripts /build_resized_dataset.py
# convitom
# fix(scripts): carry non-image files through resize + tar shards
# 0a99045
"""
build_resized_dataset.py
------------------------
One-time data shrinker for cloud training (Vast.ai / Lightning.ai / Colab).
MIMIC-CXR-JPG originals are ~2-3 MP each; RAD-DINO downscales to ~518 px
internally anyway, so storing full-res images just wastes I/O. This script
re-encodes every image to a --target px shorter side (default 518), preserving the EXACT
directory tree so you only have to re-point `data.mimic_cxr_root` at the
output -- no change to dataset.py / cxr_vlm.py.
Why image-resize and NOT feature-cache: a frozen-encoder feature tensor is
~2 MB/image (1369x768 fp16, incompressible) -- larger than the source JPG.
The encoder is also only ~1-2% of per-step compute (Vicuna-7B dominates),
so caching it barely speeds training. Shrinking the JPG instead removes the
real bottleneck (decode of huge images) at ~1/30th the storage, with no
architecture risk and augmentation still possible later.
Pipeline (each step skippable):
1. resize : src tree -> dst tree (only downscales; skips up-to-date files, resumable)
2. pack : dst tree -> tar shards (~2 GB each, keeps the tree on extract)
3. push : shards -> HF Hub private dataset repo
Usage (from project root):
# resize + pack
python scripts/build_resized_dataset.py \
--src /data/MIMIC-CXR --dst /data/MIMIC-CXR-518
# resize + pack + push to HF
export HF_TOKEN='hf_xxx'   # PowerShell: $env:HF_TOKEN='hf_xxx'
python scripts/build_resized_dataset.py \
--src /data/MIMIC-CXR --dst /data/MIMIC-CXR-518 \
--push --hf_repo <user>/cxr-vlm-data-518
# on the training box: pull shards then rebuild the tree onto fast NVMe
python scripts/build_resized_dataset.py --extract "shards/*.tar" /content/MIMIC-CXR-518
# -> set data.mimic_cxr_root: /content/MIMIC-CXR-518
"""
from __future__ import annotations
import argparse
import glob
import json
import os
import shutil
import sys
import tarfile
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from PIL import Image
from tqdm import tqdm
# CXR can be large; don't let Pillow's bomb guard abort on legit medical images.
# (Setting MAX_IMAGE_PIXELS to None turns Pillow's pixel-count limit off.)
Image.MAX_IMAGE_PIXELS = None
# Parent of the directory that contains this script (two levels above the file).
# Not referenced by any function visible in this file.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Lower-case filename suffixes treated as images. resize_tree() matches them
# against the lower-cased filename; every other file is copied verbatim.
IMG_EXTS = (".jpg", ".jpeg", ".png")
# -- Phase 1: resize ---------------------------------------------------------
def _resize_one(args) -> tuple[str, str]:
"""Worker: resize a single image. Returns (status, rel_path).
status is one of: "resized", "squared", "copied", "skipped", "error:<msg>".
"copied" = source shorter side already <= target (non-square mode only);
re-encoding would only lose quality.
"skipped" = up-to-date output already exists (makes the run resumable).
Two modes:
default : resize shortest edge -> target, KEEP aspect ratio. The
RAD-DINO processor will center-crop to 518x518 at train
time. Flexible (crop/backbone choices stay open), ~20%
bigger than square.
--square : also replicate the processor's center-crop here, so every
file is exactly target x target and the processor becomes a
true no-op. Geometry is IDENTICAL to baseline (we reproduce
its resize+crop, not a distorting squash). Bakes the crop in
-> changing crop/img_size/backbone later needs a rebuild.
"""
src_path, dst_path, rel, target, quality, square = args
try:
dst_path = Path(dst_path)
if dst_path.exists() and dst_path.stat().st_size > 0:
return "skipped", rel
dst_path.parent.mkdir(parents=True, exist_ok=True)
with Image.open(src_path) as im:
w, h = im.size
shorter = min(w, h)
# Non-square: if shorter side already <= target, downscaling would
# push it below 518 -> copy verbatim (lossless, never worsens a
# low-res source). In square mode we must always produce exactly
# target^2, replicating the processor (which itself upscales a
# sub-518 image), so don't short-circuit there.
if not square and shorter <= target:
shutil.copy2(src_path, dst_path)
return "copied", rel
# Match training-time load (dataset.py does .convert("RGB"));
# collapse exotic modes so JPEG save can't fail.
if im.mode not in ("L", "RGB"):
im = im.convert("RGB")
# Resize shorter axis EXACTLY to target (no rounding drift below
# it); longer axis scales proportionally.
if w <= h:
new_size = (target, round(h * target / w))
else:
new_size = (round(w * target / h), target)
# square mode mirrors the processor exactly -> bicubic (resample=3)
# so this output IS what the processor would have produced.
im = im.resize(new_size, Image.BICUBIC if square else Image.LANCZOS)
if square:
W, H = im.size
left, top = (W - target) // 2, (H - target) // 2
im = im.crop((left, top, left + target, top + target))
# subsampling=0 (4:4:4) preserves thin findings (e.g. pneumothorax line).
im.save(dst_path, "JPEG", quality=quality, optimize=True, subsampling=0)
return ("squared" if square else "resized"), rel
except Exception as e: # corrupt/unreadable source -- log & continue
return f"error:{type(e).__name__}: {e}", rel
def _copy_one(args) -> tuple[str, str]:
"""Worker: copy a non-image file verbatim, preserving the tree.
Used for reports (.txt), CheXpert labels (.csv), metadata (.json) and
anything else interleaved in the source tree -- so the tar shards carry
a complete copy of MIMIC-CXR_processed, not just images.
"""
src_path, dst_path, rel = args
try:
dst_path = Path(dst_path)
if dst_path.exists() and dst_path.stat().st_size > 0:
return "skipped", rel
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_path, dst_path)
return "copied_other", rel
except Exception as e:
return f"error:{type(e).__name__}: {e}", rel
def resize_tree(src: Path, dst: Path, target: int, quality: int,
                workers: int, square: bool) -> None:
    """Mirror ``src`` into ``dst``: shrink images, copy everything else.

    Walks ``src`` recursively. Files whose lower-cased name ends with one of
    ``IMG_EXTS`` are handed to ``_resize_one``; all other files go to
    ``_copy_one``. Both kinds run in one process pool of ``workers`` workers.

    Afterwards writes ``dst/_manifest.json`` (settings, per-status counts,
    output size) and, if any job failed, ``dst/_errors.txt`` with one
    ``<rel_path>\\t<status>`` line per failure. An ``_errors.txt`` left over
    from a previous run is removed when the current run has no failures, so
    the file always describes the latest run.

    Args:
        src: source dataset root.
        dst: output root; the relative layout of ``src`` is reproduced.
        target: shortest-edge size in pixels passed to each resize job.
        quality: JPEG quality passed to each resize job.
        workers: ``max_workers`` for the process pool.
        square: passed to each resize job (center-crop to target x target).

    Raises:
        SystemExit: if ``src`` contains no files at all.
    """
    print(f"[resize] scanning {src} ...")
    img_jobs, other_jobs = [], []
    for root, _, files in os.walk(src):
        for fn in files:
            sp = Path(root) / fn
            rel = sp.relative_to(src)
            dp = dst / rel
            if fn.lower().endswith(IMG_EXTS):
                img_jobs.append((str(sp), str(dp), str(rel), target, quality, square))
            else:
                # non-image: reports/csv/json/etc. copied verbatim so the
                # shipped tree mirrors the source exactly (no data loss).
                other_jobs.append((str(sp), str(dp), str(rel)))
    if not img_jobs and not other_jobs:
        sys.exit(f"ERROR: nothing found under {src}")

    mode = f"square {target}x{target}" if square else f"shortest-edge {target}px"
    print(f"[resize] {len(img_jobs):,} images + {len(other_jobs):,} non-image "
          f"-> {dst} ({mode}, q{quality}, {workers} workers)")

    # Every status a worker can return has a pre-seeded key, so the manifest
    # always lists all categories (including the zero ones).
    counts = {"resized": 0, "squared": 0, "copied": 0,
              "copied_other": 0, "skipped": 0, "error": 0}
    errors: list[str] = []
    with ProcessPoolExecutor(max_workers=workers) as ex:
        futs = [ex.submit(_resize_one, j) for j in img_jobs]
        futs += [ex.submit(_copy_one, j) for j in other_jobs]
        for f in tqdm(as_completed(futs), total=len(futs), unit="file"):
            status, rel = f.result()
            if status.startswith("error:"):
                counts["error"] += 1
                errors.append(f"{rel}\t{status}")
            else:
                counts[status] += 1

    dst.mkdir(parents=True, exist_ok=True)
    errors_path = dst / "_errors.txt"
    if not errors and errors_path.exists():
        # A clean run must not keep advertising failures from an older run.
        errors_path.unlink()

    total = len(img_jobs) + len(other_jobs)
    out_bytes = sum(p.stat().st_size for p in dst.rglob("*") if p.is_file())
    (dst / "_manifest.json").write_text(json.dumps({
        "source": str(src), "target": target,
        "mode": "square" if square else "shortest_edge",
        "jpeg_quality": quality, "subsampling": "4:4:4",
        "resampling": "BICUBIC" if square else "LANCZOS",
        "counts": counts, "total": total,
        "images": len(img_jobs), "non_image": len(other_jobs),
        "output_bytes": out_bytes,
        "built_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
    }, indent=2), encoding="utf-8")
    if errors:
        errors_path.write_text("\n".join(errors), encoding="utf-8")
        print(f"[resize] WARNING: {len(errors)} failures -> {errors_path}")
    print(f"[resize] done: {counts}")
    print(f"[resize] output size: {out_bytes / 1024**3:.2f} GB "
          f"({out_bytes / max(1, len(img_jobs)) / 1024:.0f} KB/image avg)")
# -- Phase 2: pack into tar shards -------------------------------------------
def pack_shards(dst: Path, shards_dir: Path, prefix: str, shard_gb: float) -> list[Path]:
    """Pack every file under ``dst`` into uncompressed tar shards.

    Files are added in sorted path order with archive names relative to
    ``dst``, so extracting all shards into one directory rebuilds the tree.
    ``_manifest.json`` and ``_errors.txt`` are left out of the shards; the
    manifest is copied next to them instead, together with a ``SHARDS.txt``
    listing the shard file names.

    A new shard is started once the current one already holds at least
    ``shard_gb`` GiB of file payload, so a shard can exceed the limit by up
    to one file. An empty shard is never rolled over, so a non-positive
    ``shard_gb`` yields one file per shard instead of an empty first shard.

    Args:
        dst: root of the tree to pack.
        shards_dir: output directory for the shards (created if missing).
        prefix: shard files are named ``<prefix>-NNNN.tar``.
        shard_gb: approximate payload size per shard, in GiB.

    Returns:
        Paths of the shards written, in order.

    Raises:
        SystemExit: if there is nothing to pack under ``dst``.
    """
    shard_bytes = int(shard_gb * (1024 ** 3))
    shards_dir.mkdir(parents=True, exist_ok=True)
    files = sorted(
        p for p in dst.rglob("*")
        if p.is_file() and p.name not in ("_manifest.json", "_errors.txt")
    )
    if not files:
        sys.exit(f"ERROR: nothing to pack under {dst} (run resize first)")
    print(f"[pack] {len(files):,} files -> tar shards (~{shard_gb} GB each) in {shards_dir}")

    written: list[Path] = []
    idx, cur_bytes = 0, 0

    def _open(i: int) -> tarfile.TarFile:
        # Open shard number ``i`` for writing and record its path.
        path = shards_dir / f"{prefix}-{i:04d}.tar"
        written.append(path)
        return tarfile.open(path, "w")

    tar = _open(0)
    try:
        for fp in tqdm(files, unit="file"):
            # Roll over only when the current shard is full AND non-empty.
            if cur_bytes >= shard_bytes and cur_bytes > 0:
                tar.close()
                idx += 1
                tar = _open(idx)
                cur_bytes = 0
            # arcname = path relative to dst -> extracting any shard rebuilds the tree.
            tar.add(fp, arcname=str(fp.relative_to(dst)))
            cur_bytes += fp.stat().st_size
    finally:
        # Always release the handle, even if adding a file failed.
        tar.close()

    # ship the manifest alongside the shards (not inside them)
    man = dst / "_manifest.json"
    if man.exists():
        shutil.copy2(man, shards_dir / "_manifest.json")
    (shards_dir / "SHARDS.txt").write_text(
        "\n".join(p.name for p in written), encoding="utf-8")
    print(f"[pack] wrote {len(written)} shards -> {shards_dir}")
    return written
# -- Phase 3: push to HF Hub -------------------------------------------------
def push_hf(shards_dir: Path, repo_id: str, path_in_repo: str, private: bool) -> None:
    """Upload the contents of ``shards_dir`` to a Hugging Face dataset repo.

    The access token is read from the ``HF_TOKEN`` environment variable. The
    repo is created if it does not exist yet (``exist_ok=True``), then the
    whole folder is uploaded under ``path_in_repo``.

    Args:
        shards_dir: local folder to upload.
        repo_id: target dataset repo, ``<user>/<repo>``.
        path_in_repo: destination path inside the repo.
        private: visibility passed to ``create_repo``.

    Raises:
        SystemExit: if ``HF_TOKEN`` is unset or empty.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        sys.exit("ERROR: --push needs HF_TOKEN env var (write-scope token).")

    # Function-scope import: huggingface_hub is only needed when pushing.
    from huggingface_hub import HfApi, create_repo

    print(f"[push] {shards_dir} -> {repo_id}:{path_in_repo}")
    create_repo(
        repo_id,
        repo_type="dataset",
        private=private,
        token=hf_token,
        exist_ok=True,
    )

    upload_options = {
        "folder_path": str(shards_dir),
        "path_in_repo": path_in_repo,
        "repo_id": repo_id,
        "repo_type": "dataset",
        "token": hf_token,
    }
    api = HfApi(token=hf_token)
    api.upload_folder(**upload_options)
    print(f"OK: pushed -> https://huggingface.co/datasets/{repo_id}")
# -- Extract helper (run on the training box) --------------------------------
def extract_shards(pattern: str, dest: Path) -> None:
    """Extract every tar shard matching ``pattern`` into ``dest``.

    Shards are processed in sorted name order and all unpack into the same
    directory.

    Args:
        pattern: glob pattern selecting the shard files, e.g. ``"shards/*.tar"``.
        dest: directory to extract into (created if missing).

    Raises:
        SystemExit: if no file matches ``pattern``.
    """
    tars = sorted(glob.glob(pattern))
    if not tars:
        sys.exit(f"ERROR: no tar shards match {pattern!r}")
    dest.mkdir(parents=True, exist_ok=True)
    print(f"[extract] {len(tars)} shards -> {dest}")

    # Shards usually arrive via a download, so refuse members that would
    # escape ``dest`` or create special files. Extraction filters only exist
    # on newer Python releases; fall back to the old call where they are
    # unavailable.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    for t in tqdm(tars, unit="shard"):
        with tarfile.open(t, "r") as tf:
            tf.extractall(dest, **extract_kwargs)
    print(f"[extract] done. Set data.mimic_cxr_root: {dest}")
# -- CLI ---------------------------------------------------------------------
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    The module docstring is used verbatim as the ``--help`` description.
    Returns the ``argparse.Namespace`` with one attribute per option.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    add = parser.add_argument

    # Long help texts are kept in locals so the option table below stays readable.
    target_help = ("Shortest-edge target in px. MUST be >= 518 (RAD-DINO's "
                   "processor resizes shortest edge to 518); 518 = smallest "
                   "files, zero extra upscaling. Default 518.")
    square_help = ("Also do the processor's center-crop here -> every file "
                   "is exactly target x target and the RAD-DINO processor "
                   "becomes a true no-op. Geometry identical to baseline "
                   "(reproduces resize+crop, NOT a distorting squash). "
                   "~20%% smaller but BAKES the crop in: changing "
                   "crop/img_size/backbone later needs a full rebuild. "
                   "Default off (keep aspect ratio, stay flexible).")
    extract_help = ('Standalone: rebuild the tree from shards, e.g. '
                    '--extract "shards/*.tar" /content/MIMIC-CXR-518')

    # Phase 1: resize
    add("--src", help="Original dataset root (mirrors recursively)")
    add("--dst", help="Output root for the resized tree")
    add("--target", type=int, default=518, help=target_help)
    add("--quality", type=int, default=90, help="JPEG quality (default 90)")
    add("--square", action="store_true", help=square_help)
    add("--workers", type=int, default=os.cpu_count(),
        help="Parallel resize workers (default: all cores)")
    add("--no_resize", action="store_true", help="Skip phase 1")

    # Phase 2: pack
    add("--no_pack", action="store_true", help="Skip phase 2 (tar shards)")
    add("--shards_dir", help="Where to write tar shards (default: <dst>_shards)")
    add("--shard_prefix", default="cxr", help="Shard filename prefix")
    add("--shard_gb", type=float, default=2.0, help="Approx GB per shard")

    # Phase 3: push
    add("--push", action="store_true", help="Phase 3: upload shards to HF Hub")
    add("--hf_repo", help="HF dataset repo id, e.g. <user>/cxr-vlm-data-518")
    add("--hf_path", default="shards", help="Path inside the HF repo")
    add("--public", action="store_true", help="Make the HF repo public")

    # Standalone extract mode
    add("--extract", nargs=2, metavar=("PATTERN", "DEST"), help=extract_help)

    return parser.parse_args()
def main():
    """CLI entry point: run extract, or the resize -> pack -> push pipeline.

    With ``--extract`` only the extraction helper runs. Otherwise the enabled
    phases run in order: resize (unless ``--no_resize``), pack (unless
    ``--no_pack``), push (only with ``--push``).

    Every argument check is done before any phase starts, so a missing
    ``--hf_repo`` or ``HF_TOKEN`` is reported immediately instead of after
    the resize and pack phases have already run.
    """
    a = parse_args()
    if a.extract:
        extract_shards(a.extract[0], Path(a.extract[1]))
        return

    # --dst is always needed (resize writes it, pack reads it); --src only
    # when actually resizing. Lets you re-pack/push an existing tree.
    if not a.dst:
        sys.exit("ERROR: --dst is required (or use --extract).")
    if not a.no_resize and a.target < 518:
        sys.exit(f"ERROR: --target {a.target} < 518. RAD-DINO upscales the "
                 f"shortest edge to 518, so storing smaller only adds blur. "
                 f"Use 518 (default) or higher.")

    src = None
    if not a.no_resize:
        if not a.src:
            sys.exit("ERROR: --src is required for the resize step "
                     "(pass --no_resize to pack/push an existing tree).")
        src = Path(a.src)
        if not src.is_dir():
            sys.exit(f"ERROR: --src not a directory: {src}")

    # Fail fast on push prerequisites; push_hf() re-checks the token itself.
    if a.push:
        if not a.hf_repo:
            sys.exit("ERROR: --push requires --hf_repo <user>/<repo>")
        if not os.environ.get("HF_TOKEN"):
            sys.exit("ERROR: --push needs HF_TOKEN env var (write-scope token).")

    dst = Path(a.dst)
    shards_dir = Path(a.shards_dir) if a.shards_dir else dst.parent / f"{dst.name}_shards"

    if not a.no_resize:
        resize_tree(src, dst, a.target, a.quality, a.workers, a.square)
    if not a.no_pack:
        pack_shards(dst, shards_dir, a.shard_prefix, a.shard_gb)
    if a.push:
        push_hf(shards_dir, a.hf_repo, a.hf_path, private=not a.public)
    print("\nAll done.")


if __name__ == "__main__":
    main()