# File size: 13,043 Bytes
# 9286661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""Chest2Vec — LoRA-tuned Qwen3-Embedding model for chest radiology reports.

Load with:

    from transformers import AutoModel
    model = AutoModel.from_pretrained("chest2vec/chest2vec_0.6B", trust_remote_code=True)
    emb = model.embed_texts(["Frontal chest radiograph. No pneumothorax."])  # [N, H], L2-normalized

Architecture:
  1. Base   : Qwen/Qwen3-Embedding-{0.6B,4B}  (downloaded at runtime)
  2. Adapter: frozen contrastive LoRA adapter  (./contrastive)

Embeddings use last-token (EOS) pooling with left padding, matching Qwen3-Embedding
and the Stage-2 training setup. FlashAttention-2 is used when CUDA + flash-attn>=2
are available (matching training); otherwise it falls back to SDPA so the model
also loads on CPU.
"""
import os
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig, PreTrainedModel

from .configuration_chest2vec import Chest2VecConfig

try:
    from peft import PeftModel
    _HAS_PEFT = True
except Exception:
    PeftModel = None
    _HAS_PEFT = False

try:
    from huggingface_hub import snapshot_download
    _HAS_HUB = True
except Exception:
    snapshot_download = None
    _HAS_HUB = False


# ----------------------------------------------------------------------------
# Attention backend selection
# ----------------------------------------------------------------------------
def _flash_attn_available() -> bool:
    """Report whether FlashAttention-2 can be used in this process.

    Returns True only when CUDA is available and an importable ``flash_attn``
    package reports a major version of at least 2. Any failure while importing
    the package or parsing its version is treated as "not available".
    """
    if torch.cuda.is_available():
        try:
            import flash_attn  # noqa: F401
            major = str(getattr(flash_attn, "__version__", "0.0.0")).split(".")[0]
            return int(major) >= 2
        except Exception:
            # Missing package or unparsable version string: fall through.
            pass
    return False


def _pick_attn_impl(requested: Optional[str], want_flash: bool) -> str:
    """Choose the attention implementation string passed to the base model.

    Args:
        requested: Explicit implementation chosen by the caller; returned as-is
            when truthy.
        want_flash: Whether FlashAttention-2 should be preferred when usable.

    Returns:
        ``requested`` if given, else ``"flash_attention_2"`` when it is wanted
        and available, else ``"sdpa"``. A ``RuntimeWarning`` is emitted when
        FlashAttention-2 was wanted but is unavailable.
    """
    import warnings

    # An explicit caller choice always wins.
    if requested:
        return requested
    # Flash attention not asked for: use SDPA without any warning.
    if not want_flash:
        return "sdpa"
    if _flash_attn_available():
        return "flash_attention_2"
    warnings.warn(
        "Chest2Vec was trained with FlashAttention-2, but it is unavailable "
        "(needs CUDA + flash-attn>=2). Falling back to 'sdpa'; embeddings may "
        "differ very slightly from the reference implementation.",
        RuntimeWarning,
    )
    return "sdpa"


# ----------------------------------------------------------------------------
# Tokenization / pooling helpers (match Qwen3-Embedding + training)
# ----------------------------------------------------------------------------
def build_qwen_query(instruction: str, query: str) -> str:
    """Format an instruction/query pair as a two-line ``Instruct:`` / ``Query:`` prompt.

    Both arguments are converted with ``str()`` and stripped of surrounding
    whitespace before being inserted.
    """
    task = str(instruction).strip()
    question = str(query).strip()
    return f"Instruct: {task}\nQuery: {question}"


def get_pool_token_id(tok) -> int:
    """Return the id of the token appended to each sequence for pooling.

    Looks up ``<|endoftext|>`` in the tokenizer; when that lookup yields
    ``None`` or a negative id, the tokenizer's ``pad_token_id`` is returned
    instead.
    """
    candidate = tok.convert_tokens_to_ids("<|endoftext|>")
    usable = candidate is not None and candidate >= 0
    return candidate if usable else tok.pad_token_id


def encode_with_eos_ids(tok, texts: List[str], max_len: int) -> Dict[str, torch.Tensor]:
    """Tokenize texts, append the pooling token, and left-pad into one batch.

    Each text is tokenized with ``add_special_tokens=False`` and truncated to
    ``max_len - 1`` tokens; the pooling token from ``get_pool_token_id`` is then
    appended, so every row holds at most ``max_len`` tokens. Rows are padded on
    the left to the longest row in the batch.

    Args:
        tok: Tokenizer-like callable that also exposes ``pad_token_id``,
            ``eos_token_id`` and ``convert_tokens_to_ids``.
        texts: Input texts; each element is converted with ``str()``.
        max_len: Maximum row length including the appended pooling token.

    Returns:
        Dict with ``input_ids`` and ``attention_mask``, both ``torch.long``
        tensors of shape ``[len(texts), T]``. For empty ``texts`` the shape is
        ``[0, 1]`` rather than a rank-1 empty tensor.

    Raises:
        ValueError: If ``max_len`` is smaller than 1 (no room for the pooling
            token).
    """
    if max_len < 1:
        raise ValueError(
            f"max_len must be >= 1 to leave room for the pooling token, got {max_len}."
        )
    # Fall back to EOS as the padding id when the tokenizer defines no pad token.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    eod_id = get_pool_token_id(tok)
    enc = tok(
        [str(t) for t in texts],
        add_special_tokens=False,
        truncation=True,
        max_length=max_len - 1,  # reserve one slot for the pooling token
        padding=False,
        return_attention_mask=False,
    )
    rows = [ids + [eod_id] for ids in enc["input_ids"]]
    n_rows = len(rows)
    # Every row has at least the pooling token; default only applies to an empty batch.
    width = max((len(row) for row in rows), default=1)
    padded_ids = [[pad_id] * (width - len(row)) + row for row in rows]
    padded_mask = [[0] * (width - len(row)) + [1] * len(row) for row in rows]
    # reshape is a no-op for non-empty batches and keeps an empty batch 2-D.
    return {
        "input_ids": torch.tensor(padded_ids, dtype=torch.long).reshape(n_rows, width),
        "attention_mask": torch.tensor(padded_mask, dtype=torch.long).reshape(n_rows, width),
    }


def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Select one hidden vector per row: the one at the last attended position.

    If the final column of ``attention_mask`` sums to the batch size, the batch
    is treated as left-padded and the last time step is returned for every row.
    Otherwise each row is indexed at ``attention_mask.sum(dim=1) - 1``.
    """
    n_rows = attention_mask.shape[0]
    ends_unmasked_everywhere = attention_mask[:, -1].sum() == n_rows
    if not ends_unmasked_everywhere:
        # Right-padded (or mixed) batch: gather each row's last real token.
        last_positions = attention_mask.sum(dim=1) - 1
        row_ids = torch.arange(
            last_hidden_states.size(0), device=last_hidden_states.device
        )
        return last_hidden_states[row_ids, last_positions]
    return last_hidden_states[:, -1]


