import os
import json
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig

# Optional dependencies: import failures of any kind are tolerated here and
# turned into actionable errors at the point of use.
try:
    from peft import PeftModel
    _HAS_PEFT = True
except Exception:
    PeftModel = None
    _HAS_PEFT = False

try:
    from huggingface_hub import snapshot_download
    _HAS_HUB = True
except Exception:
    snapshot_download = None
    _HAS_HUB = False

# -----------------------------
# Sections (must match training)
# -----------------------------
SECTION_NAMES = [
    "Lungs and Airways",
    "Pleura",
    "Cardiovascular",
    "Hila and Mediastinum",
    "Tubes & Devices",
    "Musculoskeletal and Chest Wall",
    "Abdominal",
    "impression",
    "Other",
]

# Short names accepted in EmbedOutput.by_alias; "global" maps to the EOS embedding.
SECTION_ALIASES = {
    "global": "global",
    "lungs": "Lungs and Airways",
    "lung": "Lungs and Airways",
    "pleura": "Pleura",
    "cardio": "Cardiovascular",
    "cardiovascular": "Cardiovascular",
    "hila": "Hila and Mediastinum",
    "mediastinum": "Hila and Mediastinum",
    "tubes": "Tubes & Devices",
    "devices": "Tubes & Devices",
    "msk": "Musculoskeletal and Chest Wall",
    "musculoskeletal": "Musculoskeletal and Chest Wall",
    "abd": "Abdominal",
    "abdominal": "Abdominal",
    "impression": "impression",
    "other": "Other",
}


def require_flash_attention_2() -> str:
    """Return ``"flash_attention_2"`` if FlashAttention-2 is usable, otherwise raise.

    Raises:
        RuntimeError: if CUDA is unavailable, ``flash_attn`` cannot be imported,
            or its major version is below 2.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("FlashAttention-2 requires CUDA, but torch.cuda.is_available() is False.")
    try:
        import flash_attn  # noqa: F401

        ver = getattr(flash_attn, "__version__", "0.0.0")
        major = int(str(ver).split(".")[0])
        if major < 2:
            raise RuntimeError(f"flash-attn version {ver} < 2.0.0")
    except Exception as e:
        # Missing package, broken CUDA build, unparsable or too-old version:
        # all are reported with the same actionable message, keeping the cause.
        raise RuntimeError(
            "FlashAttention-2 is REQUIRED but not available/importable.\n"
            "Install flash-attn>=2 and ensure it matches your torch/CUDA.\n"
            f"Import/Version error: {repr(e)}"
        ) from e
    return "flash_attention_2"


def build_qwen_query(instruction: str, query: str) -> str:
    """Format an instruction/query pair in the Qwen embedding prompt layout."""
    instruction = str(instruction).strip()
    query = str(query).strip()
    return f"Instruct: {instruction}\nQuery: {query}"


def get_pool_token_id(tok) -> int:
    """Return the id appended to every sequence and used for pooling.

    Prefers ``<|endoftext|>``; falls back to ``tok.pad_token_id`` when the lookup
    yields ``None`` or a negative id.
    """
    eod_id = tok.convert_tokens_to_ids("<|endoftext|>")
    if eod_id is None or eod_id < 0:
        eod_id = tok.pad_token_id
    return eod_id


def encode_with_eos_ids(tok, texts: List[str], max_len: int) -> Dict[str, torch.Tensor]:
    """
    Must match Stage-3 training:
      - add_special_tokens=False
      - truncation to max_len-1
      - append <|endoftext|>
      - left-pad

    Returns a dict with ``input_ids`` and ``attention_mask``, both long tensors
    of shape [len(texts), T] where T is the longest encoded sequence.

    Raises:
        ValueError: if the tokenizer provides no usable pad / pooling token id.
    """
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    eod_id = get_pool_token_id(tok)
    if pad_id is None or eod_id is None:
        # Without this, torch.tensor() later fails on a list containing None.
        raise ValueError("Tokenizer has no pad/eos/<|endoftext|> token id to use for padding and pooling.")

    enc = tok(
        [str(t) for t in texts],
        add_special_tokens=False,
        truncation=True,
        max_length=max_len - 1,  # reserve one slot for the appended pooling token
        padding=False,
        return_attention_mask=False,
    )
    input_ids = [ids + [eod_id] for ids in enc["input_ids"]]
    attn_mask = [[1] * len(ids) for ids in input_ids]

    T = max(len(ids) for ids in input_ids) if input_ids else 1
    input_ids = [[pad_id] * (T - len(ids)) + ids for ids in input_ids]
    attn_mask = [[0] * (T - len(m)) + m for m in attn_mask]

    n = len(input_ids)
    # reshape keeps the result 2-D even when texts is empty.
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long).reshape(n, T),
        "attention_mask": torch.tensor(attn_mask, dtype=torch.long).reshape(n, T),
    }


def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Left-padding aware last-token pooling (extracts EOS token embedding).

    [B,T,H] hidden states + [B,T] mask -> [B,H].
    """
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    idx = attention_mask.sum(dim=1) - 1
    return last_hidden_states[torch.arange(last_hidden_states.size(0), device=last_hidden_states.device), idx]


