| """ |
| precompute_image_features.py |
| ---------------------------- |
| One-shot pre-computation of frozen image-encoder patch features. |
| |
| The image encoder (BioViL-T / RAD-DINO / ViT) is frozen at training time and |
| the dataset uses a deterministic transform (Resize + ToTensor + Normalize), |
| so the same image always produces the same (P, 768) patch feature tensor. |
| Running the encoder every step wastes I/O (re-decoding JPEG) and a small but |
| non-trivial slice of GPU compute. |
| |
| This script walks the unified instruct JSON, encodes each UNIQUE image path |
| exactly once, and writes a `.pt` file under `feature_cache_dir` mirroring the |
| relative image path. At training time, set `data.feature_cache_dir` in |
| train_config.yaml — `data/dataset.py` loads the cached tensor on hit and |
| `model/cxr_vlm.py` detects the (P, 768) shape and skips the encoder. |
| |
| Typical usage: |
| |
| python -m scripts.precompute_image_features \ |
| --model_config configs/model_config.yaml \ |
| --train_config configs/train_config.yaml \ |
| --cache_dir cache/image_features \ |
| --batch_size 16 |
| |
| After this finishes, edit train_config.yaml: |
| data: |
| feature_cache_dir: "cache/image_features" |
| """ |
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| from omegaconf import OmegaConf |
|
|
| |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) |
|
|
| from model.rad_dino import BioViLTEncoder |
| from utils.dataset_resolver import resolve_dataset_spec |
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the feature pre-computation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            allows programmatic use (e.g. from another script or a test).

    Returns:
        argparse.Namespace with ``model_config``, ``train_config``,
        ``cache_dir`` (required), ``batch_size``, ``device``, ``limit`` and
        ``overwrite``.
    """
    parser = argparse.ArgumentParser(description="Pre-compute & cache patch features")
    parser.add_argument("--model_config", default="configs/model_config.yaml")
    parser.add_argument("--train_config", default="configs/train_config.yaml")
    parser.add_argument(
        "--cache_dir", required=True,
        help="Output dir. Each image is saved as {cache_dir}/{image_relpath}.pt",
    )
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument(
        "--device", default="cuda",
        help="cuda | cpu (cpu is fine — encoder is small)",
    )
    parser.add_argument(
        "--limit", type=int, default=None,
        help="Cap total images processed (useful for smoke tests)",
    )
    parser.add_argument(
        "--overwrite", action="store_true",
        help="Re-encode and overwrite existing .pt files (default: skip)",
    )
    return parser.parse_args(argv)
|
|
|
|
def collect_image_paths(instruct_json: str) -> list:
    """Return the sorted, de-duplicated image relpaths referenced in the JSON.

    All samples are scanned regardless of split. A sample with a non-empty
    ``image_paths`` list contributes every entry of that list (and its
    ``image_path``, if any, is ignored); otherwise a truthy ``image_path``
    contributes that single path. Samples with neither contribute nothing.
    """
    with open(instruct_json, "r", encoding="utf-8") as fh:
        records = json.load(fh)

    unique_paths = set()
    for record in records:
        multi = record.get("image_paths")
        if multi:
            # Multi-image sample: every view is encoded separately.
            unique_paths.update(multi)
            continue
        single = record.get("image_path")
        if single:
            unique_paths.add(single)
    return sorted(unique_paths)
|
|
|
|
def main():
    """Pre-compute frozen image-encoder features for every referenced image.

    Steps: parse CLI args, load both YAML configs, resolve the dataset spec,
    build the frozen ``BioViLTEncoder`` plus its "val" transform, collect the
    unique image relpaths from the instruct JSON, then encode them in batches
    of ``--batch_size`` and write one float16 tensor per image to
    ``{cache_dir}/{relpath}.pt``.

    Existing cache files are skipped unless ``--overwrite`` is set. Each file
    is written atomically (temp file + rename), so an interrupted run can be
    restarted without leaving truncated ``.pt`` files that a later run would
    skip as already cached. Images that fail to open or transform are logged
    and counted instead of aborting the run.
    """
    import os
    import time

    from PIL import Image

    args = parse_args()

    train_cfg = OmegaConf.load(args.train_config)
    model_cfg = OmegaConf.load(args.model_config)
    spec = resolve_dataset_spec(train_cfg)

    cache_dir = Path(args.cache_dir).resolve()
    cache_dir.mkdir(parents=True, exist_ok=True)
    print(f"cache_dir : {cache_dir}")
    print(f"image_root : {spec.image_root}")
    print(f"instruct_json : {spec.instruct_json}")

    # NOTE(review): the encoder dtype follows the *LLM* dtype setting
    # (model_cfg.llm.torch_dtype); unrecognised values fall back to float32.
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    enc_dtype = dtype_map.get(model_cfg.llm.torch_dtype, torch.float32)
    encoder = BioViLTEncoder(
        frozen=True,
        img_size=model_cfg.image_encoder.img_size,
        backend=getattr(model_cfg.image_encoder, "backend", "auto"),
        dtype=enc_dtype,
    ).to(args.device).eval()

    transform = BioViLTEncoder.get_transform("val")

    # Autocast settings never change during the run; autocast is only enabled
    # for reduced-precision dtypes.
    # NOTE(review): depending on the torch version, CPU autocast may only
    # support bfloat16 and disable itself (with a warning) for float16 — confirm.
    autocast_device = "cuda" if args.device.startswith("cuda") else "cpu"
    use_autocast = enc_dtype != torch.float32

    image_paths = collect_image_paths(spec.instruct_json)
    print(f"unique images : {len(image_paths)}")
    # `is not None` so an explicit `--limit 0` processes nothing rather than
    # silently meaning "no limit"; negative values are clamped to 0.
    if args.limit is not None:
        image_paths = image_paths[: max(args.limit, 0)]
        print(f"(limited to {len(image_paths)})")

    image_root = Path(spec.image_root)
    total = len(image_paths)
    done = 0
    skipped = 0
    failed = 0
    t0 = time.time()

    def _cache_path(rel):
        # Cache files mirror the image relpath: "{relpath}.pt" under cache_dir.
        # NOTE(review): an absolute `rel` makes pathlib discard cache_dir;
        # presumably the dataset loader builds the same path — confirm before
        # changing this naming.
        return cache_dir / (rel + ".pt")

    def _save_atomic(tensor, out):
        # Write to a sibling temp file, then rename over the target, so a crash
        # mid-write never leaves a truncated .pt that would be skipped later.
        out.parent.mkdir(parents=True, exist_ok=True)
        tmp = out.with_name(out.name + ".tmp")
        try:
            torch.save(tensor, tmp)
            os.replace(tmp, out)
        except BaseException:
            if tmp.exists():
                tmp.unlink()
            raise

    def _flush_batch(batch_paths, batch_tensors):
        # Encode one batch and write each image's features to the cache.
        nonlocal done
        if not batch_tensors:
            return
        x = torch.stack(batch_tensors).to(args.device)
        with torch.no_grad(), torch.autocast(
            autocast_device, dtype=enc_dtype, enabled=use_autocast
        ):
            feats = encoder(x)
        feats = feats.to(torch.float16).cpu()
        for rel, feat in zip(batch_paths, feats):
            # `feat` is a view into the whole batch's storage, and torch.save
            # serialises a view's *entire* storage. A compact contiguous clone
            # makes each file hold only this image's features, not the batch.
            compact = feat.clone(memory_format=torch.contiguous_format)
            _save_atomic(compact, _cache_path(rel))
            done += 1

    pending_paths, pending_tensors = [], []
    for i, rel in enumerate(image_paths, start=1):
        if _cache_path(rel).is_file() and not args.overwrite:
            skipped += 1
        else:
            try:
                # `with` closes the underlying file handle deterministically.
                with Image.open(image_root / rel) as img:
                    t = transform(img.convert("RGB"))
            except Exception as e:
                # Deliberate best-effort: a corrupt/missing image is reported
                # and counted, not fatal to the whole pre-computation run.
                failed += 1
                print(f" [skip] {rel}: {type(e).__name__}: {e}")
            else:
                pending_paths.append(rel)
                pending_tensors.append(t)
                if len(pending_tensors) >= args.batch_size:
                    _flush_batch(pending_paths, pending_tensors)
                    pending_paths, pending_tensors = [], []

        # Progress is reported on every 500th index, including indices whose
        # image was skipped (already cached) or failed.
        if i % 500 == 0:
            elapsed = time.time() - t0
            rate = (done + skipped) / max(elapsed, 1e-6)
            print(f" [{i:>6}/{total}] done={done} "
                  f"skipped={skipped} failed={failed} ({rate:.1f} img/s)")

    # Encode whatever is left in the final, partially filled batch.
    _flush_batch(pending_paths, pending_tensors)

    elapsed = time.time() - t0
    print(f"\nFinished. encoded={done} skipped(existing)={skipped} failed={failed} "
          f"elapsed={elapsed/60:.1f} min")
    print(f"\nNext: set this in configs/train_config.yaml under `data:` ↓")
    print(f" feature_cache_dir: \"{cache_dir}\"")
|
|
|
|
if __name__ == "__main__":
    # Run the pre-computation only when executed as a script
    # (e.g. `python -m scripts.precompute_image_features`), never on import.
    main()
|
|