---
tags:
- text-embeddings
- retrieval
- radiology
- chest
- qwen
library_name: transformers
---
|
|
|
|
|
# chest2vec_0.6b_chest |
|
|
|
|
|
This repository contains the *delta weights and pooling head* for a section-aware embedding model on top of **Qwen/Qwen3-Embedding-0.6B**: |
|
|
|
|
|
- **Stage-2**: Frozen LoRA adapter (contrastive) under `./contrastive/` |
|
|
- **Stage-3**: Section pooler `section_pooler.pt` producing **9 section embeddings** |
|
|
- **Inference helper**: `chest2vec.py` |
|
|
|
|
|
Base model weights are **not** included; they are downloaded from Hugging Face at runtime. |
|
|
|
|
|
## Model Architecture |
|
|
|
|
|
Chest2Vec is a three-stage model: |
|
|
1. **Base**: Qwen/Qwen3-Embedding-0.6B (downloaded at runtime) |
|
|
2. **Stage-2**: Contrastive LoRA adapter trained with multi-positive sigmoid loss |
|
|
3. **Stage-3**: Section-aware query-attention pooler producing embeddings for 9 radiology report sections |
|
|
|
|
|
## Sections |
|
|
|
|
|
The model produces embeddings for 9 distinct sections: |
|
|
|
|
|
1. Lungs and Airways |
|
|
2. Pleura |
|
|
3. Cardiovascular |
|
|
4. Hila and Mediastinum |
|
|
5. Tubes & Devices |
|
|
6. Musculoskeletal and Chest Wall |
|
|
7. Abdominal |
|
|
8. impression |
|
|
9. Other |
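If the `[N, 9, H]` section matrix produced by the pooler follows the order listed above (an assumption made here for illustration; the `by_section_name` lookup on the output object is the authoritative mapping), a name-to-row index can be sketched as:

```python
# Hypothetical section ordering; prefer `out.by_section_name` in real code.
SECTION_NAMES = [
    "Lungs and Airways",
    "Pleura",
    "Cardiovascular",
    "Hila and Mediastinum",
    "Tubes & Devices",
    "Musculoskeletal and Chest Wall",
    "Abdominal",
    "impression",
    "Other",
]
SECTION_INDEX = {name: i for i, name in enumerate(SECTION_NAMES)}

# e.g. row index of the pleura section within a [N, 9, H] section matrix
print(SECTION_INDEX["Pleura"])  # 1
```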
|
|
|
|
|
## Installation |
|
|
|
|
|
Install the package and all dependencies: |
|
|
|
|
|
```bash
# Install PyTorch with CUDA 12.6 support
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126

# Install transformers and trl
pip install transformers==4.57.3 trl==0.9.3

# Install deepspeed
pip install deepspeed==0.16.9

# Install flash-attention (prebuilt wheel: CUDA 12, torch 2.6, Python 3.10, linux x86_64)
pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl

# Install chest2vec package
pip install chest2vec
```
|
|
|
|
|
Or use the installation script: |
|
|
|
|
|
```bash
bash install_deps.sh
```
|
|
|
|
|
## Requirements |
|
|
|
|
|
This model **requires FlashAttention-2** (CUDA) by default. It is not pulled in automatically as a pip dependency; install it from the prebuilt wheel shown in the Installation section. The requirement can be relaxed via the `force_flash_attention_2` argument of `Chest2Vec.from_pretrained()`.
|
|
|
|
|
## Quickstart |
|
|
|
|
|
### Loading the Model
|
|
|
|
|
```python
from chest2vec import Chest2Vec

# Load model from Hugging Face Hub
m = Chest2Vec.from_pretrained("chest2vec/chest2vec_0.6b_chest", device="cuda:0")
```
|
|
|
|
|
### Instruction + Query Embeddings |
|
|
|
|
|
```python
instructions = ["Find findings about the lungs."]
queries = ["Consolidation in the right lower lobe."]

out = m.embed_instruction_query(instructions, queries, max_len=512, batch_size=8)

# Global embedding (derived): mean of the 9 section vectors, then L2-normalized
g = out.global_embedding  # [N, H]

# Per-section embeddings (by full name)
lung = out.by_section_name["Lungs and Airways"]  # [N, H]
imp = out.by_section_name["impression"]          # [N, H]

# Or use aliases (case-insensitive)
lung = out.by_alias["lungs"]     # [N, H]
cardio = out.by_alias["cardio"]  # [N, H]
```
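The derived global embedding (mean of the 9 section vectors, then L2-normalization, as described above) can be reproduced from `section_matrix` directly. A minimal sketch with a synthetic tensor standing in for `out.section_matrix`, assuming the documented `[N, 9, H]` shape:

```python
import torch

# Synthetic stand-in for out.section_matrix: [N=2, S=9, H=16]
section_matrix = torch.randn(2, 9, 16)

# Mean over the 9 section vectors, then L2-normalize each row
global_emb = section_matrix.mean(dim=1)
global_emb = torch.nn.functional.normalize(global_emb, p=2, dim=-1)

print(global_emb.shape)  # torch.Size([2, 16])
print(global_emb.norm(dim=-1))  # ~1.0 for every row after normalization
```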
|
|
|
|
|
### Candidate Embeddings (Retrieval Bank) |
|
|
|
|
|
```python
candidates = [
    "Lungs are clear. No focal consolidation.",
    "Pleural effusion on the left.",
    "Cardiomediastinal silhouette is normal.",
]

cand_out = m.embed_texts(candidates, max_len=512, batch_size=16)

cand_global = cand_out.global_embedding  # [N, H]
cand_lung = cand_out.by_alias["lungs"]   # [N, H]
```
|
|
|
|
|
### Retrieval Example (Cosine Top-K) |
|
|
|
|
|
```python
# Query embeddings for the "Lungs and Airways" section
q = out.by_alias["lungs"]  # [Nq, H]

# Document embeddings for the "Lungs and Airways" section
d = cand_out.by_alias["lungs"]  # [Nd, H]

# Compute top-k cosine similarities (k should not exceed the candidate count;
# the bank above holds 3 candidates, so k=3)
scores, idx = Chest2Vec.cosine_topk(q, d, k=3, device="cuda")
# scores: [Nq, k] - similarity scores
# idx:    [Nq, k] - indices of the top-k candidates

print(f"Top-3 scores: {scores[0]}")
print(f"Top-3 indices: {idx[0]}")
```
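To turn the returned indices back into the matched report texts, index into the candidate list. A sketch with a synthetic `idx` tensor standing in for the real `cosine_topk` output:

```python
import torch

candidates = [
    "Lungs are clear. No focal consolidation.",
    "Pleural effusion on the left.",
    "Cardiomediastinal silhouette is normal.",
]

# Synthetic stand-in for the idx returned by cosine_topk: [Nq=1, k=2]
idx = torch.tensor([[1, 0]])

for qi, row in enumerate(idx.tolist()):
    hits = [candidates[j] for j in row]
    print(f"query {qi}: {hits}")
```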
|
|
|
|
|
## API Reference |
|
|
|
|
|
### `Chest2Vec.from_pretrained()` |
|
|
|
|
|
Load the model from Hugging Face Hub or local path. |
|
|
|
|
|
```python
m = Chest2Vec.from_pretrained(
    repo_id_or_path: str,                  # Hugging Face repo ID or local path
    device: str = "cuda:0",                # Device to load the model on
    use_4bit: bool = False,                # Use 4-bit quantization
    force_flash_attention_2: bool = True,  # Require FlashAttention-2
)
```
|
|
|
|
|
### `embed_instruction_query()` |
|
|
|
|
|
Embed instruction-query pairs. Returns `EmbedOutput` with: |
|
|
- `section_matrix`: `[N, 9, H]` - embeddings for all 9 sections |
|
|
- `global_embedding`: `[N, H]` - global embedding (mean of sections, L2-normalized) |
|
|
- `by_section_name`: Dict mapping full section names to `[N, H]` tensors |
|
|
- `by_alias`: Dict mapping aliases to `[N, H]` tensors |
|
|
|
|
|
```python
out = m.embed_instruction_query(
    instructions: List[str],
    queries: List[str],
    max_len: int = 512,
    batch_size: int = 16,
)
```
|
|
|
|
|
### `embed_texts()` |
|
|
|
|
|
Embed plain texts (for document/candidate encoding). |
|
|
|
|
|
```python
out = m.embed_texts(
    texts: List[str],
    max_len: int = 512,
    batch_size: int = 16,
)
```
|
|
|
|
|
### `cosine_topk()` |
|
|
|
|
|
Static method for efficient top-k cosine similarity search. |
|
|
|
|
|
```python
scores, idx = Chest2Vec.cosine_topk(
    query_emb: torch.Tensor,  # [Nq, H]
    cand_emb: torch.Tensor,   # [Nd, H]
    k: int = 10,
    device: str = "cuda",
)
```
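For reference, the computation can be written out in plain PyTorch (normalize both sides, matrix-multiply, `topk`). This is a sketch of the underlying math, not necessarily the library's exact implementation:

```python
import torch

def cosine_topk_sketch(query_emb: torch.Tensor, cand_emb: torch.Tensor, k: int = 10):
    """Top-k cosine similarity between [Nq, H] queries and [Nd, H] candidates."""
    q = torch.nn.functional.normalize(query_emb, p=2, dim=-1)
    d = torch.nn.functional.normalize(cand_emb, p=2, dim=-1)
    sims = q @ d.T           # [Nq, Nd] cosine similarities
    k = min(k, d.shape[0])   # clamp k to the candidate count
    return torch.topk(sims, k=k, dim=-1)  # (scores [Nq, k], indices [Nq, k])

q = torch.randn(2, 16)
d = torch.randn(5, 16)
scores, idx = cosine_topk_sketch(q, d, k=3)
print(scores.shape, idx.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```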
|
|
|
|
|
## Model Files |
|
|
|
|
|
- `chest2vec.py` - Model class and inference utilities |
|
|
- `chest2vec_config.json` - Model configuration |
|
|
- `section_pooler.pt` - Stage-3 pooler weights |
|
|
- `section_pooler_config.json` - Pooler configuration |
|
|
- `contrastive/` - Stage-2 LoRA adapter directory
  - `adapter_config.json` - LoRA adapter configuration
  - `adapter_model.safetensors` - LoRA adapter weights
|
|
|
|
|
## Citation |
|
|
|
|
|
If you use this model, please cite: |
|
|
|
|
|
```bibtex
@misc{chest2vec_0.6b_chest,
  title={Chest2Vec: Section-Aware Embeddings for Chest X-Ray Reports},
  author={Your Name},
  year={2024},
  howpublished={\url{https://huggingface.co/chest2vec/chest2vec_0.6b_chest}}
}
```
|
|
|
|
|
## License |
|
|
|
|
|
[Specify your license here] |
|
|
|
|
|
|