Feature Extraction
Transformers
Safetensors
chest2vec
text-embeddings
retrieval
radiology
chest
qwen
custom_code
Instructions to use chest2vec/chest2vec_4B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use chest2vec/chest2vec_4B with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="chest2vec/chest2vec_4B", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("chest2vec/chest2vec_4B", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 13,043 Bytes
c036088 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 | """Chest2Vec — LoRA-tuned Qwen3-Embedding model for chest radiology reports.
Load with:
from transformers import AutoModel
model = AutoModel.from_pretrained("chest2vec/chest2vec_0.6B", trust_remote_code=True)
emb = model.embed_texts(["Frontal chest radiograph. No pneumothorax."]) # [N, H], L2-normalized
Architecture:
1. Base : Qwen/Qwen3-Embedding-{0.6B,4B} (downloaded at runtime)
2. Adapter: frozen contrastive LoRA adapter (./contrastive)
Embeddings use last-token (EOS) pooling with left padding, matching Qwen3-Embedding
and the Stage-2 training setup. FlashAttention-2 is used when CUDA + flash-attn>=2
are available (matching training); otherwise it falls back to SDPA so the model
also loads on CPU.
"""
import os
from typing import Dict, List, Optional
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig, PreTrainedModel
from .configuration_chest2vec import Chest2VecConfig
try:
from peft import PeftModel
_HAS_PEFT = True
except Exception:
PeftModel = None
_HAS_PEFT = False
try:
from huggingface_hub import snapshot_download
_HAS_HUB = True
except Exception:
snapshot_download = None
_HAS_HUB = False
# ----------------------------------------------------------------------------
# Attention backend selection
# ----------------------------------------------------------------------------
def _flash_attn_available() -> bool:
if not torch.cuda.is_available():
return False
try:
import flash_attn # noqa: F401
ver = getattr(flash_attn, "__version__", "0.0.0")
return int(str(ver).split(".")[0]) >= 2
except Exception:
return False
def _pick_attn_impl(requested: Optional[str], want_flash: bool) -> str:
import warnings
if requested:
return requested
if want_flash and _flash_attn_available():
return "flash_attention_2"
if want_flash:
warnings.warn(
"Chest2Vec was trained with FlashAttention-2, but it is unavailable "
"(needs CUDA + flash-attn>=2). Falling back to 'sdpa'; embeddings may "
"differ very slightly from the reference implementation.",
RuntimeWarning,
)
return "sdpa"
# ----------------------------------------------------------------------------
# Tokenization / pooling helpers (match Qwen3-Embedding + training)
# ----------------------------------------------------------------------------
def build_qwen_query(instruction: str, query: str) -> str:
    """Format an instruction/query pair as a two-line prompt.

    Both parts are converted to ``str`` and stripped of surrounding whitespace,
    then laid out as ``Instruct: <instruction>`` and ``Query: <query>`` on
    consecutive lines.
    """
    instruct_part = str(instruction).strip()
    query_part = str(query).strip()
    return "Instruct: " + instruct_part + "\nQuery: " + query_part
def get_pool_token_id(tok) -> int:
    """Return the id of the token appended to each sequence for pooling.

    Looks up ``<|endoftext|>`` in the tokenizer; when that lookup yields
    ``None`` or a negative id, the tokenizer's ``pad_token_id`` is used instead.
    """
    token_id = tok.convert_tokens_to_ids("<|endoftext|>")
    found = token_id is not None and token_id >= 0
    return token_id if found else tok.pad_token_id
def encode_with_eos_ids(tok, texts: List[str], max_len: int) -> Dict[str, torch.Tensor]:
    """Tokenize ``texts``, append the pooling token, and left-pad into a batch.

    Each text is tokenized with ``add_special_tokens=False`` and truncated to
    ``max_len - 1`` tokens, the pooling token from ``get_pool_token_id`` is
    appended, and all sequences are left-padded to the longest one.

    Args:
        tok: Tokenizer; called on the list of texts and read for
            ``pad_token_id`` / ``eos_token_id``.
        texts: Input texts; every element is converted with ``str``.
        max_len: Maximum sequence length including the appended pooling token.

    Returns:
        A dict with ``input_ids`` and ``attention_mask``, both ``torch.long``
        tensors of shape ``[N, T]``. An empty ``texts`` gives shape ``[0, 1]``.
    """
    texts = [str(t) for t in texts]
    if not texts:
        # Keep the output 2-D for an empty batch (torch.tensor([]) would be
        # 1-D) and avoid invoking the tokenizer on an empty list at all.
        return {
            "input_ids": torch.zeros((0, 1), dtype=torch.long),
            "attention_mask": torch.zeros((0, 1), dtype=torch.long),
        }

    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    eod_id = get_pool_token_id(tok)
    enc = tok(
        texts,
        add_special_tokens=False,
        truncation=True,
        max_length=max_len - 1,  # leave room for the appended pooling token
        padding=False,
        return_attention_mask=False,
    )
    sequences = [ids + [eod_id] for ids in enc["input_ids"]]
    width = max(len(ids) for ids in sequences)

    # Left padding keeps the pooling token in the last column of every row.
    input_ids = [[pad_id] * (width - len(ids)) + ids for ids in sequences]
    attn_mask = [[0] * (width - len(ids)) + [1] * len(ids) for ids in sequences]
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
    }
