"""
precompute_image_features.py
----------------------------
One-shot pre-computation of frozen image-encoder patch features.
The image encoder (BioViL-T / RAD-DINO / ViT) is frozen at training time and
the dataset uses a deterministic transform (Resize + ToTensor + Normalize),
so the same image always produces the same (P, 768) patch feature tensor.
Running the encoder every step wastes I/O (re-decoding JPEG) and a small but
non-trivial slice of GPU compute.
This script walks the unified instruct JSON, encodes each UNIQUE image path
exactly once, and writes a `.pt` file under `feature_cache_dir` mirroring the
relative image path. At training time, set `data.feature_cache_dir` in
train_config.yaml — `data/dataset.py` loads the cached tensor on hit and
`model/cxr_vlm.py` detects the (P, 768) shape and skips the encoder.
Typical usage:
    python -m scripts.precompute_image_features \
        --model_config configs/model_config.yaml \
        --train_config configs/train_config.yaml \
        --cache_dir cache/image_features \
        --batch_size 16

After this finishes, edit train_config.yaml:
    data:
      feature_cache_dir: "cache/image_features"
"""
import argparse
import json
import sys
from pathlib import Path
import torch
from omegaconf import OmegaConf
# project root on path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from model.rad_dino import BioViLTEncoder
from utils.dataset_resolver import resolve_dataset_spec
def parse_args(argv=None):
    """Parse command-line options for the feature pre-computation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what the script entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``model_config``,
        ``train_config``, ``cache_dir``, ``batch_size``, ``device``,
        ``limit`` and ``overwrite``.

    Raises:
        SystemExit: Raised by argparse when ``--cache_dir`` is missing, when
            ``--batch_size`` is smaller than 1, or when ``--limit`` is
            negative.
    """
    p = argparse.ArgumentParser(description="Pre-compute & cache patch features")
    p.add_argument("--model_config", default="configs/model_config.yaml")
    p.add_argument("--train_config", default="configs/train_config.yaml")
    p.add_argument(
        "--cache_dir", required=True,
        help="Output dir. Each image is saved as {cache_dir}/{image_relpath}.pt"
    )
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument(
        "--device", default="cuda",
        help="cuda | cpu (cpu is fine — encoder is small)"
    )
    p.add_argument(
        "--limit", type=int, default=None,
        help="Cap total images processed (useful for smoke tests)"
    )
    p.add_argument(
        "--overwrite", action="store_true",
        help="Re-encode and overwrite existing .pt files (default: skip)"
    )
    args = p.parse_args(argv)

    # Reject values that cannot describe a meaningful run instead of letting
    # them silently change behaviour downstream (a negative limit would slice
    # from the end of the path list; a batch size below 1 is meaningless).
    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")
    if args.limit is not None and args.limit < 0:
        p.error("--limit must be >= 0")
    return args
def collect_image_paths(instruct_json: str) -> list:
    """Return the sorted, de-duplicated image relpaths found in the unified JSON.

    The file is loaded as a whole and iterated sample by sample, with no
    filtering by split. A sample with a non-empty ``image_paths`` value
    contributes every entry of it; otherwise a non-empty ``image_path`` value
    contributes that single path. Samples with neither are ignored.
    """
    with open(instruct_json, "r", encoding="utf-8") as handle:
        records = json.load(handle)

    unique = set()
    for record in records:
        multi = record.get("image_paths")
        if multi:
            # Multi-image sample: `image_paths` wins and `image_path` is
            # not consulted.
            unique.update(multi)
            continue
        single = record.get("image_path")
        if single:
            unique.add(single)
    return sorted(unique)
