chest2vec_4B / modeling_chest2vec.py
lukeingawesome's picture
Add trust_remote_code integration (Qwen3-Embedding + LoRA)
c036088 verified
"""Chest2Vec — LoRA-tuned Qwen3-Embedding model for chest radiology reports.
Load with:
from transformers import AutoModel
model = AutoModel.from_pretrained("chest2vec/chest2vec_0.6B", trust_remote_code=True)
emb = model.embed_texts(["Frontal chest radiograph. No pneumothorax."]) # [N, H], L2-normalized
Architecture:
1. Base : Qwen/Qwen3-Embedding-{0.6B,4B} (downloaded at runtime)
2. Adapter: frozen contrastive LoRA adapter (./contrastive)
Embeddings use last-token (EOS) pooling with left padding, matching Qwen3-Embedding
and the Stage-2 training setup. FlashAttention-2 is used when CUDA + flash-attn>=2
are available (matching training); otherwise it falls back to SDPA so the model
also loads on CPU.
"""
import os
from typing import Dict, List, Optional
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig, PreTrainedModel
from .configuration_chest2vec import Chest2VecConfig
try:
from peft import PeftModel
_HAS_PEFT = True
except Exception:
PeftModel = None
_HAS_PEFT = False
try:
from huggingface_hub import snapshot_download
_HAS_HUB = True
except Exception:
snapshot_download = None
_HAS_HUB = False
# ----------------------------------------------------------------------------
# Attention backend selection
# ----------------------------------------------------------------------------
def _flash_attn_available() -> bool:
if not torch.cuda.is_available():
return False
try:
import flash_attn # noqa: F401
ver = getattr(flash_attn, "__version__", "0.0.0")
return int(str(ver).split(".")[0]) >= 2
except Exception:
return False
def _pick_attn_impl(requested: Optional[str], want_flash: bool) -> str:
import warnings
if requested:
return requested
if want_flash and _flash_attn_available():
return "flash_attention_2"
if want_flash:
warnings.warn(
"Chest2Vec was trained with FlashAttention-2, but it is unavailable "
"(needs CUDA + flash-attn>=2). Falling back to 'sdpa'; embeddings may "
"differ very slightly from the reference implementation.",
RuntimeWarning,
)
return "sdpa"
# ----------------------------------------------------------------------------
# Tokenization / pooling helpers (match Qwen3-Embedding + training)
# ----------------------------------------------------------------------------
def build_qwen_query(instruction: str, query: str) -> str:
    """Format an instruction/query pair in the Qwen3-Embedding prompt layout.

    Both parts are converted with ``str()`` and stripped of surrounding
    whitespace, then joined as ``"Instruct: <instruction>\\nQuery: <query>"``.
    """
    instruction_text = str(instruction).strip()
    query_text = str(query).strip()
    return "Instruct: " + instruction_text + "\nQuery: " + query_text
def get_pool_token_id(tok) -> int:
    """Return the id of the token appended to every sequence for pooling.

    Prefers ``<|endoftext|>``. If the tokenizer does not know that token, it
    falls back to ``pad_token_id`` and then to ``eos_token_id``.

    Args:
        tok: a Hugging Face tokenizer (anything exposing
            ``convert_tokens_to_ids``, ``pad_token_id`` and optionally
            ``eos_token_id`` / ``unk_token_id`` / ``unk_token``).

    Returns:
        The pooling token id.

    Raises:
        ValueError: if none of the candidate ids is available.
    """
    eod_token = "<|endoftext|>"
    eod_id = tok.convert_tokens_to_ids(eod_token)
    # HF tokenizers map unknown tokens to ``unk_token_id`` instead of failing;
    # treat that as "not found" unless <|endoftext|> really *is* the unk token
    # (as in GPT-2-style vocabularies).
    unk_id = getattr(tok, "unk_token_id", None)
    if (
        eod_id is not None
        and unk_id is not None
        and eod_id == unk_id
        and getattr(tok, "unk_token", None) != eod_token
    ):
        eod_id = None
    if eod_id is None or eod_id < 0:
        eod_id = tok.pad_token_id
    if eod_id is None:
        # Same fallback order encode_with_eos_ids uses for its pad id.
        eod_id = getattr(tok, "eos_token_id", None)
    if eod_id is None:
        raise ValueError(
            "Tokenizer has no '<|endoftext|>', pad or eos token to use for pooling."
        )
    return eod_id
