import math import os import re import json from typing import List, Optional, Dict, Tuple, Union from PIL import Image import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration # Treat these as empty/missing (case-insensitive, whitespace-tolerant) _EMPTY_SENTINELS = {"", "-1", "none", "null", "na", "n/a", "nan", ""} def _is_empty_cell(x) -> bool: """True if x should be considered 'missing'.""" if x is None: return True # float('nan') and numpy.float64('nan') try: if isinstance(x, float) and math.isnan(x): return True except Exception: pass s = str(x).strip().lower() return s in _EMPTY_SENTINELS def _clean_text_or_empty(x) -> str: """Return a clean string or '' if missing.""" return "" if _is_empty_cell(x) else str(x).strip() try: from peft import LoraConfig, get_peft_model HAS_PEFT = True except Exception: HAS_PEFT = False # ----------------------- misc utils ----------------------- def l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor: return x / (x.norm(dim=dim, keepdim=True) + eps) def masked_mean_pool(hidden: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """Mean over tokens where mask==True.""" if mask is None: return hidden.mean(dim=1) mask = mask.to(hidden.dtype) denom = mask.sum(dim=1, keepdim=True).clamp_min(1e-6) return (hidden * mask.unsqueeze(-1)).sum(dim=1) / denom def to_qwen_grid(img: Image.Image, target: int = 512, patch_size: int = 14, merge_size: int = 2) -> Image.Image: """ Resize image so H=W is a multiple of 28 (=patch_size*merge_size). FLOOR to nearest multiple (512->504, 1024->1008). """ grid = patch_size * merge_size # 28 new = max(grid, (target // grid) * grid) return img.resize((new, new), Image.BILINEAR) def _open_or_none(path: object, root: str = "") -> Optional[Image.Image]: """Returns a PIL.Image or None. Handles '', NaN, '-1', , etc.""" if _is_empty_cell(path): return None p = str(path).strip() # Don't join URI-like paths if root and not re.match(r'^[a-zA-Z][a-zA-Z0-9+\-.]*://', p): p = os.path.join(root, p) try: return Image.open(p).convert("RGB") except Exception: return None def build_image_map_from_row(row, root: str = "") -> dict: """ Mapping per your schema: - frontal_image <- img_path1 (also used as current_image) - lateral_image <- img_path2 - prior_image <- img_path3 """ m = { "frontal_image": _open_or_none(str(row.get("img_path1", "-1")), root), "lateral_image": _open_or_none(str(row.get("img_path2", "-1")), root), "prior_image": _open_or_none(str(row.get("img_path3", "-1")), root), } # --- NEW: negative images available to templates --- n1 = _open_or_none(str(row.get("neg_image1", row.get("neg_path1", "-1"))), root) n2 = _open_or_none(str(row.get("neg_image2", row.get("neg_path2", "-1"))), root) # support either column name for prior: neg_image3 or neg_prior_image, also neg_path3 n3 = _open_or_none(str(row.get("neg_image3", row.get("neg_prior_image", row.get("neg_path3", "-1")))), root) if n1 is not None: m.update({"neg_image1": n1, "neg_path1": n1, "neg_frontal_image": n1}) if n2 is not None: m.update({"neg_image2": n2, "neg_path2": n2, "neg_lateral_image": n2}) if n3 is not None: m.update({"neg_prior_image": n3, "neg_image3": n3, "neg_path3": n3}) return m def _s(x): return "" if x is None else str(x) def build_text_map_from_row(row) -> Dict[str, str]: m = { "report": _clean_text_or_empty(row.get("report")), "prior_report": _clean_text_or_empty(row.get("prior_report")), "demographics": _clean_text_or_empty(row.get("demographics")), # --- NEW --- "lab_test": _clean_text_or_empty(row.get("lab_test")), "indication": _clean_text_or_empty(row.get("indication")), } # drop empties return {k: v for k, v in m.items() if v} def parse_text_placeholders(s) -> dict: if isinstance(s, dict): d = s elif isinstance(s, str) and s.strip(): try: d = json.loads(s) except Exception: d = {} else: d = {} if not isinstance(d, dict): return {} out = {} for k, v in d.items(): val = _clean_text_or_empty(v) if val: out[str(k).lower()] = val return out # ----------------------- pooling modules ----------------------- class LatentAttentionPooler(nn.Module): """ NV-Embed style: tokens (Q) attend to trainable latents (K=V), then MLP, then mean-pool over tokens (optionally masked). """ def __init__(self, dim: int, num_latents: int = 512, num_layers: int = 1, num_heads: int = 8, mlp_ratio: float = 2.0): super().__init__() self.latents = nn.Parameter(torch.randn(num_latents, dim) / math.sqrt(dim)) self.layers = nn.ModuleList() self.ln_q = nn.LayerNorm(dim) # for token queries self.ln_kv = nn.LayerNorm(dim) # for latent K/V for _ in range(num_layers): attn = nn.MultiheadAttention(dim, num_heads, batch_first=True) ffn = nn.Sequential( nn.Linear(dim, int(dim * mlp_ratio)), nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim), ) self.layers.append(nn.ModuleDict({"attn": attn, "ffn": ffn})) def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: # x: (B, S, D) last-layer token states from the LLM B, S, D = x.shape # Prepare Q (tokens) and K,V (trainable latents) q = self.ln_q(x) lat = self.latents.unsqueeze(0).expand(B, -1, -1).contiguous() kv = self.ln_kv(lat) # Cross-attn: tokens query the latent dictionary (no key padding mask on latents) for blk in self.layers: y = blk["attn"](q, kv, kv, need_weights=False)[0] q = q + y # residual q = q + blk["ffn"](q) # MLP + residual # Mean-pool over **tokens**; mask only applied here return masked_mean_pool(q, mask) # (B, D) class Projection(nn.Module): def __init__(self, in_dim: int, out_dim: int = 1024, hidden: Optional[int] = None): super().__init__() if hidden is None: self.proj = nn.Sequential(nn.Linear(in_dim, out_dim, bias=False)) else: self.proj = nn.Sequential(nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim, bias=False)) def forward(self, x: torch.Tensor) -> torch.Tensor: return l2norm(self.proj(x)) # ----------------------- main wrapper ----------------------- class LingshuEmbedder(nn.Module): def __init__( self, model_name: str = "lingshu-medical-mllm/Lingshu-7B", attn_implementation: str = "flash_attention_2", torch_dtype: torch.dtype = torch.bfloat16, embed_dim: int = 1024, # unified pooling mode pool_mode: str = "latent_attention", # "latent_attention" | "mean" num_latents_unified: int = 512, # image grid control (supports 504 and 1008) image_size: int = 504, # default grid; per-call override allowed (504 or 1008) min_grid: int = 256, max_grid: int = 1296, # up to 36x36 (for 1008) # LoRA (optional) - tuned for memorization # r=64 for balanced performance; increase to 128 if VRAM allows use_lora: bool = False, lora_r: int = 64, lora_alpha: int = 64, lora_dropout: float = 0.0, # alpha=r, dropout=0 for memorization apply_lora_to_vision: bool = False, # make attention bi-directional (remove causal masking) bidirectional: bool = True, # text token budget (read by the training script) max_text_tokens: int = 2560, # gradient checkpointing enable_gradient_checkpointing: bool = False, device: Optional[Union[str, torch.device]] = None, ) -> None: super().__init__() # ---- device & backend ---- if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device(device) if device.type != "cuda": attn_implementation = "sdpa" if torch_dtype in (torch.float16, torch.bfloat16): torch_dtype = torch.float32 # ---- load backbone + processor ---- self.vl = Qwen2_5_VLForConditionalGeneration.from_pretrained( model_name, torch_dtype=torch_dtype, attn_implementation=attn_implementation ) self.processor = AutoProcessor.from_pretrained( model_name, min_pixels=min_grid * 28 * 28, max_pixels=max_grid * 28 * 28, ) self._propagate_attn_impl(attn_implementation) # freeze base for p in self.vl.parameters(): p.requires_grad_(False) # UNFREEZE vision projector for better image→text binding # Qwen2.5-VL has a visual projection module unfrozen_modules = [] for name, module in self.vl.named_modules(): # Look for vision projector: often named 'visual', 'vision_proj', 'mm_projector', etc. if any(x in name.lower() for x in ['visual.merger', 'visual.proj', 'vision_proj', 'mm_projector']): n_params = sum(p.numel() for p in module.parameters()) for p in module.parameters(): p.requires_grad_(True) unfrozen_modules.append((name, n_params)) if unfrozen_modules: print(f"[model] Unfrozen vision projector modules for memorization:") for name, n_params in unfrozen_modules: print(f" - {name}: {n_params:,} parameters") # dims txt_hidden = getattr(self.vl.config, "text_config", None) vis_hidden = getattr(self.vl.config, "vision_config", None) self.text_hidden = getattr(txt_hidden, "hidden_size", None) self.vision_hidden = getattr(vis_hidden, "out_hidden_size", None) or getattr(vis_hidden, "hidden_size", None) # projections (unified/text/image all project to same embed_dim space) self.text_proj = Projection(self.text_hidden, embed_dim, hidden=None) self.image_proj = Projection(self.vision_hidden, embed_dim, hidden=None) self.unified_proj = Projection(self.text_hidden, embed_dim, hidden=None) self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07))) # unified pooling config self.pool_mode = pool_mode if self.pool_mode == "latent_attention": self.unified_pooler = LatentAttentionPooler( dim=self.text_hidden, num_latents=num_latents_unified, # set default to 512 to match paper num_layers=1, num_heads=8 ) else: self.unified_pooler = None # image size handling (any multiple of 28 is allowed, e.g., 448, 504, 1008) if image_size % 28 != 0: raise ValueError(f"image_size must be a multiple of 28, got {image_size}") self.image_size = image_size # default; can override per call # optional LoRA self.peft_active = False if use_lora: if not HAS_PEFT: raise ImportError("peft not installed") targets_text = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj") targets_vision = ("qkv", "proj") targets = list(set(targets_text + (targets_vision if apply_lora_to_vision else tuple()))) cfg = LoraConfig(r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, target_modules=targets, bias="none", task_type="CAUSAL_LM") self.vl = get_peft_model(self.vl, cfg) self.peft_active = True # make bi-directional if requested if bidirectional: self._enable_bidirectional_attention() # gradient checkpointing if enable_gradient_checkpointing: # Use the non-reentrant variant to avoid "requires_grad" warnings try: self.vl.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) except TypeError: # older transformers fallback self.vl.gradient_checkpointing_enable() try: self.vl.config.use_cache = False except Exception: pass # move to device self.to(device) self.device = device # align pooler dtype with model (and device) base_dtype = next(self.vl.parameters()).dtype if getattr(self, "unified_pooler", None) is not None: self.unified_pooler.to(device=device, dtype=base_dtype) # expose text token budget for processor calls in training script self.max_text_tokens = int(max_text_tokens) # ---------- internals ---------- def _propagate_attn_impl(self, impl: str): cfgs = [getattr(self.vl, "config", None)] if cfgs[0] is not None: for sub in ("text_config", "vision_config"): cfgs.append(getattr(cfgs[0], sub, None)) for cfg in cfgs: if cfg is None: continue try: cfg._attn_implementation = impl cfg.attn_implementation = impl if hasattr(cfg, "use_flash_attention_2"): cfg.use_flash_attention_2 = (impl == "flash_attention_2") except Exception: pass for _, module in self.vl.named_modules(): if hasattr(module, "config"): try: module.config._attn_implementation = impl module.config.attn_implementation = impl if hasattr(module.config, "use_flash_attention_2"): module.config.use_flash_attention_2 = (impl == "flash_attention_2") except Exception: pass def _enable_bidirectional_attention(self): """Best-effort removal of causal masking.""" cfg = getattr(self.vl, "config", None) if cfg is not None: if hasattr(cfg, "is_decoder"): cfg.is_decoder = False if hasattr(cfg, "use_cache"): cfg.use_cache = False core = getattr(self.vl, "model", self.vl) core_cfg = getattr(core, "config", None) if core_cfg is not None: if hasattr(core_cfg, "is_decoder"): core_cfg.is_decoder = False if hasattr(core_cfg, "use_cache"): core_cfg.use_cache = False for m in self.vl.modules(): if hasattr(m, "is_causal"): try: m.is_causal = False except Exception: pass def _get_text_module(self): core = getattr(self.vl, "model", self.vl) for attr in ("language_model", "text_model", "lm"): m = getattr(core, attr, None) if m is not None and hasattr(m, "forward"): return m for _, module in self.vl.named_modules(): cname = module.__class__.__name__.lower() if "vision" in cname: continue if hasattr(module, "forward") and hasattr(module, "embed_tokens"): return module raise AttributeError("Could not locate the text submodule in Qwen-VL.") def _get_vision_module(self): core = getattr(self.vl, "model", self.vl) for attr in ("vision_model", "vision_tower", "visual", "vision"): m = getattr(core, attr, None) if m is not None and hasattr(m, "forward"): return m for _, module in self.vl.named_modules(): if "vision" in module.__class__.__name__.lower(): return module raise AttributeError("Could not locate the vision submodule in Qwen-VL.") def _get_vision_entry(self): """ Return the top-level VisionModel object that accepts: forward(pixel_values=..., grid_thw=..., output_hidden_states=..., return_dict=True) Avoid returning the inner transformer which expects (hidden_states, grid_thw). """ core = getattr(self.vl, "model", self.vl) # Prefer the canonical attribute if present vis = getattr(core, "vision_model", None) if vis is not None: return vis # Fallback: search modules for something named *VisionModel for _, m in core.named_modules(): name = m.__class__.__name__.lower() if name.endswith("visionmodel"): return m # Last resort: previous generic getter (may return transformer; not ideal) return self._get_vision_module() # ----- chat/content builders & masking ----- def _target_from_image_size(self, image_size: Optional[int]) -> int: """ Return a pixel target that will be floored to a multiple of 28 by to_qwen_grid(). Any multiple of 28 works (e.g., 448, 504, 1008). """ sz = image_size if isinstance(image_size, int) and image_size % 28 == 0 else self.image_size return int(sz) def _build_interleaved_content(self, text: str, imgs: List[Image.Image], append_unused_images: bool = False) -> Tuple[list, list]: """ NUMERIC placeholders: , , ... Returns (content_list, images_in_order). """ if text is None: text = "" content: list = [] ordered_images: list = [] imgs = imgs or [] pat = re.compile(r"", re.IGNORECASE) pos = 0 matches = list(pat.finditer(text)) if not matches: # Do not auto-append images unless explicitly requested if text.strip(): content.append({"type": "text", "text": text}) if append_unused_images: for im in imgs: content.append({"type": "image", "image": im}) ordered_images.append(im) return content, ordered_images for m in matches: s, e = m.span() if s > pos: seg = text[pos:s] if seg.strip(): content.append({"type": "text", "text": seg}) idx = int(m.group(1)) - 1 if 0 <= idx < len(imgs): content.append({"type": "image", "image": imgs[idx]}) ordered_images.append(imgs[idx]) pos = e if pos < len(text): seg = text[pos:] if seg.strip(): content.append({"type": "text", "text": seg}) if append_unused_images: used = set(ordered_images) for im in imgs: if im not in used: content.append({"type": "image", "image": im}) ordered_images.append(im) return content, ordered_images def _build_content_from_template( self, template: str, image_map: Optional[Dict[str, Image.Image]], text_map: Optional[Dict[str, str]], append_unused_images: bool = False, ) -> Tuple[list, list]: """ NAMED placeholders: , , , , , , ... Also supports alias: -> . """ template = template or "" image_map = {k.lower(): v for k, v in (image_map or {}).items() if v is not None} text_map = {k.lower(): v for k, v in (text_map or {}).items() if v is not None and str(v).strip()} content: list = [] images_in_order: list = [] pat = re.compile(r"<\s*([A-Za-z_]\w*)\s*>") pos = 0 for m in pat.finditer(template): s, e = m.span() if s > pos: seg = template[pos:s] if seg.strip(): content.append({"type": "text", "text": seg}) name = m.group(1).lower() # alias: current_image -> frontal_image if name == "current_image": name = "frontal_image" if name in image_map: # <<< generalized image handling img = image_map.get(name) if img is not None: content.append({"type": "image", "image": img}) images_in_order.append(img) else: val = text_map.get(name) if val is not None: content.append({"type": "text", "text": str(val)}) pos = e if pos < len(template): tail = template[pos:] if tail.strip(): content.append({"type": "text", "text": tail}) # Append any not-yet-used images at the end (conditionally) if append_unused_images: for key, img in image_map.items(): if img is not None and img not in images_in_order: content.append({"type": "image", "image": img}) images_in_order.append(img) return content, images_in_order def _mask_last_role_block(self, inputs: dict, hidden: torch.Tensor) -> torch.Tensor: """ Boolean mask (B,S) selecting tokens inside the **last** role block (user/assistant), excluding the final <|im_end|>, for **any** batch size. Falls back to attention_mask if special tokens are unavailable. """ device = hidden.device ids = inputs.get("input_ids", None) attn = inputs.get("attention_mask", None) if ids is None: return (attn if attn is not None else torch.ones(hidden.shape[:2], device=device, dtype=torch.long)).bool() B, S = ids.shape mask = torch.zeros((B, S), device=device, dtype=torch.bool) # Try to get ChatML boundary tokens try: start_id = self.processor.tokenizer.convert_tokens_to_ids("<|im_start|>") except Exception: start_id = None try: end_id = self.processor.tokenizer.convert_tokens_to_ids("<|im_end|>") except Exception: end_id = None if end_id is None: return (attn if attn is not None else torch.ones((B, S), device=device, dtype=torch.long)).bool() for b in range(B): # Limit search to valid tokens when attention mask is present if attn is not None: valid_len = int(attn[b].sum().item()) else: valid_len = S valid_len = max(1, min(valid_len, S)) seq = ids[b, :valid_len] ends = (seq == end_id).nonzero(as_tuple=False).flatten() if ends.numel() == 0: # No explicit blocks; fall back to all valid tokens mask[b, :valid_len] = True continue last_end = int(ends[-1].item()) last_start = -1 if start_id is not None: starts = (seq == start_id).nonzero(as_tuple=False).flatten() starts_before = starts[starts < last_end] if starts.numel() > 0 else None if starts_before is not None and starts_before.numel() > 0: last_start = int(starts_before[-1].item()) elif ends.numel() >= 2: # Heuristic: if no <|im_start|>, use previous end as start last_start = int(ends[-2].item()) else: if ends.numel() >= 2: last_start = int(ends[-2].item()) left = max(last_start + 1, 0) right = max(last_end - 1, left) mask[b, left:right + 1] = True if attn is not None: mask = mask & attn.bool() return mask # ---------- encoders (unified everywhere) ---------- @torch.no_grad() def encode_text_unified(self, instructions: List[Optional[str]], texts: List[str], role: str = "user", normalize: bool = True) -> torch.Tensor: """Text-only, but still go through the unified VL path for consistency.""" empty_images = [[] for _ in texts] return self.encode_interleaved(instructions, texts, empty_images, role=role, normalize=normalize) @torch.no_grad() def encode_images_unified(self, instructions: List[Optional[str]], image_templates: List[str], image_maps: List[Dict[str, Image.Image]], role: str = "user", normalize: bool = True, image_size: Optional[int] = None) -> torch.Tensor: """ Image-only via unified path. Pass templates like "" or "" (images only included if explicitly referenced). """ empty_text_maps = [{} for _ in image_templates] return self.encode_interleaved_with_ph(instructions, image_templates, image_maps, empty_text_maps, role=role, normalize=normalize, image_size=image_size) @torch.no_grad() def encode_interleaved( self, instructions: List[Optional[str]], contents: List[str], images: List[List[Image.Image]], role: str = "user", normalize: bool = True, image_size: Optional[int] = None, # 504 or 1008 override ) -> torch.Tensor: device = self.device vm = self._get_vision_module() vision_dtype = next(vm.parameters()).dtype assert len(instructions) == len(contents) == len(images), "length mismatch" out_vecs = [] target = self._target_from_image_size(image_size) for inst, text, imgs in zip(instructions, contents, images): proc_imgs = [to_qwen_grid(im, target=target) for im in (imgs or [])] content_list, images_in_order = self._build_interleaved_content( text or "", proc_imgs, append_unused_images=False ) msgs = [] if inst and str(inst).strip(): msgs.append({"role": "system", "content": [{"type": "text", "text": inst}]}) msgs.append({"role": role, "content": content_list}) chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False) proc = self.processor( text=[chat_text], images=images_in_order if images_in_order else None, return_tensors="pt", padding=True, truncation=True, do_resize=False, max_length=self.max_text_tokens, ) inputs = {k: v.to(device) for k, v in proc.items()} if "pixel_values" in inputs: inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype) if "image_grid_thw" in inputs: inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device) out = self.vl(**inputs, output_hidden_states=True, use_cache=False) hidden = out.hidden_states[-1] # (1, S, H) span_mask = self._mask_last_role_block(inputs, hidden) # (1, S) if self.pool_mode == "latent_attention": pool_dtype = next(self.unified_pooler.parameters()).dtype if hidden.dtype != pool_dtype: hidden = hidden.to(dtype=pool_dtype) vec = self.unified_pooler(hidden, span_mask).squeeze(0) else: vec = masked_mean_pool(hidden, span_mask).squeeze(0) out_vecs.append(vec) embs = torch.stack(out_vecs, dim=0) proj_dtype = next(self.unified_proj.parameters()).dtype emb = self.unified_proj(embs.to(dtype=proj_dtype)) if normalize: emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12) return emb @torch.no_grad() def encode_interleaved_with_ph( self, instructions: List[Optional[str]], templates: List[str], image_maps: List[Optional[Dict[str, Image.Image]]], text_maps: List[Optional[Dict[str, str]]], role: str = "user", normalize: bool = True, image_size: Optional[int] = None, # 504 or 1008 override ) -> torch.Tensor: device = self.device vm = self._get_vision_module() vision_dtype = next(vm.parameters()).dtype assert len(instructions) == len(templates) == len(image_maps) == len(text_maps), "length mismatch" vecs = [] target = self._target_from_image_size(image_size) for inst, tmpl, imap, tmap in zip(instructions, templates, image_maps, text_maps): proc_imap: Dict[str, Image.Image] = {} if imap: for k, im in imap.items(): if im is not None: proc_imap[k.lower()] = to_qwen_grid(im, target=target) content_list, images_in_order = self._build_content_from_template(tmpl or "", proc_imap, (tmap or {})) msgs = [] if inst and str(inst).strip(): msgs.append({"role": "system", "content": [{"type": "text", "text": inst}]}) msgs.append({"role": role, "content": content_list}) chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False) proc = self.processor( text=[chat_text], images=images_in_order if images_in_order else None, return_tensors="pt", padding=True, truncation=True, do_resize=False, max_length=self.max_text_tokens, ) inputs = {k: v.to(device) for k, v in proc.items()} if "pixel_values" in inputs: inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype) if "image_grid_thw" in inputs: inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device) out = self.vl(**inputs, output_hidden_states=True, use_cache=False) hidden = out.hidden_states[-1] # (1, S, H) span_mask = self._mask_last_role_block(inputs, hidden) # (1, S) if self.pool_mode == "latent_attention": pool_dtype = next(self.unified_pooler.parameters()).dtype if hidden.dtype != pool_dtype: hidden = hidden.to(dtype=pool_dtype) vec = self.unified_pooler(hidden, span_mask).squeeze(0) else: vec = masked_mean_pool(hidden, span_mask).squeeze(0) vecs.append(vec) embs = torch.stack(vecs, dim=0) proj_dtype = next(self.unified_proj.parameters()).dtype emb = self.unified_proj(embs.to(dtype=proj_dtype)) if normalize: emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12) return emb # ------------- (dual encoders for debugging) ------------- @torch.no_grad() def encode_text_dual(self, texts: List[str], normalize: bool = True) -> torch.Tensor: device = self.device tok = self.processor.tokenizer(text=texts, padding=True, truncation=True, return_tensors="pt", max_length=self.max_text_tokens) tok = {k: v.to(device) for k, v in tok.items()} lm = self._get_text_module() out = lm(**tok, output_hidden_states=True, use_cache=False) hidden = out.last_hidden_state mask = tok.get("attention_mask") pooled = masked_mean_pool(hidden, mask) proj_dtype = next(self.text_proj.parameters()).dtype emb = self.text_proj(pooled.to(dtype=proj_dtype)) if normalize: emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12) return emb @torch.no_grad() def encode_images_dual(self, images: List[List[Image.Image]], normalize: bool = True, image_size: Optional[int] = None) -> torch.Tensor: device = self.device flat = [img for group in images for img in group] counts = [len(g) for g in images] if len(flat) == 0: proj_dtype = next(self.image_proj.parameters()).dtype zeros = torch.zeros((len(images), self.vision_hidden), device=device, dtype=proj_dtype) emb = self.image_proj(zeros) if normalize: emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12) return emb target = self._target_from_image_size(image_size) processed = [to_qwen_grid(img, target=target) for img in flat] proc = self.processor.image_processor(images=processed, return_tensors="pt", do_resize=False) vm = self._get_vision_module() vision_dtype = next(vm.parameters()).dtype pixel_values = proc["pixel_values"].to(device=device, dtype=vision_dtype) vis_out = vm(pixel_values=pixel_values, output_hidden_states=True) feats = vis_out[0] if isinstance(vis_out, (tuple, list)) else getattr(vis_out, "last_hidden_state", None) if feats is None: feats = getattr(vis_out, "pooler_output", None) if feats is None: raise RuntimeError("Vision backbone did not return features as expected.") per_img = feats.mean(dim=1) if feats.ndim == 3 else feats splits = torch.split(per_img, counts, dim=0) set_vecs = torch.stack([s.mean(dim=0) if s.ndim > 1 else s for s in splits], dim=0) proj_dtype = next(self.image_proj.parameters()).dtype emb = self.image_proj(set_vecs.to(dtype=proj_dtype)) if normalize: emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12) return emb # ===================== PHRASE GROUNDING UTILS ===================== def _find_subsequence(self, haystack: list, needle: list) -> list: """Return start indices where 'needle' occurs in 'haystack' (exact match).""" if not haystack or not needle or len(needle) > len(haystack): return [] hits = [] n = len(needle) for i in range(len(haystack) - n + 1): if haystack[i:i+n] == needle: hits.append(i) return hits def _window_decode_matches(self, tokenizer, ids, target_lower: str) -> list: """Fallback: sliding-window decode match (robust to BPE splits). Returns window (start,end) indices.""" hits = [] L = len(ids) # Small cap on window length to avoid expensive decode; most medical terms fit <= 5 tokens. for w in range(1, 8): for i in range(0, L - w + 1): s, e = i, i + w text = tokenizer.decode(ids[s:e], skip_special_tokens=True).lower().replace(" ", "") if target_lower in text: hits.append((s, e)) # De-duplicate overlapping windows by preferring shortest span hits = sorted(set(hits), key=lambda x: (x[1]-x[0], x[0])) return hits def _resize_heatmap_like(self, hm_np, target_w, target_h): from PIL import Image import numpy as np # hm_np: (H, W) in [0,1]; resize with bilinear to (target_h, target_w) H, W = hm_np.shape im = Image.fromarray((hm_np * 255.0).astype("uint8"), mode="L") im = im.resize((target_w, target_h), Image.BILINEAR) out = (np.array(im).astype("float32") / 255.0) return out def _overlay_heatmap_on_image(self, img_pil, hm_np, alpha=0.45): """Return PIL with heatmap overlay; hm_np in [0,1] same size as img.""" import matplotlib import numpy as np from PIL import Image img = np.array(img_pil.convert("RGB")).astype("float32") / 255.0 H, W = img.shape[:2] hm = np.clip(hm_np, 0.0, 1.0) if hm.shape[:2] != (H, W): raise ValueError("Heatmap and image size mismatch") # Use a perceptually reasonable colormap without fixing colors for downstream tools. cmap = matplotlib.cm.get_cmap("jet") color_hm = cmap(hm)[..., :3] # (H,W,3) blended = (1.0 - alpha) * img + alpha * color_hm blended = np.clip(blended, 0.0, 1.0) return Image.fromarray((blended * 255).astype("uint8")) def phrase_ground_and_visualize( self, word: str, template: str, row, role: str = "user", instruction: str = None, image_size: int = None, # multiples of 28; defaults to self.image_size layer_for_text: int = -1, # which hidden_states layer to pull token reps from save_dir: str = None, # if set, saves overlays as PNGs return_arrays: bool = False, # if True, return heatmaps as numpy arrays ): """ Compute patch-level grounding for a word against images referenced in `template` filled by `row`. Returns a PhraseGroundingOutput, and optionally writes overlay PNGs. Strategy: - Build a single-sample chat like encode_interleaved_with_ph(). - Forward Qwen-VL with hidden_states (+ attention if available). - Locate word tokens inside last role block. - Run vision tower once to get per-patch features per image. - Project (text token avg) with text_proj, patches with image_proj; cosine sim per patch → heatmap. - (Optional) also compute LM self-attn from word tokens to any image placeholders if available. """ import os, numpy as np, torch from PIL import Image device = self.device tok = self.processor.tokenizer target = self._target_from_image_size(image_size) # --- Build content exactly like your training path --- imap = build_image_map_from_row(row, root="") # resize to Qwen grid (only for actually referenced keys) # We won't pre-filter keys; _build_content_from_template handles which placeholders are used. proc_imap = {k.lower(): to_qwen_grid(v, target=target) for k, v in (imap or {}).items() if v is not None} tmap = build_text_map_from_row(row) content_list, images_in_order = self._build_content_from_template(template or "", proc_imap, (tmap or {}), append_unused_images=False) msgs = [] if instruction and str(instruction).strip(): msgs.append({"role": "system", "content": [{"type": "text", "text": f"INSTRUCTION:\n{instruction}"}]}) msgs.append({"role": role, "content": content_list}) chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False) vm = self._get_vision_module() vision_dtype = next(vm.parameters()).dtype proc = self.processor( text=[chat_text], images=images_in_order if images_in_order else None, return_tensors="pt", padding=True, truncation=True, do_resize=False, max_length=self.max_text_tokens, ) inputs = {k: v.to(device) for k, v in proc.items()} if "pixel_values" in inputs: inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype) if "image_grid_thw" in inputs: inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device) # --- Forward with hidden states (+ attentions if the model exposes them) --- with torch.no_grad(): out = self.vl(**inputs, output_hidden_states=True, output_attentions=True, use_cache=False, return_dict=True) hidden = out.hidden_states[layer_for_text] # (1, S, H) span_mask = self._mask_last_role_block(inputs, hidden)[0].bool() # (S,) seq_ids = inputs["input_ids"][0].tolist() # --- Find token indices for the word inside the last role block --- # 1) exact subsequence match of token ids tgt_ids = tok(word, add_special_tokens=False)["input_ids"] last_role_positions = [i for i, m in enumerate(span_mask.tolist()) if m] id_seq_in_span = [seq_ids[i] for i in last_role_positions] hits = self._find_subsequence(id_seq_in_span, tgt_ids) token_span = None # (abs_start, abs_end) if hits: start_in_span = hits[0] abs_start = last_role_positions[start_in_span] abs_end = last_role_positions[start_in_span + len(tgt_ids) - 1] + 1 # exclusive token_span = (abs_start, abs_end) else: # 2) fallback: decode windows in-span and fuzzy match lowercase without spaces win_hits = self._window_decode_matches(tok, id_seq_in_span, target_lower=word.lower().replace(" ", "")) if win_hits: s, e = win_hits[0] abs_start = last_role_positions[s] abs_end = last_role_positions[e - 1] + 1 token_span = (abs_start, abs_end) if token_span is None: # If the word cannot be located, we center on the last token in the last-role block. # This keeps the visualization functional for debugging. last_idx = last_role_positions[-1] token_span = (last_idx, last_idx + 1) s_idx, e_idx = token_span word_tokens = hidden[0, s_idx:e_idx, :] # (T_word, Htxt) # Average sub-tokens → one vector word_vec_txt = word_tokens.mean(dim=0, keepdim=True) # (1, Htxt) # --- Get vision patch features per image --- heatmaps = [] per_image_debug = [] if "pixel_values" in inputs: # Use the TOP-LEVEL vision model entry vmodel = self._get_vision_entry() with torch.no_grad(): vout = vmodel( pixel_values=inputs["pixel_values"], grid_thw=inputs.get("image_grid_thw", None), output_hidden_states=True, return_dict=True, ) # vout.last_hidden_state: (B, Svis, C) vlast = vout.last_hidden_state B, Svis, C = vlast.shape # Grid sizes per image (T,H,W) grids = inputs.get("image_grid_thw", None) if grids is not None: # grids shape: (B, 3) => (T, H, W) thw = grids.detach().cpu().tolist() if isinstance(thw[0], (int, float)): # single image edge case thw = [thw] else: thw = [[1, int(round(Svis ** 0.5)), int(round(Svis ** 0.5))] for _ in range(B)] # If a CLS token exists, Svis == T*H*W + 1; drop it per_img = [] offset = 0 for i in range(B): t, h, w = map(int, thw[i]) tokens_per = t * h * w take_from = 1 if (Svis == tokens_per + 1) else 0 patches = vlast[i, take_from:take_from + tokens_per, :] # (T*H*W, C) per_img.append((patches, (t, h, w))) proj_dtype_img = next(self.image_proj.parameters()).dtype proj_dtype_txt = next(self.text_proj.parameters()).dtype word_vec = self.text_proj(word_vec_txt.to(dtype=proj_dtype_txt)) word_vec = word_vec / (word_vec.norm(dim=-1, keepdim=True) + 1e-12) for (patches, (t, h, w)) in per_img: patch_emb = self.image_proj(patches.to(dtype=proj_dtype_img)) patch_emb = patch_emb / (patch_emb.norm(dim=-1, keepdim=True) + 1e-12) sim = (patch_emb @ word_vec[0].T).squeeze(-1) # (P,) sim = sim.reshape(t, h, w).mean(dim=0) # (H, W) smin, smax = float(sim.min()), float(sim.max()) hm = (sim - smin) / max(1e-6, (smax - smin)) heatmaps.append(hm.detach().cpu().numpy()) per_image_debug.append({"tokens_per": t*h*w, "grid": (t, h, w)}) # --- Save overlays if requested --- saved_paths = [] if save_dir and heatmaps: os.makedirs(save_dir, exist_ok=True) for i, im in enumerate(images_in_order): # Ensure the heatmap is resized to the same (square) size we fed Qwen tgt_w, tgt_h = im.size hm_np = self._resize_heatmap_like(heatmaps[i], tgt_w, tgt_h) overlay = self._overlay_heatmap_on_image(im, hm_np, alpha=0.45) fname = os.path.join(save_dir, f"ground_{i:02d}_{word.replace(' ','_')}.png") overlay.save(fname) saved_paths.append(fname) result = PhraseGroundingOutput( token_span=(int(s_idx), int(e_idx)), per_image=[{ "heatmap": (heatmaps[i] if return_arrays else None), "saved_path": (saved_paths[i] if i < len(saved_paths) else None), "grid": per_image_debug[i].get("grid", None), "tokens_per": per_image_debug[i].get("tokens_per", None), "placeholder_attn": per_image_debug[i].get("placeholder_attn", None), } for i in range(len(heatmaps))] ) return result class PhraseGroundingOutput: def __init__(self, token_span, per_image): self.token_span = token_span # (start_idx, end_idx) within last-role span self.per_image = per_image # list of dicts with fields below