def main():
    """Encode every unique image once and write its patch features to disk.

    Loads the train/model configs, resolves the dataset spec, builds the
    image encoder in eval mode, then walks the unique image paths from the
    instruct JSON. Each image is transformed, encoded in batches, converted
    to float16 and saved as ``{cache_dir}/{image_relpath}.pt``.

    Existing ``.pt`` files are skipped unless ``--overwrite`` is given.
    Images that fail to open or transform are counted, reported and skipped
    so one bad file does not abort the whole run.
    """
    from PIL import Image
    import time

    args = parse_args()
    train_cfg = OmegaConf.load(args.train_config)
    model_cfg = OmegaConf.load(args.model_config)
    spec = resolve_dataset_spec(train_cfg)

    cache_dir = Path(args.cache_dir).resolve()
    cache_dir.mkdir(parents=True, exist_ok=True)
    print(f"cache_dir : {cache_dir}")
    print(f"image_root : {spec.image_root}")
    print(f"instruct_json : {spec.instruct_json}")

    # ── Build encoder (frozen, in inference dtype) ───────────────────────
    _DTYPE_MAP = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    enc_dtype = _DTYPE_MAP.get(model_cfg.llm.torch_dtype, torch.float32)
    encoder = BioViLTEncoder(
        frozen=True,
        img_size=model_cfg.image_encoder.img_size,
        backend=getattr(model_cfg.image_encoder, "backend", "auto"),
        dtype=enc_dtype,
    ).to(args.device).eval()
    transform = BioViLTEncoder.get_transform("val")

    # ── Collect unique image paths ───────────────────────────────────────
    image_paths = collect_image_paths(spec.instruct_json)
    print(f"unique images : {len(image_paths)}")
    # `is not None` so that an explicit `--limit 0` means "process nothing"
    # rather than being mistaken for "no limit"; clamp so a negative value
    # cannot slice from the end of the list.
    if args.limit is not None:
        image_paths = image_paths[: max(args.limit, 0)]
        print(f"(limited to {len(image_paths)})")

    # ── Encode in batches ────────────────────────────────────────────────
    image_root = Path(spec.image_root)
    done = 0
    skipped = 0
    failed = 0
    t0 = time.time()

    # Autocast settings do not change between batches, so compute them once.
    autocast_device = "cuda" if args.device.startswith("cuda") else "cpu"
    use_autocast = enc_dtype != torch.float32

    def _flush_batch(batch_paths, batch_tensors):
        """Encode one pending batch and save one .pt file per image."""
        nonlocal done
        if not batch_tensors:
            return
        x = torch.stack(batch_tensors).to(args.device)
        with torch.no_grad():
            with torch.autocast(autocast_device, dtype=enc_dtype, enabled=use_autocast):
                feats = encoder(x)  # (B, P, 768)
            feats = feats.to(torch.float16).cpu()  # fp16 .pt → smaller on disk
        for rel, f in zip(batch_paths, feats):
            out = cache_dir / (rel + ".pt")
            out.parent.mkdir(parents=True, exist_ok=True)
            # `f` is a view into the batch tensor. torch.save serializes the
            # whole underlying storage of a view, so saving it directly would
            # write the entire batch into every file. clone() gives the row
            # its own storage.
            # Write to a temp file and rename so an interrupted run never
            # leaves a truncated .pt that a later run would skip as existing.
            tmp = out.with_name(out.name + ".tmp")
            torch.save(f.clone(), tmp)
            tmp.replace(out)
            done += 1

    pending_paths, pending_tensors = [], []
    for i, rel in enumerate(image_paths):
        out_path = cache_dir / (rel + ".pt")
        if out_path.is_file() and not args.overwrite:
            skipped += 1
            continue
        try:
            # Context manager releases the file handle as soon as the image
            # has been decoded and transformed.
            with Image.open(image_root / rel) as img:
                t = transform(img.convert("RGB"))
        except Exception as e:
            # Deliberate best-effort: report the unreadable image and move on.
            failed += 1
            print(f" [skip] {rel}: {type(e).__name__}: {e}")
            continue
        pending_paths.append(rel)
        pending_tensors.append(t)

        if len(pending_tensors) >= args.batch_size:
            _flush_batch(pending_paths, pending_tensors)
            pending_paths, pending_tensors = [], []
        if (i + 1) % 500 == 0:
            elapsed = time.time() - t0
            rate = (done + skipped) / max(elapsed, 1e-6)
            print(f" [{i+1:>6}/{len(image_paths)}] done={done} "
                  f"skipped={skipped} failed={failed} ({rate:.1f} img/s)")

    # Encode whatever is left over after the last full batch.
    _flush_batch(pending_paths, pending_tensors)

    elapsed = time.time() - t0
    print(f"\nFinished. encoded={done} skipped(existing)={skipped} failed={failed} "
          f"elapsed={elapsed/60:.1f} min")
    print(f"\nNext: set this in configs/train_config.yaml under `data:` ↓")
    print(f" feature_cache_dir: \"{cache_dir}\"")
# Run the pre-computation only when this file is executed as a script
# (e.g. `python -m scripts.precompute_image_features`), not when imported.
if __name__ == "__main__":
    main()
|