|
|
import math |
|
|
import os |
|
|
import re |
|
|
import json |
|
|
from typing import List, Optional, Dict, Tuple, Union |
|
|
|
|
|
from PIL import Image |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration |
|
|
|
|
|
|
|
|
_EMPTY_SENTINELS = {"", "-1", "none", "null", "na", "n/a", "nan", "<na>"} |
|
|
|
|
|
def _is_empty_cell(x) -> bool: |
|
|
"""True if x should be considered 'missing'.""" |
|
|
if x is None: |
|
|
return True |
|
|
|
|
|
try: |
|
|
if isinstance(x, float) and math.isnan(x): |
|
|
return True |
|
|
except Exception: |
|
|
pass |
|
|
s = str(x).strip().lower() |
|
|
return s in _EMPTY_SENTINELS |
|
|
|
|
|
def _clean_text_or_empty(x) -> str:
    """Return ``x`` as a stripped string, or '' when it counts as missing."""
    if _is_empty_cell(x):
        return ""
    return str(x).strip()
|
|
|
|
|
# Optional dependency: LoRA support via `peft`. Any import failure (missing
# package, version conflict) is swallowed and recorded in HAS_PEFT so the
# rest of the module works without adapters; LingshuEmbedder raises later
# only if use_lora=True is actually requested.
try:
    from peft import LoraConfig, get_peft_model
    HAS_PEFT = True
except Exception:
    HAS_PEFT = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
    """Scale ``x`` to unit L2 norm along ``dim``; ``eps`` guards zero vectors."""
    norm = x.norm(dim=dim, keepdim=True)
    return x / (norm + eps)