def encode_with_eos_ids(tok, texts: List[str], max_len: int) -> Dict[str, torch.Tensor]:
    """Tokenize ``texts`` the way Chest2Vec pools them.

    Each text is ``str()``-ed and tokenized without special tokens. It is
    truncated to ``max_len - 1`` ids so that the pooling token from
    ``get_pool_token_id`` still fits, and that token is appended. Every row is
    then left-padded to the longest row with ``pad_token_id`` (or
    ``eos_token_id`` when no pad token is set).

    Returns:
        dict with ``"input_ids"`` and ``"attention_mask"`` as int64 tensors;
        the mask is 0 on padding positions and 1 elsewhere.
    """
    pad_id = tok.eos_token_id if tok.pad_token_id is None else tok.pad_token_id
    pool_id = get_pool_token_id(tok)
    encoded = tok(
        [str(t) for t in texts],
        add_special_tokens=False,
        truncation=True,
        max_length=max_len - 1,
        padding=False,
        return_attention_mask=False,
    )
    sequences = [ids + [pool_id] for ids in encoded["input_ids"]]
    width = max(map(len, sequences), default=1)

    padded_ids = []
    masks = []
    for seq in sequences:
        gap = width - len(seq)
        # Left padding keeps the pooling token in the final column.
        padded_ids.append([pad_id] * gap + seq)
        masks.append([0] * gap + [1] * len(seq))
    return {
        "input_ids": torch.tensor(padded_ids, dtype=torch.long),
        "attention_mask": torch.tensor(masks, dtype=torch.long),
    }
