lukeingawesome committed
Commit f88f88f · verified · 1 Parent(s): 5944c7c

Upload folder using huggingface_hub
README.md ADDED
@@ -0,0 +1,72 @@
# CxREmbed (multi-image / multi-text unified embeddings)

This repository contains **lightweight inference code + trained embedding heads** for a multi-modal CXR embedding model built on top of the base **Lingshu-7B / Qwen2.5-VL** backbone.

The repo is structured to upload only the *delta weights* (LoRA adapter + pooling/projection heads). The base model weights remain in the original upstream repository.

## What is included

- `lora/` (optional) — PEFT LoRA adapter weights
- `unified_pooler.pt` — pooling head
- `unified_proj.pt` — projection head to the unified embedding space
- `text_proj.pt` / `image_proj.pt` (optional)
- `cxrembed_config.json` — minimal configuration
- `cxrembed/` — small Python package with an inference wrapper

## Quickstart

```python
import torch
from cxrembed import CxREmbedder

# Download from the Hub and load the backbone + adapters + heads
m = CxREmbedder.from_pretrained(
    "<ORG>/<REPO>",
    device="cuda" if torch.cuda.is_available() else "cpu",
    amp=True,
)

# Embed a structured record (multi-image + multi-text)
emb = m.embed_record(
    current_img="/path/to/current_frontal.png",
    lateral_img="/path/to/lateral.png",
    prior_img="/path/to/prior.png",
    additional_img=None,
    prior_report="...",
    current_report="...",
    demographics="Age 67, male",
    lab_test="WBC 12.3",
    history="SOB, fever",
    additional_txt="Question: pneumonia?",
    instruction="Embed this clinical record for retrieval.",
)

# Embed a candidate answer (text-only)
ans = m.embed_answer("Right lower lobe consolidation consistent with pneumonia.")

# Similarity in embedding space (both vectors are 1-D and L2-normalized)
score = (emb @ ans).item()
print(score)
```

## Placeholders supported in templates

Images:
- `<current_image>` (alias of `<frontal_image>`)
- `<lateral_image>`
- `<prior_image>`
- `<additional_image>`
- `<additional_image1>`, `<additional_image2>`, ... if you pass a list to `additional_img`

Texts:
- `<current_report>` (alias `<report>`)
- `<prior_report>`
- `<demographics>`
- `<lab_test>`
- `<history>`
- `<additional_txt>`

## Notes

- This model is intended for **research** and may require additional validation for clinical use.
- Do not upload protected health information (PHI) to public repositories.
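Since `embed_record` and `embed_answer` both return L2-normalized float32 vectors on CPU, the dot product in the Quickstart is a cosine similarity, and ranking several candidate answers is just a matrix-vector product. A minimal sketch with random stand-in vectors (the dimension 1280 matches `cxrembed_config.json`; no model is loaded here):

```python
import torch

# Stand-ins for outputs of embed_record / embed_answer: unit-norm vectors,
# so the dot product below is exactly cosine similarity.
torch.manual_seed(0)
record = torch.nn.functional.normalize(torch.randn(1280), dim=-1)      # [D]
candidates = torch.nn.functional.normalize(torch.randn(3, 1280), dim=-1)  # [N, D]

scores = candidates @ record                      # [N] cosine similarities in [-1, 1]
ranked = torch.argsort(scores, descending=True)   # best candidate first
best = int(ranked[0])
print(scores.tolist(), best)
```

With real embeddings, `candidates` would be `torch.stack([m.embed_answer(a) for a in answers])`.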
cxrembed/__init__.py ADDED
@@ -0,0 +1,22 @@
"""cxrembed

Lightweight Hugging Face Hub loader + inference utilities for a multi-image, multi-text
unified embedding model built on top of Lingshu/Qwen2.5-VL.

Primary entrypoint:
  - CxREmbedder

This package is meant to live inside your Hugging Face model repo (or be vendored into
another codebase). It assumes the repo contains:
  - lora/ (optional)
  - unified_pooler.pt
  - unified_proj.pt
  - text_proj.pt (optional)
  - image_proj.pt (optional)
  - misc.pt (optional)
  - cxrembed_config.json
"""

from .embedder import CxREmbedder, CxRInputs

__all__ = ["CxREmbedder", "CxRInputs"]
cxrembed/embedder.py ADDED
@@ -0,0 +1,600 @@
# -*- coding: utf-8 -*-
"""cxrembed.embedder

A small, opinionated inference wrapper around your `LingshuEmbedder` (defined in
`model_embed.py`) that:
  1) loads your projection/pooling heads + (optional) LoRA adapter from either
     - a local training checkpoint directory, or
     - a Hugging Face Hub repo (snapshot)
  2) exposes a clean inference API that accepts explicit multi-image + multi-text inputs.

Why this wrapper?
- Your training/eval code is row/template driven. This wrapper offers an ergonomic
  function signature while keeping the exact same placeholder interleaving path.
- It avoids implicitly attaching images/text unless a placeholder appears in the template.

Key assumptions about the checkpoint format (matching your training script):
  - unified_pooler.pt
  - unified_proj.pt
  - text_proj.pt (optional)
  - image_proj.pt (optional)
  - misc.pt (optional; may include logit_scale)
  - lora/ (optional; PEFT adapter)

This file is designed to be copied into your HF repo under `cxrembed/`.
"""

from __future__ import annotations

import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch

try:
    from PIL import Image
except Exception:  # pragma: no cover
    Image = Any  # type: ignore


# --------------------------- Types ---------------------------

ImageLike = Union[str, "Image.Image"]  # file path or PIL image


@dataclass
class CxRInputs:
    """Typed container for a single sample."""

    # Images
    current_img: Optional[ImageLike] = None
    prior_img: Optional[ImageLike] = None
    lateral_img: Optional[ImageLike] = None
    additional_img: Optional[Union[ImageLike, Sequence[ImageLike]]] = None

    # Texts
    prior_report: Optional[str] = None
    current_report: Optional[str] = None
    demographics: Optional[str] = None
    lab_test: Optional[str] = None
    history: Optional[str] = None
    additional_txt: Optional[str] = None

    # Optional conditioning
    instruction: Optional[str] = None


# --------------------------- Defaults ---------------------------

DEFAULT_RECORD_TEMPLATE = (
    "<current_image> <lateral_image> <prior_image> <additional_image>\n"
    "\n"
    "DEMOGRAPHICS:\n<demographics>\n\n"
    "HISTORY / INDICATION:\n<history>\n\n"
    "LAB TESTS:\n<lab_test>\n\n"
    "PRIOR REPORT:\n<prior_report>\n\n"
    "CURRENT REPORT:\n<current_report>\n\n"
    "<additional_txt>"
)

# We treat these as placeholders for named images.
# NOTE: LingshuEmbedder already aliases <current_image> -> <frontal_image>.
_IMAGE_KEYS = (
    "current_image",
    "frontal_image",
    "lateral_image",
    "prior_image",
    "additional_image",
)

_TEXT_KEYS = (
    "prior_report",
    "current_report",
    "report",  # alias of current_report
    "demographics",
    "lab_test",
    "history",
    "additional_txt",
)


# --------------------------- Small helpers ---------------------------

_EMPTY_SENTINELS = {"", "-1", "none", "null", "na", "n/a", "nan", "<na>"}


def _clean_text(x: Optional[str]) -> str:
    if x is None:
        return ""
    s = str(x).strip()
    return "" if s.lower() in _EMPTY_SENTINELS else s


def _load_image(x: Optional[ImageLike]) -> Optional["Image.Image"]:
    """Load an image from a path, or pass through a PIL Image."""
    if x is None:
        return None
    if Image is Any:
        raise ImportError("Pillow is required to load images. Please `pip install pillow`.")
    if isinstance(x, str):
        s = x.strip()
        if not s or s.lower() in _EMPTY_SENTINELS:
            return None
        return Image.open(s).convert("RGB")
    # PIL.Image.Image
    if hasattr(x, "convert"):
        return x.convert("RGB")
    raise TypeError(f"Unsupported image type: {type(x)}")


def _tmpl_uses_any_named_image_ph(tmpl: str) -> bool:
    s = (tmpl or "").lower()
    # Cheap check (we avoid compiling a regex per sample)
    return any(f"<{k}>" in s for k in _IMAGE_KEYS) or ("<image" in s)


def _tmpl_uses_any_text_ph(tmpl: str) -> bool:
    s = (tmpl or "").lower()
    return any(f"<{k}>" in s for k in _TEXT_KEYS)


def _warn_missing_referenced_images(tmpl: str, image_map: Dict[str, Optional["Image.Image"]]):
    """Warn if the template references images the caller didn't supply."""
    s = (tmpl or "").lower()
    for k in _IMAGE_KEYS:
        tag = f"<{k}>"
        if tag in s and image_map.get(k) is None:
            # Keep it as print (no logging dependency) since this is a template repo.
            print(f"[cxrembed] WARNING: template references {tag} but it is missing")


def _build_image_map_from_inputs(inp: CxRInputs) -> Dict[str, "Image.Image"]:
    """Build the named image map consumed by LingshuEmbedder._build_content_from_template."""
    cur = _load_image(inp.current_img)
    lat = _load_image(inp.lateral_img)
    prv = _load_image(inp.prior_img)

    out: Dict[str, Optional["Image.Image"]] = {
        "current_image": cur,
        "frontal_image": cur,  # alias
        "lateral_image": lat,
        "prior_image": prv,
    }

    add = inp.additional_img
    if add is None:
        out["additional_image"] = None
    elif isinstance(add, (list, tuple)):
        # Provide both a generic alias and indexed placeholders: <additional_image1>, ...
        add_list = [_load_image(x) for x in add]
        out["additional_image"] = next((x for x in add_list if x is not None), None)
        for i, im in enumerate(add_list, start=1):
            if im is not None:
                out[f"additional_image{i}"] = im
    else:
        out["additional_image"] = _load_image(add)

    # Drop None to keep the template builder tight.
    return {k: v for k, v in out.items() if v is not None}


def _build_text_map_from_inputs(inp: CxRInputs) -> Dict[str, str]:
    cur_r = _clean_text(inp.current_report)
    out = {
        "current_report": cur_r,
        "report": cur_r,  # alias
        "prior_report": _clean_text(inp.prior_report),
        "demographics": _clean_text(inp.demographics),
        "lab_test": _clean_text(inp.lab_test),
        "history": _clean_text(inp.history),
        "additional_txt": _clean_text(inp.additional_txt),
    }
    return {k: v for k, v in out.items() if v}


def _read_json_if_exists(path: str) -> Dict[str, Any]:
    if not path or not os.path.isfile(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return {}


# --------------------------- Main wrapper ---------------------------


class CxREmbedder:
    """Inference wrapper.

    Loads your Lingshu/Qwen2.5-VL backbone + (optional) LoRA + your embedding heads.

    Minimal API:
      - embed_record(...): embed a structured multi-modal record.
      - embed_answer(text): embed a candidate answer in the same space.
      - embed(...): lower-level template-based embedding.

    Notes:
      - The embedding path is **placeholder-aware**: images/texts are only attached
        if the template includes the corresponding placeholders.
      - `mask_style` controls which tokens are pooled.
    """

    def __init__(
        self,
        model,  # LingshuEmbedder (kept untyped to avoid a hard dependency at import time)
        device: Union[str, torch.device] = "cuda",
        amp: bool = True,
    ):
        self.model = model
        self.device = torch.device(device)
        self.amp = bool(amp)

        # Ensure eval mode
        self.model.eval()

    # ---------------------- constructors ----------------------

    @classmethod
    def from_local_checkpoint(
        cls,
        ckpt_dir: str,
        *,
        base_model_name: Optional[str] = None,
        device: Union[str, torch.device] = "cuda",
        amp: bool = True,
        embed_dim: Optional[int] = None,
        pool_mode: Optional[str] = None,
        image_size: int = 504,
        max_text_tokens: int = 1560,
        apply_lora_to_vision: bool = False,
        bidirectional: bool = True,
    ) -> "CxREmbedder":
        """Load from a training checkpoint directory.

        Expected contents:
          - unified_pooler.pt, unified_proj.pt, (optional) text_proj.pt, image_proj.pt
          - misc.pt (optional)
          - lora/ (optional)
          - cxrembed_config.json (optional)
        """

        ckpt_dir = os.path.abspath(ckpt_dir)
        cfg_path = os.path.join(ckpt_dir, "cxrembed_config.json")
        cfg = _read_json_if_exists(cfg_path)

        # Defer the import so the package can be inspected without transformers installed.
        from model_embed import LingshuEmbedder  # type: ignore

        if base_model_name is None:
            base_model_name = cfg.get("base_model_name") or cfg.get("model_name") or "lingshu-medical-mllm/Lingshu-7B"

        if embed_dim is None:
            embed_dim = int(cfg.get("embed_dim", 1280))
        if pool_mode is None:
            pool_mode = str(cfg.get("pool_mode", "latent_attention"))

        # If CUDA isn't available, force CPU.
        dev = torch.device(device)
        if dev.type == "cuda" and not torch.cuda.is_available():
            print("[cxrembed] CUDA requested but not available; falling back to CPU")
            dev = torch.device("cpu")

        use_cuda = (dev.type == "cuda")

        # IMPORTANT: if a LoRA folder exists, build the base backbone WITHOUT LoRA first,
        # then load the LoRA adapter weights with PEFT.
        lora_dir = os.path.join(ckpt_dir, "lora")

        m = LingshuEmbedder(
            model_name=base_model_name,
            attn_implementation=("flash_attention_2" if use_cuda else "sdpa"),
            torch_dtype=(torch.bfloat16 if use_cuda else torch.float32),
            embed_dim=int(embed_dim),
            pool_mode=str(pool_mode),
            image_size=int(image_size),
            max_grid=1296,
            bidirectional=bool(bidirectional),
            use_lora=False,  # never build train-time LoRA modules for inference
            apply_lora_to_vision=bool(apply_lora_to_vision),
            max_text_tokens=int(max_text_tokens),
            enable_gradient_checkpointing=False,
            device=str(dev),
        )

        # Load LoRA adapter (optional)
        if os.path.isdir(lora_dir):
            try:
                from peft import PeftModel  # type: ignore

                m.vl = PeftModel.from_pretrained(m.vl, lora_dir, is_trainable=False)
                if hasattr(m.vl, "set_adapter"):
                    m.vl.set_adapter("default")
                print(f"[cxrembed] loaded LoRA adapter from: {lora_dir}")
            except Exception as e:
                print(f"[cxrembed] WARNING: failed to load LoRA adapter from {lora_dir}: {e}")

        # Load heads (strict)
        _load_heads(m, ckpt_dir, device=dev)

        return cls(model=m, device=dev, amp=amp)

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str,
        *,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        device: Union[str, torch.device] = "cuda",
        amp: bool = True,
        **kwargs,
    ) -> "CxREmbedder":
        """Load from a Hugging Face Hub repo.

        The repo should contain the same files as from_local_checkpoint().

        Under the hood we `snapshot_download()` and then call from_local_checkpoint().
        """
        try:
            from huggingface_hub import snapshot_download  # type: ignore
        except Exception as e:  # pragma: no cover
            raise ImportError("Please `pip install huggingface_hub` to load from HF.") from e

        local_dir = snapshot_download(
            repo_id=repo_id,
            revision=revision,
            cache_dir=cache_dir,
            local_files_only=False,
        )
        return cls.from_local_checkpoint(local_dir, device=device, amp=amp, **kwargs)

    # ---------------------- public embedding API ----------------------

    @torch.no_grad()
    def embed_record(
        self,
        *,
        current_img: Optional[ImageLike] = None,
        prior_img: Optional[ImageLike] = None,
        lateral_img: Optional[ImageLike] = None,
        additional_img: Optional[Union[ImageLike, Sequence[ImageLike]]] = None,
        prior_report: Optional[str] = None,
        current_report: Optional[str] = None,
        demographics: Optional[str] = None,
        lab_test: Optional[str] = None,
        history: Optional[str] = None,
        additional_txt: Optional[str] = None,
        instruction: Optional[str] = None,
        template: str = DEFAULT_RECORD_TEMPLATE,
        image_size: Optional[int] = None,
        mask_style: str = "q_full_a_last",
        normalize: bool = True,
    ) -> torch.Tensor:
        """Embed a single multi-modal record into the unified embedding space."""
        inp = CxRInputs(
            current_img=current_img,
            prior_img=prior_img,
            lateral_img=lateral_img,
            additional_img=additional_img,
            prior_report=prior_report,
            current_report=current_report,
            demographics=demographics,
            lab_test=lab_test,
            history=history,
            additional_txt=additional_txt,
            instruction=instruction,
        )
        return self.embed(
            inputs=[inp],
            templates=[template],
            role="user",
            image_size=image_size,
            mask_style=mask_style,
            normalize=normalize,
        )[0]

    @torch.no_grad()
    def embed_answer(
        self,
        answer: str,
        *,
        normalize: bool = True,
        mask_style: str = "q_full_a_last",
    ) -> torch.Tensor:
        """Embed a candidate answer string in the unified embedding space."""
        inp = CxRInputs()
        return self.embed(
            inputs=[inp],
            templates=[str(answer)],
            role="assistant",
            image_size=None,
            mask_style=mask_style,
            normalize=normalize,
        )[0]

    @torch.no_grad()
    def embed(
        self,
        *,
        inputs: Sequence[CxRInputs],
        templates: Sequence[str],
        role: str,
        image_size: Optional[int] = None,
        mask_style: str = "q_full_a_last",
        normalize: bool = True,
    ) -> torch.Tensor:
        """Low-level batched embedding with templates.

        Args:
            inputs: list of CxRInputs
            templates: list of templates (same length)
            role: "user" or "assistant"
            image_size: optional override (must be a multiple of 28 for the Qwen grid)
            mask_style:
                - q_full_a_last: queries use full attention (system+user), assistant uses last role block.
                - both_full: both sides use full attention.
                - both_last: both sides use last role block.
        Returns:
            torch.FloatTensor [B, D] on CPU.
        """

        if role not in {"user", "assistant"}:
            raise ValueError("role must be 'user' or 'assistant'")
        if len(inputs) != len(templates):
            raise ValueError("inputs and templates must have the same length")

        wrapper_model = self.model
        device = self.device

        # We rely on internals from LingshuEmbedder.
        vm = wrapper_model._get_vision_module()
        vision_dtype = next(vm.parameters()).dtype

        # Determine the target image size (flooring to a multiple of 28 happens inside to_qwen_grid).
        target = wrapper_model._target_from_image_size(image_size)

        texts: List[str] = []
        flat_images: List["Image.Image"] = []

        for inp, tmpl_raw in zip(inputs, templates):
            tmpl = str(tmpl_raw or "")

            want_img = _tmpl_uses_any_named_image_ph(tmpl)
            want_txt = _tmpl_uses_any_text_ph(tmpl)

            imap = _build_image_map_from_inputs(inp) if want_img else {}
            tmap = _build_text_map_from_inputs(inp) if want_txt else {}

            _warn_missing_referenced_images(tmpl, {k: imap.get(k) for k in _IMAGE_KEYS})

            # Resize only the images actually present
            if want_img and imap:
                # Import here to reuse your exact resizing behavior.
                from model_embed import to_qwen_grid  # type: ignore

                imap = {k.lower(): to_qwen_grid(im, target=target) for k, im in imap.items() if im is not None}
            else:
                imap = {}

            content_list, images_in_order = wrapper_model._build_content_from_template(
                tmpl, image_map=imap, text_map=tmap, append_unused_images=False
            )

            # Ensure last-role masking is stable (non-empty role block)
            if not content_list:
                content_list = [{"type": "text", "text": " "}]

            msgs = []
            if role == "user":
                inst = _clean_text(inp.instruction)
                if inst:
                    msgs.append({"role": "system", "content": [{"type": "text", "text": f"INSTRUCTION:\n{inst}"}]})
            msgs.append({"role": role, "content": content_list})

            chat_text = wrapper_model.processor.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=False
            )
            texts.append(chat_text)
            for im in images_in_order:
                flat_images.append(im)

        proc = wrapper_model.processor(
            text=texts,
            images=(flat_images if len(flat_images) > 0 else None),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=getattr(wrapper_model, "max_text_tokens", 2560),
            do_resize=False,
        )
        proc = {k: v.to(device) for k, v in proc.items()}
        if "pixel_values" in proc:
            proc["pixel_values"] = proc["pixel_values"].to(device=device, dtype=vision_dtype)
        if "image_grid_thw" in proc:
            proc["image_grid_thw"] = proc["image_grid_thw"].to(device)

        autocast_dtype = torch.bfloat16 if (self.amp and device.type == "cuda") else None
        with torch.autocast(
            device_type=device.type,
            dtype=autocast_dtype,
            enabled=(autocast_dtype is not None),
        ):
            out = wrapper_model.vl(**proc, output_hidden_states=True, use_cache=False)
            hidden = out.hidden_states[-1]

            # Pooling mask
            if mask_style == "both_full" or (mask_style == "q_full_a_last" and role == "user"):
                attn = proc.get("attention_mask", None)
                span_mask = attn.bool() if attn is not None else torch.ones(hidden.shape[:2], device=hidden.device, dtype=torch.bool)
            elif mask_style in {"both_last", "q_full_a_last"}:
                span_mask = wrapper_model._mask_last_role_block(proc, hidden)
            else:
                raise ValueError(f"Unknown mask_style: {mask_style}")

            if getattr(wrapper_model, "pool_mode", "latent_attention") == "latent_attention":
                pooler = wrapper_model.unified_pooler
                pool_dtype = next(pooler.parameters()).dtype
                if hidden.dtype != pool_dtype:
                    hidden = hidden.to(dtype=pool_dtype)
                vec = pooler(hidden, span_mask)
            else:
                mask = span_mask.to(hidden.dtype)
                denom = mask.sum(dim=1, keepdim=True).clamp_min(1e-6)
                vec = (hidden * mask.unsqueeze(-1)).sum(dim=1) / denom

            proj = wrapper_model.unified_proj
            proj_dtype = next(proj.parameters()).dtype
            emb = proj(vec.to(dtype=proj_dtype))

            if normalize:
                emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)

        return emb.detach().float().cpu()


# --------------------------- checkpoint loading ---------------------------


def _load_state_dict_strict(module: torch.nn.Module, path: str, device: torch.device):
    state = torch.load(path, map_location=device)
    if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()):
        state = {k.replace("module.", "", 1): v for k, v in state.items()}
    module.load_state_dict(state, strict=True)


def _load_heads(model, ckpt_dir: str, device: torch.device):
    """Load your trained pooling/projection heads into a freshly constructed LingshuEmbedder."""

    # Files are named consistently with your training save_checkpoint() helper.
    heads: List[Tuple[str, Optional[torch.nn.Module]]] = [
        ("unified_pooler.pt", getattr(model, "unified_pooler", None)),
        ("unified_proj.pt", getattr(model, "unified_proj", None)),
        ("text_proj.pt", getattr(model, "text_proj", None)),
        ("image_proj.pt", getattr(model, "image_proj", None)),
    ]

    for fname, mod in heads:
        if mod is None:
            continue
        p = os.path.join(ckpt_dir, fname)
        if os.path.isfile(p):
            _load_state_dict_strict(mod, p, device=device)
            print(f"[cxrembed] loaded {fname}")

    # Optional misc (temperature)
    misc_p = os.path.join(ckpt_dir, "misc.pt")
    if os.path.isfile(misc_p):
        try:
            misc = torch.load(misc_p, map_location="cpu")
            if isinstance(misc, dict) and "logit_scale" in misc and hasattr(model, "logit_scale"):
                with torch.no_grad():
                    model.logit_scale.data = torch.tensor(float(misc["logit_scale"]), dtype=torch.float32, device=device)
        except Exception:
            pass
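The fallback pooling branch in `embed` (taken when `pool_mode` is not `latent_attention`) is plain masked mean pooling: average only the hidden states whose positions the span mask selects, with a clamped denominator so an all-False mask cannot divide by zero. A standalone sketch with toy tensors:

```python
import torch

def masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean over token positions where mask is True ([B, T, D] -> [B, D])."""
    m = mask.to(hidden.dtype)
    denom = m.sum(dim=1, keepdim=True).clamp_min(1e-6)  # guard against empty masks
    return (hidden * m.unsqueeze(-1)).sum(dim=1) / denom

# One sequence of 3 tokens with hidden size 2; the last token is masked out
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # [1, 3, 2]
mask = torch.tensor([[True, True, False]])
print(masked_mean(hidden, mask))  # tensor([[2., 3.]])
```

The masked-out token (here a deliberately extreme value) does not contaminate the average, which is exactly why padding and non-selected role blocks are excluded before pooling.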
cxrembed_config.json ADDED
@@ -0,0 +1,8 @@
{
  "base_model_name": "lingshu-medical-mllm/Lingshu-7B",
  "embed_dim": 1280,
  "pool_mode": "latent_attention",
  "image_size": 504,
  "max_text_tokens": 1560,
  "format": "cxrembed_adapter_v1"
}
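Every field in this config is optional at load time: `from_local_checkpoint` reads it through `_read_json_if_exists`, which returns `{}` for a missing or malformed file, and then falls back to the defaults shown above. A minimal sketch of that default-chaining (the `read_cfg` helper is a stand-in for `_read_json_if_exists`):

```python
import json
import os
import tempfile

def read_cfg(path: str) -> dict:
    # Mirrors _read_json_if_exists: missing or malformed files yield {}
    if not os.path.isfile(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return {}

with tempfile.TemporaryDirectory() as d:
    p = os.path.join(d, "cxrembed_config.json")
    with open(p, "w", encoding="utf-8") as f:
        json.dump({"embed_dim": 1280, "pool_mode": "latent_attention"}, f)
    cfg = read_cfg(p)
    # Same fallback pattern as from_local_checkpoint:
    embed_dim = int(cfg.get("embed_dim", 1280))
    pool_mode = str(cfg.get("pool_mode", "latent_attention"))
    missing = read_cfg(os.path.join(d, "no_such_file.json"))

print(embed_dim, pool_mode, missing)  # 1280 latent_attention {}
```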
model_embed.py ADDED
@@ -0,0 +1,1067 @@
1
+ import math
2
+ import os
3
+ import re
4
+ import json
5
+ from typing import List, Optional, Dict, Tuple, Union
6
+
7
+ from PIL import Image
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
12
+
13
+ # Treat these as empty/missing (case-insensitive, whitespace-tolerant)
14
+ _EMPTY_SENTINELS = {"", "-1", "none", "null", "na", "n/a", "nan", "<na>"}
15
+
16
+ def _is_empty_cell(x) -> bool:
17
+ """True if x should be considered 'missing'."""
18
+ if x is None:
19
+ return True
20
+ # float('nan') and numpy.float64('nan')
21
+ try:
22
+ if isinstance(x, float) and math.isnan(x):
23
+ return True
24
+ except Exception:
25
+ pass
26
+ s = str(x).strip().lower()
27
+ return s in _EMPTY_SENTINELS
28
+
29
+ def _clean_text_or_empty(x) -> str:
30
+ """Return a clean string or '' if missing."""
31
+ return "" if _is_empty_cell(x) else str(x).strip()
32
+
33
+ try:
34
+ from peft import LoraConfig, get_peft_model
35
+ HAS_PEFT = True
36
+ except Exception:
37
+ HAS_PEFT = False
38
+
39
+
40
+ # ----------------------- misc utils -----------------------
41
+
42
+ def l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
43
+ return x / (x.norm(dim=dim, keepdim=True) + eps)
44
+
45
+ def masked_mean_pool(hidden: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
46
+ """Mean over tokens where mask==True."""
47
+ if mask is None:
48
+ return hidden.mean(dim=1)
49
+ mask = mask.to(hidden.dtype)
50
+ denom = mask.sum(dim=1, keepdim=True).clamp_min(1e-6)
51
+ return (hidden * mask.unsqueeze(-1)).sum(dim=1) / denom
52
+
53
+ def to_qwen_grid(img: Image.Image, target: int = 512, patch_size: int = 14, merge_size: int = 2) -> Image.Image:
54
+ """
55
+ Resize image so H=W is a multiple of 28 (=patch_size*merge_size).
56
+ FLOOR to nearest multiple (512->504, 1024->1008).
57
+ """
58
+ grid = patch_size * merge_size # 28
59
+ new = max(grid, (target // grid) * grid)
60
+ return img.resize((new, new), Image.BILINEAR)
61
+
+ def _open_or_none(path: object, root: str = "") -> Optional[Image.Image]:
+     """Returns a PIL.Image or None. Handles '', NaN, '-1', <NA>, etc."""
+     if _is_empty_cell(path):
+         return None
+     p = str(path).strip()
+     # Don't join URI-like paths
+     if root and not re.match(r'^[a-zA-Z][a-zA-Z0-9+\-.]*://', p):
+         p = os.path.join(root, p)
+     try:
+         return Image.open(p).convert("RGB")
+     except Exception:
+         return None
+
+ def build_image_map_from_row(row, root: str = "") -> dict:
+     """
+     Mapping per your schema:
+       - frontal_image <- img_path1 (also used as current_image)
+       - lateral_image <- img_path2
+       - prior_image   <- img_path3
+     """
+     m = {
+         "frontal_image": _open_or_none(str(row.get("img_path1", "-1")), root),
+         "lateral_image": _open_or_none(str(row.get("img_path2", "-1")), root),
+         "prior_image": _open_or_none(str(row.get("img_path3", "-1")), root),
+     }
+     # --- NEW: negative images available to templates ---
+     n1 = _open_or_none(str(row.get("neg_image1", row.get("neg_path1", "-1"))), root)
+     n2 = _open_or_none(str(row.get("neg_image2", row.get("neg_path2", "-1"))), root)
+     # support either column name for prior: neg_image3 or neg_prior_image, also neg_path3
+     n3 = _open_or_none(str(row.get("neg_image3", row.get("neg_prior_image", row.get("neg_path3", "-1")))), root)
+     if n1 is not None:
+         m.update({"neg_image1": n1, "neg_path1": n1, "neg_frontal_image": n1})
+     if n2 is not None:
+         m.update({"neg_image2": n2, "neg_path2": n2, "neg_lateral_image": n2})
+     if n3 is not None:
+         m.update({"neg_prior_image": n3, "neg_image3": n3, "neg_path3": n3})
+     return m
+
+ def _s(x): return "" if x is None else str(x)
+
+ def build_text_map_from_row(row) -> Dict[str, str]:
+     m = {
+         "report": _clean_text_or_empty(row.get("report")),
+         "prior_report": _clean_text_or_empty(row.get("prior_report")),
+         "demographics": _clean_text_or_empty(row.get("demographics")),
+         # --- NEW ---
+         "lab_test": _clean_text_or_empty(row.get("lab_test")),
+         "indication": _clean_text_or_empty(row.get("indication")),
+     }
+     # drop empties
+     return {k: v for k, v in m.items() if v}
+
+ def parse_text_placeholders(s) -> dict:
+     if isinstance(s, dict):
+         d = s
+     elif isinstance(s, str) and s.strip():
+         try:
+             d = json.loads(s)
+         except Exception:
+             d = {}
+     else:
+         d = {}
+     if not isinstance(d, dict):
+         return {}
+     out = {}
+     for k, v in d.items():
+         val = _clean_text_or_empty(v)
+         if val:
+             out[str(k).lower()] = val
+     return out
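The behavior of `parse_text_placeholders` can be sketched standalone. This version substitutes a plain `str(v).strip()` for the module's `_clean_text_or_empty` helper (an assumption for the example), and shows the lowercasing of keys and dropping of empty or null values:

```python
import json

def parse_text_placeholders_sketch(s):
    """Accept a dict or a JSON string; lowercase keys; drop empty values."""
    if isinstance(s, dict):
        d = s
    elif isinstance(s, str) and s.strip():
        try:
            d = json.loads(s)
        except Exception:
            d = {}
    else:
        d = {}
    if not isinstance(d, dict):
        return {}
    # Simple stand-in for _clean_text_or_empty: keep non-empty stripped strings only.
    return {str(k).lower(): str(v).strip() for k, v in d.items() if v is not None and str(v).strip()}

out = parse_text_placeholders_sketch('{"History": "SOB, fever", "Lab_Test": "", "Indication": null}')
```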
132
+
+
+ # ----------------------- pooling modules -----------------------
+
+ class LatentAttentionPooler(nn.Module):
+     """
+     NV-Embed style: tokens (Q) attend to trainable latents (K=V), then MLP,
+     then mean-pool over tokens (optionally masked).
+     """
+     def __init__(self, dim: int, num_latents: int = 512, num_layers: int = 1,
+                  num_heads: int = 8, mlp_ratio: float = 2.0):
+         super().__init__()
+         self.latents = nn.Parameter(torch.randn(num_latents, dim) / math.sqrt(dim))
+         self.layers = nn.ModuleList()
+         self.ln_q = nn.LayerNorm(dim)   # for token queries
+         self.ln_kv = nn.LayerNorm(dim)  # for latent K/V
+
+         for _ in range(num_layers):
+             attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
+             ffn = nn.Sequential(
+                 nn.Linear(dim, int(dim * mlp_ratio)),
+                 nn.GELU(),
+                 nn.Linear(int(dim * mlp_ratio), dim),
+             )
+             self.layers.append(nn.ModuleDict({"attn": attn, "ffn": ffn}))
+
+     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         # x: (B, S, D) last-layer token states from the LLM
+         B, S, D = x.shape
+
+         # Prepare Q (tokens) and K,V (trainable latents)
+         q = self.ln_q(x)
+         lat = self.latents.unsqueeze(0).expand(B, -1, -1).contiguous()
+         kv = self.ln_kv(lat)
+
+         # Cross-attn: tokens query the latent dictionary (no key padding mask on latents)
+         for blk in self.layers:
+             y = blk["attn"](q, kv, kv, need_weights=False)[0]
+             q = q + y              # residual
+             q = q + blk["ffn"](q)  # MLP + residual
+
+         # Mean-pool over **tokens**; mask only applied here
+         return masked_mean_pool(q, mask)  # (B, D)
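The shape flow of this token-queries-latent-dictionary cross-attention can be sketched in NumPy. This is a single-head simplification (no LayerNorm, no learned Q/K/V projections, no MLP, all names hypothetical), kept only to show that attending to a fixed latent set preserves the (B, S, D) token shape:

```python
import numpy as np

def latent_cross_attention(x: np.ndarray, latents: np.ndarray) -> np.ndarray:
    """Single-head sketch: token queries attend to a shared latent dictionary (K = V = latents)."""
    B, S, D = x.shape
    scores = x @ latents.T / np.sqrt(D)            # (B, S, L): each token scores every latent
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability before softmax
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)       # rows sum to 1 over the L latents
    return x + attn @ latents                      # residual; output stays (B, S, D)

x = np.random.default_rng(0).normal(size=(2, 5, 8))        # B=2 sequences of S=5 tokens, D=8
latents = np.random.default_rng(1).normal(size=(16, 8))    # L=16 trainable latents in real model
y = latent_cross_attention(x, latents)
```

Because K and V come from a fixed-size latent set rather than the sequence itself, the cost is O(S·L) instead of O(S²), and a mean pool over the S axis afterwards gives the (B, D) sentence vector.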
175
+
+ class Projection(nn.Module):
+     def __init__(self, in_dim: int, out_dim: int = 1024, hidden: Optional[int] = None):
+         super().__init__()
+         if hidden is None:
+             self.proj = nn.Sequential(nn.Linear(in_dim, out_dim, bias=False))
+         else:
+             self.proj = nn.Sequential(nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim, bias=False))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return l2norm(self.proj(x))
+
+
+ # ----------------------- main wrapper -----------------------
+
+ class LingshuEmbedder(nn.Module):
+     def __init__(
+         self,
+         model_name: str = "lingshu-medical-mllm/Lingshu-7B",
+         attn_implementation: str = "flash_attention_2",
+         torch_dtype: torch.dtype = torch.bfloat16,
+         embed_dim: int = 1024,
+
+         # unified pooling mode
+         pool_mode: str = "latent_attention",  # "latent_attention" | "mean"
+         num_latents_unified: int = 512,
+
+         # image grid control (supports 504 and 1008)
+         image_size: int = 504,  # default grid; per-call override allowed (504 or 1008)
+         min_grid: int = 256,
+         max_grid: int = 1296,   # up to 36x36 (for 1008)
+
+         # LoRA (optional) - tuned for memorization
+         # r=64 for balanced performance; increase to 128 if VRAM allows
+         use_lora: bool = False,
+         lora_r: int = 64, lora_alpha: int = 64, lora_dropout: float = 0.0,  # alpha=r, dropout=0 for memorization
+         apply_lora_to_vision: bool = False,
+
+         # make attention bi-directional (remove causal masking)
+         bidirectional: bool = True,
+
+         # text token budget (read by the training script)
+         max_text_tokens: int = 2560,
+
+         # gradient checkpointing
+         enable_gradient_checkpointing: bool = False,
+
+         device: Optional[Union[str, torch.device]] = None,
+     ) -> None:
+         super().__init__()
+
+         # ---- device & backend ----
+         if device is None:
+             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         else:
+             device = torch.device(device)
+         if device.type != "cuda":
+             attn_implementation = "sdpa"
+             if torch_dtype in (torch.float16, torch.bfloat16):
+                 torch_dtype = torch.float32
+
+         # ---- load backbone + processor ----
+         self.vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             model_name, torch_dtype=torch_dtype, attn_implementation=attn_implementation
+         )
+         self.processor = AutoProcessor.from_pretrained(
+             model_name,
+             min_pixels=min_grid * 28 * 28,
+             max_pixels=max_grid * 28 * 28,
+         )
+         self._propagate_attn_impl(attn_implementation)
+
+         # freeze base
+         for p in self.vl.parameters():
+             p.requires_grad_(False)
+
+         # UNFREEZE vision projector for better image→text binding
+         # Qwen2.5-VL has a visual projection module
+         unfrozen_modules = []
+         for name, module in self.vl.named_modules():
+             # Look for the vision projector: often named 'visual', 'vision_proj', 'mm_projector', etc.
+             if any(x in name.lower() for x in ['visual.merger', 'visual.proj', 'vision_proj', 'mm_projector']):
+                 n_params = sum(p.numel() for p in module.parameters())
+                 for p in module.parameters():
+                     p.requires_grad_(True)
+                 unfrozen_modules.append((name, n_params))
+
+         if unfrozen_modules:
+             print("[model] Unfrozen vision projector modules for memorization:")
+             for name, n_params in unfrozen_modules:
+                 print(f"  - {name}: {n_params:,} parameters")
+
+         # dims
+         txt_hidden = getattr(self.vl.config, "text_config", None)
+         vis_hidden = getattr(self.vl.config, "vision_config", None)
+         self.text_hidden = getattr(txt_hidden, "hidden_size", None)
+         self.vision_hidden = getattr(vis_hidden, "out_hidden_size", None) or getattr(vis_hidden, "hidden_size", None)
+
+         # projections (unified/text/image all project to the same embed_dim space)
+         self.text_proj = Projection(self.text_hidden, embed_dim, hidden=None)
+         self.image_proj = Projection(self.vision_hidden, embed_dim, hidden=None)
+         self.unified_proj = Projection(self.text_hidden, embed_dim, hidden=None)
+
+         self.logit_scale = nn.Parameter(torch.tensor(math.log(1 / 0.07)))
+
+         # unified pooling config
+         self.pool_mode = pool_mode
+         if self.pool_mode == "latent_attention":
+             self.unified_pooler = LatentAttentionPooler(
+                 dim=self.text_hidden,
+                 num_latents=num_latents_unified,  # default 512 to match the paper
+                 num_layers=1,
+                 num_heads=8,
+             )
+         else:
+             self.unified_pooler = None
+
+         # image size handling (any multiple of 28 is allowed, e.g., 448, 504, 1008)
+         if image_size % 28 != 0:
+             raise ValueError(f"image_size must be a multiple of 28, got {image_size}")
+         self.image_size = image_size  # default; can be overridden per call
+
+         # optional LoRA
+         self.peft_active = False
+         if use_lora:
+             if not HAS_PEFT:
+                 raise ImportError("peft not installed")
+             targets_text = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
+             targets_vision = ("qkv", "proj")
+             targets = list(set(targets_text + (targets_vision if apply_lora_to_vision else tuple())))
+             cfg = LoraConfig(r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+                              target_modules=targets, bias="none", task_type="CAUSAL_LM")
+             self.vl = get_peft_model(self.vl, cfg)
+             self.peft_active = True
+
+         # make bi-directional if requested
+         if bidirectional:
+             self._enable_bidirectional_attention()
+
+         # gradient checkpointing
+         if enable_gradient_checkpointing:
+             # Use the non-reentrant variant to avoid "requires_grad" warnings
+             try:
+                 self.vl.gradient_checkpointing_enable(
+                     gradient_checkpointing_kwargs={"use_reentrant": False}
+                 )
+             except TypeError:
+                 # older transformers fallback
+                 self.vl.gradient_checkpointing_enable()
+             try:
+                 self.vl.config.use_cache = False
+             except Exception:
+                 pass
+
+         # move to device
+         self.to(device)
+         self.device = device
+
+         # align pooler dtype with the model (and device)
+         base_dtype = next(self.vl.parameters()).dtype
+         if getattr(self, "unified_pooler", None) is not None:
+             self.unified_pooler.to(device=device, dtype=base_dtype)
+
+         # expose the text token budget for processor calls in the training script
+         self.max_text_tokens = int(max_text_tokens)
+
+     # ---------- internals ----------
+
+     def _propagate_attn_impl(self, impl: str):
+         cfgs = [getattr(self.vl, "config", None)]
+         if cfgs[0] is not None:
+             for sub in ("text_config", "vision_config"):
+                 cfgs.append(getattr(cfgs[0], sub, None))
+         for cfg in cfgs:
+             if cfg is None:
+                 continue
+             try:
+                 cfg._attn_implementation = impl
+                 cfg.attn_implementation = impl
+                 if hasattr(cfg, "use_flash_attention_2"):
+                     cfg.use_flash_attention_2 = (impl == "flash_attention_2")
+             except Exception:
+                 pass
+         for _, module in self.vl.named_modules():
+             if hasattr(module, "config"):
+                 try:
+                     module.config._attn_implementation = impl
+                     module.config.attn_implementation = impl
+                     if hasattr(module.config, "use_flash_attention_2"):
+                         module.config.use_flash_attention_2 = (impl == "flash_attention_2")
+                 except Exception:
+                     pass
+
+     def _enable_bidirectional_attention(self):
+         """Best-effort removal of causal masking."""
+         cfg = getattr(self.vl, "config", None)
+         if cfg is not None:
+             if hasattr(cfg, "is_decoder"):
+                 cfg.is_decoder = False
+             if hasattr(cfg, "use_cache"):
+                 cfg.use_cache = False
+         core = getattr(self.vl, "model", self.vl)
+         core_cfg = getattr(core, "config", None)
+         if core_cfg is not None:
+             if hasattr(core_cfg, "is_decoder"):
+                 core_cfg.is_decoder = False
+             if hasattr(core_cfg, "use_cache"):
+                 core_cfg.use_cache = False
+         for m in self.vl.modules():
+             if hasattr(m, "is_causal"):
+                 try:
+                     m.is_causal = False
+                 except Exception:
+                     pass
+
+     def _get_text_module(self):
+         core = getattr(self.vl, "model", self.vl)
+         for attr in ("language_model", "text_model", "lm"):
+             m = getattr(core, attr, None)
+             if m is not None and hasattr(m, "forward"):
+                 return m
+         for _, module in self.vl.named_modules():
+             cname = module.__class__.__name__.lower()
+             if "vision" in cname:
+                 continue
+             if hasattr(module, "forward") and hasattr(module, "embed_tokens"):
+                 return module
+         raise AttributeError("Could not locate the text submodule in Qwen-VL.")
+
+     def _get_vision_module(self):
+         core = getattr(self.vl, "model", self.vl)
+         for attr in ("vision_model", "vision_tower", "visual", "vision"):
+             m = getattr(core, attr, None)
+             if m is not None and hasattr(m, "forward"):
+                 return m
+         for _, module in self.vl.named_modules():
+             if "vision" in module.__class__.__name__.lower():
+                 return module
+         raise AttributeError("Could not locate the vision submodule in Qwen-VL.")
+
+     def _get_vision_entry(self):
+         """
+         Return the top-level VisionModel object that accepts:
+             forward(pixel_values=..., grid_thw=..., output_hidden_states=..., return_dict=True)
+         Avoid returning the inner transformer, which expects (hidden_states, grid_thw).
+         """
+         core = getattr(self.vl, "model", self.vl)
+         # Prefer the canonical attribute if present
+         vis = getattr(core, "vision_model", None)
+         if vis is not None:
+             return vis
+         # Fallback: search modules for something named *VisionModel
+         for _, m in core.named_modules():
+             name = m.__class__.__name__.lower()
+             if name.endswith("visionmodel"):
+                 return m
+         # Last resort: the previous generic getter (may return the transformer; not ideal)
+         return self._get_vision_module()
+
+     # ----- chat/content builders & masking -----
+
+     def _target_from_image_size(self, image_size: Optional[int]) -> int:
+         """
+         Return a pixel target that will be floored to a multiple of 28 by to_qwen_grid().
+         Any multiple of 28 works (e.g., 448, 504, 1008).
+         """
+         sz = image_size if isinstance(image_size, int) and image_size % 28 == 0 else self.image_size
+         return int(sz)
+
+     def _build_interleaved_content(self, text: str, imgs: List[Image.Image], append_unused_images: bool = False) -> Tuple[list, list]:
+         """
+         NUMERIC placeholders: <image1>, <image2>, ...
+         Returns (content_list, images_in_order).
+         """
+         if text is None:
+             text = ""
+         content: list = []
+         ordered_images: list = []
+         imgs = imgs or []
+
+         pat = re.compile(r"<image\s*(\d+)\s*>", re.IGNORECASE)
+         pos = 0
+         matches = list(pat.finditer(text))
+
+         if not matches:
+             # Do not auto-append images unless explicitly requested
+             if text.strip():
+                 content.append({"type": "text", "text": text})
+             if append_unused_images:
+                 for im in imgs:
+                     content.append({"type": "image", "image": im})
+                     ordered_images.append(im)
+             return content, ordered_images
+
+         for m in matches:
+             s, e = m.span()
+             if s > pos:
+                 seg = text[pos:s]
+                 if seg.strip():
+                     content.append({"type": "text", "text": seg})
+             idx = int(m.group(1)) - 1
+             if 0 <= idx < len(imgs):
+                 content.append({"type": "image", "image": imgs[idx]})
+                 ordered_images.append(imgs[idx])
+             pos = e
+
+         if pos < len(text):
+             seg = text[pos:]
+             if seg.strip():
+                 content.append({"type": "text", "text": seg})
+
+         if append_unused_images:
+             used = set(ordered_images)
+             for im in imgs:
+                 if im not in used:
+                     content.append({"type": "image", "image": im})
+                     ordered_images.append(im)
+
+         return content, ordered_images
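The numeric-placeholder splitting above can be exercised without PIL images. This text-only sketch (the function name `split_numeric_placeholders` is invented for the example) uses the same regex and returns `("text", segment)` / `("image", index)` tuples instead of content dicts:

```python
import re

def split_numeric_placeholders(text: str, n_images: int):
    """Split text on <imageK> markers, keeping only indices that exist (sketch of _build_interleaved_content)."""
    pat = re.compile(r"<image\s*(\d+)\s*>", re.IGNORECASE)
    parts, pos = [], 0
    for m in pat.finditer(text):
        if m.start() > pos and text[pos:m.start()].strip():
            parts.append(("text", text[pos:m.start()]))
        idx = int(m.group(1)) - 1          # <image1> is 1-based in templates, 0-based here
        if 0 <= idx < n_images:
            parts.append(("image", idx))
        pos = m.end()
    if pos < len(text) and text[pos:].strip():
        parts.append(("text", text[pos:]))
    return parts

parts = split_numeric_placeholders("Frontal: <image1> Prior: <image 2> done.", 2)
```

Note that `<image 2>` with internal whitespace still matches, mirroring the `\s*` in the original pattern, and out-of-range indices are silently dropped.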
489
+
490
+ def _build_content_from_template(
491
+ self,
492
+ template: str,
493
+ image_map: Optional[Dict[str, Image.Image]],
494
+ text_map: Optional[Dict[str, str]],
495
+ append_unused_images: bool = False,
496
+ ) -> Tuple[list, list]:
497
+ """
498
+ NAMED placeholders: <frontal_image>, <lateral_image>, <prior_image>, <report>, <prior_report>, <demographics>, ...
499
+ Also supports alias: <current_image> -> <frontal_image>.
500
+ """
501
+ template = template or ""
502
+ image_map = {k.lower(): v for k, v in (image_map or {}).items() if v is not None}
503
+ text_map = {k.lower(): v for k, v in (text_map or {}).items() if v is not None and str(v).strip()}
504
+
505
+ content: list = []
506
+ images_in_order: list = []
507
+
508
+ pat = re.compile(r"<\s*([A-Za-z_]\w*)\s*>")
509
+ pos = 0
510
+ for m in pat.finditer(template):
511
+ s, e = m.span()
512
+ if s > pos:
513
+ seg = template[pos:s]
514
+ if seg.strip():
515
+ content.append({"type": "text", "text": seg})
516
+
517
+ name = m.group(1).lower()
518
+ # alias: current_image -> frontal_image
519
+ if name == "current_image":
520
+ name = "frontal_image"
521
+
522
+ if name in image_map: # <<< generalized image handling
523
+ img = image_map.get(name)
524
+ if img is not None:
525
+ content.append({"type": "image", "image": img})
526
+ images_in_order.append(img)
527
+ else:
528
+ val = text_map.get(name)
529
+ if val is not None:
530
+ content.append({"type": "text", "text": str(val)})
531
+
532
+ pos = e
533
+
534
+ if pos < len(template):
535
+ tail = template[pos:]
536
+ if tail.strip():
537
+ content.append({"type": "text", "text": tail})
538
+
539
+ # Append any not-yet-used images at the end (conditionally)
540
+ if append_unused_images:
541
+ for key, img in image_map.items():
542
+ if img is not None and img not in images_in_order:
543
+ content.append({"type": "image", "image": img})
544
+ images_in_order.append(img)
545
+
546
+ return content, images_in_order
547
+
548
+ def _mask_last_role_block(self, inputs: dict, hidden: torch.Tensor) -> torch.Tensor:
549
+ """
550
+ Boolean mask (B,S) selecting tokens inside the **last** role block (user/assistant),
551
+ excluding the final <|im_end|>, for **any** batch size.
552
+ Falls back to attention_mask if special tokens are unavailable.
553
+ """
554
+ device = hidden.device
555
+ ids = inputs.get("input_ids", None)
556
+ attn = inputs.get("attention_mask", None)
557
+ if ids is None:
558
+ return (attn if attn is not None else torch.ones(hidden.shape[:2], device=device, dtype=torch.long)).bool()
559
+
560
+ B, S = ids.shape
561
+ mask = torch.zeros((B, S), device=device, dtype=torch.bool)
562
+
563
+ # Try to get ChatML boundary tokens
564
+ try:
565
+ start_id = self.processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
566
+ except Exception:
567
+ start_id = None
568
+ try:
569
+ end_id = self.processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
570
+ except Exception:
571
+ end_id = None
572
+
573
+ if end_id is None:
574
+ return (attn if attn is not None else torch.ones((B, S), device=device, dtype=torch.long)).bool()
575
+
576
+ for b in range(B):
577
+ # Limit search to valid tokens when attention mask is present
578
+ if attn is not None:
579
+ valid_len = int(attn[b].sum().item())
580
+ else:
581
+ valid_len = S
582
+ valid_len = max(1, min(valid_len, S))
583
+ seq = ids[b, :valid_len]
584
+
585
+ ends = (seq == end_id).nonzero(as_tuple=False).flatten()
586
+ if ends.numel() == 0:
587
+ # No explicit blocks; fall back to all valid tokens
588
+ mask[b, :valid_len] = True
589
+ continue
590
+ last_end = int(ends[-1].item())
591
+
592
+ last_start = -1
593
+ if start_id is not None:
594
+ starts = (seq == start_id).nonzero(as_tuple=False).flatten()
595
+ starts_before = starts[starts < last_end] if starts.numel() > 0 else None
596
+ if starts_before is not None and starts_before.numel() > 0:
597
+ last_start = int(starts_before[-1].item())
598
+ elif ends.numel() >= 2:
599
+ # Heuristic: if no <|im_start|>, use previous end as start
600
+ last_start = int(ends[-2].item())
601
+ else:
602
+ if ends.numel() >= 2:
603
+ last_start = int(ends[-2].item())
604
+
605
+ left = max(last_start + 1, 0)
606
+ right = max(last_end - 1, left)
607
+ mask[b, left:right + 1] = True
608
+
609
+ if attn is not None:
610
+ mask = mask & attn.bool()
611
+ return mask
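The last-block selection logic can be illustrated with a toy id sequence. This torch-free sketch uses stand-in ids for the ChatML boundary tokens (in the real model they come from `tokenizer.convert_tokens_to_ids`); all names here are hypothetical:

```python
IM_START, IM_END = 100, 101  # stand-ins for <|im_start|> / <|im_end|> token ids

def last_block_span(ids):
    """Indices inside the last role block, excluding the closing IM_END (sketch of _mask_last_role_block)."""
    ends = [i for i, t in enumerate(ids) if t == IM_END]
    if not ends:
        return list(range(len(ids)))           # no blocks: keep everything
    last_end = ends[-1]
    starts = [i for i, t in enumerate(ids) if t == IM_START and i < last_end]
    last_start = starts[-1] if starts else (ends[-2] if len(ends) >= 2 else -1)
    left = max(last_start + 1, 0)
    right = max(last_end - 1, left)
    return list(range(left, right + 1))

# system block: [100, 1, 2, 101]; user block: [100, 3, 4, 5, 101]
ids = [100, 1, 2, 101, 100, 3, 4, 5, 101]
span = last_block_span(ids)   # only the user block's interior survives
```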
612
+
+     # ---------- encoders (unified everywhere) ----------
+
+     @torch.no_grad()
+     def encode_text_unified(self, instructions: List[Optional[str]], texts: List[str], role: str = "user",
+                             normalize: bool = True) -> torch.Tensor:
+         """Text-only, but still go through the unified VL path for consistency."""
+         empty_images = [[] for _ in texts]
+         return self.encode_interleaved(instructions, texts, empty_images, role=role, normalize=normalize)
+
+     @torch.no_grad()
+     def encode_images_unified(self, instructions: List[Optional[str]], image_templates: List[str],
+                               image_maps: List[Dict[str, Image.Image]], role: str = "user",
+                               normalize: bool = True, image_size: Optional[int] = None) -> torch.Tensor:
+         """
+         Image-only via the unified path. Pass templates like "<frontal_image>" or ""
+         (images are only included if explicitly referenced).
+         """
+         empty_text_maps = [{} for _ in image_templates]
+         return self.encode_interleaved_with_ph(instructions, image_templates, image_maps, empty_text_maps,
+                                                role=role, normalize=normalize, image_size=image_size)
+
+     @torch.no_grad()
+     def encode_interleaved(
+         self,
+         instructions: List[Optional[str]],
+         contents: List[str],
+         images: List[List[Image.Image]],
+         role: str = "user",
+         normalize: bool = True,
+         image_size: Optional[int] = None,  # 504 or 1008 override
+     ) -> torch.Tensor:
+         device = self.device
+         vm = self._get_vision_module()
+         vision_dtype = next(vm.parameters()).dtype
+
+         assert len(instructions) == len(contents) == len(images), "length mismatch"
+         out_vecs = []
+         target = self._target_from_image_size(image_size)
+
+         for inst, text, imgs in zip(instructions, contents, images):
+             proc_imgs = [to_qwen_grid(im, target=target) for im in (imgs or [])]
+             content_list, images_in_order = self._build_interleaved_content(
+                 text or "", proc_imgs, append_unused_images=False
+             )
+
+             msgs = []
+             if inst and str(inst).strip():
+                 msgs.append({"role": "system", "content": [{"type": "text", "text": inst}]})
+             msgs.append({"role": role, "content": content_list})
+
+             chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
+
+             proc = self.processor(
+                 text=[chat_text],
+                 images=images_in_order if images_in_order else None,
+                 return_tensors="pt",
+                 padding=True,
+                 truncation=True,
+                 do_resize=False,
+                 max_length=self.max_text_tokens,
+             )
+             inputs = {k: v.to(device) for k, v in proc.items()}
+             if "pixel_values" in inputs:
+                 inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype)
+             if "image_grid_thw" in inputs:
+                 inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device)
+
+             out = self.vl(**inputs, output_hidden_states=True, use_cache=False)
+             hidden = out.hidden_states[-1]                          # (1, S, H)
+             span_mask = self._mask_last_role_block(inputs, hidden)  # (1, S)
+
+             if self.pool_mode == "latent_attention":
+                 pool_dtype = next(self.unified_pooler.parameters()).dtype
+                 if hidden.dtype != pool_dtype:
+                     hidden = hidden.to(dtype=pool_dtype)
+                 vec = self.unified_pooler(hidden, span_mask).squeeze(0)
+             else:
+                 vec = masked_mean_pool(hidden, span_mask).squeeze(0)
+
+             out_vecs.append(vec)
+
+         embs = torch.stack(out_vecs, dim=0)
+         proj_dtype = next(self.unified_proj.parameters()).dtype
+         emb = self.unified_proj(embs.to(dtype=proj_dtype))
+         if normalize:
+             emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+         return emb
+
+     @torch.no_grad()
+     def encode_interleaved_with_ph(
+         self,
+         instructions: List[Optional[str]],
+         templates: List[str],
+         image_maps: List[Optional[Dict[str, Image.Image]]],
+         text_maps: List[Optional[Dict[str, str]]],
+         role: str = "user",
+         normalize: bool = True,
+         image_size: Optional[int] = None,  # 504 or 1008 override
+     ) -> torch.Tensor:
+         device = self.device
+         vm = self._get_vision_module()
+         vision_dtype = next(vm.parameters()).dtype
+
+         assert len(instructions) == len(templates) == len(image_maps) == len(text_maps), "length mismatch"
+
+         vecs = []
+         target = self._target_from_image_size(image_size)
+
+         for inst, tmpl, imap, tmap in zip(instructions, templates, image_maps, text_maps):
+             proc_imap: Dict[str, Image.Image] = {}
+             if imap:
+                 for k, im in imap.items():
+                     if im is not None:
+                         proc_imap[k.lower()] = to_qwen_grid(im, target=target)
+
+             content_list, images_in_order = self._build_content_from_template(tmpl or "", proc_imap, (tmap or {}))
+
+             msgs = []
+             if inst and str(inst).strip():
+                 msgs.append({"role": "system", "content": [{"type": "text", "text": inst}]})
+             msgs.append({"role": role, "content": content_list})
+
+             chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
+
+             proc = self.processor(
+                 text=[chat_text],
+                 images=images_in_order if images_in_order else None,
+                 return_tensors="pt",
+                 padding=True,
+                 truncation=True,
+                 do_resize=False,
+                 max_length=self.max_text_tokens,
+             )
+             inputs = {k: v.to(device) for k, v in proc.items()}
+             if "pixel_values" in inputs:
+                 inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype)
+             if "image_grid_thw" in inputs:
+                 inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device)
+
+             out = self.vl(**inputs, output_hidden_states=True, use_cache=False)
+             hidden = out.hidden_states[-1]                          # (1, S, H)
+             span_mask = self._mask_last_role_block(inputs, hidden)  # (1, S)
+
+             if self.pool_mode == "latent_attention":
+                 pool_dtype = next(self.unified_pooler.parameters()).dtype
+                 if hidden.dtype != pool_dtype:
+                     hidden = hidden.to(dtype=pool_dtype)
+                 vec = self.unified_pooler(hidden, span_mask).squeeze(0)
+             else:
+                 vec = masked_mean_pool(hidden, span_mask).squeeze(0)
+
+             vecs.append(vec)
+
+         embs = torch.stack(vecs, dim=0)
+         proj_dtype = next(self.unified_proj.parameters()).dtype
+         emb = self.unified_proj(embs.to(dtype=proj_dtype))
+         if normalize:
+             emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+         return emb
+
+     # ------------- (dual encoders for debugging) -------------
+
+     @torch.no_grad()
+     def encode_text_dual(self, texts: List[str], normalize: bool = True) -> torch.Tensor:
+         device = self.device
+         tok = self.processor.tokenizer(text=texts, padding=True, truncation=True, return_tensors="pt", max_length=self.max_text_tokens)
+         tok = {k: v.to(device) for k, v in tok.items()}
+         lm = self._get_text_module()
+         out = lm(**tok, output_hidden_states=True, use_cache=False)
+         hidden = out.last_hidden_state
+         mask = tok.get("attention_mask")
+         pooled = masked_mean_pool(hidden, mask)
+         proj_dtype = next(self.text_proj.parameters()).dtype
+         emb = self.text_proj(pooled.to(dtype=proj_dtype))
+         if normalize:
+             emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+         return emb
+
+     @torch.no_grad()
+     def encode_images_dual(self, images: List[List[Image.Image]], normalize: bool = True,
+                            image_size: Optional[int] = None) -> torch.Tensor:
+         device = self.device
+         flat = [img for group in images for img in group]
+         counts = [len(g) for g in images]
+         if len(flat) == 0:
+             proj_dtype = next(self.image_proj.parameters()).dtype
+             zeros = torch.zeros((len(images), self.vision_hidden), device=device, dtype=proj_dtype)
+             emb = self.image_proj(zeros)
+             if normalize:
+                 emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+             return emb
+         target = self._target_from_image_size(image_size)
+         processed = [to_qwen_grid(img, target=target) for img in flat]
+         proc = self.processor.image_processor(images=processed, return_tensors="pt", do_resize=False)
+         vm = self._get_vision_module()
+         vision_dtype = next(vm.parameters()).dtype
+         pixel_values = proc["pixel_values"].to(device=device, dtype=vision_dtype)
+         vis_out = vm(pixel_values=pixel_values, output_hidden_states=True)
+         feats = vis_out[0] if isinstance(vis_out, (tuple, list)) else getattr(vis_out, "last_hidden_state", None)
+         if feats is None:
+             feats = getattr(vis_out, "pooler_output", None)
+         if feats is None:
+             raise RuntimeError("Vision backbone did not return features as expected.")
+         per_img = feats.mean(dim=1) if feats.ndim == 3 else feats
+         splits = torch.split(per_img, counts, dim=0)
+         set_vecs = torch.stack([s.mean(dim=0) if s.ndim > 1 else s for s in splits], dim=0)
+         proj_dtype = next(self.image_proj.parameters()).dtype
+         emb = self.image_proj(set_vecs.to(dtype=proj_dtype))
+         if normalize:
+             emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+         return emb
+
+     # ===================== PHRASE GROUNDING UTILS =====================
+
+     def _find_subsequence(self, haystack: list, needle: list) -> list:
+         """Return start indices where 'needle' occurs in 'haystack' (exact match)."""
+         if not haystack or not needle or len(needle) > len(haystack):
+             return []
+         hits = []
+         n = len(needle)
+         for i in range(len(haystack) - n + 1):
+             if haystack[i:i + n] == needle:
+                 hits.append(i)
+         return hits
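The exact-match token search can be demonstrated standalone. This free-function sketch (the name `find_subsequence` is invented for the example) mirrors the method above, returning every start index of `needle` inside `haystack`, including overlapping occurrences:

```python
def find_subsequence(haystack: list, needle: list) -> list:
    """All start indices where needle occurs in haystack (exact token-id match)."""
    if not haystack or not needle or len(needle) > len(haystack):
        return []
    n = len(needle)
    return [i for i in range(len(haystack) - n + 1) if haystack[i:i + n] == needle]

hits = find_subsequence([5, 7, 9, 7, 9, 2], [7, 9])   # the phrase's token ids occur twice
```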
836
+
+     def _window_decode_matches(self, tokenizer, ids, target_lower: str) -> list:
+         """Fallback: sliding-window decode match (robust to BPE splits). Returns window (start, end) indices."""
+         hits = []
+         L = len(ids)
+         # Small cap on window length to avoid expensive decodes; most medical terms fit in <= 5 tokens.
+         for w in range(1, 8):
+             for i in range(0, L - w + 1):
+                 s, e = i, i + w
+                 text = tokenizer.decode(ids[s:e], skip_special_tokens=True).lower().replace(" ", "")
+                 if target_lower in text:
+                     hits.append((s, e))
+         # De-duplicate overlapping windows by preferring the shortest span
+         hits = sorted(set(hits), key=lambda x: (x[1] - x[0], x[0]))
+         return hits
+
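The de-duplication step at the end of `_window_decode_matches` orders candidate windows so that the tightest spans come first. A small standalone sketch (`prefer_shortest` is a hypothetical name) shows the sort key in isolation:

```python
def prefer_shortest(hits):
    """De-duplicate (start, end) windows, ordering shortest spans first (tie: earlier start)."""
    return sorted(set(hits), key=lambda x: (x[1] - x[0], x[0]))

# Three windows cover the same match; the 1-token window (2, 3) wins.
ordered = prefer_shortest([(0, 3), (2, 3), (2, 3), (1, 3)])
```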
851
+
852
+ def _resize_heatmap_like(self, hm_np, target_w, target_h):
853
+ from PIL import Image
854
+ import numpy as np
855
+ # hm_np: (H, W) in [0,1]; resize with bilinear to (target_h, target_w)
856
+ H, W = hm_np.shape
857
+ im = Image.fromarray((hm_np * 255.0).astype("uint8"), mode="L")
858
+ im = im.resize((target_w, target_h), Image.BILINEAR)
859
+ out = (np.array(im).astype("float32") / 255.0)
860
+ return out
861
+
862
+ def _overlay_heatmap_on_image(self, img_pil, hm_np, alpha=0.45):
863
+ """Return PIL with heatmap overlay; hm_np in [0,1] same size as img."""
864
+ import matplotlib
865
+ import numpy as np
866
+ from PIL import Image
867
+
868
+ img = np.array(img_pil.convert("RGB")).astype("float32") / 255.0
869
+ H, W = img.shape[:2]
870
+ hm = np.clip(hm_np, 0.0, 1.0)
871
+ if hm.shape[:2] != (H, W):
872
+ raise ValueError("Heatmap and image size mismatch")
873
+ # Use a perceptually reasonable colormap without fixing colors for downstream tools.
874
+ cmap = matplotlib.cm.get_cmap("jet")
875
+ color_hm = cmap(hm)[..., :3] # (H,W,3)
876
+ blended = (1.0 - alpha) * img + alpha * color_hm
877
+ blended = np.clip(blended, 0.0, 1.0)
878
+ return Image.fromarray((blended * 255).astype("uint8"))
879
+
880
+ def phrase_ground_and_visualize(
881
+ self,
882
+ word: str,
883
+ template: str,
884
+ row,
885
+ role: str = "user",
886
+ instruction: str = None,
887
+ image_size: int = None, # multiples of 28; defaults to self.image_size
888
+ layer_for_text: int = -1, # which hidden_states layer to pull token reps from
889
+ save_dir: str = None, # if set, saves overlays as PNGs
890
+ return_arrays: bool = False, # if True, return heatmaps as numpy arrays
891
+ ):
892
+ """
893
+ Compute patch-level grounding for a word against images referenced in `template` filled by `row`.
894
+ Returns a PhraseGroundingOutput, and optionally writes overlay PNGs.
895
+
896
+ Strategy:
897
+ - Build a single-sample chat like encode_interleaved_with_ph().
898
+ - Forward Qwen-VL with hidden_states (+ attention if available).
899
+ - Locate word tokens inside last role block.
900
+ - Run vision tower once to get per-patch features per image.
901
+ - Project (text token avg) with text_proj, patches with image_proj; cosine sim per patch → heatmap.
902
+ - (Optional) also compute LM self-attn from word tokens to any image placeholders if available.
903
+ """
904
+        import os
+        import torch
+
+        device = self.device
+        tok = self.processor.tokenizer
+        target = self._target_from_image_size(image_size)
+
+        # --- Build content exactly like the training path ---
+        imap = build_image_map_from_row(row, root="")
+        # Resize to the Qwen grid (only for actually referenced keys).
+        # We don't pre-filter keys; _build_content_from_template decides which placeholders are used.
+        proc_imap = {k.lower(): to_qwen_grid(v, target=target) for k, v in (imap or {}).items() if v is not None}
+        tmap = build_text_map_from_row(row)
+
+        content_list, images_in_order = self._build_content_from_template(
+            template or "", proc_imap, (tmap or {}), append_unused_images=False
+        )
+
+        msgs = []
+        if instruction and str(instruction).strip():
+            msgs.append({"role": "system", "content": [{"type": "text", "text": f"INSTRUCTION:\n{instruction}"}]})
+        msgs.append({"role": role, "content": content_list})
+        chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
+
+        vm = self._get_vision_module()
+        vision_dtype = next(vm.parameters()).dtype
+
+        proc = self.processor(
+            text=[chat_text],
+            images=images_in_order if images_in_order else None,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            do_resize=False,
+            max_length=self.max_text_tokens,
+        )
+        inputs = {k: v.to(device) for k, v in proc.items()}
+        if "pixel_values" in inputs:
+            inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype)
+        if "image_grid_thw" in inputs:
+            inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device)
+
+        # --- Forward with hidden states (+ attentions if the model exposes them) ---
+        with torch.no_grad():
+            out = self.vl(**inputs, output_hidden_states=True, output_attentions=True, use_cache=False, return_dict=True)
+
+        hidden = out.hidden_states[layer_for_text]  # (1, S, H)
+        span_mask = self._mask_last_role_block(inputs, hidden)[0].bool()  # (S,)
+        seq_ids = inputs["input_ids"][0].tolist()
+
+        # --- Find token indices for the word inside the last role block ---
+        # 1) Exact subsequence match of token ids.
+        tgt_ids = tok(word, add_special_tokens=False)["input_ids"]
+        last_role_positions = [i for i, m in enumerate(span_mask.tolist()) if m]
+        id_seq_in_span = [seq_ids[i] for i in last_role_positions]
+        hits = self._find_subsequence(id_seq_in_span, tgt_ids)
+        token_span = None  # (abs_start, abs_end)
+        if hits:
+            start_in_span = hits[0]
+            abs_start = last_role_positions[start_in_span]
+            abs_end = last_role_positions[start_in_span + len(tgt_ids) - 1] + 1  # exclusive
+            token_span = (abs_start, abs_end)
+        else:
+            # 2) Fallback: decode windows in-span and fuzzy-match lowercase without spaces.
+            win_hits = self._window_decode_matches(tok, id_seq_in_span, target_lower=word.lower().replace(" ", ""))
+            if win_hits:
+                s, e = win_hits[0]
+                abs_start = last_role_positions[s]
+                abs_end = last_role_positions[e - 1] + 1
+                token_span = (abs_start, abs_end)
+
+        if token_span is None:
+            # If the word cannot be located, center on the last token in the last-role block.
+            # This keeps the visualization functional for debugging.
+            last_idx = last_role_positions[-1]
+            token_span = (last_idx, last_idx + 1)
+
+        s_idx, e_idx = token_span
+        word_tokens = hidden[0, s_idx:e_idx, :]  # (T_word, Htxt)
+        # Average sub-tokens -> one vector.
+        word_vec_txt = word_tokens.mean(dim=0, keepdim=True)  # (1, Htxt)
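The exact-match step above relies on `self._find_subsequence`, whose implementation is elsewhere in the repo. A minimal sketch of what such a helper presumably does (a hypothetical stand-in, with dummy token ids; the real helper may differ):

```python
def find_subsequence(haystack, needle):
    """Return every start index where `needle` occurs as a contiguous run in `haystack`."""
    n = len(needle)
    return [i for i in range(len(haystack) - n + 1) if haystack[i:i + n] == needle]

# Dummy token ids for the last-role span and for the target word:
id_seq_in_span = [101, 7, 42, 99, 7, 42, 102]
tgt_ids = [7, 42]
hits = find_subsequence(id_seq_in_span, tgt_ids)
print(hits)  # -> [1, 4]
```

Taking `hits[0]`, as the method does, grounds only the first mention of the word; later mentions are ignored.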
+
+        # --- Get vision patch features per image ---
+        heatmaps = []
+        per_image_debug = []
+        if "pixel_values" in inputs:
+            # Use the top-level vision model entry.
+            vmodel = self._get_vision_entry()
+            with torch.no_grad():
+                vout = vmodel(
+                    pixel_values=inputs["pixel_values"],
+                    grid_thw=inputs.get("image_grid_thw", None),
+                    output_hidden_states=True,
+                    return_dict=True,
+                )
+
+            # vout.last_hidden_state: (B, Svis, C)
+            vlast = vout.last_hidden_state
+            B, Svis, C = vlast.shape
+
+            # Grid sizes per image: (T, H, W)
+            grids = inputs.get("image_grid_thw", None)
+            if grids is not None:
+                # grids shape: (B, 3) => (T, H, W)
+                thw = grids.detach().cpu().tolist()
+                if isinstance(thw[0], (int, float)):  # single-image edge case
+                    thw = [thw]
+            else:
+                thw = [[1, int(round(Svis ** 0.5)), int(round(Svis ** 0.5))] for _ in range(B)]
+
+            # If a CLS token exists, Svis == T*H*W + 1; drop it.
+            per_img = []
+            for i in range(B):
+                t, h, w = map(int, thw[i])
+                tokens_per = t * h * w
+                take_from = 1 if (Svis == tokens_per + 1) else 0
+                patches = vlast[i, take_from:take_from + tokens_per, :]  # (T*H*W, C)
+                per_img.append((patches, (t, h, w)))
+
+            proj_dtype_img = next(self.image_proj.parameters()).dtype
+            proj_dtype_txt = next(self.text_proj.parameters()).dtype
+
+            word_vec = self.text_proj(word_vec_txt.to(dtype=proj_dtype_txt))
+            word_vec = word_vec / (word_vec.norm(dim=-1, keepdim=True) + 1e-12)
+
+            for (patches, (t, h, w)) in per_img:
+                patch_emb = self.image_proj(patches.to(dtype=proj_dtype_img))
+                patch_emb = patch_emb / (patch_emb.norm(dim=-1, keepdim=True) + 1e-12)
+                sim = patch_emb @ word_vec[0]           # (P,) cosine similarities
+                sim = sim.reshape(t, h, w).mean(dim=0)  # (H, W); average over the temporal dim
+                smin, smax = float(sim.min()), float(sim.max())
+                hm = (sim - smin) / max(1e-6, (smax - smin))
+                heatmaps.append(hm.detach().cpu().numpy())
+                per_image_debug.append({"tokens_per": t * h * w, "grid": (t, h, w)})
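The per-patch scoring above is plain cosine similarity between unit-normalized embeddings, followed by a min-max rescale to [0, 1]. A tiny numpy sketch of the same arithmetic (shapes and values are dummies: four patches on a 2x2 grid, 3-dim embeddings):

```python
import numpy as np

rng = np.random.default_rng(0)
patches = rng.normal(size=(4, 3))  # stand-in for projected patch embeddings
word = rng.normal(size=(3,))       # stand-in for the projected word vector

# Unit-normalize, mirroring the 1e-12-guarded norms in the method.
patches /= np.linalg.norm(patches, axis=-1, keepdims=True) + 1e-12
word /= np.linalg.norm(word) + 1e-12

sim = patches @ word                                     # (P,) cosine similarities
hm = (sim - sim.min()) / max(1e-6, sim.max() - sim.min())  # min-max to [0, 1]
hm = hm.reshape(2, 2)                                    # (H, W) heatmap
```

Note that the min-max rescale makes each heatmap relative: the hottest patch is always 1.0, even when the word matches no patch well, so heatmaps are not comparable across images.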
+
+        # --- Save overlays if requested ---
+        saved_paths = []
+        if save_dir and heatmaps:
+            os.makedirs(save_dir, exist_ok=True)
+            for i, im in enumerate(images_in_order):
+                # Resize the heatmap to the same (square) size we fed Qwen.
+                tgt_w, tgt_h = im.size
+                hm_np = self._resize_heatmap_like(heatmaps[i], tgt_w, tgt_h)
+                overlay = self._overlay_heatmap_on_image(im, hm_np, alpha=0.45)
+                fname = os.path.join(save_dir, f"ground_{i:02d}_{word.replace(' ', '_')}.png")
+                overlay.save(fname)
+                saved_paths.append(fname)
+
+        result = PhraseGroundingOutput(
+            token_span=(int(s_idx), int(e_idx)),
+            per_image=[{
+                "heatmap": (heatmaps[i] if return_arrays else None),
+                "saved_path": (saved_paths[i] if i < len(saved_paths) else None),
+                "grid": per_image_debug[i].get("grid", None),
+                "tokens_per": per_image_debug[i].get("tokens_per", None),
+                "placeholder_attn": per_image_debug[i].get("placeholder_attn", None),
+            } for i in range(len(heatmaps))],
+        )
+        return result
+
+
+class PhraseGroundingOutput:
+    def __init__(self, token_span, per_image):
+        self.token_span = token_span  # (start_idx, end_idx) within the last-role span
+        self.per_image = per_image    # list of dicts: heatmap, saved_path, grid, tokens_per, placeholder_attn
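A minimal sketch of consuming the result container (the class is redefined here so the snippet is self-contained; the field values are dummies, not real model output):

```python
class PhraseGroundingOutput:
    def __init__(self, token_span, per_image):
        self.token_span = token_span  # (start_idx, end_idx) within the last-role span
        self.per_image = per_image    # one dict per image

# Shape of the object phrase_ground_and_visualize returns, with dummy values:
out = PhraseGroundingOutput(
    token_span=(12, 14),
    per_image=[{
        "heatmap": None,                               # numpy array when return_arrays=True
        "saved_path": "ground_00_pneumonia.png",       # set when save_dir is given
        "grid": (1, 16, 16),                           # (T, H, W) patch grid
        "tokens_per": 256,                             # T * H * W
        "placeholder_attn": None,
    }],
)
t, h, w = out.per_image[0]["grid"]
assert out.per_image[0]["tokens_per"] == t * h * w
```

Callers that want arrays in memory should pass `return_arrays=True`; otherwise `heatmap` is None and only `saved_path` (if `save_dir` was set) points at the overlay.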
requirements.txt ADDED
@@ -0,0 +1,5 @@
+torch
+transformers
+peft
+huggingface_hub
+pillow