---
tags:
- text-embeddings
- retrieval
- radiology
- chest
- qwen
library_name: transformers
---

# chest2vec_4B

This repository contains the *delta weights* for a global embedding model built on top of **Qwen/Qwen3-Embedding-4B**:

- **LoRA Adapter**: Contrastive LoRA adapter trained with a multi-positive sigmoid loss, stored under `./contrastive/`
- **Inference helper**: `chest2vec.py`

Base model weights are **not** included; they are downloaded from Hugging Face at runtime.

## Model Architecture

Chest2Vec consists of three parts:

1. **Base**: Qwen/Qwen3-Embedding-4B (downloaded at runtime)
2. **LoRA Adapter**: Contrastive LoRA adapter trained with a multi-positive sigmoid loss
3. **Pooling**: Last-token pooling (EOS token) for global embeddings

The model produces **global embeddings only** (no section-specific embeddings).

## Installation

Install the package and all dependencies:

```bash
# Install PyTorch with CUDA 12.6 support
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126

# Install transformers and trl
pip install transformers==4.57.3 trl==0.9.3

# Install deepspeed
pip install deepspeed==0.16.9

# Install flash-attention (this prebuilt wheel targets Python 3.10, CUDA 12, PyTorch 2.6, Linux x86_64)
pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl

# Install the chest2vec package
pip install chest2vec
```

Or use the installation script:

```bash
bash install_deps.sh
```

## Requirements

The `chest2vec` helper **requires FlashAttention-2** (CUDA) by default (`force_flash_attention_2=True` in `Chest2Vec.from_pretrained`). Install it with the `flash-attn` wheel step above.
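To make the pooling step of the architecture concrete, here is a minimal, dependency-free sketch of last-token pooling followed by L2 normalization. It assumes the tokenizer appends an EOS token, so the last non-padding position carries the sequence summary. `last_token_pool` and `l2_normalize` are hypothetical names used for illustration, not functions from `chest2vec.py`; the real model performs this on PyTorch tensors produced by Qwen3-Embedding-4B.

```python
import math
from typing import List

Vector = List[float]


def last_token_pool(hidden: List[List[Vector]], mask: List[List[int]]) -> List[Vector]:
    """Return the hidden state of the last real (non-padding) token of each sequence.

    hidden: [B, T, H] token states; mask: [B, T], 1 = real token, 0 = padding.
    Scanning the mask handles both left- and right-padded batches.
    """
    pooled = []
    for states, m in zip(hidden, mask):
        # Index of the last real token (assumed here to be the EOS position)
        last = max(i for i, v in enumerate(m) if v == 1)
        pooled.append(list(states[last]))
    return pooled


def l2_normalize(rows: List[Vector]) -> List[Vector]:
    """Scale each row to unit Euclidean length (all-zero rows are returned unchanged)."""
    out = []
    for r in rows:
        norm = math.sqrt(sum(x * x for x in r))
        out.append([x / norm for x in r] if norm > 0 else list(r))
    return out


# Toy batch: B=2 sequences, T=3 positions, H=2 dims (made-up numbers)
hidden = [
    [[1.0, 0.0], [0.0, 3.0], [4.0, 0.0]],  # no padding: last real token is position 2
    [[0.0, 2.0], [3.0, 4.0], [9.0, 9.0]],  # right-padded: last real token is position 1
]
mask = [[1, 1, 1], [1, 1, 0]]

emb = l2_normalize(last_token_pool(hidden, mask))
print(emb)  # [[1.0, 0.0], [0.6, 0.8]]
```

Because every output row has unit length, a plain dot product between two embeddings equals their cosine similarity.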
## Quickstart

### Loading

```python
from chest2vec import Chest2Vec

# Load the model from the Hugging Face Hub
m = Chest2Vec.from_pretrained("lukeingawesome/chest2vec_4b_chest", device="cuda:0")
```

### Instruction + Query Embeddings

```python
instructions = ["Find findings about the lungs."]
queries = ["Consolidation in the right lower lobe."]

out = m.embed_instruction_query(instructions, queries, max_len=512, batch_size=8)

# Global embedding (last-token pooling)
emb = out.embedding  # [N, H]
```

### Candidate Embeddings (Retrieval Bank)

```python
candidates = [
    "Lungs are clear. No focal consolidation.",
    "Pleural effusion on the left.",
    "Cardiomediastinal silhouette is normal.",
]

cand_out = m.embed_texts(candidates, max_len=512, batch_size=16)
cand_emb = cand_out.embedding  # [N, H]
```

### Retrieval Example (Cosine Top-K)

```python
# Query embeddings
q = out.embedding  # [Nq, H]

# Document embeddings
d = cand_out.embedding  # [Nd, H]

# Compute top-k cosine similarities (k=3 because the bank holds only 3 candidates)
scores, idx = Chest2Vec.cosine_topk(q, d, k=3, device="cuda")
# scores: [Nq, k] - similarity scores
# idx:    [Nq, k] - indices of the top-k candidates

print(f"Top-3 scores: {scores[0]}")
print(f"Top-3 indices: {idx[0]}")
```

## API Reference

### `Chest2Vec.from_pretrained()`

Load the model from the Hugging Face Hub or a local path.

```python
m = Chest2Vec.from_pretrained(
    repo_id_or_path: str,                 # Hugging Face repo ID or local path
    device: str = "cuda:0",               # Device to load the model on
    use_4bit: bool = False,               # Use 4-bit quantization
    force_flash_attention_2: bool = True  # Require FlashAttention-2
)
```

### `embed_instruction_query()`

Embed instruction-query pairs.

```python
out = m.embed_instruction_query(
    instructions: List[str],
    queries: List[str],
    max_len: int = 512,
    batch_size: int = 16
)
```

Returns `EmbedOutput` with:

- `embedding`: `[N, H]` - global embeddings (L2-normalized, last-token pooling)

### `embed_texts()`

Embed plain texts (for document/candidate encoding).
```python
out = m.embed_texts(
    texts: List[str],
    max_len: int = 512,
    batch_size: int = 16
)
```

Returns `EmbedOutput` with:

- `embedding`: `[N, H]` - global embeddings (L2-normalized, last-token pooling)

### `cosine_topk()`

Static method for efficient top-k cosine similarity search.

```python
scores, idx = Chest2Vec.cosine_topk(
    query_emb: torch.Tensor,  # [Nq, H]
    cand_emb: torch.Tensor,   # [Nd, H]
    k: int = 10,
    device: str = "cuda"
)
```

Returns `scores` (similarity scores) and `idx` (candidate indices), each of shape `[Nq, k]`.

## Model Files

- `chest2vec.py` - Model class and inference utilities
- `chest2vec_config.json` - Model configuration
- `contrastive/` - LoRA adapter directory
  - `adapter_config.json` - LoRA adapter configuration
  - `adapter_model.safetensors` - LoRA adapter weights

## Citation

If you use this model, please cite:

```bibtex
@misc{chest2vec_4b_chest,
  title={Chest2Vec: Global Embeddings for Chest X-Ray Reports},
  author={Your Name},
  year={2024},
  howpublished={\url{https://huggingface.co/lukeingawesome/chest2vec_4b_chest}}
}
```

## License

[Specify your license here]

## Usage (🤗 transformers)

```python
from transformers import AutoModel

# Base Qwen3-Embedding weights download automatically; trust_remote_code is required
model = AutoModel.from_pretrained("chest2vec/chest2vec_4B", trust_remote_code=True)

emb = model.embed_texts([
    "Frontal chest radiograph. No focal consolidation. No pneumothorax. Heart size normal.",
    "Right lower lobe consolidation, concerning for pneumonia.",
])
# emb: [N, H] L2-normalized report embeddings (last-token / EOS pooling)

# Cosine similarity between the two reports (rows are unit-norm, so a dot product suffices)
similarity = emb[0] @ emb[1]
```

FlashAttention-2 is used automatically on CUDA when `flash-attn>=2` is installed (matching training); otherwise the model falls back to PyTorch SDPA attention, so it also loads on CPU.
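Retrieval reduces to dot products between unit-norm rows followed by a per-query top-k. The plain-Python sketch below shows that computation. `cosine_topk_py` is a hypothetical stand-in, not the implementation of `Chest2Vec.cosine_topk`. It re-normalizes its inputs and clamps `k` to the bank size; both are choices made for this sketch rather than documented behavior of the helper.

```python
import heapq
import math
from typing import List, Tuple

Vector = List[float]


def _unit(v: Vector) -> Vector:
    """Scale a vector to unit length (all-zero vectors are returned unchanged)."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v] if n > 0 else list(v)


def cosine_topk_py(q: List[Vector], d: List[Vector], k: int) -> Tuple[List[Vector], List[List[int]]]:
    """For each query row, return the k highest cosine scores and their candidate indices."""
    k = min(k, len(d))  # cannot return more candidates than the bank holds
    qn = [_unit(r) for r in q]
    dn = [_unit(r) for r in d]
    all_scores, all_idx = [], []
    for qr in qn:
        # For unit-norm rows, the dot product equals the cosine similarity
        sims = [sum(a * b for a, b in zip(qr, dr)) for dr in dn]
        top = heapq.nlargest(k, range(len(sims)), key=sims.__getitem__)
        all_idx.append(top)
        all_scores.append([sims[j] for j in top])
    return all_scores, all_idx


# Toy 2-D embeddings (made-up numbers): one query, three candidates
q = [[1.0, 0.0]]
d = [[0.0, 1.0], [1.0, 0.0], [0.6, 0.8]]

scores, idx = cosine_topk_py(q, d, k=2)
print(idx)  # [[1, 2]]
```

Using `heapq.nlargest` avoids sorting the full score list when `k` is much smaller than the bank; a tensor implementation would typically use a single matrix multiply plus a top-k call instead.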