|
|
|
|
|
def masked_mean_pool(hidden: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Average ``hidden`` (B, S, D) over the token axis.

    With ``mask`` (B, S), only positions where mask is non-zero contribute;
    the denominator is clamped so an all-zero row cannot divide by zero.
    Without a mask this is a plain mean over all S tokens.
    """
    if mask is None:
        return hidden.mean(dim=1)
    weights = mask.to(hidden.dtype)
    summed = (hidden * weights.unsqueeze(-1)).sum(dim=1)
    count = weights.sum(dim=1, keepdim=True).clamp_min(1e-6)
    return summed / count
|
|
|
|
|
def to_qwen_grid(img: Image.Image, target: int = 512, patch_size: int = 14, merge_size: int = 2) -> Image.Image:
    """Resize to a square whose side is a multiple of patch_size*merge_size.

    The side is FLOORED to the nearest multiple of the grid cell (28 by
    default: 512 -> 504, 1024 -> 1008), but never below one cell.
    """
    cell = patch_size * merge_size
    side = (target // cell) * cell
    if side < cell:
        side = cell
    return img.resize((side, side), Image.BILINEAR)
|
|
|
|
|
def _open_or_none(path: object, root: str = "") -> Optional[Image.Image]:
    """Open ``path`` as an RGB PIL image, or return None.

    Missing sentinels ('', NaN, '-1', '<NA>', ...) and any I/O/decode failure
    yield None instead of raising.
    """
    if _is_empty_cell(path):
        return None
    candidate = str(path).strip()

    # Prepend `root` only for plain paths, not URL-like strings with a scheme.
    has_scheme = re.match(r'^[a-zA-Z][a-zA-Z0-9+\-.]*://', candidate) is not None
    if root and not has_scheme:
        candidate = os.path.join(root, candidate)
    try:
        return Image.open(candidate).convert("RGB")
    except Exception:
        return None
|
|
|
|
|
def build_image_map_from_row(row, root: str = "") -> dict:
    """Build the placeholder-name -> PIL image map for one dataset row.

    Schema:
      - frontal_image <- img_path1 (also used as current_image)
      - lateral_image <- img_path2
      - prior_image   <- img_path3

    Hard negatives accept several column aliases per slot; when present, each
    negative is registered under every alias key so any template spelling hits.
    """
    out = {
        "frontal_image": _open_or_none(str(row.get("img_path1", "-1")), root),
        "lateral_image": _open_or_none(str(row.get("img_path2", "-1")), root),
        "prior_image": _open_or_none(str(row.get("img_path3", "-1")), root),
    }

    neg1 = _open_or_none(str(row.get("neg_image1", row.get("neg_path1", "-1"))), root)
    neg2 = _open_or_none(str(row.get("neg_image2", row.get("neg_path2", "-1"))), root)
    neg3 = _open_or_none(str(row.get("neg_image3", row.get("neg_prior_image", row.get("neg_path3", "-1")))), root)

    if neg1 is not None:
        for key in ("neg_image1", "neg_path1", "neg_frontal_image"):
            out[key] = neg1
    if neg2 is not None:
        for key in ("neg_image2", "neg_path2", "neg_lateral_image"):
            out[key] = neg2
    if neg3 is not None:
        for key in ("neg_prior_image", "neg_image3", "neg_path3"):
            out[key] = neg3
    return out
|
|
|
|
|
def _s(x): return "" if x is None else str(x) |
|
|
|
|
|
def build_text_map_from_row(row) -> Dict[str, str]:
    """Collect the non-empty text fields of a row, keyed by placeholder name.

    Fields that are missing (per `_is_empty_cell` semantics) are omitted
    entirely so templates render nothing for them.
    """
    fields = ("report", "prior_report", "demographics", "lab_test", "indication")
    out: Dict[str, str] = {}
    for key in fields:
        cleaned = _clean_text_or_empty(row.get(key))
        if cleaned:
            out[key] = cleaned
    return out
|
|
|
|
|
def parse_text_placeholders(s) -> dict:
    """Parse a JSON-object string (or pass through a dict) into a
    {lowercased key: cleaned value} mapping.

    Non-dict input, unparseable JSON, and empty values all degrade to {}
    (or are dropped per-key) rather than raising.
    """
    if isinstance(s, dict):
        data = s
    else:
        data = {}
        if isinstance(s, str) and s.strip():
            try:
                data = json.loads(s)
            except Exception:
                data = {}
    if not isinstance(data, dict):
        return {}

    result = {}
    for key, raw in data.items():
        cleaned = _clean_text_or_empty(raw)
        if cleaned:
            result[str(key).lower()] = cleaned
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LatentAttentionPooler(nn.Module):
    """
    NV-Embed style: tokens (Q) attend to trainable latents (K=V), then MLP,
    then mean-pool over tokens (optionally masked).

    Input:  (B, S, dim) token hidden states.
    Output: (B, dim) pooled vector (via masked_mean_pool).
    """
    def __init__(self, dim: int, num_latents: int = 512, num_layers: int = 1,
                 num_heads: int = 8, mlp_ratio: float = 2.0):
        super().__init__()
        # Trainable latent bank shared across the batch; 1/sqrt(dim) scaling
        # keeps initial activations modest.
        self.latents = nn.Parameter(torch.randn(num_latents, dim) / math.sqrt(dim))
        self.layers = nn.ModuleList()
        self.ln_q = nn.LayerNorm(dim)
        self.ln_kv = nn.LayerNorm(dim)

        for _ in range(num_layers):
            attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
            ffn = nn.Sequential(
                nn.Linear(dim, int(dim * mlp_ratio)),
                nn.GELU(),
                nn.Linear(int(dim * mlp_ratio), dim),
            )
            self.layers.append(nn.ModuleDict({"attn": attn, "ffn": ffn}))

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x: (B, S, D) token states; mask: optional (B, S) validity mask.
        B, S, D = x.shape

        q = self.ln_q(x)
        # Broadcast the shared latents over the batch dimension.
        lat = self.latents.unsqueeze(0).expand(B, -1, -1).contiguous()
        kv = self.ln_kv(lat)

        for blk in self.layers:
            # Cross-attention (token queries over latent keys/values) with a
            # residual connection, then a residual feed-forward block.
            y = blk["attn"](q, kv, kv, need_weights=False)[0]
            q = q + y
            q = q + blk["ffn"](q)

        # Mask (if given) excludes padding tokens from the final average.
        return masked_mean_pool(q, mask)
|
|
|
|
|
class Projection(nn.Module):
    """Linear (or 2-layer MLP) head mapping features to a unit-norm embedding.

    With ``hidden=None`` this is a single bias-free linear layer; otherwise a
    Linear -> GELU -> Linear stack with a bias-free output layer.
    """

    def __init__(self, in_dim: int, out_dim: int = 1024, hidden: Optional[int] = None):
        super().__init__()
        if hidden is None:
            modules = [nn.Linear(in_dim, out_dim, bias=False)]
        else:
            modules = [nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim, bias=False)]
        self.proj = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # L2-normalize so downstream cosine similarities are well-scaled.
        return l2norm(self.proj(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LingshuEmbedder(nn.Module): |
|
|
    def __init__(
        self,
        model_name: str = "lingshu-medical-mllm/Lingshu-7B",
        attn_implementation: str = "flash_attention_2",
        torch_dtype: torch.dtype = torch.bfloat16,
        embed_dim: int = 1024,
        pool_mode: str = "latent_attention",
        num_latents_unified: int = 512,
        image_size: int = 504,
        min_grid: int = 256,
        max_grid: int = 1296,
        use_lora: bool = False,
        lora_r: int = 64, lora_alpha: int = 64, lora_dropout: float = 0.0,
        apply_lora_to_vision: bool = False,
        bidirectional: bool = True,
        max_text_tokens: int = 2560,
        enable_gradient_checkpointing: bool = False,
        device: Optional[Union[str, torch.device]] = None,
    ) -> None:
        """Wrap a Qwen2.5-VL checkpoint as a (mostly frozen) embedding model.

        Args:
            model_name: HF checkpoint id to load.
            attn_implementation: attention kernel; downgraded to "sdpa" off-CUDA.
            torch_dtype: backbone dtype; fp16/bf16 are upcast to fp32 off-CUDA.
            embed_dim: output dimension of all projection heads.
            pool_mode: "latent_attention" uses LatentAttentionPooler; anything
                else falls back to masked mean pooling.
            num_latents_unified: latent count for the unified pooler.
            image_size: square side used for resizing; must be divisible by 28.
            min_grid / max_grid: processor pixel budget in 28x28 cells.
            use_lora / lora_r / lora_alpha / lora_dropout: optional LoRA
                adapters (requires `peft`).
            apply_lora_to_vision: additionally target vision qkv/proj modules.
            bidirectional: best-effort removal of causal masking.
            max_text_tokens: tokenizer truncation length.
            enable_gradient_checkpointing: trade compute for memory.
            device: explicit device; defaults to CUDA when available.

        Raises:
            ValueError: if image_size is not a multiple of 28.
            ImportError: if use_lora=True but `peft` is not installed.
        """
        super().__init__()

        # Resolve device first: kernel and dtype choices below depend on it.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            device = torch.device(device)
        if device.type != "cuda":
            # flash-attn and half precision are treated as CUDA-only here.
            attn_implementation = "sdpa"
            if torch_dtype in (torch.float16, torch.bfloat16):
                torch_dtype = torch.float32

        self.vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch_dtype, attn_implementation=attn_implementation
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            min_pixels=min_grid * 28 * 28,
            max_pixels=max_grid * 28 * 28,
        )
        self._propagate_attn_impl(attn_implementation)

        # Freeze the whole backbone first ...
        for p in self.vl.parameters():
            p.requires_grad_(False)

        # ... then selectively unfreeze vision->language projector modules
        # (merger / projection layers) so they remain trainable.
        unfrozen_modules = []
        for name, module in self.vl.named_modules():
            if any(x in name.lower() for x in ['visual.merger', 'visual.proj', 'vision_proj', 'mm_projector']):
                n_params = sum(p.numel() for p in module.parameters())
                for p in module.parameters():
                    p.requires_grad_(True)
                unfrozen_modules.append((name, n_params))

        if unfrozen_modules:
            print(f"[model] Unfrozen vision projector modules for memorization:")
            for name, n_params in unfrozen_modules:
                print(f"  - {name}: {n_params:,} parameters")

        # Hidden widths differ between the text and vision towers; vision may
        # expose either out_hidden_size or hidden_size depending on version.
        txt_hidden = getattr(self.vl.config, "text_config", None)
        vis_hidden = getattr(self.vl.config, "vision_config", None)
        self.text_hidden = getattr(txt_hidden, "hidden_size", None)
        self.vision_hidden = getattr(vis_hidden, "out_hidden_size", None) or getattr(vis_hidden, "hidden_size", None)

        # Projection heads into the shared embedding space.
        self.text_proj = Projection(self.text_hidden, embed_dim, hidden=None)
        self.image_proj = Projection(self.vision_hidden, embed_dim, hidden=None)
        self.unified_proj = Projection(self.text_hidden, embed_dim, hidden=None)

        # CLIP-style learnable temperature, initialized to log(1/0.07).
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))

        self.pool_mode = pool_mode
        if self.pool_mode == "latent_attention":
            self.unified_pooler = LatentAttentionPooler(
                dim=self.text_hidden,
                num_latents=num_latents_unified,
                num_layers=1,
                num_heads=8
            )
        else:
            self.unified_pooler = None

        # 28 = patch_size(14) * merge_size(2), the grid unit used elsewhere.
        if image_size % 28 != 0:
            raise ValueError(f"image_size must be a multiple of 28, got {image_size}")
        self.image_size = image_size

        self.peft_active = False
        if use_lora:
            if not HAS_PEFT:
                raise ImportError("peft not installed")
            targets_text = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
            targets_vision = ("qkv", "proj")
            targets = list(set(targets_text + (targets_vision if apply_lora_to_vision else tuple())))
            cfg = LoraConfig(r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                             target_modules=targets, bias="none", task_type="CAUSAL_LM")
            self.vl = get_peft_model(self.vl, cfg)
            self.peft_active = True

        if bidirectional:
            self._enable_bidirectional_attention()

        if enable_gradient_checkpointing:
            # Newer transformers accept kwargs; older versions take none.
            try:
                self.vl.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
                )
            except TypeError:
                self.vl.gradient_checkpointing_enable()
            try:
                self.vl.config.use_cache = False
            except Exception:
                pass

        self.to(device)
        self.device = device

        # Keep the pooler in the backbone's dtype so the last hidden state can
        # flow through without implicit upcasting.
        base_dtype = next(self.vl.parameters()).dtype
        if getattr(self, "unified_pooler", None) is not None:
            self.unified_pooler.to(device=device, dtype=base_dtype)

        self.max_text_tokens = int(max_text_tokens)
|
|
|
|
|
|
|
|
|
|
|
def _propagate_attn_impl(self, impl: str): |
|
|
cfgs = [getattr(self.vl, "config", None)] |
|
|
if cfgs[0] is not None: |
|
|
for sub in ("text_config", "vision_config"): |
|
|
cfgs.append(getattr(cfgs[0], sub, None)) |
|
|
for cfg in cfgs: |
|
|
if cfg is None: |
|
|
continue |
|
|
try: |
|
|
cfg._attn_implementation = impl |
|
|
cfg.attn_implementation = impl |
|
|
if hasattr(cfg, "use_flash_attention_2"): |
|
|
cfg.use_flash_attention_2 = (impl == "flash_attention_2") |
|
|
except Exception: |
|
|
pass |
|
|
for _, module in self.vl.named_modules(): |
|
|
if hasattr(module, "config"): |
|
|
try: |
|
|
module.config._attn_implementation = impl |
|
|
module.config.attn_implementation = impl |
|
|
if hasattr(module.config, "use_flash_attention_2"): |
|
|
module.config.use_flash_attention_2 = (impl == "flash_attention_2") |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
def _enable_bidirectional_attention(self): |
|
|
"""Best-effort removal of causal masking.""" |
|
|
cfg = getattr(self.vl, "config", None) |
|
|
if cfg is not None: |
|
|
if hasattr(cfg, "is_decoder"): cfg.is_decoder = False |
|
|
if hasattr(cfg, "use_cache"): cfg.use_cache = False |
|
|
core = getattr(self.vl, "model", self.vl) |
|
|
core_cfg = getattr(core, "config", None) |
|
|
if core_cfg is not None: |
|
|
if hasattr(core_cfg, "is_decoder"): core_cfg.is_decoder = False |
|
|
if hasattr(core_cfg, "use_cache"): core_cfg.use_cache = False |
|
|
for m in self.vl.modules(): |
|
|
if hasattr(m, "is_causal"): |
|
|
try: |
|
|
m.is_causal = False |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
def _get_text_module(self): |
|
|
core = getattr(self.vl, "model", self.vl) |
|
|
for attr in ("language_model", "text_model", "lm"): |
|
|
m = getattr(core, attr, None) |
|
|
if m is not None and hasattr(m, "forward"): |
|
|
return m |
|
|
for _, module in self.vl.named_modules(): |
|
|
cname = module.__class__.__name__.lower() |
|
|
if "vision" in cname: |
|
|
continue |
|
|
if hasattr(module, "forward") and hasattr(module, "embed_tokens"): |
|
|
return module |
|
|
raise AttributeError("Could not locate the text submodule in Qwen-VL.") |
|
|
|
|
|
def _get_vision_module(self): |
|
|
core = getattr(self.vl, "model", self.vl) |
|
|
for attr in ("vision_model", "vision_tower", "visual", "vision"): |
|
|
m = getattr(core, attr, None) |
|
|
if m is not None and hasattr(m, "forward"): |
|
|
return m |
|
|
for _, module in self.vl.named_modules(): |
|
|
if "vision" in module.__class__.__name__.lower(): |
|
|
return module |
|
|
raise AttributeError("Could not locate the vision submodule in Qwen-VL.") |
|
|
|
|
|
def _get_vision_entry(self): |
|
|
""" |
|
|
Return the top-level VisionModel object that accepts: |
|
|
forward(pixel_values=..., grid_thw=..., output_hidden_states=..., return_dict=True) |
|
|
Avoid returning the inner transformer which expects (hidden_states, grid_thw). |
|
|
""" |
|
|
core = getattr(self.vl, "model", self.vl) |
|
|
|
|
|
vis = getattr(core, "vision_model", None) |
|
|
if vis is not None: |
|
|
return vis |
|
|
|
|
|
for _, m in core.named_modules(): |
|
|
name = m.__class__.__name__.lower() |
|
|
if name.endswith("visionmodel"): |
|
|
return m |
|
|
|
|
|
return self._get_vision_module() |
|
|
|
|
|
|
|
|
|
|
|
def _target_from_image_size(self, image_size: Optional[int]) -> int: |
|
|
""" |
|
|
Return a pixel target that will be floored to a multiple of 28 by to_qwen_grid(). |
|
|
Any multiple of 28 works (e.g., 448, 504, 1008). |
|
|
""" |
|
|
sz = image_size if isinstance(image_size, int) and image_size % 28 == 0 else self.image_size |
|
|
return int(sz) |
|
|
|
|
|
def _build_interleaved_content(self, text: str, imgs: List[Image.Image], append_unused_images: bool = False) -> Tuple[list, list]: |
|
|
""" |
|
|
NUMERIC placeholders: <image1>, <image2>, ... |
|
|
Returns (content_list, images_in_order). |
|
|
""" |
|
|
if text is None: |
|
|
text = "" |
|
|
content: list = [] |
|
|
ordered_images: list = [] |
|
|
imgs = imgs or [] |
|
|
|
|
|
pat = re.compile(r"<image\s*(\d+)\s*>", re.IGNORECASE) |
|
|
pos = 0 |
|
|
matches = list(pat.finditer(text)) |
|
|
|
|
|
if not matches: |
|
|
|
|
|
if text.strip(): |
|
|
content.append({"type": "text", "text": text}) |
|
|
if append_unused_images: |
|
|
for im in imgs: |
|
|
content.append({"type": "image", "image": im}) |
|
|
ordered_images.append(im) |
|
|
return content, ordered_images |
|
|
|
|
|
for m in matches: |
|
|
s, e = m.span() |
|
|
if s > pos: |
|
|
seg = text[pos:s] |
|
|
if seg.strip(): |
|
|
content.append({"type": "text", "text": seg}) |
|
|
idx = int(m.group(1)) - 1 |
|
|
if 0 <= idx < len(imgs): |
|
|
content.append({"type": "image", "image": imgs[idx]}) |
|
|
ordered_images.append(imgs[idx]) |
|
|
pos = e |
|
|
|
|
|
if pos < len(text): |
|
|
seg = text[pos:] |
|
|
if seg.strip(): |
|
|
content.append({"type": "text", "text": seg}) |
|
|
|
|
|
if append_unused_images: |
|
|
used = set(ordered_images) |
|
|
for im in imgs: |
|
|
if im not in used: |
|
|
content.append({"type": "image", "image": im}) |
|
|
ordered_images.append(im) |
|
|
|
|
|
return content, ordered_images |
|
|
|
|
|
def _build_content_from_template( |
|
|
self, |
|
|
template: str, |
|
|
image_map: Optional[Dict[str, Image.Image]], |
|
|
text_map: Optional[Dict[str, str]], |
|
|
append_unused_images: bool = False, |
|
|
) -> Tuple[list, list]: |
|
|
""" |
|
|
NAMED placeholders: <frontal_image>, <lateral_image>, <prior_image>, <report>, <prior_report>, <demographics>, ... |
|
|
Also supports alias: <current_image> -> <frontal_image>. |
|
|
""" |
|
|
template = template or "" |
|
|
image_map = {k.lower(): v for k, v in (image_map or {}).items() if v is not None} |
|
|
text_map = {k.lower(): v for k, v in (text_map or {}).items() if v is not None and str(v).strip()} |
|
|
|
|
|
content: list = [] |
|
|
images_in_order: list = [] |
|
|
|
|
|
pat = re.compile(r"<\s*([A-Za-z_]\w*)\s*>") |
|
|
pos = 0 |
|
|
for m in pat.finditer(template): |
|
|
s, e = m.span() |
|
|
if s > pos: |
|
|
seg = template[pos:s] |
|
|
if seg.strip(): |
|
|
content.append({"type": "text", "text": seg}) |
|
|
|
|
|
name = m.group(1).lower() |
|
|
|
|
|
if name == "current_image": |
|
|
name = "frontal_image" |
|
|
|
|
|
if name in image_map: |
|
|
img = image_map.get(name) |
|
|
if img is not None: |
|
|
content.append({"type": "image", "image": img}) |
|
|
images_in_order.append(img) |
|
|
else: |
|
|
val = text_map.get(name) |
|
|
if val is not None: |
|
|
content.append({"type": "text", "text": str(val)}) |
|
|
|
|
|
pos = e |
|
|
|
|
|
if pos < len(template): |
|
|
tail = template[pos:] |
|
|
if tail.strip(): |
|
|
content.append({"type": "text", "text": tail}) |
|
|
|
|
|
|
|
|
if append_unused_images: |
|
|
for key, img in image_map.items(): |
|
|
if img is not None and img not in images_in_order: |
|
|
content.append({"type": "image", "image": img}) |
|
|
images_in_order.append(img) |
|
|
|
|
|
return content, images_in_order |
|
|
|
|
|
    def _mask_last_role_block(self, inputs: dict, hidden: torch.Tensor) -> torch.Tensor:
        """
        Boolean mask (B,S) selecting tokens inside the **last** role block (user/assistant),
        excluding the final <|im_end|>, for **any** batch size.
        Falls back to attention_mask if special tokens are unavailable.
        """
        device = hidden.device
        ids = inputs.get("input_ids", None)
        attn = inputs.get("attention_mask", None)
        # No token ids at all: everything the attention mask allows is valid.
        if ids is None:
            return (attn if attn is not None else torch.ones(hidden.shape[:2], device=device, dtype=torch.long)).bool()

        B, S = ids.shape
        mask = torch.zeros((B, S), device=device, dtype=torch.bool)

        # Resolve the ChatML delimiters; either may be missing for non-ChatML
        # tokenizers, in which case convert_tokens_to_ids can misbehave.
        try:
            start_id = self.processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
        except Exception:
            start_id = None
        try:
            end_id = self.processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
        except Exception:
            end_id = None

        # Without an end marker role blocks cannot be segmented at all.
        if end_id is None:
            return (attn if attn is not None else torch.ones((B, S), device=device, dtype=torch.long)).bool()

        for b in range(B):
            # Restrict the search to non-padding positions.
            if attn is not None:
                valid_len = int(attn[b].sum().item())
            else:
                valid_len = S
            valid_len = max(1, min(valid_len, S))
            seq = ids[b, :valid_len]

            ends = (seq == end_id).nonzero(as_tuple=False).flatten()
            if ends.numel() == 0:
                # No role delimiters: take the whole valid sequence.
                mask[b, :valid_len] = True
                continue
            last_end = int(ends[-1].item())

            # Find where the last block starts: the last <|im_start|> before
            # the final <|im_end|>, else fall back to the previous <|im_end|>.
            last_start = -1
            if start_id is not None:
                starts = (seq == start_id).nonzero(as_tuple=False).flatten()
                starts_before = starts[starts < last_end] if starts.numel() > 0 else None
                if starts_before is not None and starts_before.numel() > 0:
                    last_start = int(starts_before[-1].item())
                elif ends.numel() >= 2:
                    last_start = int(ends[-2].item())
            else:
                if ends.numel() >= 2:
                    last_start = int(ends[-2].item())

            # Select (last_start, last_end) exclusive of both delimiter tokens.
            left = max(last_start + 1, 0)
            right = max(last_end - 1, left)
            mask[b, left:right + 1] = True

        # Never select padding, even if the span arithmetic overlapped it.
        if attn is not None:
            mask = mask & attn.bool()
        return mask
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def encode_text_unified(self, instructions: List[Optional[str]], texts: List[str], role: str = "user", |
|
|
normalize: bool = True) -> torch.Tensor: |
|
|
"""Text-only, but still go through the unified VL path for consistency.""" |
|
|
empty_images = [[] for _ in texts] |
|
|
return self.encode_interleaved(instructions, texts, empty_images, role=role, normalize=normalize) |
|
|
|
|
|
@torch.no_grad() |
|
|
def encode_images_unified(self, instructions: List[Optional[str]], image_templates: List[str], |
|
|
image_maps: List[Dict[str, Image.Image]], role: str = "user", |
|
|
normalize: bool = True, image_size: Optional[int] = None) -> torch.Tensor: |
|
|
""" |
|
|
Image-only via unified path. Pass templates like "<frontal_image>" or "" (images only included if explicitly referenced). |
|
|
""" |
|
|
empty_text_maps = [{} for _ in image_templates] |
|
|
return self.encode_interleaved_with_ph(instructions, image_templates, image_maps, empty_text_maps, |
|
|
role=role, normalize=normalize, image_size=image_size) |
|
|
|
|
|
    @torch.no_grad()
    def encode_interleaved(
        self,
        instructions: List[Optional[str]],
        contents: List[str],
        images: List[List[Image.Image]],
        role: str = "user",
        normalize: bool = True,
        image_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Encode (instruction, text-with-<imageN>-placeholders, images)
        samples into unified embeddings.

        Each sample is rendered as a chat (optional system instruction + one
        role message), run through the full VL model, pooled over the last
        role block, and projected with `unified_proj`. Samples are processed
        one at a time. Returns a (len(contents), embed_dim) tensor.
        """
        device = self.device
        vm = self._get_vision_module()
        vision_dtype = next(vm.parameters()).dtype

        assert len(instructions) == len(contents) == len(images), "length mismatch"
        out_vecs = []
        target = self._target_from_image_size(image_size)

        for inst, text, imgs in zip(instructions, contents, images):
            # Snap every image onto the 28px grid before the processor sees it.
            proc_imgs = [to_qwen_grid(im, target=target) for im in (imgs or [])]
            content_list, images_in_order = self._build_interleaved_content(
                text or "", proc_imgs, append_unused_images=False
            )

            msgs = []
            if inst and str(inst).strip():
                msgs.append({"role": "system", "content": [{"type": "text", "text": inst}]})
            msgs.append({"role": role, "content": content_list})

            chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)

            proc = self.processor(
                text=[chat_text],
                images=images_in_order if images_in_order else None,
                return_tensors="pt",
                padding=True,
                truncation=True,
                do_resize=False,  # images already resized above
                max_length=self.max_text_tokens,
            )
            inputs = {k: v.to(device) for k, v in proc.items()}
            # Match the vision tower's dtype explicitly; grid sizes stay int.
            if "pixel_values" in inputs:
                inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype)
            if "image_grid_thw" in inputs:
                inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device)

            out = self.vl(**inputs, output_hidden_states=True, use_cache=False)
            hidden = out.hidden_states[-1]
            # Pool only over the last role block (the actual sample content).
            span_mask = self._mask_last_role_block(inputs, hidden)

            if self.pool_mode == "latent_attention":
                pool_dtype = next(self.unified_pooler.parameters()).dtype
                if hidden.dtype != pool_dtype:
                    hidden = hidden.to(dtype=pool_dtype)
                vec = self.unified_pooler(hidden, span_mask).squeeze(0)
            else:
                vec = masked_mean_pool(hidden, span_mask).squeeze(0)

            out_vecs.append(vec)

        embs = torch.stack(out_vecs, dim=0)
        proj_dtype = next(self.unified_proj.parameters()).dtype
        emb = self.unified_proj(embs.to(dtype=proj_dtype))
        if normalize:
            # unified_proj already l2-normalizes; this is a cheap guard.
            emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        return emb
