|
|
import os |
|
|
import json |
|
|
import math |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Dict, List, Optional, Union, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig |
|
|
|
|
|
try: |
|
|
from peft import PeftModel |
|
|
_HAS_PEFT = True |
|
|
except Exception: |
|
|
PeftModel = None |
|
|
_HAS_PEFT = False |
|
|
|
|
|
try: |
|
|
from huggingface_hub import snapshot_download |
|
|
_HAS_HUB = True |
|
|
except Exception: |
|
|
snapshot_download = None |
|
|
_HAS_HUB = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Canonical section names, in the order used as the default ``sections`` list
# when chest2vec_config.json does not provide one. The position of a name in
# this list is the index of its embedding along the section axis.
SECTION_NAMES = [
    # Anatomical / device sections.
    "Lungs and Airways",
    "Pleura",
    "Cardiovascular",
    "Hila and Mediastinum",
    "Tubes & Devices",
    "Musculoskeletal and Chest Wall",
    "Abdominal",
    # Report-level sections (note: "impression" is intentionally lowercase).
    "impression",
    "Other",
]
|
|
|
|
|
# Short alias -> canonical section name. The special value "global" does not
# name a section; it marks the alias reserved for the global (EOS) embedding.
SECTION_ALIASES = {
    "global": "global",
    # Lungs and Airways
    "lungs": "Lungs and Airways",
    "lung": "Lungs and Airways",
    # Pleura
    "pleura": "Pleura",
    # Cardiovascular
    "cardio": "Cardiovascular",
    "cardiovascular": "Cardiovascular",
    # Hila and Mediastinum
    "hila": "Hila and Mediastinum",
    "mediastinum": "Hila and Mediastinum",
    # Tubes & Devices
    "tubes": "Tubes & Devices",
    "devices": "Tubes & Devices",
    # Musculoskeletal and Chest Wall
    "msk": "Musculoskeletal and Chest Wall",
    "musculoskeletal": "Musculoskeletal and Chest Wall",
    # Abdominal
    "abd": "Abdominal",
    "abdominal": "Abdominal",
    # Report-level sections
    "impression": "impression",
    "other": "Other",
}
|
|
|
|
|
|
|
|
def require_flash_attention_2() -> str:
    """
    Verify that FlashAttention-2 can be used and return its attention-implementation key.

    Returns:
        The string "flash_attention_2".

    Raises:
        RuntimeError: if CUDA is unavailable, if ``flash_attn`` cannot be imported,
            if its version string cannot be parsed, or if its major version is < 2.
            For the import/version failures the original exception is attached as
            ``__cause__``.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("FlashAttention-2 requires CUDA, but torch.cuda.is_available() is False.")

    try:
        import flash_attn

        # A missing __version__ is treated as "0.0.0" and therefore rejected below.
        ver = getattr(flash_attn, "__version__", "0.0.0")
        major = int(str(ver).split(".")[0])
        if major < 2:
            raise RuntimeError(f"flash-attn version {ver} < 2.0.0")
    except Exception as e:
        # Deliberately broad: any import-time or version-parsing failure is converted
        # into one actionable error, explicitly chained to the underlying cause.
        raise RuntimeError(
            "FlashAttention-2 is REQUIRED but not available/importable.\n"
            "Install flash-attn>=2 and ensure it matches your torch/CUDA.\n"
            f"Import/Version error: {repr(e)}"
        ) from e

    return "flash_attention_2"
|
|
|
|
|
|
|
|
def build_qwen_query(instruction: str, query: str) -> str:
    """
    Build the two-line prompt: an ``Instruct:`` line followed by a ``Query:`` line.

    Both arguments are coerced with ``str()`` and stripped of surrounding
    whitespace before being inserted into the template.
    """
    cleaned_instruction = str(instruction).strip()
    cleaned_query = str(query).strip()
    return f"Instruct: {cleaned_instruction}\nQuery: {cleaned_query}"
|
|
|
|
|
|
|
|
def get_pool_token_id(tok) -> int:
    """
    Return the id of the token appended to every sequence for last-token pooling.

    The id of ``<|endoftext|>`` is preferred. If the tokenizer reports it as
    ``None`` or negative, the pad token id is used, then the EOS token id.

    Args:
        tok: tokenizer exposing ``convert_tokens_to_ids`` and ``pad_token_id``
            (``eos_token_id`` is read if present).

    Raises:
        ValueError: if none of the candidates yields an id. Returning ``None``
            here would only fail later, and obscurely, when tensors are built.
    """
    token_id = tok.convert_tokens_to_ids("<|endoftext|>")
    if token_id is not None and token_id >= 0:
        return token_id

    for fallback in (tok.pad_token_id, getattr(tok, "eos_token_id", None)):
        if fallback is not None:
            return fallback

    raise ValueError(
        "Tokenizer has no '<|endoftext|>' token and neither pad_token_id nor "
        "eos_token_id is set; cannot determine the pooling token id."
    )
|
|
|
|
|
|
|
|
def encode_with_eos_ids(tok, texts: List[str], max_len: int) -> Dict[str, torch.Tensor]:
    """
    Tokenize ``texts``, append the pooling token to each sequence and left-pad the batch.

    Must match Stage-3 training:
      - add_special_tokens=False
      - truncation to max_len-1
      - append <|endoftext|>
      - left-pad

    Args:
        tok: tokenizer; called on the batch and queried for pad/eos ids.
        texts: iterable of texts; every item is coerced with ``str()``.
        max_len: maximum sequence length including the appended pooling token.

    Returns:
        Dict with ``input_ids`` and ``attention_mask``, both ``torch.long`` of
        shape [B, T], where T is the longest sequence in the batch. An empty
        batch yields tensors of shape [0, 1] and does not touch the tokenizer.
    """
    batch = [str(t) for t in texts]
    if not batch:
        # torch.tensor([]) would be 1-D; keep the 2-D [B, T] contract for B == 0.
        return {
            "input_ids": torch.zeros((0, 1), dtype=torch.long),
            "attention_mask": torch.zeros((0, 1), dtype=torch.long),
        }

    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    eod_id = get_pool_token_id(tok)

    enc = tok(
        batch,
        add_special_tokens=False,
        truncation=True,
        max_length=max_len - 1,  # leave room for the appended pooling token
        padding=False,
        return_attention_mask=False,
    )

    # list(...) so that tokenizers returning tuples/arrays of ids are accepted too.
    sequences = [list(ids) + [eod_id] for ids in enc["input_ids"]]
    width = max(len(ids) for ids in sequences)

    input_ids = []
    attn_mask = []
    for ids in sequences:
        gap = width - len(ids)
        # Left padding keeps the pooling token in the last position of every row.
        input_ids.append([pad_id] * gap + ids)
        attn_mask.append([0] * gap + [1] * len(ids))

    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
    }
|
|
|
|
|
|
|
|
def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Pick one hidden state per row: the one at the last attended position.

    If every row has a 1 in the final mask column (the left-padded case), the
    final time step is returned for all rows. Otherwise each row is indexed at
    ``attention_mask.sum(dim=1) - 1``.
    """
    num_rows = attention_mask.shape[0]
    all_rows_end_with_token = attention_mask[:, -1].sum() == num_rows

    if not all_rows_end_with_token:
        last_positions = attention_mask.sum(dim=1) - 1
        row_index = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device)
        return last_hidden_states[row_index, last_positions]

    return last_hidden_states[:, -1]
