"""
precompute_image_features.py
----------------------------
One-shot pre-computation of frozen image-encoder patch features.

The image encoder (BioViL-T / RAD-DINO / ViT) is frozen at training time and
the dataset uses a deterministic transform (Resize + ToTensor + Normalize), so
the same image always produces the same (P, 768) patch-feature tensor. Running
the encoder every step wastes I/O (re-decoding JPEGs) and a small but
non-trivial slice of GPU compute.

This script walks the unified instruct JSON, encodes each UNIQUE image path
exactly once, and writes a `.pt` file under `feature_cache_dir` mirroring the
relative image path. At training time, set `data.feature_cache_dir` in
train_config.yaml — `data/dataset.py` loads the cached tensor on a hit and
`model/cxr_vlm.py` detects the (P, 768) shape and skips the encoder.

Typical usage:

    python -m scripts.precompute_image_features \\
        --model_config configs/model_config.yaml \\
        --train_config configs/train_config.yaml \\
        --cache_dir    cache/image_features \\
        --batch_size   16

After this finishes, edit train_config.yaml:

    data:
      feature_cache_dir: "cache/image_features"
"""
import argparse
import json
import sys
from pathlib import Path

import torch
from omegaconf import OmegaConf

# project root on path (so the project-local imports below resolve when this
# file is run directly)
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from model.rad_dino import BioViLTEncoder  # noqa: E402
from utils.dataset_resolver import resolve_dataset_spec  # noqa: E402

# Config string -> torch dtype for the encoder; unknown strings fall back to fp32.
_DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


def parse_args():
    """Parse command-line options for the feature pre-computation run."""
    p = argparse.ArgumentParser(description="Pre-compute & cache patch features")
    p.add_argument("--model_config", default="configs/model_config.yaml")
    p.add_argument("--train_config", default="configs/train_config.yaml")
    p.add_argument(
        "--cache_dir",
        required=True,
        help="Output dir. Each image is saved as {cache_dir}/{image_relpath}.pt",
    )
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument(
        "--device",
        default="cuda",
        help="cuda | cpu (cpu is fine — encoder is small)",
    )
    p.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Cap total images processed (useful for smoke tests)",
    )
    p.add_argument(
        "--overwrite",
        action="store_true",
        help="Re-encode and overwrite existing .pt files (default: skip)",
    )
    return p.parse_args()


def collect_image_paths(instruct_json: str) -> list:
    """Return the sorted list of unique image relpaths in *instruct_json*.

    Covers ALL splits. A sample with a non-empty ``image_paths`` list
    contributes each entry; otherwise its ``image_path`` (if truthy) is used.
    """
    with open(instruct_json, "r", encoding="utf-8") as f:
        samples = json.load(f)

    seen = set()
    for s in samples:
        if s.get("image_paths"):
            seen.update(s["image_paths"])
        elif s.get("image_path"):
            seen.add(s["image_path"])
    return sorted(seen)


def _load_image(path: Path):
    """Decode the image at *path* as RGB and release the file handle."""
    from PIL import Image

    # convert() returns a new, fully-loaded image, so it stays valid after
    # the context manager closes the underlying file.
    with Image.open(path) as im:
        return im.convert("RGB")


def _save_atomic(tensor: torch.Tensor, out: Path) -> None:
    """Write *tensor* to *out* via a temp file + rename.

    The main loop skips any existing ``.pt`` file, so a partially written file
    from an interrupted run would otherwise be treated as a valid cache hit
    forever. Renaming only after a complete write avoids that.
    """
    out.parent.mkdir(parents=True, exist_ok=True)
    tmp = out.with_name(out.name + ".tmp")
    torch.save(tensor, tmp)
    tmp.replace(out)