def get_last_hidden_state(model, input_ids, attention_mask):
    """
    Run the encoder and return its final hidden states [B,T,H].

    Explicit position_ids are provided so left padding is handled correctly
    (FlashAttention-2). If the model output carries no ``last_hidden_state``,
    the forward pass is repeated with ``output_hidden_states=True`` and the
    last entry of ``hidden_states`` is returned.
    """
    m = model.module if hasattr(model, "module") else model
    # Positions count only real tokens; padded slots get position 0.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 0)

    out = m(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        use_cache=False,
        return_dict=True,
    )
    last = getattr(out, "last_hidden_state", None)
    if last is not None:
        return last

    out = m(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        output_hidden_states=True,
        use_cache=False,
        return_dict=True,
    )
    return out.hidden_states[-1]


# -----------------------------
# Stage-3 pooler (query_attn)
# -----------------------------
class SectionQueryAttnPooler(nn.Module):
    """
    Match your Stage-3 training pooler.

    One learned query vector per section attends over the token hidden states;
    each attended vector goes through a bias-free MLP and is L2-normalized.
    """

    def __init__(
        self,
        hidden_size: int,
        num_sections: int,
        mlp_hidden: int,
        use_layernorm: bool = True,
        pool_dropout: float = 0.1,
        pool_scale: float = 0.0,  # 0 => 1/sqrt(H)
    ):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_sections = int(num_sections)

        self.ln = nn.LayerNorm(self.hidden_size) if use_layernorm else nn.Identity()
        self.pool_queries = nn.Parameter(torch.empty(self.num_sections, self.hidden_size))
        nn.init.normal_(self.pool_queries, mean=0.0, std=0.02)

        self.pool_scale = float(pool_scale) if (pool_scale and pool_scale > 0) else (1.0 / math.sqrt(self.hidden_size))
        self.pool_dropout = nn.Dropout(pool_dropout) if pool_dropout and pool_dropout > 0 else nn.Identity()

        # Bias-free MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, int(mlp_hidden), bias=False),
            nn.GELU(),
            nn.Linear(int(mlp_hidden), self.hidden_size, bias=False),
        )

    def forward_all(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Pool [B,T,H] hidden states into L2-normalized section embeddings [B,S,H]."""
        if isinstance(self.ln, nn.LayerNorm):
            # LayerNorm computed in fp32 for stability, then cast back.
            x = F.layer_norm(
                hidden_states.float(),
                self.ln.normalized_shape,
                self.ln.weight.float() if self.ln.weight is not None else None,
                self.ln.bias.float() if self.ln.bias is not None else None,
                self.ln.eps,
            ).to(dtype=hidden_states.dtype)
        else:
            x = hidden_states

        scores = torch.einsum("bth,sh->bts", x.float(), self.pool_queries.float()) * self.pool_scale
        # Padded tokens get a large negative score so softmax ignores them.
        scores = scores.masked_fill(attention_mask.unsqueeze(-1) == 0, -1e4)
        attn = torch.softmax(scores, dim=1).to(dtype=x.dtype)  # [B,T,S], softmax over tokens
        attn = self.pool_dropout(attn)

        pooled = torch.einsum("bth,bts->bsh", x, attn)  # [B,S,H]
        pooled = pooled.to(dtype=next(self.mlp.parameters()).dtype)
        pooled = self.mlp(pooled)
        return F.normalize(pooled, p=2, dim=-1)


def _ensure_pooler_device_dtype(pooler: nn.Module, device: torch.device, dtype: torch.dtype) -> None:
    """Move *pooler* to (device, dtype) if its first parameter is elsewhere; no-op without parameters."""
    p = next(pooler.parameters(), None)
    if p is None:
        return
    if p.device != device or p.dtype != dtype:
        pooler.to(device=device, dtype=dtype)


def _read_json(path: str) -> Dict[str, Any]:
    """Load a UTF-8 JSON file."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _resolve_repo_path(repo_id_or_path: str) -> str:
    """Return a local directory for *repo_id_or_path*, downloading a HF snapshot if needed."""
    # If it's a local directory, use it as-is.
    if os.path.isdir(repo_id_or_path):
        return repo_id_or_path

    # Otherwise treat as HF repo_id and download snapshot.
    if not _HAS_HUB:
        raise RuntimeError(
            "huggingface_hub is required to load by repo_id. "
            "Install it: pip install huggingface_hub"
        )
    return snapshot_download(repo_id_or_path)


@dataclass
class EmbedOutput:
    # Always available:
    section_matrix: torch.Tensor   # [N,S,H]; float32 on CPU when return_cpu_float32=True (default)
    global_embedding: torch.Tensor  # [N,H];  float32 on CPU when return_cpu_float32=True (default)

    # Convenience dicts:
    by_section_name: Dict[str, torch.Tensor]  # each [N,H]
    by_alias: Dict[str, torch.Tensor]         # alias -> [N,H]


class Chest2Vec:
    """
    Lightweight wrapper:
      - loads base Qwen3-Embedding
      - applies LoRA adapter
      - attaches Stage-3 section pooler
    """

    def __init__(self, tokenizer, model, pooler, sections: List[str], device: torch.device):
        self.tokenizer = tokenizer
        self.model = model
        self.pooler = pooler
        self.sections = list(sections)
        self.device = device

        self.model.eval()
        self.pooler.eval()

    @classmethod
    def from_pretrained(
        cls,
        repo_id_or_path: str,
        *,
        device: str = "cuda:0",
        use_4bit: bool = False,
        force_flash_attention_2: bool = True,
    ) -> "Chest2Vec":
        """
        Build a Chest2Vec from a local directory or HF repo id.

        The directory must contain ``chest2vec_config.json``, the LoRA adapter
        subfolder, and the pooler weights/config files it names.

        Raises:
            FileNotFoundError: if a required file is missing.
            RuntimeError: if FlashAttention-2 is required but unusable, peft is
                missing, or transformers does not accept ``attn_implementation``.
        """
        repo_path = _resolve_repo_path(repo_id_or_path)

        cfg_path = os.path.join(repo_path, "chest2vec_config.json")
        if not os.path.isfile(cfg_path):
            raise FileNotFoundError(f"Missing chest2vec_config.json in {repo_path}")
        cfg = _read_json(cfg_path)

        base_model = str(cfg["base_model"])
        adapter_subdir = str(cfg.get("adapter_subdir", "contrastive"))
        pooler_pt = str(cfg.get("pooler_pt", "section_pooler.pt"))
        pooler_cfg = str(cfg.get("pooler_cfg", "section_pooler_config.json"))
        sections = cfg.get("sections", SECTION_NAMES)

        if force_flash_attention_2 or bool(cfg.get("require_flash_attention_2", False)):
            attn_impl = require_flash_attention_2()
        else:
            attn_impl = "sdpa"

        if not _HAS_PEFT:
            raise RuntimeError("peft is required. Install: pip install peft")

        device_t = torch.device(device)

        tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left", trust_remote_code=True)
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token

        device_map = {"": str(device_t)}

        # Load base model with the selected attention implementation; the two
        # precision modes differ only in the extra keyword they pass.
        load_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
            "attn_implementation": attn_impl,
            "device_map": device_map,
        }
        if use_4bit:
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        else:
            load_kwargs["torch_dtype"] = torch.bfloat16
        try:
            base = AutoModel.from_pretrained(base_model, **load_kwargs)
        except TypeError as e:
            raise RuntimeError(
                "Your transformers version does not support attn_implementation=... "
                "Upgrade transformers to use FlashAttention-2."
            ) from e

        # Load adapter from this repo folder
        adapter_dir = os.path.join(repo_path, adapter_subdir)
        if not os.path.isfile(os.path.join(adapter_dir, "adapter_config.json")):
            raise FileNotFoundError(f"adapter_config.json not found under: {adapter_dir}")

        model = PeftModel.from_pretrained(base, adapter_dir)
        model.eval()

        # Attach section pooler
        pooler_cfg_path = os.path.join(repo_path, pooler_cfg)
        pooler_pt_path = os.path.join(repo_path, pooler_pt)
        if not os.path.isfile(pooler_cfg_path):
            raise FileNotFoundError(f"Missing pooler config: {pooler_cfg_path}")
        if not os.path.isfile(pooler_pt_path):
            raise FileNotFoundError(f"Missing pooler weights: {pooler_pt_path}")

        pcfg = _read_json(pooler_cfg_path)
        hidden_size = int(getattr(model.module if hasattr(model, "module") else model, "config").hidden_size)
        mlp_hidden = int(pcfg.get("mlp_hidden", hidden_size))
        use_layernorm = bool(pcfg.get("use_layernorm", True))
        pool_dropout = float(pcfg.get("pool_dropout", 0.1))
        pool_scale = float(pcfg.get("pool_scale", 0.0))

        pooler = SectionQueryAttnPooler(
            hidden_size=hidden_size,
            num_sections=len(sections),
            mlp_hidden=mlp_hidden,
            use_layernorm=use_layernorm,
            pool_dropout=pool_dropout,
            pool_scale=pool_scale,
        )
        # The weights file may come from a downloaded hub snapshot, so avoid
        # full unpickling; weights_only is not accepted by torch < 1.13.
        try:
            sd = torch.load(pooler_pt_path, map_location="cpu", weights_only=True)
        except TypeError:
            sd = torch.load(pooler_pt_path, map_location="cpu")
        pooler.load_state_dict(sd, strict=True)
        pooler.eval()

        # Move pooler to same device/dtype as hidden states
        # (we keep inference in autocast)
        pooler.to(device=device_t, dtype=torch.bfloat16 if device_t.type == "cuda" else torch.float32)

        return cls(tokenizer=tokenizer, model=model, pooler=pooler, sections=sections, device=device_t)

    def _build_output(self, section_matrix: torch.Tensor, global_emb: torch.Tensor) -> EmbedOutput:
        """Wrap final tensors in an EmbedOutput with per-section and alias views."""
        by_section_name = {name: section_matrix[:, idx, :] for idx, name in enumerate(self.sections)}

        # Helpful aliases for quick access
        by_alias: Dict[str, torch.Tensor] = {"global": global_emb}
        for alias, real in SECTION_ALIASES.items():
            if real != "global" and real in by_section_name:
                by_alias[alias] = by_section_name[real]

        return EmbedOutput(
            section_matrix=section_matrix,
            global_embedding=global_emb,
            by_section_name=by_section_name,
            by_alias=by_alias,
        )

    @torch.inference_mode()
    def embed_texts(
        self,
        texts: Union[str, List[str]],
        *,
        max_len: int = 512,
        batch_size: int = 16,
        return_cpu_float32: bool = True,
    ) -> EmbedOutput:
        """
        Encodes arbitrary texts (candidates, section strings, etc.).

        NOTE: This uses Stage-3 section pooling:
          - Section embeddings: section_pooler → [B,S,H] (one embedding per configured section)
          - Global embedding: EOS token embedding extracted BEFORE pooler → [B,H] (matches Stage-3 training)

        Args:
            texts: a list (or other iterable) of texts; a single ``str`` is
                treated as one text.
            max_len: maximum tokens per text, including the appended EOS token.
            batch_size: texts per forward pass; must be >= 1.
            return_cpu_float32: if True, results are float32 on CPU and re-normalized.

        Returns:
          - section_matrix: [N,S,H] - section-specific embeddings
          - global_embedding: [N,H] - EOS token embedding (extracted before pooler)
          - by_section_name: dict[name] -> [N,H]
          - by_alias: dict['lungs'/'impression'/...] -> [N,H]
          Empty input yields tensors with N == 0.

        Raises:
            ValueError: if batch_size < 1.
        """
        # A bare string would otherwise be iterated character by character.
        text_list = [texts] if isinstance(texts, str) else list(texts)
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        device = self.device

        if not text_list:
            # torch.cat([]) would raise; return correctly shaped empty results.
            out_device = torch.device("cpu") if return_cpu_float32 else device
            hidden = int(self.pooler.hidden_size)
            empty_sec = torch.empty((0, len(self.sections), hidden), dtype=torch.float32, device=out_device)
            empty_glob = torch.empty((0, hidden), dtype=torch.float32, device=out_device)
            return self._build_output(empty_sec, empty_glob)

        # Determine AMP
        if device.type == "cuda":
            amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            use_amp = True
        else:
            amp_dtype = torch.float32
            use_amp = False

        outs_sec = []
        outs_global = []

        for i in range(0, len(text_list), batch_size):
            chunk = [str(t) for t in text_list[i:i + batch_size]]
            enc = encode_with_eos_ids(self.tokenizer, chunk, max_len)
            input_ids = enc["input_ids"].to(device, non_blocking=True)
            attention_mask = enc["attention_mask"].to(device, non_blocking=True)

            with torch.autocast(device_type=("cuda" if device.type == "cuda" else "cpu"), dtype=amp_dtype, enabled=use_amp):
                h = get_last_hidden_state(self.model, input_ids, attention_mask)  # [B,T,H]

                # Global embedding: extract EOS token embedding BEFORE pooler (matches Stage-3 training)
                global_eos = last_token_pool(h, attention_mask)  # [B,H]
                global_eos = F.normalize(global_eos.float(), p=2, dim=-1)

                # Section embeddings: pass through pooler
                _ensure_pooler_device_dtype(self.pooler, device=h.device, dtype=h.dtype)
                sec = self.pooler.forward_all(h, attention_mask)  # [B,S,H] normalized

            outs_sec.append(sec.detach())
            outs_global.append(global_eos.detach())

        section_matrix = torch.cat(outs_sec, dim=0)  # on device, pooler dtype (bf16 on CUDA)
        global_emb = torch.cat(outs_global, dim=0)   # on device, float32 (normalized in fp32 above)

        # Move to CPU float32 if requested (recommended for retrieval stability)
        if return_cpu_float32:
            # re-normalize to fix any numerical drift
            section_matrix_out = F.normalize(section_matrix.float().cpu(), p=2, dim=-1)
            global_out = F.normalize(global_emb.float().cpu(), p=2, dim=-1)
        else:
            section_matrix_out = section_matrix
            global_out = global_emb

        return self._build_output(section_matrix_out, global_out)

    @torch.inference_mode()
    def embed_instruction_query(
        self,
        instructions: List[str],
        queries: List[str],
        *,
        max_len: int = 512,
        batch_size: int = 16,
        return_cpu_float32: bool = True,
    ) -> EmbedOutput:
        """Embed (instruction, query) pairs formatted with ``build_qwen_query``."""
        if len(instructions) != len(queries):
            raise ValueError("instructions and queries must have the same length.")
        q_texts = [build_qwen_query(i, q) for i, q in zip(instructions, queries)]
        return self.embed_texts(
            q_texts,
            max_len=max_len,
            batch_size=batch_size,
            return_cpu_float32=return_cpu_float32,
        )

    @staticmethod
    def cosine_topk(
        query_emb: torch.Tensor,  # [Nq,H] (or [H]) CPU float32 recommended
        cand_emb: torch.Tensor,   # [Nd,H] CPU float32 recommended
        k: int = 10,
        *,
        device: str = "cuda",
        query_batch_size: int = 256,
        doc_chunk_size: int = 8192,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Chunked cosine top-k, stable in float32.

        A 1-D ``query_emb`` is treated as a single query. ``k`` is clamped to the
        number of candidates. If ``device`` is CUDA but CUDA is unavailable, the
        computation runs on CPU.

        Returns (top_scores [Nq,k], top_indices [Nq,k]) on CPU.
        """
        device_t = torch.device(device)
        if device_t.type == "cuda" and not torch.cuda.is_available():
            # Degrade gracefully on CPU-only hosts instead of failing on .to("cuda").
            device_t = torch.device("cpu")

        q = F.normalize(query_emb.float(), p=2, dim=-1)
        if q.dim() == 1:
            q = q.unsqueeze(0)
        d = F.normalize(cand_emb.float(), p=2, dim=-1)

        Nq = q.shape[0]
        Nd = d.shape[0]
        k = min(int(k), Nd)

        top_scores_all = torch.empty((Nq, k), dtype=torch.float32)
        top_indices_all = torch.empty((Nq, k), dtype=torch.long)

        for qs in range(0, Nq, query_batch_size):
            qe = q[qs:qs + query_batch_size].to(device_t, non_blocking=True)
            bq = qe.size(0)

            # Running top-k; the -1e9 sentinel is below any cosine score, so
            # every placeholder is displaced once k real candidates are seen.
            top_scores = torch.full((bq, k), -1e9, device=device_t, dtype=torch.float32)
            top_indices = torch.full((bq, k), -1, device=device_t, dtype=torch.long)

            for ds in range(0, Nd, doc_chunk_size):
                de = d[ds:ds + doc_chunk_size].to(device_t, non_blocking=True)
                scores = (qe @ de.T).float()

                chunk = scores.size(1)
                idx_chunk = torch.arange(ds, ds + chunk, device=device_t, dtype=torch.long).unsqueeze(0).expand(bq, -1)

                # Merge the running top-k with this chunk and keep the best k.
                comb_scores = torch.cat([top_scores, scores], dim=1)
                comb_idx = torch.cat([top_indices, idx_chunk], dim=1)

                new_scores, new_pos = torch.topk(comb_scores, k, dim=1)
                top_scores, top_indices = new_scores, comb_idx.gather(1, new_pos)

            top_scores_all[qs:qs + bq] = top_scores.cpu()
            top_indices_all[qs:qs + bq] = top_indices.cpu()

        return top_scores_all, top_indices_all