# cxr-vlm-code / scripts/precompute_image_features.py — convitom, c61f01a
"""
precompute_image_features.py
----------------------------
One-shot pre-computation of frozen image-encoder patch features.
The image encoder (BioViL-T / RAD-DINO / ViT) is frozen at training time and
the dataset uses a deterministic transform (Resize + ToTensor + Normalize),
so the same image always produces the same (P, 768) patch feature tensor.
Running the encoder every step wastes disk I/O and CPU time (re-reading and
re-decoding JPEGs) plus a small but non-trivial slice of GPU compute.
This script walks the unified instruct JSON, encodes each UNIQUE image path
exactly once, and writes a `.pt` file under `feature_cache_dir` mirroring the
relative image path. At training time, set `data.feature_cache_dir` in
train_config.yaml — `data/dataset.py` loads the cached tensor on hit and
`model/cxr_vlm.py` detects the (P, 768) shape and skips the encoder.
Typical usage:

    python -m scripts.precompute_image_features \
        --model_config configs/model_config.yaml \
        --train_config configs/train_config.yaml \
        --cache_dir cache/image_features \
        --batch_size 16

After this finishes, edit train_config.yaml:

    data:
      feature_cache_dir: "cache/image_features"
"""
import argparse
import json
import sys
from pathlib import Path
import torch
from omegaconf import OmegaConf
# project root on path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from model.rad_dino import BioViLTEncoder
from utils.dataset_resolver import resolve_dataset_spec
def _positive_int(text: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected an integer >= 1, got {value}")
    return value


def _non_negative_int(text: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 0."""
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError(f"expected an integer >= 0, got {value}")
    return value


def parse_args(argv=None):
    """Parse command-line options for the feature pre-computation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) means
            ``sys.argv[1:]``, as before. Passing a list lets tests or other
            scripts drive the parser directly.

    Returns:
        argparse.Namespace with ``model_config``, ``train_config``,
        ``cache_dir``, ``batch_size``, ``device``, ``limit`` and ``overwrite``.

    Raises:
        SystemExit: on invalid arguments. This includes ``--batch_size < 1``
            and ``--limit < 0``, which previously were accepted silently.
    """
    p = argparse.ArgumentParser(description="Pre-compute & cache patch features")
    p.add_argument("--model_config", default="configs/model_config.yaml")
    p.add_argument("--train_config", default="configs/train_config.yaml")
    p.add_argument(
        "--cache_dir", required=True,
        help="Output dir. Each image is saved as {cache_dir}/{image_relpath}.pt",
    )
    # A batch size <= 0 would make the batching loop flush after every image.
    p.add_argument("--batch_size", type=_positive_int, default=16)
    p.add_argument(
        "--device", default="cuda",
        help="cuda | cpu (cpu is fine — encoder is small)",
    )
    # A negative limit would slice images off the END of the list.
    p.add_argument(
        "--limit", type=_non_negative_int, default=None,
        help="Cap total images processed (useful for smoke tests)",
    )
    p.add_argument(
        "--overwrite", action="store_true",
        help="Re-encode and overwrite existing .pt files (default: skip)",
    )
    return p.parse_args(argv)
def collect_image_paths(instruct_json: str) -> list:
    """Return the sorted, de-duplicated image relpaths referenced by the JSON.

    The JSON is a list of sample dicts covering ALL splits. A sample
    contributes every entry of ``image_paths`` (multi-image samples) when that
    field is non-empty; otherwise it contributes its single ``image_path``.

    Falsy entries (``None`` / ``""``) are ignored for two reasons:

    * a ``None`` mixed with strings would make ``sorted`` raise ``TypeError``;
    * an empty string would resolve to the image root directory itself.

    A bare string stored under ``image_paths`` is treated as one path rather
    than being iterated character by character.

    Args:
        instruct_json: Path to the unified instruct JSON file (UTF-8).

    Returns:
        Sorted list of unique image relpath strings.
    """
    with open(instruct_json, "r", encoding="utf-8") as f:
        samples = json.load(f)

    seen = set()
    for sample in samples:
        paths = sample.get("image_paths")
        if paths:
            if isinstance(paths, str):
                paths = [paths]
        else:
            paths = [sample.get("image_path")]
        seen.update(p for p in paths if p)
    return sorted(seen)
def _build_encoder(model_cfg, device: str):
    """Instantiate the frozen image encoder in eval mode on *device*.

    Returns:
        ``(encoder, enc_dtype)``. ``enc_dtype`` is looked up from
        ``model_cfg.llm.torch_dtype``; unknown names fall back to float32.
    """
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    enc_dtype = dtype_map.get(model_cfg.llm.torch_dtype, torch.float32)
    encoder = BioViLTEncoder(
        frozen=True,
        img_size=model_cfg.image_encoder.img_size,
        backend=getattr(model_cfg.image_encoder, "backend", "auto"),
        dtype=enc_dtype,
    ).to(device).eval()
    return encoder, enc_dtype


def _cache_path(cache_dir: Path, rel: str) -> Path:
    """Return the cache location for image *rel*: ``{cache_dir}/{rel}.pt``.

    This mirrors the relative image path, which is the key the training-time
    loader presumably looks up (see the module docstring). Keep the two in sync.

    NOTE(review): if *rel* is absolute, pathlib discards *cache_dir* and the
    file lands next to the source image. Presumably instruct paths are always
    relative; confirm.
    """
    return cache_dir / (rel + ".pt")


def _save_atomic(tensor: torch.Tensor, out: Path) -> None:
    """Save *tensor* to *out* through a temp file followed by a rename.

    An interrupted run therefore never leaves a truncated ``.pt`` behind. A
    later (non-``--overwrite``) run would otherwise skip such a file as an
    existing, valid cache entry.
    """
    out.parent.mkdir(parents=True, exist_ok=True)
    tmp = out.with_name(out.name + ".tmp")
    torch.save(tensor, tmp)
    tmp.replace(out)


def main():
    """Encode every unique image in the instruct JSON and cache its features.

    Steps:

    1. Load the model/train configs named on the command line and resolve the
       dataset spec.
    2. Encode images in batches of ``--batch_size``.
    3. Write one fp16 tensor per image to ``{cache_dir}/{image_relpath}.pt``.

    Existing cache files are skipped unless ``--overwrite`` is given. Images
    that fail to open or transform are logged and counted as failed instead
    of aborting the run.
    """
    import time

    from PIL import Image

    args = parse_args()
    train_cfg = OmegaConf.load(args.train_config)
    model_cfg = OmegaConf.load(args.model_config)
    spec = resolve_dataset_spec(train_cfg)

    cache_dir = Path(args.cache_dir).resolve()
    cache_dir.mkdir(parents=True, exist_ok=True)
    print(f"cache_dir : {cache_dir}")
    print(f"image_root : {spec.image_root}")
    print(f"instruct_json : {spec.instruct_json}")

    # ── Build encoder (frozen, in inference dtype) ───────────────────────
    encoder, enc_dtype = _build_encoder(model_cfg, args.device)
    transform = BioViLTEncoder.get_transform("val")

    # ── Collect unique image paths ───────────────────────────────────────
    image_paths = collect_image_paths(spec.instruct_json)
    print(f"unique images : {len(image_paths)}")
    # `is not None` so that an explicit `--limit 0` is honoured instead of
    # being treated as "no limit".
    if args.limit is not None:
        image_paths = image_paths[: args.limit]
        print(f"(limited to {len(image_paths)})")

    # ── Encode in batches ────────────────────────────────────────────────
    autocast_device = "cuda" if args.device.startswith("cuda") else "cpu"
    use_autocast = enc_dtype != torch.float32

    def _flush_batch(batch_paths, batch_tensors) -> int:
        """Encode one batch, cache each image's features, return #saved."""
        if not batch_tensors:
            return 0
        x = torch.stack(batch_tensors).to(args.device)
        with torch.no_grad(), torch.autocast(
            autocast_device, dtype=enc_dtype, enabled=use_autocast
        ):
            feats = encoder(x)  # (B, P, 768)
        feats = feats.to(torch.float16).cpu()  # fp16 .pt → smaller on disk
        for rel, feat in zip(batch_paths, feats):
            # `feat` is a view into the whole batch, and torch.save writes the
            # view's entire underlying storage (.contiguous() is a no-op here).
            # clone() so each file holds only this image's features, not all B.
            _save_atomic(feat.clone(), _cache_path(cache_dir, rel))
        return len(batch_paths)

    image_root = Path(spec.image_root)
    total = len(image_paths)
    done = skipped = failed = 0
    pending_paths, pending_tensors = [], []
    t0 = time.time()

    for i, rel in enumerate(image_paths, start=1):
        if not args.overwrite and _cache_path(cache_dir, rel).is_file():
            skipped += 1
        else:
            try:
                with Image.open(image_root / rel) as img:
                    tensor = transform(img.convert("RGB"))
            except Exception as e:  # best-effort: one bad image must not abort the run
                failed += 1
                print(f" [skip] {rel}: {type(e).__name__}: {e}")
            else:
                pending_paths.append(rel)
                pending_tensors.append(tensor)
                if len(pending_tensors) >= args.batch_size:
                    done += _flush_batch(pending_paths, pending_tensors)
                    pending_paths, pending_tensors = [], []

        # Report progress for every path seen, skipped and failed ones
        # included, so a resumed run that mostly skips still shows output.
        if i % 500 == 0:
            elapsed = time.time() - t0
            rate = (done + skipped) / max(elapsed, 1e-6)
            print(f" [{i:>6}/{total}] done={done} "
                  f"skipped={skipped} failed={failed} ({rate:.1f} img/s)")

    done += _flush_batch(pending_paths, pending_tensors)
    elapsed = time.time() - t0
    print(f"\nFinished. encoded={done} skipped(existing)={skipped} failed={failed} "
          f"elapsed={elapsed/60:.1f} min")
    print("\nNext: set this in configs/train_config.yaml under `data:` ↓")
    print(f" feature_cache_dir: \"{cache_dir}\"")


if __name__ == "__main__":
    main()