Upload folder using huggingface_hub
Browse files- README.md +72 -0
- cxrembed/__init__.py +22 -0
- cxrembed/embedder.py +600 -0
- cxrembed_config.json +8 -0
- model_embed.py +1067 -0
- requirements.txt +5 -0
README.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CxREmbed (multi-image / multi-text unified embeddings)
|
| 2 |
+
|
| 3 |
+
This repository contains **lightweight inference code + trained embedding heads** for a multi-modal CXR embedding model built on top of the base **Lingshu-7B / Qwen2.5-VL** backbone.
|
| 4 |
+
|
| 5 |
+
The repo is structured to upload only the *delta weights* (LoRA adapter + pooling/projection heads). The base model weights remain in the original upstream repository.
|
| 6 |
+
|
| 7 |
+
## What is included
|
| 8 |
+
|
| 9 |
+
- `lora/` (optional) — PEFT LoRA adapter weights
|
| 10 |
+
- `unified_pooler.pt` — pooling head
|
| 11 |
+
- `unified_proj.pt` — projection head to the unified embedding space
|
| 12 |
+
- `text_proj.pt` / `image_proj.pt` (optional)
|
| 13 |
+
- `cxrembed_config.json` — minimal configuration
|
| 14 |
+
- `cxrembed/` — small Python package with an inference wrapper
|
| 15 |
+
|
| 16 |
+
## Quickstart
|
| 17 |
+
|
| 18 |
+
```python
|
| 19 |
+
import torch
|
| 20 |
+
from cxrembed import CxREmbedder
|
| 21 |
+
|
| 22 |
+
# Download from the Hub and load the backbone + adapters + heads
|
| 23 |
+
m = CxREmbedder.from_pretrained(
|
| 24 |
+
"<ORG>/<REPO>",
|
| 25 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
| 26 |
+
amp=True,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# Embed a structured record (multi-image + multi-text)
|
| 30 |
+
emb = m.embed_record(
|
| 31 |
+
current_img="/path/to/current_frontal.png",
|
| 32 |
+
lateral_img="/path/to/lateral.png",
|
| 33 |
+
prior_img="/path/to/prior.png",
|
| 34 |
+
additional_img=None,
|
| 35 |
+
prior_report="...",
|
| 36 |
+
current_report="...",
|
| 37 |
+
demographics="Age 67, male",
|
| 38 |
+
lab_test="WBC 12.3",
|
| 39 |
+
history="SOB, fever",
|
| 40 |
+
additional_txt="Question: pneumonia?",
|
| 41 |
+
instruction="Embed this clinical record for retrieval.",
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Embed a candidate answer (text-only)
|
| 45 |
+
ans = m.embed_answer("Right lower lobe consolidation consistent with pneumonia.")
|
| 46 |
+
|
| 47 |
+
# Similarity in embedding space
|
| 48 |
+
score = float(emb @ ans)  # cosine similarity (both embeddings are L2-normalized); `.T` on a 1-D tensor is a deprecated no-op
|
| 49 |
+
print(score)
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## Placeholders supported in templates
|
| 53 |
+
|
| 54 |
+
Images:
|
| 55 |
+
- `<current_image>` (alias of `<frontal_image>`)
|
| 56 |
+
- `<lateral_image>`
|
| 57 |
+
- `<prior_image>`
|
| 58 |
+
- `<additional_image>`
|
| 59 |
+
- `<additional_image1>`, `<additional_image2>`, ... if you pass a list to `additional_img`
|
| 60 |
+
|
| 61 |
+
Texts:
|
| 62 |
+
- `<current_report>` (alias `<report>`)
|
| 63 |
+
- `<prior_report>`
|
| 64 |
+
- `<demographics>`
|
| 65 |
+
- `<lab_test>`
|
| 66 |
+
- `<history>`
|
| 67 |
+
- `<additional_txt>`
|
| 68 |
+
|
| 69 |
+
## Notes
|
| 70 |
+
|
| 71 |
+
- This model is intended for **research** and may require additional validation for clinical use.
|
| 72 |
+
- Do not upload protected health information (PHI) to public repositories.
|
cxrembed/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""cxrembed
|
| 2 |
+
|
| 3 |
+
Lightweight Hugging Face Hub loader + inference utilities for a multi-image, multi-text
|
| 4 |
+
unified embedding model built on top of Lingshu/Qwen2.5-VL.
|
| 5 |
+
|
| 6 |
+
Primary entrypoint:
|
| 7 |
+
- CxREmbedder
|
| 8 |
+
|
| 9 |
+
This package is meant to live inside your Hugging Face model repo (or be vendored into
|
| 10 |
+
another codebase). It assumes the repo contains:
|
| 11 |
+
- lora/ (optional)
|
| 12 |
+
- unified_pooler.pt
|
| 13 |
+
- unified_proj.pt
|
| 14 |
+
- text_proj.pt (optional)
|
| 15 |
+
- image_proj.pt (optional)
|
| 16 |
+
- misc.pt (optional)
|
| 17 |
+
- cxrembed_config.json
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from .embedder import CxREmbedder, CxRInputs
|
| 21 |
+
|
| 22 |
+
__all__ = ["CxREmbedder", "CxRInputs"]
|
cxrembed/embedder.py
ADDED
|
@@ -0,0 +1,600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""cxrembed.embedder
|
| 3 |
+
|
| 4 |
+
A small, opinionated inference wrapper around your `LingshuEmbedder` (defined in
|
| 5 |
+
`model_embed.py`) that:
|
| 6 |
+
1) loads your projection/pooling heads + (optional) LoRA adapter from either
|
| 7 |
+
- a local training checkpoint directory, or
|
| 8 |
+
- a Hugging Face Hub repo (snapshot)
|
| 9 |
+
2) exposes a clean inference API that accepts explicit multi-image + multi-text inputs.
|
| 10 |
+
|
| 11 |
+
Why this wrapper?
|
| 12 |
+
- Your training/eval code is row/template driven. This wrapper offers an ergonomic
|
| 13 |
+
function signature while keeping the exact same placeholder interleaving path.
|
| 14 |
+
- It avoids implicitly attaching images/text unless a placeholder appears in the template.
|
| 15 |
+
|
| 16 |
+
Key assumptions about the checkpoint format (matching your training script):
|
| 17 |
+
- unified_pooler.pt
|
| 18 |
+
- unified_proj.pt
|
| 19 |
+
- text_proj.pt (optional)
|
| 20 |
+
- image_proj.pt (optional)
|
| 21 |
+
- misc.pt (optional; may include logit_scale)
|
| 22 |
+
- lora/ (optional; PEFT adapter)
|
| 23 |
+
|
| 24 |
+
This file is designed to be copied into your HF repo under `cxrembed/`.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
|
| 29 |
+
import json
|
| 30 |
+
import os
|
| 31 |
+
import re
|
| 32 |
+
from dataclasses import dataclass
|
| 33 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
from PIL import Image
|
| 39 |
+
except Exception: # pragma: no cover
|
| 40 |
+
Image = Any # type: ignore
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# --------------------------- Types ---------------------------

# An image argument may be a filesystem path or an already-loaded PIL image.
ImageLike = Union[str, "Image.Image"]  # file path or PIL


@dataclass
class CxRInputs:
    """Typed container for a single sample.

    Every field is optional: a field is only consumed when the active template
    contains the matching placeholder (see the `_build_*_map_from_inputs`
    helpers). Sentinel strings ("", "-1", "nan", ...) are treated as missing.
    """

    # Images (paths or PIL images)
    current_img: Optional[ImageLike] = None
    prior_img: Optional[ImageLike] = None
    lateral_img: Optional[ImageLike] = None
    # May be a single image or a sequence; a sequence is exposed both as the
    # generic <additional_image> alias and as indexed <additional_image1>, ...
    additional_img: Optional[Union[ImageLike, Sequence[ImageLike]]] = None

    # Texts
    prior_report: Optional[str] = None
    current_report: Optional[str] = None
    demographics: Optional[str] = None
    lab_test: Optional[str] = None
    history: Optional[str] = None
    additional_txt: Optional[str] = None

    # Optional conditioning — prepended as a system message when role="user"
    instruction: Optional[str] = None
| 70 |
+
# --------------------------- Defaults ---------------------------

# Default layout for a full clinical record: image slots first, then the text
# sections. Placeholders whose inputs are missing resolve to nothing.
DEFAULT_RECORD_TEMPLATE = (
    "<current_image> <lateral_image> <prior_image> <additional_image>\n"
    "\n"
    "DEMOGRAPHICS:\n<demographics>\n\n"
    "HISTORY / INDICATION:\n<history>\n\n"
    "LAB TESTS:\n<lab_test>\n\n"
    "PRIOR REPORT:\n<prior_report>\n\n"
    "CURRENT REPORT:\n<current_report>\n\n"
    "<additional_txt>"
)

# We treat these as placeholders for named images.
# NOTE: LingshuEmbedder already aliases <current_image> -> <frontal_image>.
_IMAGE_KEYS = (
    "current_image",
    "frontal_image",
    "lateral_image",
    "prior_image",
    "additional_image",
)

# Named text placeholders recognized in templates.
_TEXT_KEYS = (
    "prior_report",
    "current_report",
    "report",  # alias of current_report
    "demographics",
    "lab_test",
    "history",
    "additional_txt",
)
| 104 |
+
# --------------------------- Small helpers ---------------------------
|
| 105 |
+
|
| 106 |
+
_EMPTY_SENTINELS = {"", "-1", "none", "null", "na", "n/a", "nan", "<na>"}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _clean_text(x: Optional[str]) -> str:
|
| 110 |
+
if x is None:
|
| 111 |
+
return ""
|
| 112 |
+
s = str(x).strip()
|
| 113 |
+
return "" if s.lower() in _EMPTY_SENTINELS else s
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _load_image(x: Optional[ImageLike]) -> Optional["Image.Image"]:
    """Load an image from a path or pass through a PIL Image, as RGB.

    Returns None for missing inputs (None, "", or a sentinel string such as
    "nan"/"-1"). Raises ImportError if Pillow is unavailable and TypeError for
    unsupported input types.
    """
    if x is None:
        return None
    if Image is Any:
        raise ImportError("Pillow is required to load images. Please `pip install pillow`.")
    if isinstance(x, str):
        s = x.strip()
        if not s or s.lower() in _EMPTY_SENTINELS:
            return None
        # Use a context manager so the underlying file handle is released
        # promptly — PIL opens files lazily and would otherwise keep the
        # descriptor alive; convert() forces a full decode before closing.
        with Image.open(s) as im:
            return im.convert("RGB")
    # PIL.Image.Image (duck-typed so this module imports without a hard PIL dep)
    if hasattr(x, "convert"):
        return x.convert("RGB")
    raise TypeError(f"Unsupported image type: {type(x)}")
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _tmpl_uses_any_named_image_ph(tmpl: str) -> bool:
    """True if the template references any image placeholder (named or generic)."""
    lowered = (tmpl or "").lower()
    # Cheap substring checks — we avoid compiling a regex per sample.
    if "<image" in lowered:
        return True
    return any(f"<{key}>" in lowered for key in _IMAGE_KEYS)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _tmpl_uses_any_text_ph(tmpl: str) -> bool:
    """True if the template references any named text placeholder."""
    lowered = (tmpl or "").lower()
    return any(f"<{key}>" in lowered for key in _TEXT_KEYS)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _warn_missing_referenced_images(tmpl: str, image_map: Dict[str, Optional["Image.Image"]]):
    """Warn if the template referenced images but the caller didn't supply them."""
    lowered = (tmpl or "").lower()
    referenced_but_absent = [
        key for key in _IMAGE_KEYS
        if f"<{key}>" in lowered and image_map.get(key) is None
    ]
    for key in referenced_but_absent:
        # Plain print (no logging dependency) — this is a template repo.
        print(f"[cxrembed] WARNING: template references <{key}> but it is missing")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _build_image_map_from_inputs(inp: CxRInputs) -> Dict[str, "Image.Image"]:
    """Build the named image map consumed by LingshuEmbedder._build_content_from_template."""
    frontal = _load_image(inp.current_img)

    named: Dict[str, Optional["Image.Image"]] = {
        "current_image": frontal,
        "frontal_image": frontal,  # alias
        "lateral_image": _load_image(inp.lateral_img),
        "prior_image": _load_image(inp.prior_img),
    }

    extra = inp.additional_img
    if extra is None:
        named["additional_image"] = None
    elif isinstance(extra, (list, tuple)):
        # Expose both a generic alias and indexed placeholders:
        # <additional_image1>, <additional_image2>, ...
        loaded = [_load_image(item) for item in extra]
        named["additional_image"] = next((img for img in loaded if img is not None), None)
        for idx, img in enumerate(loaded, start=1):
            if img is not None:
                named[f"additional_image{idx}"] = img
    else:
        named["additional_image"] = _load_image(extra)

    # Drop None entries so the template builder only sees real images.
    return {key: img for key, img in named.items() if img is not None}
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _build_text_map_from_inputs(inp: CxRInputs) -> Dict[str, str]:
    """Build the placeholder -> text map; empty/sentinel fields are omitted."""
    current = _clean_text(inp.current_report)
    candidates = {
        "current_report": current,
        "report": current,  # alias
        "prior_report": _clean_text(inp.prior_report),
        "demographics": _clean_text(inp.demographics),
        "lab_test": _clean_text(inp.lab_test),
        "history": _clean_text(inp.history),
        "additional_txt": _clean_text(inp.additional_txt),
    }
    return {key: text for key, text in candidates.items() if text}
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _read_json_if_exists(path: str) -> Dict[str, Any]:
|
| 199 |
+
if not path or not os.path.isfile(path):
|
| 200 |
+
return {}
|
| 201 |
+
try:
|
| 202 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 203 |
+
return json.load(f)
|
| 204 |
+
except Exception:
|
| 205 |
+
return {}
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# --------------------------- Main wrapper ---------------------------
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class CxREmbedder:
    """Inference wrapper.

    Loads the Lingshu/Qwen2.5-VL backbone + (optional) LoRA adapter + the
    trained embedding heads, and exposes a small embedding API.

    Minimal API:
      - embed_record(...): embed a structured multi-modal record.
      - embed_answer(text): embed a candidate answer in the same space.
      - embed(...): lower-level template-based embedding.

    Notes:
      - The embedding path is **placeholder-aware**: images/texts are only attached
        if the template includes the corresponding placeholders.
      - The `mask_style` controls which tokens are pooled.
    """

    def __init__(
        self,
        model,  # LingshuEmbedder (kept untyped to avoid hard dependency at import-time)
        device: Union[str, torch.device] = "cuda",
        amp: bool = True,
    ):
        self.model = model
        self.device = torch.device(device)
        self.amp = bool(amp)

        # Inference only — disable dropout etc.
        self.model.eval()

    # ---------------------- constructors ----------------------

    @classmethod
    def from_local_checkpoint(
        cls,
        ckpt_dir: str,
        *,
        base_model_name: Optional[str] = None,
        device: Union[str, torch.device] = "cuda",
        amp: bool = True,
        embed_dim: Optional[int] = None,
        pool_mode: Optional[str] = None,
        image_size: int = 504,
        max_text_tokens: int = 1560,
        apply_lora_to_vision: bool = False,
        bidirectional: bool = True,
    ) -> "CxREmbedder":
        """Load from a training checkpoint directory.

        Expected contents:
          - unified_pooler.pt, unified_proj.pt, (optional) text_proj.pt, image_proj.pt
          - misc.pt (optional)
          - lora/ (optional)
          - cxrembed_config.json (optional; fills in base_model_name/embed_dim/pool_mode
            when the caller does not pass them)
        """
        ckpt_dir = os.path.abspath(ckpt_dir)
        cfg = _read_json_if_exists(os.path.join(ckpt_dir, "cxrembed_config.json"))

        # Defer import so the package can be inspected without transformers installed.
        from model_embed import LingshuEmbedder  # type: ignore

        if base_model_name is None:
            base_model_name = cfg.get("base_model_name") or cfg.get("model_name") or "lingshu-medical-mllm/Lingshu-7B"
        if embed_dim is None:
            embed_dim = int(cfg.get("embed_dim", 1280))
        if pool_mode is None:
            pool_mode = str(cfg.get("pool_mode", "latent_attention"))

        # If CUDA isn't available, force CPU.
        dev = torch.device(device)
        if dev.type == "cuda" and not torch.cuda.is_available():
            print("[cxrembed] CUDA requested but not available; falling back to CPU")
            dev = torch.device("cpu")
        use_cuda = (dev.type == "cuda")

        # IMPORTANT: if a LoRA folder exists, build the base backbone WITHOUT LoRA
        # first, then load the adapter weights with PEFT below.
        lora_dir = os.path.join(ckpt_dir, "lora")

        m = LingshuEmbedder(
            model_name=base_model_name,
            attn_implementation=("flash_attention_2" if use_cuda else "sdpa"),
            torch_dtype=(torch.bfloat16 if use_cuda else torch.float32),
            embed_dim=int(embed_dim),
            pool_mode=str(pool_mode),
            image_size=int(image_size),
            max_grid=1296,
            bidirectional=bool(bidirectional),
            use_lora=False,  # never build train-time LoRA modules for inference
            apply_lora_to_vision=bool(apply_lora_to_vision),
            max_text_tokens=int(max_text_tokens),
            enable_gradient_checkpointing=False,
            device=str(dev),
        )

        # Load LoRA adapter (optional; best-effort — a missing/broken adapter
        # degrades to the base backbone with a warning rather than crashing).
        if os.path.isdir(lora_dir):
            try:
                from peft import PeftModel  # type: ignore

                m.vl = PeftModel.from_pretrained(m.vl, lora_dir, is_trainable=False)
                if hasattr(m.vl, "set_adapter"):
                    m.vl.set_adapter("default")
                print(f"[cxrembed] loaded LoRA adapter from: {lora_dir}")
            except Exception as e:
                print(f"[cxrembed] WARNING: failed to load LoRA adapter from {lora_dir}: {e}")

        # Load heads (strict)
        _load_heads(m, ckpt_dir, device=dev)

        return cls(model=m, device=dev, amp=amp)

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str,
        *,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        device: Union[str, torch.device] = "cuda",
        amp: bool = True,
        **kwargs,
    ) -> "CxREmbedder":
        """Load from a Hugging Face Hub repo.

        The repo should contain the same files as from_local_checkpoint().

        Under the hood we `snapshot_download()` and then call from_local_checkpoint().
        """
        try:
            from huggingface_hub import snapshot_download  # type: ignore
        except Exception as e:  # pragma: no cover
            raise ImportError("Please `pip install huggingface_hub` to load from HF.") from e

        local_dir = snapshot_download(
            repo_id=repo_id,
            revision=revision,
            cache_dir=cache_dir,
            local_files_only=False,
        )
        return cls.from_local_checkpoint(local_dir, device=device, amp=amp, **kwargs)

    # ---------------------- public embedding API ----------------------

    @torch.no_grad()
    def embed_record(
        self,
        *,
        current_img: Optional[ImageLike] = None,
        prior_img: Optional[ImageLike] = None,
        lateral_img: Optional[ImageLike] = None,
        additional_img: Optional[Union[ImageLike, Sequence[ImageLike]]] = None,
        prior_report: Optional[str] = None,
        current_report: Optional[str] = None,
        demographics: Optional[str] = None,
        lab_test: Optional[str] = None,
        history: Optional[str] = None,
        additional_txt: Optional[str] = None,
        instruction: Optional[str] = None,
        template: str = DEFAULT_RECORD_TEMPLATE,
        image_size: Optional[int] = None,
        mask_style: str = "q_full_a_last",
        normalize: bool = True,
    ) -> torch.Tensor:
        """Embed a single multi-modal record into the unified embedding space."""
        inp = CxRInputs(
            current_img=current_img,
            prior_img=prior_img,
            lateral_img=lateral_img,
            additional_img=additional_img,
            prior_report=prior_report,
            current_report=current_report,
            demographics=demographics,
            lab_test=lab_test,
            history=history,
            additional_txt=additional_txt,
            instruction=instruction,
        )
        return self.embed(
            inputs=[inp],
            templates=[template],
            role="user",
            image_size=image_size,
            mask_style=mask_style,
            normalize=normalize,
        )[0]

    @torch.no_grad()
    def embed_answer(
        self,
        answer: str,
        *,
        normalize: bool = True,
        mask_style: str = "q_full_a_last",
    ) -> torch.Tensor:
        """Embed a candidate answer string in the unified embedding space."""
        inp = CxRInputs()
        return self.embed(
            inputs=[inp],
            templates=[str(answer)],
            role="assistant",
            image_size=None,
            mask_style=mask_style,
            normalize=normalize,
        )[0]

    @torch.no_grad()
    def embed(
        self,
        *,
        inputs: Sequence[CxRInputs],
        templates: Sequence[str],
        role: str,
        image_size: Optional[int] = None,
        mask_style: str = "q_full_a_last",
        normalize: bool = True,
    ) -> torch.Tensor:
        """Low-level batched embedding with templates.

        Args:
            inputs: list of CxRInputs
            templates: list of templates (same length)
            role: "user" or "assistant"
            image_size: optional override (must be multiple of 28 for Qwen grid)
            mask_style:
                - q_full_a_last: queries use full attention (system+user), assistant uses last role block.
                - both_full: both sides use full attention.
                - both_last: both sides use last role block.
        Returns:
            torch.FloatTensor [B, D] on CPU.
        """
        if role not in {"user", "assistant"}:
            raise ValueError("role must be 'user' or 'assistant'")
        if len(inputs) != len(templates):
            raise ValueError("inputs and templates must have the same length")

        wrapper_model = self.model
        device = self.device

        # We rely on internals from LingshuEmbedder.
        vm = wrapper_model._get_vision_module()
        vision_dtype = next(vm.parameters()).dtype

        # Determine target image size (flooring to multiple of 28 happens inside to_qwen_grid).
        target = wrapper_model._target_from_image_size(image_size)

        texts: List[str] = []
        flat_images: List["Image.Image"] = []

        for inp, tmpl_raw in zip(inputs, templates):
            tmpl = str(tmpl_raw or "")

            want_img = _tmpl_uses_any_named_image_ph(tmpl)
            want_txt = _tmpl_uses_any_text_ph(tmpl)

            imap = _build_image_map_from_inputs(inp) if want_img else {}
            tmap = _build_text_map_from_inputs(inp) if want_txt else {}

            _warn_missing_referenced_images(tmpl, {k: imap.get(k) for k in _IMAGE_KEYS})

            # Resize only the images actually present
            if want_img and imap:
                # Import here to reuse your exact resizing behavior.
                from model_embed import to_qwen_grid  # type: ignore

                imap = {k.lower(): to_qwen_grid(im, target=target) for k, im in imap.items() if im is not None}
            else:
                imap = {}

            content_list, images_in_order = wrapper_model._build_content_from_template(
                tmpl, image_map=imap, text_map=tmap, append_unused_images=False
            )

            # Ensure last-role masking is stable (non-empty role block)
            if not content_list:
                content_list = [{"type": "text", "text": " "}]

            msgs = []
            if role == "user":
                inst = _clean_text(inp.instruction)
                if inst:
                    msgs.append({"role": "system", "content": [{"type": "text", "text": f"INSTRUCTION:\n{inst}"}]})
            msgs.append({"role": role, "content": content_list})

            chat_text = wrapper_model.processor.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=False
            )
            texts.append(chat_text)
            flat_images.extend(images_in_order)

        proc = wrapper_model.processor(
            text=texts,
            images=(flat_images if len(flat_images) > 0 else None),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=getattr(wrapper_model, "max_text_tokens", 2560),
            do_resize=False,
        )
        # Guard the device move: only tensors understand `.to`; processor
        # outputs can in principle carry non-tensor entries.
        proc = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in proc.items()}
        if "pixel_values" in proc:
            # Vision tower may run in a reduced dtype (e.g. bf16) — match it.
            proc["pixel_values"] = proc["pixel_values"].to(device=device, dtype=vision_dtype)

        autocast_dtype = torch.bfloat16 if (self.amp and device.type == "cuda") else None
        with torch.autocast(
            device_type="cuda",
            dtype=autocast_dtype,
            enabled=(autocast_dtype is not None and device.type == "cuda"),
        ):
            out = wrapper_model.vl(**proc, output_hidden_states=True, use_cache=False)
        hidden = out.hidden_states[-1]

        # Pooling mask
        if mask_style == "both_full" or (mask_style == "q_full_a_last" and role == "user"):
            attn = proc.get("attention_mask", None)
            span_mask = attn.bool() if attn is not None else torch.ones(hidden.shape[:2], device=hidden.device, dtype=torch.bool)
        elif mask_style in {"both_last", "q_full_a_last"}:
            span_mask = wrapper_model._mask_last_role_block(proc, hidden)
        else:
            raise ValueError(f"Unknown mask_style: {mask_style}")

        if getattr(wrapper_model, "pool_mode", "latent_attention") == "latent_attention":
            pooler = wrapper_model.unified_pooler
            pool_dtype = next(pooler.parameters()).dtype
            if hidden.dtype != pool_dtype:
                hidden = hidden.to(dtype=pool_dtype)
            vec = pooler(hidden, span_mask)
        else:
            # Masked mean pooling fallback.
            mask = span_mask.to(hidden.dtype)
            denom = mask.sum(dim=1, keepdim=True).clamp_min(1e-6)
            vec = (hidden * mask.unsqueeze(-1)).sum(dim=1) / denom

        proj = wrapper_model.unified_proj
        proj_dtype = next(proj.parameters()).dtype
        emb = proj(vec.to(dtype=proj_dtype))

        if normalize:
            emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)

        return emb.detach().float().cpu()
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
# --------------------------- checkpoint loading ---------------------------
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def _load_state_dict_strict(module: torch.nn.Module, path: str, device: torch.device):
|
| 565 |
+
state = torch.load(path, map_location=device)
|
| 566 |
+
if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()):
|
| 567 |
+
state = {k.replace("module.", "", 1): v for k, v in state.items()}
|
| 568 |
+
module.load_state_dict(state, strict=True)
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def _load_heads(model, ckpt_dir: str, device: torch.device):
    """Load trained pooling/projection heads into a freshly built LingshuEmbedder.

    File names mirror the training-side save_checkpoint() helper; heads that
    are absent from either the model or the checkpoint directory are skipped.
    """
    head_files = {
        "unified_pooler.pt": "unified_pooler",
        "unified_proj.pt": "unified_proj",
        "text_proj.pt": "text_proj",
        "image_proj.pt": "image_proj",
    }
    for fname, attr in head_files.items():
        head = getattr(model, attr, None)
        if head is None:
            continue
        ckpt_path = os.path.join(ckpt_dir, fname)
        if os.path.isfile(ckpt_path):
            _load_state_dict_strict(head, ckpt_path, device=device)
            print(f"[cxrembed] loaded {fname}")

    # Optional misc payload: restore the contrastive temperature if saved.
    misc_path = os.path.join(ckpt_dir, "misc.pt")
    if os.path.isfile(misc_path):
        try:
            misc = torch.load(misc_path, map_location="cpu")
            has_scale = isinstance(misc, dict) and "logit_scale" in misc
            if has_scale and hasattr(model, "logit_scale"):
                with torch.no_grad():
                    model.logit_scale.data = torch.tensor(
                        float(misc["logit_scale"]), dtype=torch.float32, device=device
                    )
        except Exception:
            # misc is best-effort; a corrupt file must not block head loading.
            pass
|
| 600 |
+
|
cxrembed_config.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"base_model_name": "lingshu-medical-mllm/Lingshu-7B",
|
| 3 |
+
"embed_dim": 1280,
|
| 4 |
+
"pool_mode": "latent_attention",
|
| 5 |
+
"image_size": 504,
|
| 6 |
+
"max_text_tokens": 1560,
|
| 7 |
+
"format": "cxrembed_adapter_v1"
|
| 8 |
+
}
|
model_embed.py
ADDED
|
@@ -0,0 +1,1067 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import json
|
| 5 |
+
from typing import List, Optional, Dict, Tuple, Union
|
| 6 |
+
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
| 12 |
+
|
| 13 |
+
# Treat these as empty/missing (case-insensitive, whitespace-tolerant)
|
| 14 |
+
_EMPTY_SENTINELS = {"", "-1", "none", "null", "na", "n/a", "nan", "<na>"}
|
| 15 |
+
|
| 16 |
+
def _is_empty_cell(x) -> bool:
|
| 17 |
+
"""True if x should be considered 'missing'."""
|
| 18 |
+
if x is None:
|
| 19 |
+
return True
|
| 20 |
+
# float('nan') and numpy.float64('nan')
|
| 21 |
+
try:
|
| 22 |
+
if isinstance(x, float) and math.isnan(x):
|
| 23 |
+
return True
|
| 24 |
+
except Exception:
|
| 25 |
+
pass
|
| 26 |
+
s = str(x).strip().lower()
|
| 27 |
+
return s in _EMPTY_SENTINELS
|
| 28 |
+
|
| 29 |
+
def _clean_text_or_empty(x) -> str:
|
| 30 |
+
"""Return a clean string or '' if missing."""
|
| 31 |
+
return "" if _is_empty_cell(x) else str(x).strip()
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
from peft import LoraConfig, get_peft_model
|
| 35 |
+
HAS_PEFT = True
|
| 36 |
+
except Exception:
|
| 37 |
+
HAS_PEFT = False
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ----------------------- misc utils -----------------------
|
| 41 |
+
|
| 42 |
+
def l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
    """L2-normalize *x* along *dim*; *eps* keeps the division finite."""
    norms = x.norm(dim=dim, keepdim=True)
    return x / (norms + eps)
|
| 44 |
+
|
| 45 |
+
def masked_mean_pool(hidden: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Average token states over the sequence axis, honoring an optional mask.

    hidden: (B, S, D); mask: (B, S), truthy where tokens count. With no mask
    this is a plain mean over all S positions. The denominator is clamped so
    an all-zero mask cannot divide by zero.
    """
    if mask is None:
        return hidden.mean(dim=1)
    weights = mask.to(hidden.dtype).unsqueeze(-1)   # (B, S, 1)
    total = (hidden * weights).sum(dim=1)           # (B, D)
    count = weights.sum(dim=1).clamp_min(1e-6)      # (B, 1)
    return total / count
|
| 52 |
+
|
| 53 |
+
def to_qwen_grid(img: "Image.Image", target: int = 512, patch_size: int = 14, merge_size: int = 2) -> "Image.Image":
    """Resize *img* to a square whose side is a multiple of patch_size*merge_size (28).

    The side is FLOORED to the nearest grid multiple (512 -> 504, 1024 -> 1008)
    and never drops below a single grid cell.
    """
    cell = patch_size * merge_size  # 28 with the defaults
    side = max(cell, (target // cell) * cell)
    return img.resize((side, side), Image.BILINEAR)
|
| 61 |
+
|
| 62 |
+
def _open_or_none(path: object, root: str = "") -> "Optional[Image.Image]":
    """Open *path* as an RGB PIL image, or return None for missing/unreadable.

    Missing means '', NaN, '-1', '<NA>' etc. (see _is_empty_cell). *root* is
    prepended unless the path already carries a URI scheme (e.g. http://).
    """
    if _is_empty_cell(path):
        return None
    candidate = str(path).strip()
    # A scheme prefix means the path is self-contained; joining it onto
    # root would corrupt it.
    looks_like_uri = re.match(r'^[a-zA-Z][a-zA-Z0-9+\-.]*://', candidate)
    if root and not looks_like_uri:
        candidate = os.path.join(root, candidate)
    try:
        return Image.open(candidate).convert("RGB")
    except Exception:
        return None
|
| 74 |
+
|
| 75 |
+
def build_image_map_from_row(row, root: str = "") -> dict:
    """
    Build the named-image map for one dataset row.

    Mapping per the schema:
      - frontal_image <- img_path1 (aliased as current_image downstream)
      - lateral_image <- img_path2
      - prior_image   <- img_path3
    Negative images (several supported column spellings) are exposed under
    multiple aliases so any template placeholder can reach them.
    """
    result = {
        "frontal_image": _open_or_none(str(row.get("img_path1", "-1")), root),
        "lateral_image": _open_or_none(str(row.get("img_path2", "-1")), root),
        "prior_image": _open_or_none(str(row.get("img_path3", "-1")), root),
    }

    neg1 = _open_or_none(str(row.get("neg_image1", row.get("neg_path1", "-1"))), root)
    neg2 = _open_or_none(str(row.get("neg_image2", row.get("neg_path2", "-1"))), root)
    # prior negative may live under neg_image3, neg_prior_image, or neg_path3
    neg3 = _open_or_none(str(row.get("neg_image3", row.get("neg_prior_image", row.get("neg_path3", "-1")))), root)

    if neg1 is not None:
        for alias in ("neg_image1", "neg_path1", "neg_frontal_image"):
            result[alias] = neg1
    if neg2 is not None:
        for alias in ("neg_image2", "neg_path2", "neg_lateral_image"):
            result[alias] = neg2
    if neg3 is not None:
        for alias in ("neg_prior_image", "neg_image3", "neg_path3"):
            result[alias] = neg3
    return result
|
| 99 |
+
|
| 100 |
+
def _s(x): return "" if x is None else str(x)
|
| 101 |
+
|
| 102 |
+
def build_text_map_from_row(row) -> Dict[str, str]:
    """Collect the non-empty text fields of a row, keyed by placeholder name."""
    fields = ("report", "prior_report", "demographics", "lab_test", "indication")
    out: Dict[str, str] = {}
    for key in fields:
        value = _clean_text_or_empty(row.get(key))
        if value:  # empties are dropped entirely
            out[key] = value
    return out
|
| 113 |
+
|
| 114 |
+
def parse_text_placeholders(s) -> dict:
    """Parse a JSON object (or accept a dict) of extra text placeholders.

    Keys are lower-cased; values are cleaned via _clean_text_or_empty and
    empty/missing ones dropped. Anything unparseable yields {}.
    """
    if isinstance(s, dict):
        raw = s
    elif isinstance(s, str) and s.strip():
        try:
            raw = json.loads(s)
        except Exception:
            raw = {}
    else:
        raw = {}

    if not isinstance(raw, dict):
        return {}

    cleaned = {}
    for key, value in raw.items():
        text = _clean_text_or_empty(value)
        if text:
            cleaned[str(key).lower()] = text
    return cleaned
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# ----------------------- pooling modules -----------------------
|
| 135 |
+
|
| 136 |
+
class LatentAttentionPooler(nn.Module):
    """
    NV-Embed style pooling: token states (queries) cross-attend to a
    trainable latent dictionary (keys/values), pass through an MLP, and are
    finally mean-pooled over tokens (optionally masked).
    """
    def __init__(self, dim: int, num_latents: int = 512, num_layers: int = 1,
                 num_heads: int = 8, mlp_ratio: float = 2.0):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim) / math.sqrt(dim))
        self.layers = nn.ModuleList()
        self.ln_q = nn.LayerNorm(dim)   # normalizes token queries
        self.ln_kv = nn.LayerNorm(dim)  # normalizes latent K/V
        ffn_hidden = int(dim * mlp_ratio)
        for _ in range(num_layers):
            block = nn.ModuleDict({
                "attn": nn.MultiheadAttention(dim, num_heads, batch_first=True),
                "ffn": nn.Sequential(
                    nn.Linear(dim, ffn_hidden),
                    nn.GELU(),
                    nn.Linear(ffn_hidden, dim),
                ),
            })
            self.layers.append(block)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x: (B, S, D) last-layer token states from the LLM
        batch = x.shape[0]
        q = self.ln_q(x)
        latents = self.latents.unsqueeze(0).expand(batch, -1, -1).contiguous()
        kv = self.ln_kv(latents)

        # Tokens query the latent dictionary; no padding mask on latents.
        for block in self.layers:
            attn_out, _ = block["attn"](q, kv, kv, need_weights=False)
            q = q + attn_out            # residual around attention
            q = q + block["ffn"](q)     # residual around the MLP

        # The mask is applied only at the final token-wise mean pool.
        return masked_mean_pool(q, mask)  # (B, D)
|
| 175 |
+
|
| 176 |
+
class Projection(nn.Module):
    """Linear (or 2-layer MLP) projection followed by L2 normalization."""
    def __init__(self, in_dim: int, out_dim: int = 1024, hidden: Optional[int] = None):
        super().__init__()
        if hidden is None:
            stages = [nn.Linear(in_dim, out_dim, bias=False)]
        else:
            stages = [nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim, bias=False)]
        self.proj = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return l2norm(self.proj(x))
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# ----------------------- main wrapper -----------------------
|
| 188 |
+
|
| 189 |
+
class LingshuEmbedder(nn.Module):
|
| 190 |
+
def __init__(
    self,
    model_name: str = "lingshu-medical-mllm/Lingshu-7B",
    attn_implementation: str = "flash_attention_2",
    torch_dtype: torch.dtype = torch.bfloat16,
    embed_dim: int = 1024,

    # unified pooling mode
    pool_mode: str = "latent_attention",  # "latent_attention" | "mean"
    num_latents_unified: int = 512,

    # image grid control (supports 504 and 1008)
    image_size: int = 504,  # default grid; per-call override allowed (504 or 1008)
    min_grid: int = 256,
    max_grid: int = 1296,  # up to 36x36 (for 1008)

    # LoRA (optional) - tuned for memorization
    # r=64 for balanced performance; increase to 128 if VRAM allows
    use_lora: bool = False,
    lora_r: int = 64, lora_alpha: int = 64, lora_dropout: float = 0.0,  # alpha=r, dropout=0 for memorization
    apply_lora_to_vision: bool = False,

    # make attention bi-directional (remove causal masking)
    bidirectional: bool = True,

    # text token budget (read by the training script)
    max_text_tokens: int = 2560,

    # gradient checkpointing
    enable_gradient_checkpointing: bool = False,

    device: Optional[Union[str, torch.device]] = None,
) -> None:
    """Wrap Lingshu-7B (Qwen2.5-VL) with pooling/projection heads for embeddings.

    Loads the backbone + processor, freezes the base weights (except the
    vision projector submodules), builds the unified/text/image projection
    heads, and optionally applies LoRA, bidirectional attention, and
    gradient checkpointing. Heavy side effects: downloads/loads model
    weights from the Hub and moves the whole module to *device*.
    """
    super().__init__()

    # ---- device & backend ----
    # CPU cannot run flash-attention or (reliably) half precision, so fall
    # back to SDPA + fp32 when CUDA is unavailable or not requested.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)
    if device.type != "cuda":
        attn_implementation = "sdpa"
        if torch_dtype in (torch.float16, torch.bfloat16):
            torch_dtype = torch.float32

    # ---- load backbone + processor ----
    self.vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_name, torch_dtype=torch_dtype, attn_implementation=attn_implementation
    )
    # min/max pixels are expressed as (#28x28 cells) * 28 * 28.
    self.processor = AutoProcessor.from_pretrained(
        model_name,
        min_pixels=min_grid * 28 * 28,
        max_pixels=max_grid * 28 * 28,
    )
    self._propagate_attn_impl(attn_implementation)

    # freeze base
    for p in self.vl.parameters():
        p.requires_grad_(False)

    # UNFREEZE vision projector for better image→text binding
    # Qwen2.5-VL has a visual projection module
    unfrozen_modules = []
    for name, module in self.vl.named_modules():
        # Look for vision projector: often named 'visual', 'vision_proj', 'mm_projector', etc.
        if any(x in name.lower() for x in ['visual.merger', 'visual.proj', 'vision_proj', 'mm_projector']):
            n_params = sum(p.numel() for p in module.parameters())
            for p in module.parameters():
                p.requires_grad_(True)
            unfrozen_modules.append((name, n_params))

    if unfrozen_modules:
        print(f"[model] Unfrozen vision projector modules for memorization:")
        for name, n_params in unfrozen_modules:
            print(f"  - {name}: {n_params:,} parameters")

    # dims
    # NOTE(review): txt_hidden/vis_hidden hold the sub-*configs*, not sizes;
    # the actual hidden sizes are extracted just below.
    txt_hidden = getattr(self.vl.config, "text_config", None)
    vis_hidden = getattr(self.vl.config, "vision_config", None)
    self.text_hidden = getattr(txt_hidden, "hidden_size", None)
    self.vision_hidden = getattr(vis_hidden, "out_hidden_size", None) or getattr(vis_hidden, "hidden_size", None)

    # projections (unified/text/image all project to same embed_dim space)
    self.text_proj = Projection(self.text_hidden, embed_dim, hidden=None)
    self.image_proj = Projection(self.vision_hidden, embed_dim, hidden=None)
    self.unified_proj = Projection(self.text_hidden, embed_dim, hidden=None)

    # CLIP-style learnable temperature, initialized to 1/0.07.
    self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07)))

    # unified pooling config
    self.pool_mode = pool_mode
    if self.pool_mode == "latent_attention":
        self.unified_pooler = LatentAttentionPooler(
            dim=self.text_hidden,
            num_latents=num_latents_unified,  # set default to 512 to match paper
            num_layers=1,
            num_heads=8
        )
    else:
        self.unified_pooler = None

    # image size handling (any multiple of 28 is allowed, e.g., 448, 504, 1008)
    if image_size % 28 != 0:
        raise ValueError(f"image_size must be a multiple of 28, got {image_size}")
    self.image_size = image_size  # default; can override per call

    # optional LoRA
    self.peft_active = False
    if use_lora:
        if not HAS_PEFT:
            raise ImportError("peft not installed")
        targets_text = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
        targets_vision = ("qkv", "proj")
        targets = list(set(targets_text + (targets_vision if apply_lora_to_vision else tuple())))
        cfg = LoraConfig(r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                         target_modules=targets, bias="none", task_type="CAUSAL_LM")
        self.vl = get_peft_model(self.vl, cfg)
        self.peft_active = True

    # make bi-directional if requested
    if bidirectional:
        self._enable_bidirectional_attention()

    # gradient checkpointing
    if enable_gradient_checkpointing:
        # Use the non-reentrant variant to avoid "requires_grad" warnings
        try:
            self.vl.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
        except TypeError:
            # older transformers fallback
            self.vl.gradient_checkpointing_enable()
        try:
            self.vl.config.use_cache = False
        except Exception:
            pass

    # move to device
    self.to(device)
    self.device = device

    # align pooler dtype with model (and device)
    base_dtype = next(self.vl.parameters()).dtype
    if getattr(self, "unified_pooler", None) is not None:
        self.unified_pooler.to(device=device, dtype=base_dtype)

    # expose text token budget for processor calls in training script
    self.max_text_tokens = int(max_text_tokens)
|
| 339 |
+
|
| 340 |
+
# ---------- internals ----------
|
| 341 |
+
|
| 342 |
+
def _propagate_attn_impl(self, impl: str):
|
| 343 |
+
cfgs = [getattr(self.vl, "config", None)]
|
| 344 |
+
if cfgs[0] is not None:
|
| 345 |
+
for sub in ("text_config", "vision_config"):
|
| 346 |
+
cfgs.append(getattr(cfgs[0], sub, None))
|
| 347 |
+
for cfg in cfgs:
|
| 348 |
+
if cfg is None:
|
| 349 |
+
continue
|
| 350 |
+
try:
|
| 351 |
+
cfg._attn_implementation = impl
|
| 352 |
+
cfg.attn_implementation = impl
|
| 353 |
+
if hasattr(cfg, "use_flash_attention_2"):
|
| 354 |
+
cfg.use_flash_attention_2 = (impl == "flash_attention_2")
|
| 355 |
+
except Exception:
|
| 356 |
+
pass
|
| 357 |
+
for _, module in self.vl.named_modules():
|
| 358 |
+
if hasattr(module, "config"):
|
| 359 |
+
try:
|
| 360 |
+
module.config._attn_implementation = impl
|
| 361 |
+
module.config.attn_implementation = impl
|
| 362 |
+
if hasattr(module.config, "use_flash_attention_2"):
|
| 363 |
+
module.config.use_flash_attention_2 = (impl == "flash_attention_2")
|
| 364 |
+
except Exception:
|
| 365 |
+
pass
|
| 366 |
+
|
| 367 |
+
def _enable_bidirectional_attention(self):
|
| 368 |
+
"""Best-effort removal of causal masking."""
|
| 369 |
+
cfg = getattr(self.vl, "config", None)
|
| 370 |
+
if cfg is not None:
|
| 371 |
+
if hasattr(cfg, "is_decoder"): cfg.is_decoder = False
|
| 372 |
+
if hasattr(cfg, "use_cache"): cfg.use_cache = False
|
| 373 |
+
core = getattr(self.vl, "model", self.vl)
|
| 374 |
+
core_cfg = getattr(core, "config", None)
|
| 375 |
+
if core_cfg is not None:
|
| 376 |
+
if hasattr(core_cfg, "is_decoder"): core_cfg.is_decoder = False
|
| 377 |
+
if hasattr(core_cfg, "use_cache"): core_cfg.use_cache = False
|
| 378 |
+
for m in self.vl.modules():
|
| 379 |
+
if hasattr(m, "is_causal"):
|
| 380 |
+
try:
|
| 381 |
+
m.is_causal = False
|
| 382 |
+
except Exception:
|
| 383 |
+
pass
|
| 384 |
+
|
| 385 |
+
def _get_text_module(self):
|
| 386 |
+
core = getattr(self.vl, "model", self.vl)
|
| 387 |
+
for attr in ("language_model", "text_model", "lm"):
|
| 388 |
+
m = getattr(core, attr, None)
|
| 389 |
+
if m is not None and hasattr(m, "forward"):
|
| 390 |
+
return m
|
| 391 |
+
for _, module in self.vl.named_modules():
|
| 392 |
+
cname = module.__class__.__name__.lower()
|
| 393 |
+
if "vision" in cname:
|
| 394 |
+
continue
|
| 395 |
+
if hasattr(module, "forward") and hasattr(module, "embed_tokens"):
|
| 396 |
+
return module
|
| 397 |
+
raise AttributeError("Could not locate the text submodule in Qwen-VL.")
|
| 398 |
+
|
| 399 |
+
def _get_vision_module(self):
|
| 400 |
+
core = getattr(self.vl, "model", self.vl)
|
| 401 |
+
for attr in ("vision_model", "vision_tower", "visual", "vision"):
|
| 402 |
+
m = getattr(core, attr, None)
|
| 403 |
+
if m is not None and hasattr(m, "forward"):
|
| 404 |
+
return m
|
| 405 |
+
for _, module in self.vl.named_modules():
|
| 406 |
+
if "vision" in module.__class__.__name__.lower():
|
| 407 |
+
return module
|
| 408 |
+
raise AttributeError("Could not locate the vision submodule in Qwen-VL.")
|
| 409 |
+
|
| 410 |
+
def _get_vision_entry(self):
|
| 411 |
+
"""
|
| 412 |
+
Return the top-level VisionModel object that accepts:
|
| 413 |
+
forward(pixel_values=..., grid_thw=..., output_hidden_states=..., return_dict=True)
|
| 414 |
+
Avoid returning the inner transformer which expects (hidden_states, grid_thw).
|
| 415 |
+
"""
|
| 416 |
+
core = getattr(self.vl, "model", self.vl)
|
| 417 |
+
# Prefer the canonical attribute if present
|
| 418 |
+
vis = getattr(core, "vision_model", None)
|
| 419 |
+
if vis is not None:
|
| 420 |
+
return vis
|
| 421 |
+
# Fallback: search modules for something named *VisionModel
|
| 422 |
+
for _, m in core.named_modules():
|
| 423 |
+
name = m.__class__.__name__.lower()
|
| 424 |
+
if name.endswith("visionmodel"):
|
| 425 |
+
return m
|
| 426 |
+
# Last resort: previous generic getter (may return transformer; not ideal)
|
| 427 |
+
return self._get_vision_module()
|
| 428 |
+
|
| 429 |
+
# ----- chat/content builders & masking -----
|
| 430 |
+
|
| 431 |
+
def _target_from_image_size(self, image_size: Optional[int]) -> int:
|
| 432 |
+
"""
|
| 433 |
+
Return a pixel target that will be floored to a multiple of 28 by to_qwen_grid().
|
| 434 |
+
Any multiple of 28 works (e.g., 448, 504, 1008).
|
| 435 |
+
"""
|
| 436 |
+
sz = image_size if isinstance(image_size, int) and image_size % 28 == 0 else self.image_size
|
| 437 |
+
return int(sz)
|
| 438 |
+
|
| 439 |
+
def _build_interleaved_content(self, text: str, imgs: List[Image.Image], append_unused_images: bool = False) -> Tuple[list, list]:
|
| 440 |
+
"""
|
| 441 |
+
NUMERIC placeholders: <image1>, <image2>, ...
|
| 442 |
+
Returns (content_list, images_in_order).
|
| 443 |
+
"""
|
| 444 |
+
if text is None:
|
| 445 |
+
text = ""
|
| 446 |
+
content: list = []
|
| 447 |
+
ordered_images: list = []
|
| 448 |
+
imgs = imgs or []
|
| 449 |
+
|
| 450 |
+
pat = re.compile(r"<image\s*(\d+)\s*>", re.IGNORECASE)
|
| 451 |
+
pos = 0
|
| 452 |
+
matches = list(pat.finditer(text))
|
| 453 |
+
|
| 454 |
+
if not matches:
|
| 455 |
+
# Do not auto-append images unless explicitly requested
|
| 456 |
+
if text.strip():
|
| 457 |
+
content.append({"type": "text", "text": text})
|
| 458 |
+
if append_unused_images:
|
| 459 |
+
for im in imgs:
|
| 460 |
+
content.append({"type": "image", "image": im})
|
| 461 |
+
ordered_images.append(im)
|
| 462 |
+
return content, ordered_images
|
| 463 |
+
|
| 464 |
+
for m in matches:
|
| 465 |
+
s, e = m.span()
|
| 466 |
+
if s > pos:
|
| 467 |
+
seg = text[pos:s]
|
| 468 |
+
if seg.strip():
|
| 469 |
+
content.append({"type": "text", "text": seg})
|
| 470 |
+
idx = int(m.group(1)) - 1
|
| 471 |
+
if 0 <= idx < len(imgs):
|
| 472 |
+
content.append({"type": "image", "image": imgs[idx]})
|
| 473 |
+
ordered_images.append(imgs[idx])
|
| 474 |
+
pos = e
|
| 475 |
+
|
| 476 |
+
if pos < len(text):
|
| 477 |
+
seg = text[pos:]
|
| 478 |
+
if seg.strip():
|
| 479 |
+
content.append({"type": "text", "text": seg})
|
| 480 |
+
|
| 481 |
+
if append_unused_images:
|
| 482 |
+
used = set(ordered_images)
|
| 483 |
+
for im in imgs:
|
| 484 |
+
if im not in used:
|
| 485 |
+
content.append({"type": "image", "image": im})
|
| 486 |
+
ordered_images.append(im)
|
| 487 |
+
|
| 488 |
+
return content, ordered_images
|
| 489 |
+
|
| 490 |
+
def _build_content_from_template(
|
| 491 |
+
self,
|
| 492 |
+
template: str,
|
| 493 |
+
image_map: Optional[Dict[str, Image.Image]],
|
| 494 |
+
text_map: Optional[Dict[str, str]],
|
| 495 |
+
append_unused_images: bool = False,
|
| 496 |
+
) -> Tuple[list, list]:
|
| 497 |
+
"""
|
| 498 |
+
NAMED placeholders: <frontal_image>, <lateral_image>, <prior_image>, <report>, <prior_report>, <demographics>, ...
|
| 499 |
+
Also supports alias: <current_image> -> <frontal_image>.
|
| 500 |
+
"""
|
| 501 |
+
template = template or ""
|
| 502 |
+
image_map = {k.lower(): v for k, v in (image_map or {}).items() if v is not None}
|
| 503 |
+
text_map = {k.lower(): v for k, v in (text_map or {}).items() if v is not None and str(v).strip()}
|
| 504 |
+
|
| 505 |
+
content: list = []
|
| 506 |
+
images_in_order: list = []
|
| 507 |
+
|
| 508 |
+
pat = re.compile(r"<\s*([A-Za-z_]\w*)\s*>")
|
| 509 |
+
pos = 0
|
| 510 |
+
for m in pat.finditer(template):
|
| 511 |
+
s, e = m.span()
|
| 512 |
+
if s > pos:
|
| 513 |
+
seg = template[pos:s]
|
| 514 |
+
if seg.strip():
|
| 515 |
+
content.append({"type": "text", "text": seg})
|
| 516 |
+
|
| 517 |
+
name = m.group(1).lower()
|
| 518 |
+
# alias: current_image -> frontal_image
|
| 519 |
+
if name == "current_image":
|
| 520 |
+
name = "frontal_image"
|
| 521 |
+
|
| 522 |
+
if name in image_map: # <<< generalized image handling
|
| 523 |
+
img = image_map.get(name)
|
| 524 |
+
if img is not None:
|
| 525 |
+
content.append({"type": "image", "image": img})
|
| 526 |
+
images_in_order.append(img)
|
| 527 |
+
else:
|
| 528 |
+
val = text_map.get(name)
|
| 529 |
+
if val is not None:
|
| 530 |
+
content.append({"type": "text", "text": str(val)})
|
| 531 |
+
|
| 532 |
+
pos = e
|
| 533 |
+
|
| 534 |
+
if pos < len(template):
|
| 535 |
+
tail = template[pos:]
|
| 536 |
+
if tail.strip():
|
| 537 |
+
content.append({"type": "text", "text": tail})
|
| 538 |
+
|
| 539 |
+
# Append any not-yet-used images at the end (conditionally)
|
| 540 |
+
if append_unused_images:
|
| 541 |
+
for key, img in image_map.items():
|
| 542 |
+
if img is not None and img not in images_in_order:
|
| 543 |
+
content.append({"type": "image", "image": img})
|
| 544 |
+
images_in_order.append(img)
|
| 545 |
+
|
| 546 |
+
return content, images_in_order
|
| 547 |
+
|
| 548 |
+
def _mask_last_role_block(self, inputs: dict, hidden: torch.Tensor) -> torch.Tensor:
    """
    Boolean mask (B,S) selecting tokens inside the **last** role block (user/assistant),
    excluding the final <|im_end|>, for **any** batch size.
    Falls back to attention_mask if special tokens are unavailable.

    Args:
        inputs: processor output dict; reads "input_ids" (B,S) and optionally
            "attention_mask" (B,S).
        hidden: hidden-state tensor used only for its device and (B,S) shape
            when no input_ids are available.

    Returns:
        torch.BoolTensor of shape (B,S); True marks tokens to pool over.
    """
    device = hidden.device
    ids = inputs.get("input_ids", None)
    attn = inputs.get("attention_mask", None)
    # No token ids at all: pool over every attended position (or everything).
    if ids is None:
        return (attn if attn is not None else torch.ones(hidden.shape[:2], device=device, dtype=torch.long)).bool()

    B, S = ids.shape
    mask = torch.zeros((B, S), device=device, dtype=torch.bool)

    # Try to get ChatML boundary tokens
    # NOTE(review): convert_tokens_to_ids typically returns the unk id rather
    # than raising for unknown tokens — the except branches may rarely fire.
    try:
        start_id = self.processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
    except Exception:
        start_id = None
    try:
        end_id = self.processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
    except Exception:
        end_id = None

    # Without an end marker we cannot delimit role blocks: fall back to attn mask.
    if end_id is None:
        return (attn if attn is not None else torch.ones((B, S), device=device, dtype=torch.long)).bool()

    # Per-sample scan: locate the last <|im_start|>…<|im_end|> span.
    for b in range(B):
        # Limit search to valid tokens when attention mask is present
        if attn is not None:
            valid_len = int(attn[b].sum().item())
        else:
            valid_len = S
        valid_len = max(1, min(valid_len, S))
        seq = ids[b, :valid_len]

        # Positions of every <|im_end|> in this sequence.
        ends = (seq == end_id).nonzero(as_tuple=False).flatten()
        if ends.numel() == 0:
            # No explicit blocks; fall back to all valid tokens
            mask[b, :valid_len] = True
            continue
        last_end = int(ends[-1].item())

        # Find the start of the last block; -1 means "from the beginning".
        last_start = -1
        if start_id is not None:
            starts = (seq == start_id).nonzero(as_tuple=False).flatten()
            starts_before = starts[starts < last_end] if starts.numel() > 0 else None
            if starts_before is not None and starts_before.numel() > 0:
                last_start = int(starts_before[-1].item())
            elif ends.numel() >= 2:
                # Heuristic: if no <|im_start|>, use previous end as start
                last_start = int(ends[-2].item())
        else:
            if ends.numel() >= 2:
                last_start = int(ends[-2].item())

        # Select (last_start, last_end) exclusive of both boundary tokens;
        # `right` is clamped so at least one position is always selected.
        left = max(last_start + 1, 0)
        right = max(last_end - 1, left)
        mask[b, left:right + 1] = True

    # Never select padding positions.
    if attn is not None:
        mask = mask & attn.bool()
    return mask
|
| 612 |
+
|
| 613 |
+
# ---------- encoders (unified everywhere) ----------
|
| 614 |
+
|
| 615 |
+
@torch.no_grad()
def encode_text_unified(self, instructions: List[Optional[str]], texts: List[str], role: str = "user",
                        normalize: bool = True) -> torch.Tensor:
    """Embed text-only samples via the unified VL path.

    Each text is paired with an empty image list so it travels through the
    exact same chat-template / pooling / projection pipeline as multimodal
    inputs, keeping the embedding spaces consistent.
    """
    no_images: List[list] = [[] for _ in range(len(texts))]
    return self.encode_interleaved(
        instructions,
        texts,
        no_images,
        role=role,
        normalize=normalize,
    )
|
| 621 |
+
|
| 622 |
+
@torch.no_grad()
def encode_images_unified(self, instructions: List[Optional[str]], image_templates: List[str],
                          image_maps: List[Dict[str, Image.Image]], role: str = "user",
                          normalize: bool = True, image_size: Optional[int] = None) -> torch.Tensor:
    """Embed image-only samples via the unified placeholder path.

    Templates such as "<frontal_image>" (or "") control which images are
    included; supplying empty text maps means no text placeholder resolves
    to content.
    """
    blank_text_maps: List[Dict[str, str]] = [{} for _ in range(len(image_templates))]
    return self.encode_interleaved_with_ph(
        instructions,
        image_templates,
        image_maps,
        blank_text_maps,
        role=role,
        normalize=normalize,
        image_size=image_size,
    )
|
| 632 |
+
|
| 633 |
+
@torch.no_grad()
def encode_interleaved(
    self,
    instructions: List[Optional[str]],
    contents: List[str],
    images: List[List[Image.Image]],
    role: str = "user",
    normalize: bool = True,
    image_size: Optional[int] = None,  # 504 or 1008 override
) -> torch.Tensor:
    """Encode a batch of interleaved text+image samples into unified embeddings.

    Each sample is rendered as a chat (optional system instruction plus one
    `role` turn), forwarded through the VL backbone, pooled over the last
    role block, then projected with `unified_proj`.

    Args:
        instructions: per-sample system instruction; None/empty omits the turn.
        contents: per-sample text (may reference images via placeholders
            handled by `_build_interleaved_content`).
        images: per-sample list of PIL images.
        role: chat role for the content turn.
        normalize: if True, L2-normalize the output rows.
        image_size: optional resize-target override passed to
            `_target_from_image_size` (comment says 504 or 1008).

    Returns:
        (B, D) tensor of embeddings, one row per sample.
    """
    device = self.device
    vm = self._get_vision_module()
    # Pixel tensors must be cast to the vision tower's parameter dtype below.
    vision_dtype = next(vm.parameters()).dtype

    assert len(instructions) == len(contents) == len(images), "length mismatch"
    out_vecs = []
    target = self._target_from_image_size(image_size)

    # Samples are processed one at a time (each forward pass has batch size 1).
    for inst, text, imgs in zip(instructions, contents, images):
        # Snap every image to the Qwen patch grid before templating.
        proc_imgs = [to_qwen_grid(im, target=target) for im in (imgs or [])]
        content_list, images_in_order = self._build_interleaved_content(
            text or "", proc_imgs, append_unused_images=False
        )

        msgs = []
        if inst and str(inst).strip():
            msgs.append({"role": "system", "content": [{"type": "text", "text": inst}]})
        msgs.append({"role": role, "content": content_list})

        chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)

        # do_resize=False: images were already sized by to_qwen_grid above.
        proc = self.processor(
            text=[chat_text],
            images=images_in_order if images_in_order else None,
            return_tensors="pt",
            padding=True,
            truncation=True,
            do_resize=False,
            max_length=self.max_text_tokens,
        )
        inputs = {k: v.to(device) for k, v in proc.items()}
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype)
        if "image_grid_thw" in inputs:
            inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device)

        out = self.vl(**inputs, output_hidden_states=True, use_cache=False)
        hidden = out.hidden_states[-1]  # (1, S, H)
        # Pool only over the last role block (excludes chat boundary tokens).
        span_mask = self._mask_last_role_block(inputs, hidden)  # (1, S)

        if self.pool_mode == "latent_attention":
            # Pooler may run in a different dtype than the backbone.
            pool_dtype = next(self.unified_pooler.parameters()).dtype
            if hidden.dtype != pool_dtype:
                hidden = hidden.to(dtype=pool_dtype)
            vec = self.unified_pooler(hidden, span_mask).squeeze(0)
        else:
            vec = masked_mean_pool(hidden, span_mask).squeeze(0)

        out_vecs.append(vec)

    # Project all pooled vectors at once, then optionally L2-normalize.
    embs = torch.stack(out_vecs, dim=0)
    proj_dtype = next(self.unified_proj.parameters()).dtype
    emb = self.unified_proj(embs.to(dtype=proj_dtype))
    if normalize:
        emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    return emb
|
| 699 |
+
|
| 700 |
+
@torch.no_grad()
def encode_interleaved_with_ph(
    self,
    instructions: List[Optional[str]],
    templates: List[str],
    image_maps: List[Optional[Dict[str, Image.Image]]],
    text_maps: List[Optional[Dict[str, str]]],
    role: str = "user",
    normalize: bool = True,
    image_size: Optional[int] = None,  # 504 or 1008 override
) -> torch.Tensor:
    """Encode samples defined by a placeholder template plus image/text maps.

    Like `encode_interleaved`, but each sample's content is built by filling
    named placeholders (e.g. <frontal_image>, <report>) in `templates[i]`
    from `image_maps[i]` / `text_maps[i]` via `_build_content_from_template`.

    Args:
        instructions: per-sample system instruction; None/empty omits the turn.
        templates: per-sample placeholder template strings.
        image_maps: per-sample {placeholder_name: PIL image} (or None).
        text_maps: per-sample {placeholder_name: text} (or None).
        role: chat role for the content turn.
        normalize: if True, L2-normalize the output rows.
        image_size: optional resize-target override.

    Returns:
        (B, D) tensor of embeddings, one row per sample.
    """
    device = self.device
    vm = self._get_vision_module()
    # Pixel tensors must be cast to the vision tower's parameter dtype below.
    vision_dtype = next(vm.parameters()).dtype

    assert len(instructions) == len(templates) == len(image_maps) == len(text_maps), "length mismatch"

    vecs = []
    target = self._target_from_image_size(image_size)

    # One forward pass per sample (batch size 1 each).
    for inst, tmpl, imap, tmap in zip(instructions, templates, image_maps, text_maps):
        # Lower-case keys and drop None images; resize to the Qwen grid.
        proc_imap: Dict[str, Image.Image] = {}
        if imap:
            for k, im in imap.items():
                if im is not None:
                    proc_imap[k.lower()] = to_qwen_grid(im, target=target)

        content_list, images_in_order = self._build_content_from_template(tmpl or "", proc_imap, (tmap or {}))

        msgs = []
        if inst and str(inst).strip():
            msgs.append({"role": "system", "content": [{"type": "text", "text": inst}]})
        msgs.append({"role": role, "content": content_list})

        chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)

        # do_resize=False: images were already sized by to_qwen_grid above.
        proc = self.processor(
            text=[chat_text],
            images=images_in_order if images_in_order else None,
            return_tensors="pt",
            padding=True,
            truncation=True,
            do_resize=False,
            max_length=self.max_text_tokens,
        )
        inputs = {k: v.to(device) for k, v in proc.items()}
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype)
        if "image_grid_thw" in inputs:
            inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device)

        out = self.vl(**inputs, output_hidden_states=True, use_cache=False)
        hidden = out.hidden_states[-1]  # (1, S, H)
        # Pool only over the last role block (excludes chat boundary tokens).
        span_mask = self._mask_last_role_block(inputs, hidden)  # (1, S)

        if self.pool_mode == "latent_attention":
            # Pooler may run in a different dtype than the backbone.
            pool_dtype = next(self.unified_pooler.parameters()).dtype
            if hidden.dtype != pool_dtype:
                hidden = hidden.to(dtype=pool_dtype)
            vec = self.unified_pooler(hidden, span_mask).squeeze(0)
        else:
            vec = masked_mean_pool(hidden, span_mask).squeeze(0)

        vecs.append(vec)

    # Project all pooled vectors at once, then optionally L2-normalize.
    embs = torch.stack(vecs, dim=0)
    proj_dtype = next(self.unified_proj.parameters()).dtype
    emb = self.unified_proj(embs.to(dtype=proj_dtype))
    if normalize:
        emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    return emb
|
| 771 |
+
|
| 772 |
+
# ------------- (dual encoders for debugging) -------------
|
| 773 |
+
|
| 774 |
+
@torch.no_grad()
def encode_text_dual(self, texts: List[str], normalize: bool = True) -> torch.Tensor:
    """Debug-only dual encoder: embed texts with the LM tower alone.

    Tokenizes the batch, mean-pools the last hidden state over the attention
    mask, and projects the result with `text_proj` (optionally L2-normalized).
    """
    lm = self._get_text_module()
    batch = self.processor.tokenizer(
        text=texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=self.max_text_tokens,
    )
    batch = {name: tensor.to(self.device) for name, tensor in batch.items()}
    lm_out = lm(**batch, output_hidden_states=True, use_cache=False)
    pooled = masked_mean_pool(lm_out.last_hidden_state, batch.get("attention_mask"))
    head_dtype = next(self.text_proj.parameters()).dtype
    vectors = self.text_proj(pooled.to(dtype=head_dtype))
    if normalize:
        vectors = vectors / vectors.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    return vectors
|
| 789 |
+
|
| 790 |
+
@torch.no_grad()
def encode_images_dual(self, images: List[List[Image.Image]], normalize: bool = True,
                       image_size: Optional[int] = None) -> torch.Tensor:
    """Debug-only dual encoder: embed image *sets* with the vision tower alone.

    Each inner list is one sample; its images are encoded, patch features are
    averaged per image, the per-image vectors are averaged over the set, and
    the result is projected with `image_proj`.

    Args:
        images: per-sample lists of PIL images (a sample's list may be empty).
        normalize: if True, L2-normalize the output rows.
        image_size: optional resize-target override (multiple of 28).

    Returns:
        (B, D) tensor, one row per sample. Samples with no images map to the
        projection of a zero vector.

    Raises:
        RuntimeError: if the vision backbone returns no recognizable features.
    """
    device = self.device
    flat = [img for group in images for img in group]
    counts = [len(g) for g in images]
    if len(flat) == 0:
        # No images anywhere: project zeros so the output shape stays (B, D).
        proj_dtype = next(self.image_proj.parameters()).dtype
        zeros = torch.zeros((len(images), self.vision_hidden), device=device, dtype=proj_dtype)
        emb = self.image_proj(zeros)
        if normalize:
            emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        return emb
    target = self._target_from_image_size(image_size)
    processed = [to_qwen_grid(img, target=target) for img in flat]
    # do_resize=False: images were already sized by to_qwen_grid above.
    proc = self.processor.image_processor(images=processed, return_tensors="pt", do_resize=False)
    vm = self._get_vision_module()
    vision_dtype = next(vm.parameters()).dtype
    pixel_values = proc["pixel_values"].to(device=device, dtype=vision_dtype)
    vis_out = vm(pixel_values=pixel_values, output_hidden_states=True)
    # The backbone may return a tuple or a ModelOutput; try both layouts.
    feats = vis_out[0] if isinstance(vis_out, (tuple, list)) else getattr(vis_out, "last_hidden_state", None)
    if feats is None:
        feats = getattr(vis_out, "pooler_output", None)
    if feats is None:
        raise RuntimeError("Vision backbone did not return features as expected.")
    # (N, P, C) token features -> (N, C) per-image vectors.
    per_img = feats.mean(dim=1) if feats.ndim == 3 else feats
    splits = torch.split(per_img, counts, dim=0)
    # Bug fix: a zero-length group used to take the mean of an empty tensor,
    # producing a NaN row. Empty groups now contribute a zero vector instead.
    set_vecs = torch.stack(
        [s.mean(dim=0) if s.shape[0] > 0 else per_img.new_zeros(per_img.shape[-1]) for s in splits],
        dim=0,
    )
    proj_dtype = next(self.image_proj.parameters()).dtype
    emb = self.image_proj(set_vecs.to(dtype=proj_dtype))
    if normalize:
        emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    return emb
|
| 823 |
+
|
| 824 |
+
# ===================== PHRASE GROUNDING UTILS =====================
|
| 825 |
+
|
| 826 |
+
def _find_subsequence(self, haystack: list, needle: list) -> list:
|
| 827 |
+
"""Return start indices where 'needle' occurs in 'haystack' (exact match)."""
|
| 828 |
+
if not haystack or not needle or len(needle) > len(haystack):
|
| 829 |
+
return []
|
| 830 |
+
hits = []
|
| 831 |
+
n = len(needle)
|
| 832 |
+
for i in range(len(haystack) - n + 1):
|
| 833 |
+
if haystack[i:i+n] == needle:
|
| 834 |
+
hits.append(i)
|
| 835 |
+
return hits
|
| 836 |
+
|
| 837 |
+
def _window_decode_matches(self, tokenizer, ids, target_lower: str) -> list:
|
| 838 |
+
"""Fallback: sliding-window decode match (robust to BPE splits). Returns window (start,end) indices."""
|
| 839 |
+
hits = []
|
| 840 |
+
L = len(ids)
|
| 841 |
+
# Small cap on window length to avoid expensive decode; most medical terms fit <= 5 tokens.
|
| 842 |
+
for w in range(1, 8):
|
| 843 |
+
for i in range(0, L - w + 1):
|
| 844 |
+
s, e = i, i + w
|
| 845 |
+
text = tokenizer.decode(ids[s:e], skip_special_tokens=True).lower().replace(" ", "")
|
| 846 |
+
if target_lower in text:
|
| 847 |
+
hits.append((s, e))
|
| 848 |
+
# De-duplicate overlapping windows by preferring shortest span
|
| 849 |
+
hits = sorted(set(hits), key=lambda x: (x[1]-x[0], x[0]))
|
| 850 |
+
return hits
|
| 851 |
+
|
| 852 |
+
def _resize_heatmap_like(self, hm_np, target_w, target_h):
    """Bilinearly resize a (H, W) heatmap with values in [0, 1] to (target_h, target_w).

    Args:
        hm_np: 2-D numpy array in [0, 1].
        target_w: output width in pixels.
        target_h: output height in pixels.

    Returns:
        float32 numpy array of shape (target_h, target_w), values in [0, 1].

    Raises:
        ValueError: if `hm_np` is not 2-dimensional.
    """
    from PIL import Image
    import numpy as np

    # Replaces an unused `H, W = hm_np.shape` unpack with an explicit check.
    if hm_np.ndim != 2:
        raise ValueError(f"expected a 2-D heatmap, got shape {hm_np.shape}")
    # Round-trip through an 8-bit grayscale image so PIL's bilinear filter
    # performs the resampling.
    im = Image.fromarray((hm_np * 255.0).astype("uint8"), mode="L")
    im = im.resize((target_w, target_h), Image.BILINEAR)
    return np.array(im).astype("float32") / 255.0
|
| 861 |
+
|
| 862 |
+
def _overlay_heatmap_on_image(self, img_pil, hm_np, alpha=0.45):
    """Return a PIL image with `hm_np` alpha-blended over `img_pil`.

    Args:
        img_pil: base PIL image (converted to RGB).
        hm_np: (H, W) float heatmap in [0, 1], same size as the image.
        alpha: blend weight of the colorized heatmap.

    Returns:
        PIL.Image: the blended RGB image.

    Raises:
        ValueError: if heatmap and image sizes differ.
    """
    import matplotlib
    import numpy as np
    from PIL import Image

    img = np.array(img_pil.convert("RGB")).astype("float32") / 255.0
    H, W = img.shape[:2]
    hm = np.clip(hm_np, 0.0, 1.0)
    if hm.shape[:2] != (H, W):
        raise ValueError("Heatmap and image size mismatch")
    # Use a perceptually reasonable colormap without fixing colors for downstream tools.
    # matplotlib.cm.get_cmap() was deprecated in 3.7 and removed in 3.9; prefer the
    # colormap registry and fall back for older matplotlib versions.
    try:
        cmap = matplotlib.colormaps["jet"]
    except AttributeError:
        cmap = matplotlib.cm.get_cmap("jet")
    color_hm = cmap(hm)[..., :3]  # (H, W, 3) RGB; drop the alpha channel
    blended = (1.0 - alpha) * img + alpha * color_hm
    blended = np.clip(blended, 0.0, 1.0)
    return Image.fromarray((blended * 255).astype("uint8"))
|
| 879 |
+
|
| 880 |
+
def phrase_ground_and_visualize(
    self,
    word: str,
    template: str,
    row,
    role: str = "user",
    instruction: str = None,
    image_size: int = None,  # multiples of 28; defaults to self.image_size
    layer_for_text: int = -1,  # which hidden_states layer to pull token reps from
    save_dir: str = None,  # if set, saves overlays as PNGs
    return_arrays: bool = False,  # if True, return heatmaps as numpy arrays
):
    """
    Compute patch-level grounding for a word against images referenced in `template` filled by `row`.
    Returns a PhraseGroundingOutput, and optionally writes overlay PNGs.

    Strategy:
    - Build a single-sample chat like encode_interleaved_with_ph().
    - Forward Qwen-VL with hidden_states (+ attention if available).
    - Locate word tokens inside last role block.
    - Run vision tower once to get per-patch features per image.
    - Project (text token avg) with text_proj, patches with image_proj; cosine sim per patch → heatmap.
    - (Optional) also compute LM self-attn from word tokens to any image placeholders if available.

    NOTE(review): `row` is presumably a dataset record consumed by
    build_image_map_from_row / build_text_map_from_row — confirm its schema
    against those helpers.
    """
    import os, numpy as np, torch
    from PIL import Image

    device = self.device
    tok = self.processor.tokenizer
    target = self._target_from_image_size(image_size)

    # --- Build content exactly like your training path ---
    imap = build_image_map_from_row(row, root="")
    # resize to Qwen grid (only for actually referenced keys)
    # We won't pre-filter keys; _build_content_from_template handles which placeholders are used.
    proc_imap = {k.lower(): to_qwen_grid(v, target=target) for k, v in (imap or {}).items() if v is not None}
    tmap = build_text_map_from_row(row)

    content_list, images_in_order = self._build_content_from_template(template or "", proc_imap, (tmap or {}), append_unused_images=False)

    msgs = []
    if instruction and str(instruction).strip():
        msgs.append({"role": "system", "content": [{"type": "text", "text": f"INSTRUCTION:\n{instruction}"}]})
    msgs.append({"role": role, "content": content_list})
    chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)

    vm = self._get_vision_module()
    # Pixel tensors must match the vision tower's parameter dtype.
    vision_dtype = next(vm.parameters()).dtype

    # do_resize=False: images were already sized by to_qwen_grid above.
    proc = self.processor(
        text=[chat_text],
        images=images_in_order if images_in_order else None,
        return_tensors="pt",
        padding=True,
        truncation=True,
        do_resize=False,
        max_length=self.max_text_tokens,
    )
    inputs = {k: v.to(device) for k, v in proc.items()}
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype)
    if "image_grid_thw" in inputs:
        inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device)

    # --- Forward with hidden states (+ attentions if the model exposes them) ---
    with torch.no_grad():
        out = self.vl(**inputs, output_hidden_states=True, output_attentions=True, use_cache=False, return_dict=True)

    hidden = out.hidden_states[layer_for_text]  # (1, S, H)
    span_mask = self._mask_last_role_block(inputs, hidden)[0].bool()  # (S,)
    seq_ids = inputs["input_ids"][0].tolist()

    # --- Find token indices for the word inside the last role block ---
    # 1) exact subsequence match of token ids
    tgt_ids = tok(word, add_special_tokens=False)["input_ids"]
    last_role_positions = [i for i, m in enumerate(span_mask.tolist()) if m]
    id_seq_in_span = [seq_ids[i] for i in last_role_positions]
    hits = self._find_subsequence(id_seq_in_span, tgt_ids)
    token_span = None  # (abs_start, abs_end)
    if hits:
        # Map the first in-span hit back to absolute sequence positions.
        start_in_span = hits[0]
        abs_start = last_role_positions[start_in_span]
        abs_end = last_role_positions[start_in_span + len(tgt_ids) - 1] + 1  # exclusive
        token_span = (abs_start, abs_end)
    else:
        # 2) fallback: decode windows in-span and fuzzy match lowercase without spaces
        win_hits = self._window_decode_matches(tok, id_seq_in_span, target_lower=word.lower().replace(" ", ""))
        if win_hits:
            s, e = win_hits[0]
            abs_start = last_role_positions[s]
            abs_end = last_role_positions[e - 1] + 1
            token_span = (abs_start, abs_end)

    if token_span is None:
        # If the word cannot be located, we center on the last token in the last-role block.
        # This keeps the visualization functional for debugging.
        last_idx = last_role_positions[-1]
        token_span = (last_idx, last_idx + 1)

    s_idx, e_idx = token_span
    word_tokens = hidden[0, s_idx:e_idx, :]  # (T_word, Htxt)
    # Average sub-tokens → one vector
    word_vec_txt = word_tokens.mean(dim=0, keepdim=True)  # (1, Htxt)

    # --- Get vision patch features per image ---
    heatmaps = []
    per_image_debug = []
    if "pixel_values" in inputs:
        # Use the TOP-LEVEL vision model entry
        vmodel = self._get_vision_entry()
        with torch.no_grad():
            vout = vmodel(
                pixel_values=inputs["pixel_values"],
                grid_thw=inputs.get("image_grid_thw", None),
                output_hidden_states=True,
                return_dict=True,
            )

        # vout.last_hidden_state: (B, Svis, C)
        vlast = vout.last_hidden_state
        B, Svis, C = vlast.shape

        # Grid sizes per image (T,H,W)
        grids = inputs.get("image_grid_thw", None)
        if grids is not None:
            # grids shape: (B, 3) => (T, H, W)
            thw = grids.detach().cpu().tolist()
            if isinstance(thw[0], (int, float)):  # single image edge case
                thw = [thw]
        else:
            # No grid info: assume a square patch layout per image.
            thw = [[1, int(round(Svis ** 0.5)), int(round(Svis ** 0.5))] for _ in range(B)]

        # If a CLS token exists, Svis == T*H*W + 1; drop it
        per_img = []
        offset = 0  # NOTE(review): unused; kept for compatibility
        for i in range(B):
            t, h, w = map(int, thw[i])
            tokens_per = t * h * w
            take_from = 1 if (Svis == tokens_per + 1) else 0
            patches = vlast[i, take_from:take_from + tokens_per, :]  # (T*H*W, C)
            per_img.append((patches, (t, h, w)))

        proj_dtype_img = next(self.image_proj.parameters()).dtype
        proj_dtype_txt = next(self.text_proj.parameters()).dtype

        # Project the averaged word representation into the shared space.
        word_vec = self.text_proj(word_vec_txt.to(dtype=proj_dtype_txt))
        word_vec = word_vec / (word_vec.norm(dim=-1, keepdim=True) + 1e-12)

        for (patches, (t, h, w)) in per_img:
            # Cosine similarity of each patch embedding to the word embedding.
            patch_emb = self.image_proj(patches.to(dtype=proj_dtype_img))
            patch_emb = patch_emb / (patch_emb.norm(dim=-1, keepdim=True) + 1e-12)
            sim = (patch_emb @ word_vec[0].T).squeeze(-1)  # (P,)
            sim = sim.reshape(t, h, w).mean(dim=0)  # (H, W)
            # Min-max normalize each heatmap independently to [0, 1].
            smin, smax = float(sim.min()), float(sim.max())
            hm = (sim - smin) / max(1e-6, (smax - smin))
            heatmaps.append(hm.detach().cpu().numpy())
            per_image_debug.append({"tokens_per": t*h*w, "grid": (t, h, w)})

    # --- Save overlays if requested ---
    saved_paths = []
    if save_dir and heatmaps:
        os.makedirs(save_dir, exist_ok=True)
        for i, im in enumerate(images_in_order):
            # Ensure the heatmap is resized to the same (square) size we fed Qwen
            tgt_w, tgt_h = im.size
            hm_np = self._resize_heatmap_like(heatmaps[i], tgt_w, tgt_h)
            overlay = self._overlay_heatmap_on_image(im, hm_np, alpha=0.45)
            fname = os.path.join(save_dir, f"ground_{i:02d}_{word.replace(' ','_')}.png")
            overlay.save(fname)
            saved_paths.append(fname)

    # NOTE(review): "placeholder_attn" is never populated here, so it is
    # always None in the output despite output_attentions=True above.
    result = PhraseGroundingOutput(
        token_span=(int(s_idx), int(e_idx)),
        per_image=[{
            "heatmap": (heatmaps[i] if return_arrays else None),
            "saved_path": (saved_paths[i] if i < len(saved_paths) else None),
            "grid": per_image_debug[i].get("grid", None),
            "tokens_per": per_image_debug[i].get("tokens_per", None),
            "placeholder_attn": per_image_debug[i].get("placeholder_attn", None),
        } for i in range(len(heatmaps))]
    )
    return result
|
| 1062 |
+
|
| 1063 |
+
|
| 1064 |
+
class PhraseGroundingOutput:
    """Result container for `phrase_ground_and_visualize`.

    Attributes:
        token_span: (start_idx, end_idx) of the grounded word's tokens within
            the last-role span of the chat sequence (end exclusive).
        per_image: one dict per referenced image, with keys "heatmap",
            "saved_path", "grid", "tokens_per", and "placeholder_attn".
    """

    def __init__(self, token_span, per_image):
        self.token_span = token_span
        self.per_image = per_image

    def __repr__(self) -> str:
        # Cheap, informative repr for debugging; avoids dumping heatmap arrays.
        return (f"{type(self).__name__}(token_span={self.token_span!r}, "
                f"per_image={len(self.per_image)} image(s))")
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers
|
| 3 |
+
peft
|
| 4 |
+
huggingface_hub
|
| 5 |
+
pillow
|