def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick each row's final non-padding hidden state (EOS pooling).

    If every row's last mask position is 1, the batch is treated as left-padded
    and the last column is returned directly. Otherwise the position
    ``mask.sum(1) - 1`` is gathered for each row, which assumes the non-padding
    tokens are a contiguous prefix.
    """
    batch_rows = attention_mask.shape[0]
    if attention_mask[:, -1].sum() == batch_rows:
        return last_hidden_states[:, -1]
    final_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device)
    return last_hidden_states[row_index, final_positions]
def get_last_hidden_state(model, input_ids, attention_mask):
    """Run ``model`` and return its final-layer hidden states.

    A ``.module`` wrapper is unwrapped first (if present). Explicit position ids
    are derived from the mask (cumulative count minus one, 0 on padding) so
    left-padded rows start at position 0. If the output has no
    ``last_hidden_state``, the model is run a second time with
    ``output_hidden_states=True`` and the last entry of ``hidden_states`` is
    returned.
    """
    net = getattr(model, "module", model)
    positions = attention_mask.long().cumsum(-1) - 1
    positions.masked_fill_(attention_mask == 0, 0)
    call_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=positions,
        use_cache=False,
        return_dict=True,
    )
    result = net(**call_kwargs)
    hidden = getattr(result, "last_hidden_state", None)
    if hidden is not None:
        return hidden
    result = net(output_hidden_states=True, **call_kwargs)
    return result.hidden_states[-1]
class Chest2VecModel(PreTrainedModel):
    """LoRA-tuned Qwen3-Embedding model producing L2-normalized report embeddings.

    ``__init__`` only builds an empty shell. ``from_pretrained`` does the real
    work:

    * locates or downloads the repo;
    * loads the Qwen3 base named by ``config.base_model``;
    * applies the PEFT adapter found in ``config.adapter_subdir``;
    * stores the result on ``self.backbone`` and ``self.tokenizer``.
    """

    config_class = Chest2VecConfig
    base_model_prefix = "chest2vec"
    # Attention is handled by the inner Qwen3 backbone; advertise support so the
    # transformers attn-implementation validator on this wrapper passes.
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_flash_attn = True
    _supports_attention_backend = True

    def __init__(self, config: Chest2VecConfig):
        super().__init__(config)
        # The base+adapter are assembled in `from_pretrained` (base downloads at runtime).
        self.backbone = None
        self.tokenizer = None
        self._device = torch.device("cpu")
        # Non-persistent (not saved in state_dict) placeholder buffer.
        # NOTE(review): presumably keeps the shell module from being
        # buffer/parameter-free before assembly — confirm.
        self.register_buffer("_anchor", torch.zeros(1), persistent=False)

    def get_input_embeddings(self):
        # The wrapper exposes no embedding table of its own.
        return None

    def set_input_embeddings(self, value):
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the config, tokenizer, Qwen3 base model and LoRA adapter.

        Args:
            pretrained_model_name_or_path: local directory or Hub repo id of the
                Chest2Vec repo (the one containing the adapter sub-directory).

        Keyword Args:
            config: pre-loaded ``Chest2VecConfig``; read from the repo if omitted.
            device: target device. Defaults to ``"cuda:0"`` when CUDA is
                available, else ``"cpu"``.
            use_4bit: load the base model in bitsandbytes NF4.
            attn_implementation: force an attention backend. Otherwise it is
                chosen by ``_pick_attn_impl``.
            torch_dtype / dtype: dtype for the base weights (``torch_dtype``
                wins if both are given). Defaults to bfloat16 on CUDA and
                float32 elsewhere.
            token / use_auth_token: Hub auth; ``use_auth_token`` is used only
                when ``token`` is falsy.
            cache_dir, revision, local_files_only: forwarded to
                ``snapshot_download`` when a repo id (not a directory) is given.

        Any other HF plumbing kwargs are accepted and ignored.
        """
        config = kwargs.pop("config", None)
        device = kwargs.pop("device", None)
        use_4bit = kwargs.pop("use_4bit", False)
        attn_implementation = kwargs.pop("attn_implementation", None)
        # Newer transformers spell the dtype argument ``dtype``; accept both.
        torch_dtype = kwargs.pop("torch_dtype", None)
        dtype_alias = kwargs.pop("dtype", None)
        if torch_dtype is None:
            torch_dtype = dtype_alias
        # Pop both auth spellings unconditionally; prefer ``token``.
        token = kwargs.pop("token", None)
        legacy_token = kwargs.pop("use_auth_token", None)
        token = token or legacy_token
        cache_dir = kwargs.pop("cache_dir", None)
        # AutoModel forwards hub kwargs (revision, local_files_only, ...) to this
        # method under trust_remote_code; honour them so the adapter comes from
        # the same revision as this code/config.
        revision = kwargs.pop("revision", None)
        local_files_only = bool(kwargs.pop("local_files_only", False))
        # remaining HF plumbing kwargs (state_dict, low_cpu_mem_usage, ...) are ignored

        repo_path = pretrained_model_name_or_path
        if not os.path.isdir(repo_path):
            if not _HAS_HUB:
                raise RuntimeError("huggingface_hub is required to load by repo_id.")
            repo_path = snapshot_download(
                repo_path,
                revision=revision,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
            )
        if config is None:
            config = Chest2VecConfig.from_pretrained(repo_path)
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        device_t = torch.device(device)
        if torch_dtype is None:
            torch_dtype = torch.bfloat16 if device_t.type == "cuda" else torch.float32

        model = cls(config)
        model._assemble(repo_path, device=device_t, use_4bit=use_4bit,
                        attn_implementation=attn_implementation,
                        torch_dtype=torch_dtype, token=token)
        return model

    def _assemble(self, repo_path, *, device, use_4bit, attn_implementation, torch_dtype, token=None):
        """Build the tokenizer, base model and PEFT backbone and attach them to ``self``.

        Raises:
            RuntimeError: peft is missing, or transformers rejects
                ``attn_implementation``.
            FileNotFoundError: ``<repo_path>/<adapter_subdir>/adapter_config.json``
                does not exist.
        """
        cfg = self.config
        if not _HAS_PEFT:
            raise RuntimeError("peft is required. Install: pip install peft")
        # Validate the adapter location *before* downloading/loading the base
        # model, so a bad repo fails fast instead of after a multi-GB load.
        adapter_dir = os.path.join(repo_path, cfg.adapter_subdir)
        if not os.path.isfile(os.path.join(adapter_dir, "adapter_config.json")):
            raise FileNotFoundError(f"adapter_config.json not found under: {adapter_dir}")

        attn_impl = _pick_attn_impl(attn_implementation, bool(cfg.require_flash_attention_2))
        tokenizer = AutoTokenizer.from_pretrained(
            cfg.base_model, padding_side="left", trust_remote_code=True, token=token
        )
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token

        base_kwargs = dict(trust_remote_code=True, attn_implementation=attn_impl, token=token)
        if use_4bit:
            base_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16,
            )
            base_kwargs["device_map"] = {"": str(device)}
        else:
            base_kwargs["torch_dtype"] = torch_dtype
            if device.type == "cuda":
                # Place weights directly on the GPU instead of loading to CPU first.
                base_kwargs["device_map"] = {"": str(device)}
        try:
            base = AutoModel.from_pretrained(cfg.base_model, **base_kwargs)
        except TypeError as e:
            raise RuntimeError("transformers too old for attn_implementation=...; please upgrade.") from e
        if device.type != "cuda" and not use_4bit:
            base = base.to(device)

        backbone = PeftModel.from_pretrained(base, adapter_dir)
        backbone.eval()
        self.backbone = backbone
        self.tokenizer = tokenizer
        self._device = device
        self.eval()

    @property
    def device(self):
        # The device chosen at assembly time (CPU for an unassembled shell).
        return self._device

    @torch.inference_mode()
    def embed_texts(self, texts: List[str], *, max_len: Optional[int] = None,
                    batch_size: int = 16, return_cpu_float32: bool = True) -> torch.Tensor:
        """Return L2-normalized report embeddings, shape [N, H].

        Args:
            texts: report strings, each converted with ``str()``. A single
                ``str`` is treated as one text rather than being iterated
                character by character.
            max_len: token budget per text, including the appended pooling
                token. Falsy values fall back to ``config.default_max_len``.
            batch_size: texts per forward pass; must be >= 1.
            return_cpu_float32: if True, return a float32 CPU tensor
                (re-normalized after the transfer). Otherwise results stay on
                the model device.

        Raises:
            RuntimeError: the model was not assembled via ``from_pretrained``.
            ValueError: ``batch_size`` is smaller than 1.
        """
        if self.backbone is None:
            raise RuntimeError("Model not assembled; load via from_pretrained(...).")
        if isinstance(texts, str):
            texts = [texts]
        texts = list(texts)
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
        max_len = int(max_len or self.config.default_max_len)
        device = self._device

        if not texts:
            # torch.cat([]) would raise; return an empty [0, H] matrix instead.
            # NOTE(review): H is read from the backbone config's hidden_size
            # when reachable (0 otherwise) — confirm for non-Qwen backbones.
            hidden = getattr(getattr(self.backbone, "config", None), "hidden_size", 0) or 0
            empty = torch.empty((0, int(hidden)), dtype=torch.float32)
            return empty if return_cpu_float32 else empty.to(device)

        if device.type == "cuda":
            amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            use_amp = True
        else:
            amp_dtype, use_amp = torch.float32, False

        outs = []
        for i in range(0, len(texts), batch_size):
            chunk = [str(t) for t in texts[i:i + batch_size]]
            enc = encode_with_eos_ids(self.tokenizer, chunk, max_len)
            input_ids = enc["input_ids"].to(device, non_blocking=True)
            attention_mask = enc["attention_mask"].to(device, non_blocking=True)
            with torch.autocast(device_type=("cuda" if device.type == "cuda" else "cpu"),
                                dtype=amp_dtype, enabled=use_amp):
                h = get_last_hidden_state(self.backbone, input_ids, attention_mask)
            emb = F.normalize(last_token_pool(h, attention_mask).float(), p=2, dim=-1)
            # When CPU output is requested, move each batch off the accelerator
            # right away so device memory does not grow with N.
            outs.append(emb.cpu() if return_cpu_float32 else emb)

        embeddings = torch.cat(outs, dim=0)
        if return_cpu_float32:
            embeddings = F.normalize(embeddings.float().cpu(), p=2, dim=-1)
        return embeddings

    @torch.inference_mode()
    def embed_instruction_query(self, instructions: List[str], queries: List[str], **kw) -> torch.Tensor:
        """Embed ``"Instruct: ...\\nQuery: ..."`` prompts built by ``build_qwen_query``.

        A single ``str`` query is treated as one query. A single ``str``
        instruction is broadcast to every query. Extra keyword arguments go to
        ``embed_texts``.

        Raises:
            ValueError: instruction and query counts differ.
        """
        if isinstance(queries, str):
            queries = [queries]
        if isinstance(instructions, str):
            instructions = [instructions] * len(queries)
        if len(instructions) != len(queries):
            raise ValueError("instructions and queries must have the same length.")
        return self.embed_texts([build_qwen_query(i, q) for i, q in zip(instructions, queries)], **kw)

    def forward(self, texts: List[str], **kw) -> torch.Tensor:  # type: ignore[override]
        # Calling the module embeds texts (same as embed_texts).
        return self.embed_texts(texts, **kw)

    @staticmethod
    def cosine_topk(query_emb, cand_emb, k=10, *, device="cuda",
                    query_batch_size=256, doc_chunk_size=8192):
        """Exact cosine top-k search of ``query_emb`` rows against ``cand_emb`` rows.

        Both inputs are cast to float32 and L2-normalized, so the dot product is
        the cosine similarity. Queries are processed ``query_batch_size`` rows
        at a time and candidates ``doc_chunk_size`` rows at a time. A running
        top-k is kept per query, so the full [Nq, Nd] score matrix is never
        materialized.

        Args:
            query_emb: [Nq, H] query embeddings.
            cand_emb: [Nd, H] candidate embeddings.
            k: number of neighbours; clipped to Nd.
            device: compute device; CPU is used whenever CUDA is unavailable.

        Returns:
            ``(scores, indices)``: CPU tensors of shape [Nq, min(k, Nd)]
            (float32 and int64), ordered by descending score. Indices are row
            numbers into ``cand_emb``.
        """
        device_t = torch.device(device if torch.cuda.is_available() else "cpu")
        q = F.normalize(query_emb.float(), p=2, dim=-1)
        d = F.normalize(cand_emb.float(), p=2, dim=-1)
        Nq, _ = q.shape
        Nd = d.shape[0]
        k = min(int(k), Nd)
        top_scores_all = torch.empty((Nq, k), dtype=torch.float32)
        top_indices_all = torch.empty((Nq, k), dtype=torch.long)
        for qs in range(0, Nq, query_batch_size):
            qe = q[qs:qs + query_batch_size].to(device_t, non_blocking=True)
            bq = qe.size(0)
            # Sentinel scores far below any cosine value; they are displaced
            # once at least k real candidates have been seen (k <= Nd).
            top_scores = torch.full((bq, k), -1e9, device=device_t, dtype=torch.float32)
            top_indices = torch.full((bq, k), -1, device=device_t, dtype=torch.long)
            for ds in range(0, Nd, doc_chunk_size):
                de = d[ds:ds + doc_chunk_size].to(device_t, non_blocking=True)
                scores = (qe @ de.T).float()
                chunk = scores.size(1)
                idx_chunk = torch.arange(ds, ds + chunk, device=device_t, dtype=torch.long).unsqueeze(0).expand(bq, -1)
                # Merge the running top-k with this chunk and keep the best k.
                comb_scores = torch.cat([top_scores, scores], dim=1)
                comb_idx = torch.cat([top_indices, idx_chunk], dim=1)
                new_scores, new_pos = torch.topk(comb_scores, k, dim=1)
                top_scores, top_indices = new_scores, comb_idx.gather(1, new_pos)
            top_scores_all[qs:qs + bq] = top_scores.cpu()
            top_indices_all[qs:qs + bq] = top_indices.cpu()
        return top_scores_all, top_indices_all