def get_last_hidden_state(model, input_ids, attention_mask):
    """Run the model and return its final-layer hidden states.

    ``model.module`` is used when present (wrapped models), otherwise ``model``
    itself. Position ids are derived from the attention mask so that they start
    at 0 on the first attended token; masked positions get position 0.

    If the output has no ``last_hidden_state``, the model is called a second
    time with ``output_hidden_states=True`` and the last entry of
    ``hidden_states`` is returned.
    """
    inner = getattr(model, "module", model)
    running_index = attention_mask.long().cumsum(-1) - 1
    position_ids = running_index.masked_fill(attention_mask == 0, 0)
    shared_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        use_cache=False,
        return_dict=True,
    )
    first_pass = inner(**shared_kwargs)
    hidden = getattr(first_pass, "last_hidden_state", None)
    if hidden is not None:
        return hidden
    # Output type without `last_hidden_state`: request all layers and take the top one.
    second_pass = inner(output_hidden_states=True, **shared_kwargs)
    return second_pass.hidden_states[-1]


class Chest2VecModel(PreTrainedModel):
    """LoRA-tuned Qwen3-Embedding model producing L2-normalized report embeddings.

    The wrapper holds no weights of its own after ``__init__``; the base model,
    its tokenizer and the LoRA adapter are attached by ``from_pretrained`` via
    ``_assemble``. Text is embedded with ``embed_texts`` /
    ``embed_instruction_query``; ``forward`` is an alias for ``embed_texts``.
    """

    config_class = Chest2VecConfig
    base_model_prefix = "chest2vec"
    # Attention is handled by the inner Qwen3 backbone; advertise support so the
    # transformers attn-implementation validator on this wrapper passes.
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_flash_attn = True
    _supports_attention_backend = True

    def __init__(self, config: Chest2VecConfig):
        """Create an unassembled wrapper; ``backbone`` and ``tokenizer`` start as None."""
        super().__init__(config)
        # The base+adapter are assembled in `from_pretrained` (base downloads at runtime).
        self.backbone = None
        self.tokenizer = None
        # Device recorded at assembly time; used as a fallback by the `device` property.
        self._device = torch.device("cpu")
        self.register_buffer("_anchor", torch.zeros(1), persistent=False)

    def get_input_embeddings(self):
        """Return None: the wrapper itself exposes no input embedding module."""
        return None

    def set_input_embeddings(self, value):
        """Ignore the request; the wrapper itself has no input embedding module."""
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model from a local directory or a Hub repo id.

        Recognized keyword arguments: ``config``, ``device``, ``use_4bit``,
        ``attn_implementation``, ``torch_dtype``, ``token`` (or
        ``use_auth_token``), ``cache_dir`` and ``revision``. Everything else is
        ignored.

        Args:
            pretrained_model_name_or_path: Existing local directory, or a repo
                id that is fetched with ``snapshot_download``.
            *model_args: Accepted for signature compatibility; unused.
            **kwargs: See the list above.

        Returns:
            An assembled model in eval mode.

        Raises:
            RuntimeError: If a repo id is given but ``huggingface_hub`` is not
                importable (further errors can come from ``_assemble``).
        """
        config = kwargs.pop("config", None)
        device = kwargs.pop("device", None)
        use_4bit = kwargs.pop("use_4bit", False)
        attn_implementation = kwargs.pop("attn_implementation", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        token = kwargs.pop("token", None) or kwargs.pop("use_auth_token", None)
        cache_dir = kwargs.pop("cache_dir", None)
        # Honor a pinned revision so the adapter files match what the caller asked for.
        revision = kwargs.pop("revision", None)
        # remaining HF plumbing kwargs (state_dict, low_cpu_mem_usage, ...) are ignored

        repo_path = pretrained_model_name_or_path
        if not os.path.isdir(repo_path):
            if not _HAS_HUB:
                raise RuntimeError("huggingface_hub is required to load by repo_id.")
            repo_path = snapshot_download(
                repo_path, revision=revision, token=token, cache_dir=cache_dir
            )

        if config is None:
            config = Chest2VecConfig.from_pretrained(repo_path)

        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        device_t = torch.device(device)
        if torch_dtype is None:
            # bfloat16 on GPU, float32 on CPU.
            torch_dtype = torch.bfloat16 if device_t.type == "cuda" else torch.float32

        model = cls(config)
        model._assemble(repo_path, device=device_t, use_4bit=use_4bit,
                        attn_implementation=attn_implementation, torch_dtype=torch_dtype, token=token)
        return model

    def _assemble(self, repo_path, *, device, use_4bit, attn_implementation, torch_dtype, token=None):
        """Load tokenizer, base model and LoRA adapter, and attach them to ``self``.

        Args:
            repo_path: Local directory containing the adapter sub-directory
                named by ``config.adapter_subdir``.
            device: ``torch.device`` the base model is placed on.
            use_4bit: Load the base model with a 4-bit NF4 bitsandbytes config.
            attn_implementation: Explicit attention backend, or None to pick one
                with ``_pick_attn_impl``.
            torch_dtype: Dtype for the base model (not used when ``use_4bit``).
            token: Optional Hub access token.

        Raises:
            RuntimeError: If ``peft`` is missing, or if loading the base model
                raises ``TypeError``.
            FileNotFoundError: If ``adapter_config.json`` is not present in the
                adapter directory.
        """
        cfg = self.config
        if not _HAS_PEFT:
            raise RuntimeError("peft is required. Install: pip install peft")

        attn_impl = _pick_attn_impl(attn_implementation, bool(cfg.require_flash_attention_2))

        # Left padding so the pooled token sits at the end of every row.
        tokenizer = AutoTokenizer.from_pretrained(
            cfg.base_model, padding_side="left", trust_remote_code=True, token=token
        )
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token

        base_kwargs = dict(trust_remote_code=True, attn_implementation=attn_impl, token=token)
        if use_4bit:
            base_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16,
            )
            base_kwargs["device_map"] = {"": str(device)}
        else:
            base_kwargs["torch_dtype"] = torch_dtype
            if device.type == "cuda":
                base_kwargs["device_map"] = {"": str(device)}
        try:
            base = AutoModel.from_pretrained(cfg.base_model, **base_kwargs)
        except TypeError as e:
            raise RuntimeError("transformers too old for attn_implementation=...; please upgrade.") from e
        # Non-CUDA, non-quantized loads were made without a device_map; move explicitly.
        if device.type != "cuda" and not use_4bit:
            base = base.to(device)

        adapter_dir = os.path.join(repo_path, cfg.adapter_subdir)
        if not os.path.isfile(os.path.join(adapter_dir, "adapter_config.json")):
            raise FileNotFoundError(f"adapter_config.json not found under: {adapter_dir}")
        backbone = PeftModel.from_pretrained(base, adapter_dir)
        backbone.eval()

        self.backbone = backbone
        self.tokenizer = tokenizer
        self._device = device
        self.eval()

    @property
    def device(self):
        """Device that inputs must be placed on.

        Read from the backbone's first parameter when a backbone is attached, so
        the value stays correct after ``.to(...)`` / ``.cuda()`` moves the
        module. Falls back to the device recorded at assembly time.
        """
        backbone = getattr(self, "backbone", None)
        if backbone is not None:
            first_param = next(backbone.parameters(), None)
            if first_param is not None:
                return first_param.device
        return self._device

    @torch.inference_mode()
    def embed_texts(self, texts: List[str], *, max_len: Optional[int] = None,
                    batch_size: int = 16, return_cpu_float32: bool = True) -> torch.Tensor:
        """Return L2-normalized report embeddings, shape [N, H].

        Args:
            texts: Texts to embed. A single ``str`` is treated as one text
                rather than being iterated character by character.
            max_len: Maximum tokens per text including the pooling token;
                defaults to ``config.default_max_len``.
            batch_size: Number of texts per forward pass; must be >= 1.
            return_cpu_float32: If True, the result is moved to CPU as float32
                and re-normalized; otherwise it stays on the model device.

        Returns:
            Float32 tensor of shape ``[N, H]``. For empty input the result has
            zero rows; its width is the backbone config's ``hidden_size`` when
            that attribute exists, else 0.

        Raises:
            RuntimeError: If the model was not assembled via ``from_pretrained``.
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if self.backbone is None:
            raise RuntimeError("Model not assembled; load via from_pretrained(...).")
        # Slicing a bare string would yield substrings and embed single characters.
        texts = [texts] if isinstance(texts, str) else list(texts)
        batch_size = int(batch_size)
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")
        max_len = int(max_len or self.config.default_max_len)
        device = self.device

        if not texts:
            # torch.cat rejects an empty list; return a well-formed empty result instead.
            hidden_size = getattr(getattr(self.backbone, "config", None), "hidden_size", 0) or 0
            out_device = torch.device("cpu") if return_cpu_float32 else device
            return torch.empty((0, int(hidden_size)), dtype=torch.float32, device=out_device)

        if device.type == "cuda":
            amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            use_amp = True
        else:
            amp_dtype, use_amp = torch.float32, False

        outs = []
        for i in range(0, len(texts), batch_size):
            chunk = [str(t) for t in texts[i:i + batch_size]]
            enc = encode_with_eos_ids(self.tokenizer, chunk, max_len)
            input_ids = enc["input_ids"].to(device, non_blocking=True)
            attention_mask = enc["attention_mask"].to(device, non_blocking=True)
            with torch.autocast(device_type=("cuda" if device.type == "cuda" else "cpu"),
                                dtype=amp_dtype, enabled=use_amp):
                h = get_last_hidden_state(self.backbone, input_ids, attention_mask)
                # Pool in float32 so normalization is not done in reduced precision.
                emb = F.normalize(last_token_pool(h, attention_mask).float(), p=2, dim=-1)
            outs.append(emb.detach())
        embeddings = torch.cat(outs, dim=0)
        if return_cpu_float32:
            embeddings = F.normalize(embeddings.float().cpu(), p=2, dim=-1)
        return embeddings

    @torch.inference_mode()
    def embed_instruction_query(self, instructions: List[str], queries: List[str], **kw) -> torch.Tensor:
        """Embed instruction/query pairs formatted with ``build_qwen_query``.

        Extra keyword arguments are forwarded to ``embed_texts``.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(instructions) != len(queries):
            raise ValueError("instructions and queries must have the same length.")
        return self.embed_texts([build_qwen_query(i, q) for i, q in zip(instructions, queries)], **kw)

    def forward(self, texts: List[str], **kw) -> torch.Tensor:  # type: ignore[override]
        """Alias for ``embed_texts``; takes raw texts, not token tensors."""
        return self.embed_texts(texts, **kw)

    @staticmethod
    def cosine_topk(query_emb, cand_emb, k=10, *, device="cuda",
                    query_batch_size=256, doc_chunk_size=8192):
        """Find the top-k candidates per query by cosine similarity.

        Both inputs are L2-normalized in float32, then scored in blocks of
        ``query_batch_size`` queries by ``doc_chunk_size`` candidates while a
        running top-k is kept per query.

        Args:
            query_emb: 2-D tensor of query embeddings, one row per query.
            cand_emb: 2-D tensor of candidate embeddings, one row per candidate.
            k: Number of results per query; capped at the number of candidates.
            device: Device used for scoring when CUDA is available; CPU is used
                otherwise.
            query_batch_size: Queries scored per outer step.
            doc_chunk_size: Candidates scored per inner step.

        Returns:
            ``(scores, indices)`` CPU tensors of shape ``[Nq, k]``: float32
            similarities and long row indices into ``cand_emb``.
        """
        device_t = torch.device(device if torch.cuda.is_available() else "cpu")
        q = F.normalize(query_emb.float(), p=2, dim=-1)
        d = F.normalize(cand_emb.float(), p=2, dim=-1)
        Nq, _ = q.shape
        Nd = d.shape[0]
        k = min(int(k), Nd)
        top_scores_all = torch.empty((Nq, k), dtype=torch.float32)
        top_indices_all = torch.empty((Nq, k), dtype=torch.long)
        for qs in range(0, Nq, query_batch_size):
            qe = q[qs:qs + query_batch_size].to(device_t, non_blocking=True)
            bq = qe.size(0)
            # Sentinel scores/indices; replaced as soon as real candidates are merged in.
            top_scores = torch.full((bq, k), -1e9, device=device_t, dtype=torch.float32)
            top_indices = torch.full((bq, k), -1, device=device_t, dtype=torch.long)
            for ds in range(0, Nd, doc_chunk_size):
                de = d[ds:ds + doc_chunk_size].to(device_t, non_blocking=True)
                scores = (qe @ de.T).float()
                chunk = scores.size(1)
                # Global candidate indices for this chunk, repeated for every query row.
                idx_chunk = torch.arange(ds, ds + chunk, device=device_t, dtype=torch.long).unsqueeze(0).expand(bq, -1)
                # Merge the running top-k with this chunk and keep the best k.
                comb_scores = torch.cat([top_scores, scores], dim=1)
                comb_idx = torch.cat([top_indices, idx_chunk], dim=1)
                new_scores, new_pos = torch.topk(comb_scores, k, dim=1)
                top_scores, top_indices = new_scores, comb_idx.gather(1, new_pos)
            top_scores_all[qs:qs + bq] = top_scores.cpu()
            top_indices_all[qs:qs + bq] = top_indices.cpu()
        return top_scores_all, top_indices_all