def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick each sequence's last attended hidden state.

    When every row's final mask entry is 1 (the left-padded case) the last
    position is returned directly; otherwise the index of each row's last
    attended token is derived from the mask sum.
    """
    batch = attention_mask.shape[0]
    if attention_mask[:, -1].sum() == batch:
        return last_hidden_states[:, -1]

    last_idx = attention_mask.sum(dim=1) - 1
    rows = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device)
    return last_hidden_states[rows, last_idx]
def get_last_hidden_state(model, input_ids, attention_mask):
    """Run ``model`` and return its final-layer hidden states.

    A ``.module`` attribute on ``model`` is unwrapped first. Position ids are
    derived from the attention mask (running count of attended tokens minus
    one, with masked positions set to 0). If the output has no
    ``last_hidden_state``, the model is run again with
    ``output_hidden_states=True`` and the last entry of ``hidden_states`` is
    returned.
    """
    inner = getattr(model, "module", model)
    positions = (attention_mask.long().cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
    shared = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=positions,
        use_cache=False,
        return_dict=True,
    )

    hidden = getattr(inner(**shared), "last_hidden_state", None)
    if hidden is None:
        # Output type without `last_hidden_state`: ask for all layers instead.
        hidden = inner(output_hidden_states=True, **shared).hidden_states[-1]
    return hidden
class Chest2VecModel(PreTrainedModel):
    """LoRA-tuned Qwen3-Embedding model producing L2-normalized report embeddings.

    The wrapper is assembled by ``from_pretrained``: it loads the base model
    named by ``config.base_model`` and attaches the LoRA adapter stored under
    ``config.adapter_subdir`` in the repository.
    """

    config_class = Chest2VecConfig
    base_model_prefix = "chest2vec"
    # Attention is handled by the inner Qwen3 backbone; advertise support so the
    # transformers attn-implementation validator on this wrapper passes.
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_flash_attn = True
    _supports_attention_backend = True

    def __init__(self, config: Chest2VecConfig):
        """Create an empty wrapper; the backbone is attached by ``_assemble``."""
        super().__init__(config)
        # The base+adapter are assembled in `from_pretrained` (base downloads at runtime).
        self.backbone = None
        self.tokenizer = None
        self._device = torch.device("cpu")
        self.register_buffer("_anchor", torch.zeros(1), persistent=False)

    def get_input_embeddings(self):
        """Return None; the wrapper exposes no input embedding layer."""
        return None

    def set_input_embeddings(self, value):
        """Do nothing; the wrapper exposes no input embedding layer."""
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the config, base model and LoRA adapter and return a ready model.

        Args:
            pretrained_model_name_or_path: Local directory, or a Hub repo id
                that is fetched with ``snapshot_download``.
            **kwargs: Recognised keys are ``config``, ``device``, ``use_4bit``,
                ``attn_implementation``, ``torch_dtype`` (or its alias
                ``dtype``), ``token`` / ``use_auth_token``, ``cache_dir`` and
                ``revision``. Any other key is ignored.

        Raises:
            RuntimeError: If a repo id is given but ``huggingface_hub`` is
                not installed.
        """
        config = kwargs.pop("config", None)
        device = kwargs.pop("device", None)
        use_4bit = kwargs.pop("use_4bit", False)
        attn_implementation = kwargs.pop("attn_implementation", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        dtype = kwargs.pop("dtype", None)
        token = kwargs.pop("token", None) or kwargs.pop("use_auth_token", None)
        cache_dir = kwargs.pop("cache_dir", None)
        revision = kwargs.pop("revision", None)
        # remaining HF plumbing kwargs (state_dict, low_cpu_mem_usage, ...) are ignored

        # Honour an explicit `dtype=` when `torch_dtype` was not given.
        # "auto" is treated as "not specified" so the device-based default
        # below still applies.
        if torch_dtype is None and dtype is not None and dtype != "auto":
            torch_dtype = dtype

        repo_path = pretrained_model_name_or_path
        if not os.path.isdir(repo_path):
            if not _HAS_HUB:
                raise RuntimeError("huggingface_hub is required to load by repo_id.")
            # Pass the requested revision through so a pinned load fetches
            # that revision's adapter files.
            repo_path = snapshot_download(
                repo_path, revision=revision, token=token, cache_dir=cache_dir
            )

        if config is None:
            config = Chest2VecConfig.from_pretrained(repo_path)

        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        device_t = torch.device(device)

        if torch_dtype is None:
            torch_dtype = torch.bfloat16 if device_t.type == "cuda" else torch.float32

        model = cls(config)
        model._assemble(repo_path, device=device_t, use_4bit=use_4bit,
                        attn_implementation=attn_implementation, torch_dtype=torch_dtype, token=token)
        return model

    def _assemble(self, repo_path, *, device, use_4bit, attn_implementation, torch_dtype, token=None):
        """Load tokenizer, base model and adapter, and attach them to ``self``.

        Raises:
            RuntimeError: If ``peft`` is missing, or the base model load
                raises ``TypeError``.
            FileNotFoundError: If the adapter folder has no
                ``adapter_config.json``.
        """
        cfg = self.config
        if not _HAS_PEFT:
            raise RuntimeError("peft is required. Install: pip install peft")

        attn_impl = _pick_attn_impl(attn_implementation, bool(cfg.require_flash_attention_2))

        tokenizer = AutoTokenizer.from_pretrained(
            cfg.base_model, padding_side="left", trust_remote_code=True, token=token
        )
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token

        base_kwargs = dict(trust_remote_code=True, attn_implementation=attn_impl, token=token)
        if use_4bit:
            base_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16,
            )
            base_kwargs["device_map"] = {"": str(device)}
        else:
            base_kwargs["torch_dtype"] = torch_dtype
            if device.type == "cuda":
                base_kwargs["device_map"] = {"": str(device)}

        try:
            base = AutoModel.from_pretrained(cfg.base_model, **base_kwargs)
        except TypeError as e:
            raise RuntimeError("transformers too old for attn_implementation=...; please upgrade.") from e

        # CUDA and 4-bit loads are placed via device_map above; everything
        # else is moved explicitly.
        if device.type != "cuda" and not use_4bit:
            base = base.to(device)

        adapter_dir = os.path.join(repo_path, cfg.adapter_subdir)
        if not os.path.isfile(os.path.join(adapter_dir, "adapter_config.json")):
            raise FileNotFoundError(f"adapter_config.json not found under: {adapter_dir}")

        backbone = PeftModel.from_pretrained(base, adapter_dir)
        backbone.eval()

        self.backbone = backbone
        self.tokenizer = tokenizer
        self._device = device
        self.eval()

    @property
    def device(self):
        """Device that inputs must be placed on.

        Read from the backbone's first parameter when a backbone is attached,
        so the answer stays correct after ``.to()`` / ``.cuda()`` moves the
        model. Falls back to the device recorded at assembly time.
        """
        backbone = getattr(self, "backbone", None)
        if backbone is not None:
            first_param = next(backbone.parameters(), None)
            if first_param is not None:
                return first_param.device
        return self._device

    @torch.inference_mode()
    def embed_texts(self, texts: List[str], *, max_len: Optional[int] = None,
                    batch_size: int = 16, return_cpu_float32: bool = True) -> torch.Tensor:
        """Return L2-normalized report embeddings, shape [N, H].

        Args:
            texts: Texts to embed. A single ``str`` is treated as a one-item
                batch rather than iterated character by character.
            max_len: Token budget per text, including the pooling token.
                Defaults to ``config.default_max_len``.
            batch_size: Number of texts per forward pass; must be positive.
            return_cpu_float32: If true, return a float32 tensor on the CPU;
                otherwise the result stays on the model device.

        Returns:
            A ``[N, H]`` tensor. For empty input the result has zero rows.

        Raises:
            RuntimeError: If the model was not loaded via ``from_pretrained``.
            ValueError: If ``batch_size`` is not positive.
        """
        if self.backbone is None:
            raise RuntimeError("Model not assembled; load via from_pretrained(...).")
        if isinstance(texts, str):
            texts = [texts]
        batch_size = int(batch_size)
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        max_len = int(max_len or self.config.default_max_len)
        device = self.device

        if len(texts) == 0:
            # torch.cat cannot concatenate an empty list; return an empty
            # matrix whose width is the backbone hidden size when it is known.
            hidden = getattr(getattr(self.backbone, "config", None), "hidden_size", 0)
            target = torch.device("cpu") if return_cpu_float32 else device
            return torch.empty((0, int(hidden or 0)), dtype=torch.float32, device=target)

        if device.type == "cuda":
            amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            use_amp = True
        else:
            amp_dtype, use_amp = torch.float32, False

        outs = []
        for i in range(0, len(texts), batch_size):
            chunk = [str(t) for t in texts[i:i + batch_size]]
            enc = encode_with_eos_ids(self.tokenizer, chunk, max_len)
            input_ids = enc["input_ids"].to(device, non_blocking=True)
            attention_mask = enc["attention_mask"].to(device, non_blocking=True)

            with torch.autocast(device_type=("cuda" if device.type == "cuda" else "cpu"),
                                dtype=amp_dtype, enabled=use_amp):
                h = get_last_hidden_state(self.backbone, input_ids, attention_mask)

            # Pool and normalize in float32 regardless of the compute dtype.
            emb = F.normalize(last_token_pool(h, attention_mask).float(), p=2, dim=-1)
            outs.append(emb.detach())

        embeddings = torch.cat(outs, dim=0)
        if return_cpu_float32:
            embeddings = F.normalize(embeddings.float().cpu(), p=2, dim=-1)
        return embeddings

    @torch.inference_mode()
    def embed_instruction_query(self, instructions: List[str], queries: List[str], **kw) -> torch.Tensor:
        """Embed instruction/query pairs formatted with ``build_qwen_query``.

        Extra keyword arguments are forwarded to ``embed_texts``.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(instructions) != len(queries):
            raise ValueError("instructions and queries must have the same length.")
        return self.embed_texts([build_qwen_query(i, q) for i, q in zip(instructions, queries)], **kw)

    def forward(self, texts: List[str], **kw) -> torch.Tensor:  # type: ignore[override]
        """Alias for ``embed_texts`` so the module can be called directly."""
        return self.embed_texts(texts, **kw)

    @staticmethod
    def cosine_topk(query_emb, cand_emb, k=10, *, device="cuda",
                    query_batch_size=256, doc_chunk_size=8192):
        """Find the ``k`` most cosine-similar candidates for each query.

        Both inputs are L2-normalized, then scored chunk by chunk while a
        running top-k is kept per query.

        Args:
            query_emb: Query embeddings, one row per query.
            cand_emb: Candidate embeddings, one row per candidate.
            k: Number of results per query; capped at the candidate count.
            device: Device for the scoring; CPU is used when CUDA is not
                available.
            query_batch_size: Queries scored per outer step.
            doc_chunk_size: Candidates scored per inner step.

        Returns:
            ``(scores, indices)`` CPU tensors of shape ``[Nq, k]``; the
            indices are row numbers into ``cand_emb``.
        """
        device_t = torch.device(device if torch.cuda.is_available() else "cpu")
        q = F.normalize(query_emb.float(), p=2, dim=-1)
        d = F.normalize(cand_emb.float(), p=2, dim=-1)
        Nq, _ = q.shape
        Nd = d.shape[0]
        k = min(int(k), Nd)

        top_scores_all = torch.empty((Nq, k), dtype=torch.float32)
        top_indices_all = torch.empty((Nq, k), dtype=torch.long)

        for qs in range(0, Nq, query_batch_size):
            qe = q[qs:qs + query_batch_size].to(device_t, non_blocking=True)
            bq = qe.size(0)
            # Sentinels (-1e9 / -1) are displaced once k real scores are seen.
            top_scores = torch.full((bq, k), -1e9, device=device_t, dtype=torch.float32)
            top_indices = torch.full((bq, k), -1, device=device_t, dtype=torch.long)

            for ds in range(0, Nd, doc_chunk_size):
                de = d[ds:ds + doc_chunk_size].to(device_t, non_blocking=True)
                scores = (qe @ de.T).float()
                chunk = scores.size(1)
                idx_chunk = torch.arange(ds, ds + chunk, device=device_t, dtype=torch.long).unsqueeze(0).expand(bq, -1)

                # Merge the running best with this chunk and keep the top k.
                comb_scores = torch.cat([top_scores, scores], dim=1)
                comb_idx = torch.cat([top_indices, idx_chunk], dim=1)
                new_scores, new_pos = torch.topk(comb_scores, k, dim=1)
                top_scores, top_indices = new_scores, comb_idx.gather(1, new_pos)

            top_scores_all[qs:qs + bq] = top_scores.cpu()
            top_indices_all[qs:qs + bq] = top_indices.cpu()

        return top_scores_all, top_indices_all
|