File size: 6,827 Bytes
c61f01a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""
precompute_image_features.py
----------------------------
One-shot pre-computation of frozen image-encoder patch features.

The image encoder (BioViL-T / RAD-DINO / ViT) is frozen at training time and
the dataset uses a deterministic transform (Resize + ToTensor + Normalize),
so the same image always produces the same (P, 768) patch feature tensor.
Running the encoder every step wastes I/O (re-decoding JPEG) and a small but
non-trivial slice of GPU compute.

This script walks the unified instruct JSON, encodes each UNIQUE image path
exactly once, and writes a `.pt` file under `feature_cache_dir` mirroring the
relative image path. At training time, set `data.feature_cache_dir` in
train_config.yaml — `data/dataset.py` loads the cached tensor on hit and
`model/cxr_vlm.py` detects the (P, 768) shape and skips the encoder.

Typical usage:

    python -m scripts.precompute_image_features \
        --model_config configs/model_config.yaml \
        --train_config configs/train_config.yaml \
        --cache_dir cache/image_features \
        --batch_size 16

After this finishes, edit train_config.yaml:
    data:
      feature_cache_dir: "cache/image_features"
"""

import argparse
import json
import sys
from pathlib import Path

import torch
from omegaconf import OmegaConf

# project root on path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from model.rad_dino import BioViLTEncoder
from utils.dataset_resolver import resolve_dataset_spec


def parse_args():
    """Define this script's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``model_config``, ``train_config``,
        ``cache_dir`` (required), ``batch_size`` (int), ``device``,
        ``limit`` (int or None) and ``overwrite`` (bool).
    """
    parser = argparse.ArgumentParser(description="Pre-compute & cache patch features")

    # Config file locations (relative to the working directory by default).
    config_flags = (
        ("--model_config", "configs/model_config.yaml"),
        ("--train_config", "configs/train_config.yaml"),
    )
    for flag, default_path in config_flags:
        parser.add_argument(flag, default=default_path)

    parser.add_argument(
        "--cache_dir",
        required=True,
        help="Output dir. Each image is saved as {cache_dir}/{image_relpath}.pt",
    )
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument(
        "--device",
        default="cuda",
        help="cuda | cpu (cpu is fine — encoder is small)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Cap total images processed (useful for smoke tests)",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Re-encode and overwrite existing .pt files (default: skip)",
    )

    namespace = parser.parse_args()
    return namespace


def collect_image_paths(instruct_json: str) -> list:
    """Return the sorted, de-duplicated image relpaths referenced in a JSON file.

    The file is loaded as a single JSON array of sample dicts (all splits
    together). For each sample, a non-empty ``image_paths`` list contributes
    every entry; otherwise a non-empty ``image_path`` contributes one path.
    Samples with neither are ignored.

    Args:
        instruct_json: path to the unified instruct JSON file.

    Returns:
        Sorted list of unique image path strings.
    """
    with open(instruct_json, "r", encoding="utf-8") as fh:
        records = json.load(fh)

    unique = set()
    for record in records:
        multi = record.get("image_paths")
        if multi:
            # Multi-image sample: ``image_path`` (if any) is ignored.
            unique.update(multi)
            continue
        single = record.get("image_path")
        if single:
            unique.add(single)
    return sorted(unique)


def _save_feature_atomic(feature: torch.Tensor, out_path: Path) -> None:
    """Save one image's feature tensor to ``out_path`` without leaving partial files.

    Args:
        feature: features for a single image, typically one row obtained by
            iterating a batched tensor (i.e. a view into the batch's storage).
        out_path: destination ``.pt`` path; missing parent dirs are created.

    Why clone: ``torch.save`` serialises the full underlying storage of the
    tensor it is given, and a row view shares storage with the whole batch.
    Saving the view directly would write the entire batch into every
    per-image file. Cloning gives each file its own compact storage.

    Why tmp + rename: the caller treats an existing ``.pt`` as already done
    and skips it on re-runs. A file truncated by an interrupted write must
    therefore never appear under the final name. ``Path.replace`` renames
    within the same directory, which is atomic on POSIX.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    torch.save(feature.clone(memory_format=torch.contiguous_format), tmp_path)
    tmp_path.replace(out_path)


def main():
    """Encode every unique image referenced by the instruct JSON and cache it.

    Loads both YAML configs, builds the frozen image encoder, then encodes the
    images in batches of ``--batch_size``. It writes one fp16 tensor per image
    to ``{cache_dir}/{image_relpath}.pt``. Existing ``.pt`` files are skipped
    unless ``--overwrite`` is given, so an interrupted run can be resumed.
    Images that fail to open or transform are reported and counted, but do not
    abort the run.
    """
    args = parse_args()

    train_cfg = OmegaConf.load(args.train_config)
    model_cfg = OmegaConf.load(args.model_config)
    spec      = resolve_dataset_spec(train_cfg)

    cache_dir = Path(args.cache_dir).resolve()
    cache_dir.mkdir(parents=True, exist_ok=True)
    print(f"cache_dir       : {cache_dir}")
    print(f"image_root      : {spec.image_root}")
    print(f"instruct_json   : {spec.instruct_json}")

    # ── Build encoder (frozen, in inference dtype) ───────────────────────
    # The encoder dtype follows model_cfg.llm.torch_dtype; unknown strings
    # fall back to float32.
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    enc_dtype = dtype_map.get(model_cfg.llm.torch_dtype, torch.float32)
    encoder = BioViLTEncoder(
        frozen   = True,
        img_size = model_cfg.image_encoder.img_size,
        backend  = getattr(model_cfg.image_encoder, "backend", "auto"),
        dtype    = enc_dtype,
    ).to(args.device).eval()

    transform = BioViLTEncoder.get_transform("val")

    # Autocast is only useful for reduced precision. In float32 it is disabled.
    autocast_device = "cuda" if args.device.startswith("cuda") else "cpu"
    use_autocast    = enc_dtype != torch.float32

    # ── Collect unique image paths ───────────────────────────────────────
    image_paths = collect_image_paths(spec.instruct_json)
    print(f"unique images   : {len(image_paths)}")
    # `is not None` so `--limit 0` means zero images, not "no limit".
    if args.limit is not None:
        image_paths = image_paths[: max(args.limit, 0)]
        print(f"(limited to {len(image_paths)})")

    # ── Encode in batches ────────────────────────────────────────────────
    from PIL import Image
    import time
    image_root = Path(spec.image_root)
    done = 0
    skipped = 0
    failed = 0
    t0 = time.time()

    def _flush_batch(batch_paths, batch_tensors):
        """Encode the pending images as one batch and save one file per image."""
        nonlocal done
        if not batch_tensors:
            return
        x = torch.stack(batch_tensors).to(args.device)
        with torch.no_grad(), torch.autocast(
            autocast_device, dtype=enc_dtype, enabled=use_autocast,
        ):
            feats = encoder(x)                     # (B, P, 768)
        feats = feats.to(torch.float16).cpu()       # fp16 .pt → smaller on disk
        for rel, feat in zip(batch_paths, feats):
            _save_feature_atomic(feat, cache_dir / (rel + ".pt"))
            done += 1

    pending_paths, pending_tensors = [], []
    total = len(image_paths)
    for idx, rel in enumerate(image_paths, start=1):
        out_path = cache_dir / (rel + ".pt")
        if out_path.is_file() and not args.overwrite:
            skipped += 1
        else:
            try:
                with Image.open(image_root / rel) as img:
                    tensor = transform(img.convert("RGB"))
            except Exception as e:
                # Best-effort: one unreadable/corrupt image must not abort the run.
                failed += 1
                print(f"  [skip] {rel}: {type(e).__name__}: {e}")
            else:
                pending_paths.append(rel)
                pending_tensors.append(tensor)
                if len(pending_tensors) >= args.batch_size:
                    _flush_batch(pending_paths, pending_tensors)
                    pending_paths, pending_tensors = [], []

        # Report every 500 paths whether the path was encoded, skipped or
        # failed. Resumed runs are mostly skips and still need progress output.
        if idx % 500 == 0:
            elapsed = time.time() - t0
            rate    = (done + skipped) / max(elapsed, 1e-6)
            print(f"  [{idx:>6}/{total}]  done={done}  "
                  f"skipped={skipped}  failed={failed}  ({rate:.1f} img/s)")

    _flush_batch(pending_paths, pending_tensors)

    elapsed = time.time() - t0
    print(f"\nFinished. encoded={done}  skipped(existing)={skipped}  failed={failed}  "
          f"elapsed={elapsed/60:.1f} min")
    print("\nNext: set this in configs/train_config.yaml under `data:` ↓")
    print(f"  feature_cache_dir: \"{cache_dir}\"")


if __name__ == "__main__":
    # Script entry point; the module docstring shows the `python -m` invocation.
    main()