Feature Extraction
Transformers
Safetensors
chest2vec
text-embeddings
retrieval
radiology
chest
qwen
custom_code
Instructions to use chest2vec/chest2vec_4B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use chest2vec/chest2vec_4B with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="chest2vec/chest2vec_4B", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("chest2vec/chest2vec_4B", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """Chest2Vec — LoRA-tuned Qwen3-Embedding model for chest radiology reports. | |
| Load with: | |
| from transformers import AutoModel | |
| model = AutoModel.from_pretrained("chest2vec/chest2vec_0.6B", trust_remote_code=True) | |
| emb = model.embed_texts(["Frontal chest radiograph. No pneumothorax."]) # [N, H], L2-normalized | |
| Architecture: | |
| 1. Base : Qwen/Qwen3-Embedding-{0.6B,4B} (downloaded at runtime) | |
| 2. Adapter: frozen contrastive LoRA adapter (./contrastive) | |
| Embeddings use last-token (EOS) pooling with left padding, matching Qwen3-Embedding | |
| and the Stage-2 training setup. FlashAttention-2 is used when CUDA + flash-attn>=2 | |
| are available (matching training); otherwise it falls back to SDPA so the model | |
| also loads on CPU. | |
| """ | |
| import os | |
| from typing import Dict, List, Optional | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig, PreTrainedModel | |
| from .configuration_chest2vec import Chest2VecConfig | |
# Optional dependency: peft provides the LoRA adapter loader used when the
# model is assembled. Any failure while importing marks it as unavailable.
try:
    from peft import PeftModel
except Exception:
    PeftModel = None
    _HAS_PEFT = False
else:
    _HAS_PEFT = True

# Optional dependency: huggingface_hub is only needed to resolve Hub repo ids;
# local directories load without it.
try:
    from huggingface_hub import snapshot_download
except Exception:
    snapshot_download = None
    _HAS_HUB = False
else:
    _HAS_HUB = True
| # ---------------------------------------------------------------------------- | |
| # Attention backend selection | |
| # ---------------------------------------------------------------------------- | |
def _flash_attn_available() -> bool:
    """Report whether FlashAttention-2 can be used in this process.

    Requires a visible CUDA device and an importable ``flash_attn`` package whose
    major version is at least 2. Any error while importing the package or
    parsing its version string is treated as "unavailable".
    """
    if torch.cuda.is_available():
        try:
            import flash_attn  # noqa: F401

            major = str(getattr(flash_attn, "__version__", "0.0.0")).partition(".")[0]
            return int(major) >= 2
        except Exception:
            return False
    return False
def _pick_attn_impl(requested: Optional[str], want_flash: bool) -> str:
    """Choose the attention backend name passed to the base model.

    Args:
        requested: explicit backend chosen by the caller; returned unchanged when truthy.
        want_flash: whether the config asks for FlashAttention-2.

    Returns:
        ``requested`` if given. Otherwise ``"flash_attention_2"`` when it is wanted
        and available, else ``"sdpa"``. A ``RuntimeWarning`` is emitted when flash
        attention was wanted but is unavailable.
    """
    import warnings

    if requested:
        return requested
    if not want_flash:
        return "sdpa"
    if _flash_attn_available():
        return "flash_attention_2"
    warnings.warn(
        "Chest2Vec was trained with FlashAttention-2, but it is unavailable "
        "(needs CUDA + flash-attn>=2). Falling back to 'sdpa'; embeddings may "
        "differ very slightly from the reference implementation.",
        RuntimeWarning,
    )
    return "sdpa"
| # ---------------------------------------------------------------------------- | |
| # Tokenization / pooling helpers (match Qwen3-Embedding + training) | |
| # ---------------------------------------------------------------------------- | |
def build_qwen_query(instruction: str, query: str) -> str:
    """Format an instruction/query pair in the Qwen3-Embedding prompt layout.

    Both parts are converted with ``str`` and stripped of surrounding whitespace.
    """
    parts = (str(instruction).strip(), str(query).strip())
    return "Instruct: {}\nQuery: {}".format(*parts)
def get_pool_token_id(tok) -> int:
    """Return the token id appended to every sequence and used for pooling.

    Prefers the id of ``<|endoftext|>``. Falls back to ``tok.pad_token_id`` when
    that lookup yields ``None`` or a negative id.
    """
    eod_id = tok.convert_tokens_to_ids("<|endoftext|>")
    lookup_failed = eod_id is None or eod_id < 0
    return tok.pad_token_id if lookup_failed else eod_id
def encode_with_eos_ids(tok, texts: List[str], max_len: int) -> Dict[str, torch.Tensor]:
    """Tokenize *texts* for last-token (EOS) pooling, matching Qwen3-Embedding.

    Each text is tokenized with ``add_special_tokens=False`` and truncated to
    ``max_len - 1`` tokens. The pooling token from ``get_pool_token_id`` is then
    appended, and each row is left-padded to the longest sequence in the batch.

    Args:
        tok: Hugging Face-style tokenizer (callable, with ``pad_token_id`` and
            ``eos_token_id`` attributes).
        texts: strings to encode (non-strings are converted with ``str``).
        max_len: total per-text token budget, including the appended token.

    Returns:
        ``{"input_ids", "attention_mask"}`` as ``torch.long`` tensors of shape
        ``[len(texts), T]``. ``T`` is the longest encoded length; an empty batch
        yields shape ``[0, 1]``.

    Raises:
        ValueError: if ``max_len`` < 1, if no pooling token id can be
            determined, or if padding is required but the tokenizer has neither a
            pad nor an EOS token id.
    """
    if max_len < 1:
        # max_length=max_len-1 would be negative; there is no room for the EOS token.
        raise ValueError(f"max_len must be >= 1 (room for the EOS token), got {max_len}.")
    texts = [str(t) for t in texts]
    if not texts:
        # Keep the result 2-D so callers can rely on [batch, seq] shapes.
        empty = torch.zeros((0, 1), dtype=torch.long)
        return {"input_ids": empty, "attention_mask": empty.clone()}

    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    eod_id = get_pool_token_id(tok)
    if eod_id is None:
        raise ValueError("Tokenizer defines no <|endoftext|> or pad token id to pool on.")

    enc = tok(
        texts,
        add_special_tokens=False,
        truncation=True,
        max_length=max_len - 1,
        padding=False,
        return_attention_mask=False,
    )
    input_ids = [list(ids) + [eod_id] for ids in enc["input_ids"]]
    T = max((len(ids) for ids in input_ids), default=1)
    if pad_id is None and any(len(ids) < T for ids in input_ids):
        raise ValueError("Tokenizer defines no pad/EOS token id; cannot left-pad the batch.")

    # Left padding puts every sequence's pooling token in the final column.
    padded_ids = [[pad_id] * (T - len(ids)) + ids for ids in input_ids]
    attn_mask = [[0] * (T - len(ids)) + [1] * len(ids) for ids in input_ids]
    return {
        "input_ids": torch.tensor(padded_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
    }
def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Select one hidden vector per row: the last non-padding position.

    If every row's final mask entry is 1 (a left-padded batch), the final column
    is returned directly. Otherwise each row is indexed at
    ``attention_mask.sum(1) - 1``, which assumes the row is right-padded.
    """
    n_rows = attention_mask.shape[0]
    all_end_with_token = attention_mask[:, -1].sum() == n_rows
    if all_end_with_token:
        return last_hidden_states[:, -1]
    batch_idx = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device)
    last_idx = attention_mask.sum(dim=1) - 1
    return last_hidden_states[batch_idx, last_idx]
def get_last_hidden_state(model, input_ids, attention_mask):
    """Run *model* and return its final-layer hidden states.

    If *model* has a ``.module`` attribute (presumably a DataParallel/DDP-style
    wrapper), that inner module is called instead. Position ids are derived from
    the attention mask so each row counts from 0 at its first real token, and
    padded slots get position 0. If the output lacks ``last_hidden_state``, the
    model is called again with ``output_hidden_states=True`` and the last entry
    of ``hidden_states`` is returned.
    """
    net = getattr(model, "module", model)
    positions = (attention_mask.long().cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
    call_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=positions,
        use_cache=False,
        return_dict=True,
    )
    out = net(**call_kwargs)
    hidden = getattr(out, "last_hidden_state", None)
    if hidden is not None:
        return hidden
    return net(output_hidden_states=True, **call_kwargs).hidden_states[-1]
class Chest2VecModel(PreTrainedModel):
    """LoRA-tuned Qwen3-Embedding model producing L2-normalized report embeddings.

    Instances are built by ``from_pretrained``. It fetches ``config.base_model``
    at runtime, attaches the LoRA adapter stored under ``config.adapter_subdir``,
    and keeps the result in ``self.backbone``.
    """

    config_class = Chest2VecConfig
    base_model_prefix = "chest2vec"
    # Attention is handled by the inner Qwen3 backbone; advertise support so the
    # transformers attn-implementation validator on this wrapper passes.
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_flash_attn = True
    _supports_attention_backend = True

    def __init__(self, config: Chest2VecConfig):
        super().__init__(config)
        # The base+adapter are assembled in `from_pretrained` (base downloads at runtime).
        self.backbone = None
        self.tokenizer = None
        self._device = torch.device("cpu")
        # Non-persistent 1-element buffer (excluded from state_dict).
        # NOTE(review): presumably kept so the wrapper owns at least one tensor — confirm.
        self.register_buffer("_anchor", torch.zeros(1), persistent=False)

    def get_input_embeddings(self):
        # The wrapper exposes no embedding table of its own.
        return None

    def set_input_embeddings(self, value):
        # No-op: the wrapper has no embedding table to replace.
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Resolve the repo, then assemble base model + LoRA adapter.

        Args:
            pretrained_model_name_or_path: local directory, or a Hub repo id that is
                fetched with ``snapshot_download``.
            config: optional ``Chest2VecConfig``; read from the repo when omitted.
            device: target device; defaults to ``cuda:0`` when CUDA is available,
                else ``cpu``.
            use_4bit: load the base model with bitsandbytes NF4 quantization.
            attn_implementation: explicit attention backend; auto-selected when None.
            torch_dtype: base-model dtype; defaults to bfloat16 on CUDA, float32 otherwise.
            token / use_auth_token: auth token for Hub downloads.
            cache_dir, revision: forwarded to ``snapshot_download`` for repo ids.

        Other Hugging Face plumbing kwargs are ignored.

        Returns:
            An assembled ``Chest2VecModel`` in eval mode.
        """
        config = kwargs.pop("config", None)
        device = kwargs.pop("device", None)
        use_4bit = kwargs.pop("use_4bit", False)
        attn_implementation = kwargs.pop("attn_implementation", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        token = kwargs.pop("token", None) or kwargs.pop("use_auth_token", None)
        cache_dir = kwargs.pop("cache_dir", None)
        revision = kwargs.pop("revision", None)
        # remaining HF plumbing kwargs (state_dict, low_cpu_mem_usage, ...) are ignored
        repo_path = pretrained_model_name_or_path
        if not os.path.isdir(repo_path):
            if not _HAS_HUB:
                raise RuntimeError("huggingface_hub is required to load by repo_id.")
            # Honor a pinned revision instead of silently loading the default branch.
            repo_path = snapshot_download(
                repo_path, revision=revision, token=token, cache_dir=cache_dir
            )
        if config is None:
            config = Chest2VecConfig.from_pretrained(repo_path)
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        device_t = torch.device(device)
        if torch_dtype is None:
            torch_dtype = torch.bfloat16 if device_t.type == "cuda" else torch.float32
        model = cls(config)
        model._assemble(repo_path, device=device_t, use_4bit=use_4bit,
                        attn_implementation=attn_implementation, torch_dtype=torch_dtype, token=token)
        return model

    def _assemble(self, repo_path, *, device, use_4bit, attn_implementation, torch_dtype, token=None):
        """Load tokenizer and base model, then attach the LoRA adapter from *repo_path*.

        Raises:
            RuntimeError: if peft is not installed, or if loading the base model
                raises TypeError (reported as an outdated transformers).
            FileNotFoundError: if ``adapter_config.json`` is missing from the
                adapter directory.
        """
        cfg = self.config
        if not _HAS_PEFT:
            raise RuntimeError("peft is required. Install: pip install peft")
        attn_impl = _pick_attn_impl(attn_implementation, bool(cfg.require_flash_attention_2))
        # Left padding keeps the pooled (last) token in the final column.
        tokenizer = AutoTokenizer.from_pretrained(
            cfg.base_model, padding_side="left", trust_remote_code=True, token=token
        )
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
        base_kwargs = dict(trust_remote_code=True, attn_implementation=attn_impl, token=token)
        if use_4bit:
            base_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16,
            )
            base_kwargs["device_map"] = {"": str(device)}
        else:
            base_kwargs["torch_dtype"] = torch_dtype
            if device.type == "cuda":
                base_kwargs["device_map"] = {"": str(device)}
        try:
            base = AutoModel.from_pretrained(cfg.base_model, **base_kwargs)
        except TypeError as e:
            raise RuntimeError("transformers too old for attn_implementation=...; please upgrade.") from e
        # Non-CUDA, non-quantized loads get no device_map, so move them explicitly.
        if device.type != "cuda" and not use_4bit:
            base = base.to(device)
        adapter_dir = os.path.join(repo_path, cfg.adapter_subdir)
        if not os.path.isfile(os.path.join(adapter_dir, "adapter_config.json")):
            raise FileNotFoundError(f"adapter_config.json not found under: {adapter_dir}")
        backbone = PeftModel.from_pretrained(base, adapter_dir)
        backbone.eval()
        self.backbone = backbone
        self.tokenizer = tokenizer
        self._device = device
        self.eval()

    @property
    def device(self) -> torch.device:
        """Device the backbone was placed on (``cpu`` before assembly)."""
        return self._device

    def embed_texts(self, texts: List[str], *, max_len: Optional[int] = None,
                    batch_size: int = 16, return_cpu_float32: bool = True) -> torch.Tensor:
        """Return L2-normalized report embeddings, shape [N, H].

        Args:
            texts: input strings (non-strings are converted with ``str``).
            max_len: per-text token budget including the appended EOS token;
                ``config.default_max_len`` is used when None/0.
            batch_size: texts per forward pass; must be >= 1.
            return_cpu_float32: move the result to CPU as float32 (re-normalized);
                otherwise it stays on the model device.

        Raises:
            RuntimeError: if the model was not assembled via ``from_pretrained``.
            ValueError: if ``batch_size`` < 1.
        """
        if self.backbone is None:
            raise RuntimeError("Model not assembled; load via from_pretrained(...).")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
        max_len = int(max_len or self.config.default_max_len)
        device = self._device
        if len(texts) == 0:
            # torch.cat([]) would raise; return an empty [0, H] matrix instead.
            hidden = int(getattr(getattr(self.backbone, "config", None), "hidden_size", 0) or 0)
            out_device = torch.device("cpu") if return_cpu_float32 else device
            return torch.empty((0, hidden), dtype=torch.float32, device=out_device)
        if device.type == "cuda":
            amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            use_amp = True
        else:
            amp_dtype, use_amp = torch.float32, False
        outs = []
        # Inference only: skip autograd bookkeeping for the forward passes.
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                chunk = [str(t) for t in texts[i:i + batch_size]]
                enc = encode_with_eos_ids(self.tokenizer, chunk, max_len)
                input_ids = enc["input_ids"].to(device, non_blocking=True)
                attention_mask = enc["attention_mask"].to(device, non_blocking=True)
                with torch.autocast(device_type=("cuda" if device.type == "cuda" else "cpu"),
                                    dtype=amp_dtype, enabled=use_amp):
                    h = get_last_hidden_state(self.backbone, input_ids, attention_mask)
                emb = F.normalize(last_token_pool(h, attention_mask).float(), p=2, dim=-1)
                outs.append(emb.detach())
        embeddings = torch.cat(outs, dim=0)
        if return_cpu_float32:
            embeddings = F.normalize(embeddings.float().cpu(), p=2, dim=-1)
        return embeddings

    def embed_instruction_query(self, instructions: List[str], queries: List[str], **kw) -> torch.Tensor:
        """Embed ``Instruct: ...\\nQuery: ...`` prompts built pairwise from the inputs.

        Raises:
            ValueError: if ``instructions`` and ``queries`` differ in length.
        """
        if len(instructions) != len(queries):
            raise ValueError("instructions and queries must have the same length.")
        return self.embed_texts([build_qwen_query(i, q) for i, q in zip(instructions, queries)], **kw)

    def forward(self, texts: List[str], **kw) -> torch.Tensor:  # type: ignore[override]
        """Alias for ``embed_texts`` so ``model(texts)`` returns embeddings."""
        return self.embed_texts(texts, **kw)
def cosine_topk(query_emb, cand_emb, k=10, *, device="cuda",
                query_batch_size=256, doc_chunk_size=8192):
    """Exact top-k cosine-similarity search between two embedding matrices.

    Both inputs are cast to float32 and L2-normalized. Scores are computed in
    blocks of ``query_batch_size`` queries against ``doc_chunk_size`` candidates,
    and a running top-k is kept per query. ``device`` is ignored (CPU is used)
    when CUDA is unavailable.

    Returns:
        ``(scores, indices)``: CPU tensors of shape ``[num_queries, k]`` (float32 /
        long). ``k`` is clipped to the number of candidates.
    """
    target = torch.device("cpu") if not torch.cuda.is_available() else torch.device(device)
    queries = F.normalize(query_emb.float(), p=2, dim=-1)
    docs = F.normalize(cand_emb.float(), p=2, dim=-1)
    n_queries, _ = queries.shape
    n_docs = docs.shape[0]
    k = min(int(k), n_docs)

    best_scores = torch.empty((n_queries, k), dtype=torch.float32)
    best_ids = torch.empty((n_queries, k), dtype=torch.long)
    for q_start in range(0, n_queries, query_batch_size):
        q_block = queries[q_start:q_start + query_batch_size].to(target, non_blocking=True)
        rows = q_block.size(0)
        # Sentinels below any cosine score; they are displaced as real scores arrive.
        run_scores = torch.full((rows, k), -1e9, device=target, dtype=torch.float32)
        run_ids = torch.full((rows, k), -1, device=target, dtype=torch.long)
        for d_start in range(0, n_docs, doc_chunk_size):
            d_block = docs[d_start:d_start + doc_chunk_size].to(target, non_blocking=True)
            sims = (q_block @ d_block.T).float()
            width = sims.size(1)
            ids = torch.arange(d_start, d_start + width, device=target, dtype=torch.long)
            ids = ids.unsqueeze(0).expand(rows, -1)
            merged_scores = torch.cat([run_scores, sims], dim=1)
            merged_ids = torch.cat([run_ids, ids], dim=1)
            run_scores, pos = torch.topk(merged_scores, k, dim=1)
            run_ids = merged_ids.gather(1, pos)
        best_scores[q_start:q_start + rows] = run_scores.cpu()
        best_ids[q_start:q_start + rows] = run_ids.cpu()
    return best_scores, best_ids