---
tags:
- text-embeddings
- retrieval
- radiology
- chest
- qwen
library_name: transformers
---
# chest2vec_4B
This repository contains the *delta weights* for a global embedding model on top of **Qwen/Qwen3-Embedding-4B**:
- **LoRA Adapter**: Contrastive LoRA adapter, trained with a multi-positive sigmoid loss, stored under `./contrastive/`
- **Inference helper**: `chest2vec.py`
Base model weights are **not** included; they are downloaded from Hugging Face at runtime.
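The exact form of the multi-positive sigmoid loss is not spelled out in this repository. As an illustration only, one common formulation is a SigLIP-style pairwise sigmoid loss applied to a multi-hot label matrix, so that a query may have several positive candidates. The function name, the temperature `t`, and the bias `b` below are assumptions, not the training code.

```python
import math
from typing import List


def multi_positive_sigmoid_loss(
    sims: List[List[float]],   # [Nq][Nd] cosine similarities
    labels: List[List[int]],   # [Nq][Nd] 1 = positive pair, 0 = negative pair
    t: float = 10.0,           # assumed temperature (scale), not from the repo
    b: float = -10.0,          # assumed bias, not from the repo
) -> float:
    """Mean binary log-loss over every query-candidate pair."""
    total, count = 0.0, 0
    for sim_row, label_row in zip(sims, labels):
        for s, y in zip(sim_row, label_row):
            z = 1.0 if y else -1.0          # +1 for positives, -1 for negatives
            logit = z * (t * s + b)
            # -log(sigmoid(logit)) = log(1 + exp(-logit)), computed stably
            total += max(-logit, 0.0) + math.log1p(math.exp(-abs(logit)))
            count += 1
    return total / count
```

Because every pair is scored independently, a row of `labels` may contain any number of ones, which is what makes the loss "multi-positive" in this sketch.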
## Model Architecture
Chest2Vec is built from three components:
1. **Base**: Qwen/Qwen3-Embedding-4B (downloaded at runtime)
2. **LoRA Adapter**: Contrastive LoRA adapter trained with multi-positive sigmoid loss
3. **Pooling**: Last-token pooling (EOS token) for global embeddings
The model produces **global embeddings only** (no section-specific embeddings).
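The pooling step can be pictured with a small dependency-free sketch: take the hidden state of the last non-padding token of each sequence, then L2-normalize it. This is not the code in `chest2vec.py`; the function names and the list-based layout are assumptions chosen so the example runs without PyTorch.

```python
import math
from typing import List


def last_token_pool(
    hidden: List[List[List[float]]],  # [N][T][H] per-token hidden states
    mask: List[List[int]],            # [N][T] 1 = real token, 0 = padding
) -> List[List[float]]:
    """Return the hidden state of the last real token of each sequence."""
    pooled = []
    for states, m in zip(hidden, mask):
        last = max(i for i, keep in enumerate(m) if keep)  # works for left or right padding
        pooled.append(states[last])
    return pooled


def l2_normalize(vectors: List[List[float]]) -> List[List[float]]:
    """Scale each vector to unit length (zero vectors are left unchanged)."""
    out = []
    for v in vectors:
        norm = math.sqrt(sum(x * x for x in v)) or 1.0
        out.append([x / norm for x in v])
    return out
```

With unit-length rows, a plain dot product between two embeddings equals their cosine similarity.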
## Installation
Install the package and all dependencies:
```bash
# Install PyTorch with CUDA 12.6 support
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
# Install transformers and trl
pip install transformers==4.57.3 trl==0.9.3
# Install deepspeed
pip install deepspeed==0.16.9
# Install flash-attention
pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
# Install chest2vec package
pip install chest2vec
```
Or use the installation script:
```bash
bash install_deps.sh
```
## Requirements
The `chest2vec` package **requires FlashAttention-2** (CUDA) by default (`force_flash_attention_2=True` in `Chest2Vec.from_pretrained()`); it is installed automatically with the package.
## Quickstart
### Loading the Model
```python
from chest2vec import Chest2Vec
# Load model from Hugging Face Hub
m = Chest2Vec.from_pretrained("lukeingawesome/chest2vec_4b_chest", device="cuda:0")
```
### Instruction + Query Embeddings
```python
instructions = ["Find findings about the lungs."]
queries = ["Consolidation in the right lower lobe."]
out = m.embed_instruction_query(instructions, queries, max_len=512, batch_size=8)
# Global embedding (last-token pooling)
emb = out.embedding # [N, H]
```
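How an instruction and a query are merged into one input string is handled inside `embed_instruction_query()` and is not documented here. The base Qwen3-Embedding model card describes a template of the form `Instruct: ...\nQuery:...`; the sketch below assumes that template purely for illustration, and `build_prompts` is a hypothetical helper, not part of the package.

```python
from typing import List


def build_prompts(instructions: List[str], queries: List[str]) -> List[str]:
    """Pair each instruction with its query using the assumed Qwen3-Embedding template."""
    if len(instructions) != len(queries):
        raise ValueError("instructions and queries must have the same length")
    return [
        f"Instruct: {instruction}\nQuery:{query}"
        for instruction, query in zip(instructions, queries)
    ]
```

Candidates are embedded as plain text with `embed_texts()`, so only the query side carries an instruction.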
### Candidate Embeddings (Retrieval Bank)
```python
candidates = [
    "Lungs are clear. No focal consolidation.",
    "Pleural effusion on the left.",
    "Cardiomediastinal silhouette is normal.",
]
cand_out = m.embed_texts(candidates, max_len=512, batch_size=16)
cand_emb = cand_out.embedding # [N, H]
```
### Retrieval Example (Cosine Top-K)
```python
# Query embeddings
q = out.embedding # [Nq, H]
# Document embeddings
d = cand_out.embedding # [Nd, H]
# Compute top-k cosine similarities (k=3 because the bank above holds only three candidates)
scores, idx = Chest2Vec.cosine_topk(q, d, k=3, device="cuda")
# scores: [Nq, k] - similarity scores
# idx: [Nq, k] - indices of top-k candidates
print(f"Top-3 scores: {scores[0]}")
print(f"Top-3 indices: {idx[0]}")
```
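For readers who want to check results on a small bank without a GPU, the same ranking can be reproduced in plain Python. This is a reference sketch, not the implementation of `Chest2Vec.cosine_topk`; it assumes the embeddings are already L2-normalized (so the dot product is the cosine similarity), and clamping `k` to the bank size is a choice made here.

```python
import heapq
from typing import List, Tuple


def cosine_topk_reference(
    q: List[List[float]],  # [Nq][H] unit-norm query embeddings
    d: List[List[float]],  # [Nd][H] unit-norm candidate embeddings
    k: int = 10,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Return (scores, indices), each [Nq][k], sorted by descending similarity."""
    k = min(k, len(d))  # never ask for more results than there are candidates
    all_scores, all_idx = [], []
    for qv in q:
        sims = [sum(a * b for a, b in zip(qv, dv)) for dv in d]
        top = heapq.nlargest(k, range(len(d)), key=sims.__getitem__)
        all_idx.append(top)
        all_scores.append([sims[i] for i in top])
    return all_scores, all_idx
```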
## API Reference
### `Chest2Vec.from_pretrained()`
Load the model from Hugging Face Hub or local path.
```python
m = Chest2Vec.from_pretrained(
    repo_id_or_path,               # str: Hugging Face repo ID or local path
    device="cuda:0",               # str, default "cuda:0": device to load the model on
    use_4bit=False,                # bool, default False: use 4-bit quantization
    force_flash_attention_2=True,  # bool, default True
)
```
### `embed_instruction_query()`
Embed instruction-query pairs. Returns `EmbedOutput` with:
- `embedding`: `[N, H]` - global embeddings (L2-normalized, last-token pooling)
```python
out = m.embed_instruction_query(
    instructions,    # List[str]
    queries,         # List[str]
    max_len=512,     # int, default 512
    batch_size=16,   # int, default 16
)
```
### `embed_texts()`
Embed plain texts (for document/candidate encoding).
```python
out = m.embed_texts(
    texts,           # List[str]
    max_len=512,     # int, default 512
    batch_size=16,   # int, default 16
)
```
Returns `EmbedOutput` with:
- `embedding`: `[N, H]` - global embeddings (L2-normalized, last-token pooling)
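The `batch_size` argument presumably controls how many texts go through the model per forward pass. A minimal sketch of that kind of chunking, assuming simple consecutive slices (the helper name `iter_batches` is hypothetical and not part of the package):

```python
from typing import Iterator, List


def iter_batches(texts: List[str], batch_size: int = 16) -> Iterator[List[str]]:
    """Yield consecutive slices of at most `batch_size` texts, preserving order."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(texts), batch_size):
        yield texts[start:start + batch_size]
```

Under this assumption, lowering `batch_size` reduces peak GPU memory at the cost of more forward passes, while the order of the returned embeddings stays the same as the input order.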
### `cosine_topk()`
Static method for efficient top-k cosine similarity search.
```python
scores, idx = Chest2Vec.cosine_topk(
    query_emb,       # torch.Tensor [Nq, H]
    cand_emb,        # torch.Tensor [Nd, H]
    k=10,            # int, default 10
    device="cuda",   # str, default "cuda"
)
```
## Model Files
- `chest2vec.py` - Model class and inference utilities
- `chest2vec_config.json` - Model configuration
- `contrastive/` - LoRA adapter directory
  - `adapter_config.json` - LoRA adapter configuration
  - `adapter_model.safetensors` - LoRA adapter weights
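When loading from a local path, it can help to confirm that the files listed above are present before calling `from_pretrained()`. The check below is a small sketch using only the standard library; `missing_files` is a hypothetical helper, and the adapter files are assumed to sit inside `contrastive/`.

```python
from pathlib import Path
from typing import List

# Relative paths taken from the file list in this README
EXPECTED_FILES = [
    "chest2vec.py",
    "chest2vec_config.json",
    "contrastive/adapter_config.json",
    "contrastive/adapter_model.safetensors",
]


def missing_files(root: str) -> List[str]:
    """Return the expected files that are absent under `root`."""
    base = Path(root)
    return [rel for rel in EXPECTED_FILES if not (base / rel).is_file()]
```

An empty list means the local copy has everything this README lists.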
## Citation
If you use this model, please cite:
```bibtex
@misc{chest2vec_4b_chest,
title={Chest2Vec: Global Embeddings for Chest X-Ray Reports},
author={Your Name},
year={2024},
howpublished={\url{https://huggingface.co/lukeingawesome/chest2vec_4b_chest}}
}
```
## License
[Specify your license here]
## Usage (🤗 transformers)
```python
from transformers import AutoModel
# base Qwen3-Embedding weights download automatically; needs trust_remote_code
model = AutoModel.from_pretrained("chest2vec/chest2vec_4B", trust_remote_code=True)
emb = model.embed_texts([
    "Frontal chest radiograph. No focal consolidation. No pneumothorax. Heart size normal.",
    "Lungs are clear. No focal consolidation.",
])
emb # [N, H] L2-normalized report embedding (last-token / EOS pooling)
# similarity
(emb[0] @ emb[1]) # cosine similarity (rows are unit-norm)
```
FlashAttention-2 is used automatically on CUDA when `flash-attn>=2` is installed
(matching training); otherwise it falls back to SDPA so the model also loads on CPU.
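The fallback rule described above can be sketched as a small selection function. This is an illustration of the logic, not the code shipped in the repository: `pick_attn_implementation` is a hypothetical name, and the import check only tests that a `flash_attn` module is installed, not that its version is 2 or newer.

```python
import importlib.util
from typing import Optional


def pick_attn_implementation(device: str, flash_available: Optional[bool] = None) -> str:
    """Choose a value for the `attn_implementation` argument of transformers models."""
    if flash_available is None:
        # Probe for the flash_attn package without importing it
        flash_available = importlib.util.find_spec("flash_attn") is not None
    on_cuda = device.startswith("cuda")
    return "flash_attention_2" if (on_cuda and flash_available) else "sdpa"
```

For example, `pick_attn_implementation("cpu")` returns `"sdpa"` regardless of whether `flash-attn` is installed.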