|
|
|
|
|
    @torch.no_grad()
    def encode_interleaved_with_ph(
        self,
        instructions: List[Optional[str]],
        templates: List[str],
        image_maps: List[Optional[Dict[str, Image.Image]]],
        text_maps: List[Optional[Dict[str, str]]],
        role: str = "user",
        normalize: bool = True,
        image_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Encode samples described by NAMED-placeholder templates
        (<frontal_image>, <report>, ...) filled from per-sample image/text
        maps. Same pipeline as `encode_interleaved`, but the chat content is
        built by `_build_content_from_template`. Returns a
        (len(templates), embed_dim) tensor.
        """
        device = self.device
        vm = self._get_vision_module()
        vision_dtype = next(vm.parameters()).dtype

        assert len(instructions) == len(templates) == len(image_maps) == len(text_maps), "length mismatch"

        vecs = []
        target = self._target_from_image_size(image_size)

        for inst, tmpl, imap, tmap in zip(instructions, templates, image_maps, text_maps):
            # Pre-resize to the 28px grid; lowercase keys so they match the
            # template parser's lookups.
            proc_imap: Dict[str, Image.Image] = {}
            if imap:
                for k, im in imap.items():
                    if im is not None:
                        proc_imap[k.lower()] = to_qwen_grid(im, target=target)

            content_list, images_in_order = self._build_content_from_template(tmpl or "", proc_imap, (tmap or {}))

            msgs = []
            if inst and str(inst).strip():
                msgs.append({"role": "system", "content": [{"type": "text", "text": inst}]})
            msgs.append({"role": role, "content": content_list})

            chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)

            proc = self.processor(
                text=[chat_text],
                images=images_in_order if images_in_order else None,
                return_tensors="pt",
                padding=True,
                truncation=True,
                do_resize=False,  # images already resized above
                max_length=self.max_text_tokens,
            )
            inputs = {k: v.to(device) for k, v in proc.items()}
            # Match the vision tower's dtype explicitly; grid sizes stay int.
            if "pixel_values" in inputs:
                inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype)
            if "image_grid_thw" in inputs:
                inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device)

            out = self.vl(**inputs, output_hidden_states=True, use_cache=False)
            hidden = out.hidden_states[-1]
            # Pool only over the last role block (the actual sample content).
            span_mask = self._mask_last_role_block(inputs, hidden)

            if self.pool_mode == "latent_attention":
                pool_dtype = next(self.unified_pooler.parameters()).dtype
                if hidden.dtype != pool_dtype:
                    hidden = hidden.to(dtype=pool_dtype)
                vec = self.unified_pooler(hidden, span_mask).squeeze(0)
            else:
                vec = masked_mean_pool(hidden, span_mask).squeeze(0)

            vecs.append(vec)

        embs = torch.stack(vecs, dim=0)
        proj_dtype = next(self.unified_proj.parameters()).dtype
        emb = self.unified_proj(embs.to(dtype=proj_dtype))
        if normalize:
            # unified_proj already l2-normalizes; this is a cheap guard.
            emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        return emb
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def encode_text_dual(self, texts: List[str], normalize: bool = True) -> torch.Tensor: |
|
|
device = self.device |
|
|
tok = self.processor.tokenizer(text=texts, padding=True, truncation=True, return_tensors="pt", max_length=self.max_text_tokens) |
|
|
tok = {k: v.to(device) for k, v in tok.items()} |
|
|
lm = self._get_text_module() |
|
|
out = lm(**tok, output_hidden_states=True, use_cache=False) |
|
|
hidden = out.last_hidden_state |
|
|
mask = tok.get("attention_mask") |
|
|
pooled = masked_mean_pool(hidden, mask) |
|
|
proj_dtype = next(self.text_proj.parameters()).dtype |
|
|
emb = self.text_proj(pooled.to(dtype=proj_dtype)) |
|
|
if normalize: |
|
|
emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12) |
|
|
return emb |
|
|
|
|
|
    @torch.no_grad()
    def encode_images_dual(self, images: List[List[Image.Image]], normalize: bool = True,
                           image_size: Optional[int] = None) -> torch.Tensor:
        """Dual-encoder image path: vision tower only, no language model.

        Each inner list is one image *set*; per-image features are mean-pooled
        over patch tokens, set members are averaged into a single vector, and
        the result is projected with `image_proj`. Returns a
        (len(images), embed_dim) tensor.
        """
        device = self.device
        flat = [img for group in images for img in group]
        counts = [len(g) for g in images]
        if len(flat) == 0:
            # No images at all: project zero vectors so the output keeps the
            # (len(images), embed_dim) shape expected by callers.
            proj_dtype = next(self.image_proj.parameters()).dtype
            zeros = torch.zeros((len(images), self.vision_hidden), device=device, dtype=proj_dtype)
            emb = self.image_proj(zeros)
            if normalize:
                emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
            return emb
        target = self._target_from_image_size(image_size)
        processed = [to_qwen_grid(img, target=target) for img in flat]
        proc = self.processor.image_processor(images=processed, return_tensors="pt", do_resize=False)
        vm = self._get_vision_module()
        vision_dtype = next(vm.parameters()).dtype
        pixel_values = proc["pixel_values"].to(device=device, dtype=vision_dtype)
        vis_out = vm(pixel_values=pixel_values, output_hidden_states=True)
        # Vision backbones differ in return type; probe the common layouts.
        feats = vis_out[0] if isinstance(vis_out, (tuple, list)) else getattr(vis_out, "last_hidden_state", None)
        if feats is None:
            feats = getattr(vis_out, "pooler_output", None)
        if feats is None:
            raise RuntimeError("Vision backbone did not return features as expected.")
        # (B, T, D) -> mean over patch tokens; an already-pooled (B, D)
        # output passes through untouched.
        per_img = feats.mean(dim=1) if feats.ndim == 3 else feats
        # NOTE(review): this assumes per_img rows align 1:1 with `flat`;
        # Qwen-style towers can pack variable patch counts — confirm upstream.
        splits = torch.split(per_img, counts, dim=0)
        set_vecs = torch.stack([s.mean(dim=0) if s.ndim > 1 else s for s in splits], dim=0)
        proj_dtype = next(self.image_proj.parameters()).dtype
        emb = self.image_proj(set_vecs.to(dtype=proj_dtype))
        if normalize:
            emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        return emb
|
|
|
|
|
|
|
|
|
|
|
def _find_subsequence(self, haystack: list, needle: list) -> list: |
|
|
"""Return start indices where 'needle' occurs in 'haystack' (exact match).""" |
|
|
if not haystack or not needle or len(needle) > len(haystack): |
|
|
return [] |
|
|
hits = [] |
|
|
n = len(needle) |
|
|
for i in range(len(haystack) - n + 1): |
|
|
if haystack[i:i+n] == needle: |
|
|
hits.append(i) |
|
|
return hits |
|
|
|
|
|
def _window_decode_matches(self, tokenizer, ids, target_lower: str) -> list: |
|
|
"""Fallback: sliding-window decode match (robust to BPE splits). Returns window (start,end) indices.""" |
|
|
hits = [] |
|
|
L = len(ids) |
|
|
|
|
|
for w in range(1, 8): |
|
|
for i in range(0, L - w + 1): |
|
|
s, e = i, i + w |
|
|
text = tokenizer.decode(ids[s:e], skip_special_tokens=True).lower().replace(" ", "") |
|
|
if target_lower in text: |
|
|
hits.append((s, e)) |
|
|
|
|
|
hits = sorted(set(hits), key=lambda x: (x[1]-x[0], x[0])) |
|
|
return hits |
|
|
|
|
|
def _resize_heatmap_like(self, hm_np, target_w, target_h): |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
H, W = hm_np.shape |
|
|
im = Image.fromarray((hm_np * 255.0).astype("uint8"), mode="L") |
|
|
im = im.resize((target_w, target_h), Image.BILINEAR) |
|
|
out = (np.array(im).astype("float32") / 255.0) |
|
|
return out |
|
|
|
|
|
def _overlay_heatmap_on_image(self, img_pil, hm_np, alpha=0.45): |
|
|
"""Return PIL with heatmap overlay; hm_np in [0,1] same size as img.""" |
|
|
import matplotlib |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
img = np.array(img_pil.convert("RGB")).astype("float32") / 255.0 |
|
|
H, W = img.shape[:2] |
|
|
hm = np.clip(hm_np, 0.0, 1.0) |
|
|
if hm.shape[:2] != (H, W): |
|
|
raise ValueError("Heatmap and image size mismatch") |
|
|
|
|
|
cmap = matplotlib.cm.get_cmap("jet") |
|
|
color_hm = cmap(hm)[..., :3] |
|
|
blended = (1.0 - alpha) * img + alpha * color_hm |
|
|
blended = np.clip(blended, 0.0, 1.0) |
|
|
return Image.fromarray((blended * 255).astype("uint8")) |
|
|
|
|
|
def phrase_ground_and_visualize(
    self,
    word: str,
    template: str,
    row,
    role: str = "user",
    instruction: Optional[str] = None,
    image_size: Optional[int] = None,
    layer_for_text: int = -1,
    save_dir: Optional[str] = None,
    return_arrays: bool = False,
) -> "PhraseGroundingOutput":
    """
    Compute patch-level grounding for a word against images referenced in `template` filled by `row`.

    Returns a PhraseGroundingOutput, and optionally writes overlay PNGs.

    Strategy:
    - Build a single-sample chat like encode_interleaved_with_ph().
    - Forward Qwen-VL with hidden_states (+ attention if available).
    - Locate word tokens inside last role block.
    - Run vision tower once to get per-patch features per image.
    - Project (text token avg) with text_proj, patches with image_proj; cosine sim per patch → heatmap.
    - (Optional) also compute LM self-attn from word tokens to any image placeholders if available.

    Args:
        word: phrase to locate; tokenized and searched for inside the last
            role block of the rendered chat.
        template: placeholder template rendered against `row`.
        row: source record passed to build_image_map_from_row /
            build_text_map_from_row (schema defined by those helpers).
        role: chat role of the message carrying the template content.
        instruction: optional system instruction prepended as its own message.
        image_size: forwarded to self._target_from_image_size to pick the
            square resize target used by to_qwen_grid.
        layer_for_text: index into hidden_states for the word embedding
            (-1 = last layer).
        save_dir: when set (and heatmaps exist), overlay PNGs are written here.
        return_arrays: include the raw heatmap arrays in the result when True.

    Returns:
        PhraseGroundingOutput with the resolved token span and one dict per
        produced heatmap.
    """
    # NOTE(review): np / Image are imported but unused in this method.
    import os, numpy as np, torch
    from PIL import Image

    device = self.device
    tok = self.processor.tokenizer
    # Square resize target aligned to the Qwen patch grid.
    target = self._target_from_image_size(image_size)

    # --- Build the single-sample chat (mirrors encode_interleaved_with_ph) ---
    imap = build_image_map_from_row(row, root="")

    # Lowercase keys so placeholders match case-insensitively; resize each
    # image onto the Qwen grid up-front (processor resizing is disabled below).
    proc_imap = {k.lower(): to_qwen_grid(v, target=target) for k, v in (imap or {}).items() if v is not None}
    tmap = build_text_map_from_row(row)

    content_list, images_in_order = self._build_content_from_template(template or "", proc_imap, (tmap or {}), append_unused_images=False)

    msgs = []
    if instruction and str(instruction).strip():
        msgs.append({"role": "system", "content": [{"type": "text", "text": f"INSTRUCTION:\n{instruction}"}]})
    msgs.append({"role": role, "content": content_list})
    chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)

    # Vision tower dtype, so pixel_values can be cast to what it expects.
    vm = self._get_vision_module()
    vision_dtype = next(vm.parameters()).dtype

    proc = self.processor(
        text=[chat_text],
        images=images_in_order if images_in_order else None,
        return_tensors="pt",
        padding=True,
        truncation=True,
        do_resize=False,  # images were already sized by to_qwen_grid above
        max_length=self.max_text_tokens,
    )
    inputs = {k: v.to(device) for k, v in proc.items()}
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype)
    if "image_grid_thw" in inputs:
        inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device)

    # --- Full LM forward; only hidden_states are consumed below
    # (attentions are requested but never read in this path). ---
    with torch.no_grad():
        out = self.vl(**inputs, output_hidden_states=True, output_attentions=True, use_cache=False, return_dict=True)

    hidden = out.hidden_states[layer_for_text]
    # Boolean mask over sequence positions belonging to the last role block.
    span_mask = self._mask_last_role_block(inputs, hidden)[0].bool()
    seq_ids = inputs["input_ids"][0].tolist()

    # --- Locate the word's token span inside the last role block ---
    tgt_ids = tok(word, add_special_tokens=False)["input_ids"]
    last_role_positions = [i for i, m in enumerate(span_mask.tolist()) if m]
    id_seq_in_span = [seq_ids[i] for i in last_role_positions]
    hits = self._find_subsequence(id_seq_in_span, tgt_ids)
    token_span = None
    if hits:
        # Exact token-id subsequence match; map span-relative → absolute.
        start_in_span = hits[0]
        abs_start = last_role_positions[start_in_span]
        abs_end = last_role_positions[start_in_span + len(tgt_ids) - 1] + 1
        token_span = (abs_start, abs_end)
    else:
        # Fallback: decode sliding windows and compare lowercased text with
        # spaces stripped (handles tokenizer merges that defeat id matching).
        win_hits = self._window_decode_matches(tok, id_seq_in_span, target_lower=word.lower().replace(" ", ""))
        if win_hits:
            s, e = win_hits[0]
            abs_start = last_role_positions[s]
            abs_end = last_role_positions[e - 1] + 1
            token_span = (abs_start, abs_end)

    if token_span is None:
        # Last resort: use the final token of the role block so the rest of
        # the pipeline still runs. NOTE(review): raises IndexError if the
        # role block is empty — confirm that cannot happen upstream.
        last_idx = last_role_positions[-1]
        token_span = (last_idx, last_idx + 1)

    s_idx, e_idx = token_span
    word_tokens = hidden[0, s_idx:e_idx, :]

    # Average the word's token embeddings into a single query vector (1, C).
    word_vec_txt = word_tokens.mean(dim=0, keepdim=True)

    # --- Per-patch cosine-similarity heatmaps ---
    heatmaps = []
    per_image_debug = []
    if "pixel_values" in inputs:

        vmodel = self._get_vision_entry()
        with torch.no_grad():
            vout = vmodel(
                pixel_values=inputs["pixel_values"],
                grid_thw=inputs.get("image_grid_thw", None),
                output_hidden_states=True,
                return_dict=True,
            )

        # NOTE(review): the unpack below assumes a per-image batched
        # (B, S, C) output; some Qwen2.5-VL vision towers return patches
        # flattened across all images — confirm for the configured model.
        vlast = vout.last_hidden_state
        B, Svis, C = vlast.shape

        grids = inputs.get("image_grid_thw", None)
        if grids is not None:

            thw = grids.detach().cpu().tolist()
            # A single image may come back as one flat [t, h, w] triple.
            if isinstance(thw[0], (int, float)):
                thw = [thw]
        else:
            # No grid info: assume a single square frame per image.
            thw = [[1, int(round(Svis ** 0.5)), int(round(Svis ** 0.5))] for _ in range(B)]

        per_img = []
        offset = 0  # NOTE(review): unused — candidate for removal
        for i in range(B):
            t, h, w = map(int, thw[i])
            tokens_per = t * h * w
            # Skip a leading CLS-style token iff the sequence is exactly one
            # longer than the patch count.
            take_from = 1 if (Svis == tokens_per + 1) else 0
            patches = vlast[i, take_from:take_from + tokens_per, :]
            per_img.append((patches, (t, h, w)))

        # Cast operands to each projection head's parameter dtype.
        proj_dtype_img = next(self.image_proj.parameters()).dtype
        proj_dtype_txt = next(self.text_proj.parameters()).dtype

        word_vec = self.text_proj(word_vec_txt.to(dtype=proj_dtype_txt))
        word_vec = word_vec / (word_vec.norm(dim=-1, keepdim=True) + 1e-12)  # L2-normalize

        for (patches, (t, h, w)) in per_img:
            patch_emb = self.image_proj(patches.to(dtype=proj_dtype_img))
            patch_emb = patch_emb / (patch_emb.norm(dim=-1, keepdim=True) + 1e-12)
            # Cosine similarity of every patch vs. the word vector → (P,).
            sim = (patch_emb @ word_vec[0].T).squeeze(-1)
            # Collapse the temporal axis → (h, w) similarity map.
            sim = sim.reshape(t, h, w).mean(dim=0)
            # Min-max normalize into [0, 1] for visualization.
            smin, smax = float(sim.min()), float(sim.max())
            hm = (sim - smin) / max(1e-6, (smax - smin))
            heatmaps.append(hm.detach().cpu().numpy())
            per_image_debug.append({"tokens_per": t*h*w, "grid": (t, h, w)})

    # --- Optional overlay PNGs ---
    saved_paths = []
    if save_dir and heatmaps:
        os.makedirs(save_dir, exist_ok=True)
        for i, im in enumerate(images_in_order):
            # Upsample the patch-grid heatmap to the image's pixel size.
            tgt_w, tgt_h = im.size
            hm_np = self._resize_heatmap_like(heatmaps[i], tgt_w, tgt_h)
            overlay = self._overlay_heatmap_on_image(im, hm_np, alpha=0.45)
            fname = os.path.join(save_dir, f"ground_{i:02d}_{word.replace(' ','_')}.png")
            overlay.save(fname)
            saved_paths.append(fname)

    result = PhraseGroundingOutput(
        token_span=(int(s_idx), int(e_idx)),
        per_image=[{
            "heatmap": (heatmaps[i] if return_arrays else None),
            "saved_path": (saved_paths[i] if i < len(saved_paths) else None),
            "grid": per_image_debug[i].get("grid", None),
            "tokens_per": per_image_debug[i].get("tokens_per", None),
            # Always None in this code path: placeholder attention is never
            # populated here despite output_attentions=True above.
            "placeholder_attn": per_image_debug[i].get("placeholder_attn", None),
        } for i in range(len(heatmaps))]
    )
    return result
|
|
|
|
|
|
|
|
class PhraseGroundingOutput:
    """Result container for phrase_ground_and_visualize.

    Attributes:
        token_span: (start, end) absolute token indices (end exclusive) of
            the grounded word in the input sequence.
        per_image: one dict per produced heatmap, with keys "heatmap",
            "saved_path", "grid", "tokens_per", "placeholder_attn".
    """

    def __init__(self, token_span, per_image):
        # (start, end) token indices of the grounded phrase.
        self.token_span = token_span
        # Per-image grounding details (list of dicts).
        self.per_image = per_image

    def __repr__(self):
        # Summarize per_image by length — heatmap arrays would be noisy.
        return (f"{type(self).__name__}(token_span={self.token_span!r}, "
                f"per_image={len(self.per_image)} item(s))")