def _encode_and_save(encoder, batch_paths, batch_tensors, cache_dir: Path,
                     device: str, enc_dtype: torch.dtype) -> int:
    """Encode one batch and write each row to ``{cache_dir}/{rel}.pt``.

    Returns the number of feature files written (0 for an empty batch).
    """
    if not batch_tensors:
        return 0
    x = torch.stack(batch_tensors).to(device)
    device_type = "cuda" if device.startswith("cuda") else "cpu"
    with torch.no_grad(), torch.autocast(
        device_type,
        dtype=enc_dtype,
        enabled=enc_dtype != torch.float32,
    ):
        feats = encoder(x)  # (B, P, 768)
    feats = feats.to(torch.float16).cpu()  # fp16 .pt → smaller on disk

    written = 0
    for rel, row in zip(batch_paths, feats):
        # Iterating a batched tensor yields views sharing the whole batch's
        # storage, and torch.save serializes a view's entire storage. clone()
        # gives each file just its own row instead of the full batch.
        _save_atomic(row.clone(), cache_dir / (rel + ".pt"))
        written += 1
    return written


def main():
    """Encode every unique image referenced by the instruct JSON and cache it."""
    import time

    args = parse_args()
    train_cfg = OmegaConf.load(args.train_config)
    model_cfg = OmegaConf.load(args.model_config)
    spec = resolve_dataset_spec(train_cfg)

    cache_dir = Path(args.cache_dir).resolve()
    cache_dir.mkdir(parents=True, exist_ok=True)
    print(f"cache_dir : {cache_dir}")
    print(f"image_root : {spec.image_root}")
    print(f"instruct_json : {spec.instruct_json}")

    # ── Build encoder (frozen, in inference dtype) ───────────────────────
    enc_dtype = _DTYPE_MAP.get(model_cfg.llm.torch_dtype, torch.float32)
    encoder = BioViLTEncoder(
        frozen=True,
        img_size=model_cfg.image_encoder.img_size,
        backend=getattr(model_cfg.image_encoder, "backend", "auto"),
        dtype=enc_dtype,
    ).to(args.device).eval()
    transform = BioViLTEncoder.get_transform("val")

    # ── Collect unique image paths ───────────────────────────────────────
    image_paths = collect_image_paths(spec.instruct_json)
    print(f"unique images : {len(image_paths)}")
    if args.limit is not None:  # honour an explicit --limit 0 too
        image_paths = image_paths[: args.limit]
        print(f"(limited to {len(image_paths)})")

    # ── Encode in batches ────────────────────────────────────────────────
    # NOTE(review): relpaths from the JSON are joined onto image_root and
    # cache_dir as-is; an absolute relpath would bypass both roots — confirm
    # the JSON only contains relative paths.
    image_root = Path(spec.image_root)
    done = skipped = failed = 0
    pending_paths, pending_tensors = [], []
    t0 = time.time()

    for i, rel in enumerate(image_paths, start=1):
        out_path = cache_dir / (rel + ".pt")
        if out_path.is_file() and not args.overwrite:
            skipped += 1
        else:
            try:
                tensor = transform(_load_image(image_root / rel))
            except Exception as e:  # best-effort: log the bad image and move on
                failed += 1
                print(f"  [skip] {rel}: {type(e).__name__}: {e}")
            else:
                pending_paths.append(rel)
                pending_tensors.append(tensor)
                if len(pending_tensors) >= args.batch_size:
                    done += _encode_and_save(
                        encoder, pending_paths, pending_tensors,
                        cache_dir, args.device, enc_dtype,
                    )
                    pending_paths, pending_tensors = [], []

        # Report on every 500th path, whether it was encoded, skipped or failed.
        if i % 500 == 0:
            elapsed = time.time() - t0
            # Counts cache hits too, so this is paths handled per second.
            rate = (done + skipped) / max(elapsed, 1e-6)
            print(f"  [{i:>6}/{len(image_paths)}] done={done} "
                  f"skipped={skipped} failed={failed} ({rate:.1f} img/s)")

    done += _encode_and_save(
        encoder, pending_paths, pending_tensors,
        cache_dir, args.device, enc_dtype,
    )

    elapsed = time.time() - t0
    print(f"\nFinished. encoded={done} skipped(existing)={skipped} failed={failed} "
          f"elapsed={elapsed/60:.1f} min")
    print("\nNext: set this in configs/train_config.yaml under `data:` ↓")
    print(f" feature_cache_dir: \"{cache_dir}\"")


if __name__ == "__main__":
    main()