|
|
|
|
|
|
|
|
def get_last_hidden_state(model, input_ids, attention_mask):
    """
    Run the model with explicit position ids and return its final hidden states.

    Position ids are derived from the attention mask (cumulative sum minus one,
    with masked positions set to 0) so that left-padded rows start counting at
    their first real token (FlashAttention-2).

    If the output object has no ``last_hidden_state`` attribute, the model is
    called a second time with ``output_hidden_states=True`` and the last entry
    of ``hidden_states`` is returned instead.
    """
    # Unwrap containers that expose the real model as ``.module``.
    core = getattr(model, "module", model)

    running_count = attention_mask.long().cumsum(-1) - 1
    position_ids = running_count.masked_fill(attention_mask == 0, 0)

    shared_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        use_cache=False,
        return_dict=True,
    )

    first_pass = core(**shared_kwargs)
    if hasattr(first_pass, "last_hidden_state"):
        return first_pass.last_hidden_state

    second_pass = core(output_hidden_states=True, **shared_kwargs)
    return second_pass.hidden_states[-1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SectionQueryAttnPooler(nn.Module):
    """
    Attention pooler with one learned query vector per section.

    Match your Stage-3 training pooler: the parameter names (``ln``,
    ``pool_queries``, ``mlp``) define the state-dict keys.
    """

    def __init__(
        self,
        hidden_size: int,
        num_sections: int,
        mlp_hidden: int,
        use_layernorm: bool = True,
        pool_dropout: float = 0.1,
        pool_scale: float = 0.0,
    ):
        """
        Args:
            hidden_size: width of the incoming hidden states and of the output.
            num_sections: number of learned section queries.
            mlp_hidden: inner width of the two-layer, bias-free MLP.
            use_layernorm: apply LayerNorm to the hidden states before pooling.
            pool_dropout: dropout probability on the attention weights
                (values that are falsy or <= 0 disable dropout).
            pool_scale: score multiplier; values that are falsy or <= 0 select
                ``1 / sqrt(hidden_size)``.
        """
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_sections = int(num_sections)

        if use_layernorm:
            self.ln = nn.LayerNorm(self.hidden_size)
        else:
            self.ln = nn.Identity()

        self.pool_queries = nn.Parameter(torch.empty(self.num_sections, self.hidden_size))
        nn.init.normal_(self.pool_queries, mean=0.0, std=0.02)

        if pool_scale and pool_scale > 0:
            self.pool_scale = float(pool_scale)
        else:
            self.pool_scale = 1.0 / math.sqrt(self.hidden_size)

        if pool_dropout and pool_dropout > 0:
            self.pool_dropout = nn.Dropout(pool_dropout)
        else:
            self.pool_dropout = nn.Identity()

        inner_width = int(mlp_hidden)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, inner_width, bias=False),
            nn.GELU(),
            nn.Linear(inner_width, self.hidden_size, bias=False),
        )

    def forward_all(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Pool the token states once per section query.

        Args:
            hidden_states: [B, T, H] token states.
            attention_mask: [B, T]; positions equal to 0 are excluded from pooling.

        Returns:
            [B, S, H] tensor of L2-normalized section embeddings.
        """
        input_dtype = hidden_states.dtype

        normed = hidden_states
        if isinstance(self.ln, nn.LayerNorm):
            ln_weight = self.ln.weight
            ln_bias = self.ln.bias
            # LayerNorm is evaluated in float32, then cast back to the input dtype.
            normed = F.layer_norm(
                hidden_states.float(),
                self.ln.normalized_shape,
                None if ln_weight is None else ln_weight.float(),
                None if ln_bias is None else ln_bias.float(),
                self.ln.eps,
            ).to(dtype=input_dtype)

        # Scores and softmax are computed in float32.
        logits = torch.einsum("bth,sh->bts", normed.float(), self.pool_queries.float())
        logits = logits * self.pool_scale
        is_padding = attention_mask.unsqueeze(-1) == 0
        logits = logits.masked_fill(is_padding, -1e4)

        # Softmax runs over the time axis (dim=1), independently for each section.
        weights = torch.softmax(logits, dim=1).to(dtype=normed.dtype)
        weights = self.pool_dropout(weights)

        summary = torch.einsum("bth,bts->bsh", normed, weights)
        mlp_dtype = next(self.mlp.parameters()).dtype
        projected = self.mlp(summary.to(dtype=mlp_dtype))

        return F.normalize(projected, p=2, dim=-1)
|
|
|
|
|
|
|
|
def _ensure_pooler_device_dtype(pooler: nn.Module, device: torch.device, dtype: torch.dtype) -> None: |
|
|
p = next(pooler.parameters(), None) |
|
|
if p is None: |
|
|
return |
|
|
if p.device != device or p.dtype != dtype: |
|
|
pooler.to(device=device, dtype=dtype) |
|
|
|
|
|
|
|
|
def _read_json(path: str) -> Dict[str, Any]: |
|
|
with open(path, "r", encoding="utf-8") as f: |
|
|
return json.load(f) |
|
|
|
|
|
|
|
|
def _resolve_repo_path(repo_id_or_path: str) -> str: |
|
|
|
|
|
if os.path.isdir(repo_id_or_path): |
|
|
return repo_id_or_path |
|
|
|
|
|
if not _HAS_HUB: |
|
|
raise RuntimeError( |
|
|
"huggingface_hub is required to load by repo_id. " |
|
|
"Install it: pip install huggingface_hub" |
|
|
) |
|
|
return snapshot_download(repo_id_or_path) |
|
|
|
|
|
|
|
|
@dataclass
class EmbedOutput:
    """Container for the embeddings returned by the Chest2Vec embedding methods."""

    # Section-specific embeddings, [N, S, H].
    section_matrix: torch.Tensor
    # EOS-token (global) embedding, [N, H].
    global_embedding: torch.Tensor

    # Section name -> [N, H] slice of ``section_matrix``.
    by_section_name: Dict[str, torch.Tensor]
    # Short alias (e.g. 'lungs', 'impression', 'global') -> [N, H] embedding.
    by_alias: Dict[str, torch.Tensor]
|
|
|
|
|
|
|
|
class Chest2Vec:
    """
    Lightweight wrapper:
      - loads base Qwen3-Embedding
      - applies LoRA adapter
      - attaches Stage-3 section pooler
    """

    def __init__(self, tokenizer, model, pooler, sections: List[str], device: torch.device):
        """
        Store the components and switch model and pooler to eval mode.

        Args:
            tokenizer: tokenizer used to encode texts.
            model: encoder producing token hidden states.
            pooler: section pooler exposing ``forward_all(hidden, mask)``.
            sections: section names, in the order of the pooler's section axis.
            device: device that inputs are moved to before the forward pass.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.pooler = pooler
        self.sections = list(sections)  # copy, so a shared default list is never mutated
        self.device = device

        self.model.eval()
        self.pooler.eval()

    @classmethod
    def from_pretrained(
        cls,
        repo_id_or_path: str,
        *,
        device: str = "cuda:0",
        use_4bit: bool = False,
        force_flash_attention_2: bool = True,
    ) -> "Chest2Vec":
        """
        Build a Chest2Vec instance from a local directory or a Hub repo id.

        The repository must contain ``chest2vec_config.json``, an adapter
        sub-directory with ``adapter_config.json``, and the pooler config and
        weight files named in the config (defaults: ``contrastive``,
        ``section_pooler_config.json``, ``section_pooler.pt``).

        Args:
            repo_id_or_path: local directory or Hub repo id.
            device: device string the base model and pooler are placed on.
            use_4bit: load the base model with 4-bit NF4 quantization instead of bfloat16.
            force_flash_attention_2: require FlashAttention-2; when False (and the
                config does not require it either) the "sdpa" implementation is used.

        Raises:
            FileNotFoundError: if a required file is missing in the repository.
            RuntimeError: if FlashAttention-2 is required but unusable, if ``peft``
                is not installed, or if the base-model loader raises ``TypeError``.
        """
        repo_path = _resolve_repo_path(repo_id_or_path)

        cfg_path = os.path.join(repo_path, "chest2vec_config.json")
        if not os.path.isfile(cfg_path):
            raise FileNotFoundError(f"Missing chest2vec_config.json in {repo_path}")
        cfg = _read_json(cfg_path)

        base_model = str(cfg["base_model"])
        adapter_subdir = str(cfg.get("adapter_subdir", "contrastive"))
        pooler_pt = str(cfg.get("pooler_pt", "section_pooler.pt"))
        pooler_cfg = str(cfg.get("pooler_cfg", "section_pooler_config.json"))
        sections = cfg.get("sections", SECTION_NAMES)

        if force_flash_attention_2 or bool(cfg.get("require_flash_attention_2", False)):
            attn_impl = require_flash_attention_2()
        else:
            attn_impl = "sdpa"

        if not _HAS_PEFT:
            raise RuntimeError("peft is required. Install: pip install peft")

        # Fail fast: check every repository artifact BEFORE loading the base model,
        # so a missing file is reported without first paying for the model load.
        adapter_dir = os.path.join(repo_path, adapter_subdir)
        if not os.path.isfile(os.path.join(adapter_dir, "adapter_config.json")):
            raise FileNotFoundError(f"adapter_config.json not found under: {adapter_dir}")

        pooler_cfg_path = os.path.join(repo_path, pooler_cfg)
        pooler_pt_path = os.path.join(repo_path, pooler_pt)
        if not os.path.isfile(pooler_cfg_path):
            raise FileNotFoundError(f"Missing pooler config: {pooler_cfg_path}")
        if not os.path.isfile(pooler_pt_path):
            raise FileNotFoundError(f"Missing pooler weights: {pooler_pt_path}")

        device_t = torch.device(device)

        tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left", trust_remote_code=True)
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token

        # The two loading modes differ only in how weights are represented.
        load_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
            "attn_implementation": attn_impl,
            "device_map": {"": str(device_t)},
        }
        if use_4bit:
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        else:
            load_kwargs["torch_dtype"] = torch.bfloat16

        try:
            base = AutoModel.from_pretrained(base_model, **load_kwargs)
        except TypeError as e:
            raise RuntimeError(
                "Your transformers version does not support attn_implementation=... "
                "Upgrade transformers to use FlashAttention-2."
            ) from e

        model = PeftModel.from_pretrained(base, adapter_dir)
        model.eval()

        pcfg = _read_json(pooler_cfg_path)

        core = model.module if hasattr(model, "module") else model
        hidden_size = int(core.config.hidden_size)
        mlp_hidden = int(pcfg.get("mlp_hidden", hidden_size))
        use_layernorm = bool(pcfg.get("use_layernorm", True))
        pool_dropout = float(pcfg.get("pool_dropout", 0.1))
        pool_scale = float(pcfg.get("pool_scale", 0.0))

        pooler = SectionQueryAttnPooler(
            hidden_size=hidden_size,
            num_sections=len(sections),
            mlp_hidden=mlp_hidden,
            use_layernorm=use_layernorm,
            pool_dropout=pool_dropout,
            pool_scale=pool_scale,
        )
        # The checkpoint may come from a downloaded repository: weights_only=True
        # restricts unpickling to tensors/containers instead of arbitrary objects.
        sd = torch.load(pooler_pt_path, map_location="cpu", weights_only=True)
        pooler.load_state_dict(sd, strict=True)
        pooler.eval()

        pooler.to(device=device_t, dtype=torch.bfloat16 if device_t.type == "cuda" else torch.float32)

        return cls(tokenizer=tokenizer, model=model, pooler=pooler, sections=sections, device=device_t)

    def _empty_embeddings(self, return_cpu_float32: bool) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return zero-row float32 ``(section_matrix [0,S,H], global_embedding [0,H])`` tensors."""
        hidden = getattr(self.pooler, "hidden_size", None)
        if hidden is None:
            core = self.model.module if hasattr(self.model, "module") else self.model
            hidden = core.config.hidden_size
        hidden = int(hidden)
        target = torch.device("cpu") if return_cpu_float32 else self.device
        section_matrix = torch.zeros((0, len(self.sections), hidden), dtype=torch.float32, device=target)
        global_emb = torch.zeros((0, hidden), dtype=torch.float32, device=target)
        return section_matrix, global_emb

    def _package_output(self, section_matrix: torch.Tensor, global_emb: torch.Tensor) -> "EmbedOutput":
        """Wrap the two embedding tensors in an EmbedOutput with name and alias lookups."""
        by_section_name = {name: section_matrix[:, idx, :] for idx, name in enumerate(self.sections)}

        by_alias: Dict[str, torch.Tensor] = {"global": global_emb}
        for alias, real in SECTION_ALIASES.items():
            if real == "global":
                continue  # already mapped to the global embedding above
            if real in by_section_name:
                by_alias[alias] = by_section_name[real]

        return EmbedOutput(
            section_matrix=section_matrix,
            global_embedding=global_emb,
            by_section_name=by_section_name,
            by_alias=by_alias,
        )

    @torch.inference_mode()
    def embed_texts(
        self,
        texts: List[str],
        *,
        max_len: int = 512,
        batch_size: int = 16,
        return_cpu_float32: bool = True,
    ) -> "EmbedOutput":
        """
        Encodes arbitrary texts (candidates, section strings, etc.)

        NOTE: This uses Stage-3 section pooling:
          - Section embeddings: section_pooler -> [B,S,H] (one embedding per section)
          - Global embedding: EOS token embedding extracted BEFORE pooler -> [B,H]
            (matches Stage-3 training)

        Args:
            texts: texts to embed; each item is coerced with ``str()``.
            max_len: maximum token length including the appended EOS token.
            batch_size: number of texts per forward pass; must be >= 1.
            return_cpu_float32: if True, results are moved to CPU as float32 and
                re-normalized; otherwise they stay on the compute device.

        Returns:
          - section_matrix: [N,S,H] - section-specific embeddings
          - global_embedding: [N,H] - EOS token embedding (extracted before pooler)
          - by_section_name: dict[name] -> [N,H]
          - by_alias: dict['lungs'/'impression'/...] -> [N,H]
          For empty ``texts`` all tensors have N == 0.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        texts = [str(t) for t in texts]
        if not texts:
            # torch.cat cannot concatenate an empty list; return well-shaped empties.
            return self._package_output(*self._empty_embeddings(return_cpu_float32))

        device = self.device
        if device.type == "cuda":
            amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            use_amp = True
        else:
            amp_dtype = torch.float32
            use_amp = False

        outs_sec = []
        outs_global = []
        for i in range(0, len(texts), batch_size):
            chunk = texts[i:i + batch_size]
            enc = encode_with_eos_ids(self.tokenizer, chunk, max_len)
            input_ids = enc["input_ids"].to(device, non_blocking=True)
            attention_mask = enc["attention_mask"].to(device, non_blocking=True)

            with torch.autocast(device_type=("cuda" if device.type == "cuda" else "cpu"),
                                dtype=amp_dtype, enabled=use_amp):
                h = get_last_hidden_state(self.model, input_ids, attention_mask)

            # Global embedding: EOS token state taken before the pooler, normalized in float32.
            global_eos = last_token_pool(h, attention_mask)
            global_eos = F.normalize(global_eos.float(), p=2, dim=-1)

            # The pooler must live on the same device/dtype as the hidden states.
            _ensure_pooler_device_dtype(self.pooler, device=h.device, dtype=h.dtype)
            sec = self.pooler.forward_all(h, attention_mask)

            outs_sec.append(sec.detach())
            outs_global.append(global_eos.detach())

        section_matrix = torch.cat(outs_sec, dim=0)
        global_emb = torch.cat(outs_global, dim=0)

        if return_cpu_float32:
            # Re-normalize after the float32 cast so rows are unit-norm in the returned dtype.
            section_matrix = F.normalize(section_matrix.float().cpu(), p=2, dim=-1)
            global_emb = F.normalize(global_emb.float().cpu(), p=2, dim=-1)

        return self._package_output(section_matrix, global_emb)

    @torch.inference_mode()
    def embed_instruction_query(
        self,
        instructions: List[str],
        queries: List[str],
        *,
        max_len: int = 512,
        batch_size: int = 16,
        return_cpu_float32: bool = True,
    ) -> "EmbedOutput":
        """
        Embed instruction/query pairs formatted with ``build_qwen_query``.

        Raises:
            ValueError: if ``instructions`` and ``queries`` differ in length.
        """
        if len(instructions) != len(queries):
            raise ValueError("instructions and queries must have the same length.")
        q_texts = [build_qwen_query(i, q) for i, q in zip(instructions, queries)]
        return self.embed_texts(
            q_texts,
            max_len=max_len,
            batch_size=batch_size,
            return_cpu_float32=return_cpu_float32,
        )

    @staticmethod
    def cosine_topk(
        query_emb: torch.Tensor,
        cand_emb: torch.Tensor,
        k: int = 10,
        *,
        device: str = "cuda",
        query_batch_size: int = 256,
        doc_chunk_size: int = 8192,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Chunked cosine top-k, stable in float32.

        Args:
            query_emb: [Nq, H] query embeddings.
            cand_emb: [Nd, H] candidate embeddings.
            k: number of neighbours per query; clamped to Nd.
            device: device on which the similarity blocks are computed.
            query_batch_size: queries processed per outer step; must be >= 1.
            doc_chunk_size: candidates processed per inner step; must be >= 1.

        Returns (top_scores [Nq,k], top_indices [Nq,k]) on CPU.

        Raises:
            ValueError: if the inputs are not 2-D, if ``k`` is negative, or if a
                batch/chunk size is smaller than 1 (a non-positive size would skip
                the loops and leave the result buffers unfilled).
        """
        if query_emb.dim() != 2 or cand_emb.dim() != 2:
            raise ValueError("query_emb and cand_emb must be 2-D tensors of shape [N, H].")
        if query_batch_size < 1 or doc_chunk_size < 1:
            raise ValueError("query_batch_size and doc_chunk_size must be positive integers.")

        device_t = torch.device(device)
        q = F.normalize(query_emb.float(), p=2, dim=-1)
        d = F.normalize(cand_emb.float(), p=2, dim=-1)
        Nq = q.shape[0]
        Nd = d.shape[0]
        k = min(int(k), Nd)
        if k < 0:
            raise ValueError("k must be non-negative.")

        top_scores_all = torch.empty((Nq, k), dtype=torch.float32)
        top_indices_all = torch.empty((Nq, k), dtype=torch.long)

        for qs in range(0, Nq, query_batch_size):
            qe = q[qs:qs + query_batch_size].to(device_t, non_blocking=True)
            bq = qe.size(0)

            # Running best-k per query; the sentinels are displaced by real scores
            # because k never exceeds the number of candidates.
            top_scores = torch.full((bq, k), -1e9, device=device_t, dtype=torch.float32)
            top_indices = torch.full((bq, k), -1, device=device_t, dtype=torch.long)

            for ds in range(0, Nd, doc_chunk_size):
                de = d[ds:ds + doc_chunk_size].to(device_t, non_blocking=True)
                scores = (qe @ de.T).float()

                chunk = scores.size(1)
                # Global candidate indices for this chunk, broadcast over the query rows.
                idx_chunk = torch.arange(ds, ds + chunk, device=device_t, dtype=torch.long)
                idx_chunk = idx_chunk.unsqueeze(0).expand(bq, -1)

                # Merge the running best-k with this chunk and keep the best k again.
                comb_scores = torch.cat([top_scores, scores], dim=1)
                comb_idx = torch.cat([top_indices, idx_chunk], dim=1)
                top_scores, new_pos = torch.topk(comb_scores, k, dim=1)
                top_indices = comb_idx.gather(1, new_pos)

            top_scores_all[qs:qs + bq] = top_scores.cpu()
            top_indices_all[qs:qs + bq] = top_indices.cpu()

        return top_scores_all, top_indices_all
|
|
|