---
tags:
- text-embeddings
- retrieval
- radiology
- chest
- qwen
library_name: transformers
---

# chest2vec_0.6b_chest

This repository contains the *delta weights and pooling head* for a section-aware embedding model built on top of **Qwen/Qwen3-Embedding-0.6B**:

- **Stage-2**: Frozen LoRA adapter (contrastive) under `./contrastive/`
- **Stage-3**: Section pooler `section_pooler.pt` producing **9 section embeddings**
- **Inference helper**: `chest2vec.py`

Base model weights are **not** included; they are downloaded from Hugging Face at runtime.

## Model Architecture

Chest2Vec is a three-stage model:
1. **Base**: Qwen/Qwen3-Embedding-0.6B (downloaded at runtime)
2. **Stage-2**: Contrastive LoRA adapter trained with multi-positive sigmoid loss
3. **Stage-3**: Section-aware query-attention pooler producing embeddings for 9 radiology report sections
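
To make the Stage-3 idea concrete, here is a minimal sketch of query-attention pooling in plain Python. It assumes one learned query vector per section and scaled dot-product attention over token vectors; the actual pooler in `section_pooler.pt` may differ (for example in projections, masking, or normalization). The names `query_attention_pool`, `tokens`, and `section_queries` are illustrative and not part of the package.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def query_attention_pool(tokens: List[Vector], section_queries: List[Vector]) -> List[Vector]:
    """Return one pooled vector per section query (assumed design, not the shipped code)."""
    dim = len(tokens[0])
    scale = 1.0 / math.sqrt(dim)
    pooled = []
    for query in section_queries:
        # Scaled dot-product score between this section's query and every token.
        scores = [scale * sum(q * t for q, t in zip(query, tok)) for tok in tokens]
        weights = softmax(scores)
        # Attention-weighted sum of the token vectors.
        pooled.append([sum(w * tok[i] for w, tok in zip(weights, tokens)) for i in range(dim)])
    return pooled


# Toy example: 3 tokens of width 2 and 2 section queries (the real model uses 9).
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
section_queries = [[5.0, 0.0], [0.0, 5.0]]
section_vectors = query_attention_pool(tokens, section_queries)
print(section_vectors)
```

Each query puts most of its weight on the tokens it aligns with, so the same token sequence yields a different vector per section.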

## Sections

The model produces embeddings for 9 distinct sections:

1. Lungs and Airways
2. Pleura
3. Cardiovascular
4. Hila and Mediastinum
5. Tubes & Devices
6. Musculoskeletal and Chest Wall
7. Abdominal
8. impression
9. Other
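
If you work with the raw `[N, 9, H]` section matrix rather than the name-keyed dictionaries, a small lookup table can help. The sketch below assumes the section axis follows the order of the list above, which this README does not confirm; check `chest2vec.py` before relying on it. `SECTION_NAMES` and `section_index` are hypothetical helpers, not part of the package.

```python
from typing import Dict, Tuple

# Section names exactly as listed above, in the same order.
SECTION_NAMES: Tuple[str, ...] = (
    "Lungs and Airways",
    "Pleura",
    "Cardiovascular",
    "Hila and Mediastinum",
    "Tubes & Devices",
    "Musculoskeletal and Chest Wall",
    "Abdominal",
    "impression",
    "Other",
)

# Case-insensitive lookup from section name to its position in the list.
_INDEX_BY_NAME: Dict[str, int] = {name.lower(): i for i, name in enumerate(SECTION_NAMES)}


def section_index(name: str) -> int:
    """Return the 0-based position of a section, ignoring case and surrounding spaces."""
    key = name.strip().lower()
    if key not in _INDEX_BY_NAME:
        raise KeyError(f"Unknown section {name!r}; expected one of {SECTION_NAMES}")
    return _INDEX_BY_NAME[key]


print(section_index("Pleura"))
```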

## Installation

Install the package and all dependencies:

```bash
# Install PyTorch with CUDA 12.6 support
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126

# Install transformers and trl
pip install transformers==4.57.3 trl==0.9.3

# Install deepspeed
pip install deepspeed==0.16.9

# Install flash-attention
pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl

# Install chest2vec package
pip install chest2vec
```

Or use the installation script:

```bash
bash install_deps.sh
```

## Requirements

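By default, the model loads with **FlashAttention-2**, which requires a CUDA GPU; this matches the `force_flash_attention_2=True` default of `Chest2Vec.from_pretrained()`. FlashAttention-2 is installed automatically with the package.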
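
The pinned FlashAttention wheel in the Installation section is tagged `cp310` and `linux_x86_64`, so it only installs on Python 3.10 on 64-bit Linux. A quick pre-flight check, sketched here with the standard library only, can catch a mismatch before loading the model. `check_environment` is a hypothetical helper, not part of the package.

```python
import importlib.util
import platform
import sys
from typing import Dict


def check_environment() -> Dict[str, bool]:
    """Report whether this interpreter matches the pinned wheel's tags and has flash_attn."""
    return {
        # The pinned wheel is tagged cp310, i.e. built for Python 3.10.
        "python_3_10": sys.version_info[:2] == (3, 10),
        # The pinned wheel is tagged linux_x86_64.
        "linux_x86_64": platform.system() == "Linux" and platform.machine() == "x86_64",
        # True if the flash_attn package can be found in this environment.
        "flash_attn_installed": importlib.util.find_spec("flash_attn") is not None,
    }


for name, ok in check_environment().items():
    print(f"{name}: {'ok' if ok else 'missing or mismatched'}")
```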

## Quickstart

### Loading

```python
from chest2vec import Chest2Vec

# Load model from Hugging Face Hub
m = Chest2Vec.from_pretrained("chest2vec/chest2vec_0.6b_chest", device="cuda:0")
```

### Instruction + Query Embeddings

```python
instructions = ["Find findings about the lungs."]
queries = ["Consolidation in the right lower lobe."]

out = m.embed_instruction_query(instructions, queries, max_len=512, batch_size=8)

# Global embedding (derived): mean of 9 section vectors then L2-normalized
g = out.global_embedding                 # [N, H]

# Per-section embeddings (by full name)
lung = out.by_section_name["Lungs and Airways"]  # [N, H]
imp  = out.by_section_name["impression"]          # [N, H]

# Or use aliases (case-insensitive)
lung = out.by_alias["lungs"]   # [N, H]
cardio = out.by_alias["cardio"] # [N, H]
```

### Candidate Embeddings (Retrieval Bank)

```python
candidates = [
    "Lungs are clear. No focal consolidation.",
    "Pleural effusion on the left.",
    "Cardiomediastinal silhouette is normal."
]

cand_out = m.embed_texts(candidates, max_len=512, batch_size=16)

cand_global = cand_out.global_embedding  # [N, H]
cand_lung   = cand_out.by_alias["lungs"]  # [N, H]
```

### Retrieval Example (Cosine Top-K)

```python
# Query embeddings for "Lungs and Airways" section
q = out.by_alias["lungs"]       # [Nq, H]

# Document embeddings for "Lungs and Airways" section
d = cand_out.by_alias["lungs"]  # [Nd, H]

# Compute top-k cosine similarities, never asking for more hits than candidates
k = min(5, d.shape[0])
scores, idx = Chest2Vec.cosine_topk(q, d, k=k, device="cuda")
# scores: [Nq, k] - similarity scores
# idx: [Nq, k] - indices of top-k candidates

print(f"Top-{k} scores: {scores[0]}")
print(f"Top-{k} indices: {idx[0]}")
```
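
The returned indices point into the candidate list, so mapping them back to text is a short step. The helper below is a sketch that works on plain nested lists; `format_hits` and the example values are illustrative and not part of the package.

```python
from typing import List, Sequence, Tuple


def format_hits(
    scores: Sequence[Sequence[float]],
    idx: Sequence[Sequence[int]],
    candidates: Sequence[str],
) -> List[List[Tuple[float, str]]]:
    """Pair each top-k index with its candidate text, one list of hits per query."""
    return [
        [(float(s), candidates[int(i)]) for s, i in zip(row_scores, row_idx)]
        for row_scores, row_idx in zip(scores, idx)
    ]


# Stand-in values shaped like nested-list versions of `scores` and `idx` for one query.
example_scores = [[0.91, 0.42]]
example_idx = [[0, 2]]
example_candidates = [
    "Lungs are clear. No focal consolidation.",
    "Pleural effusion on the left.",
    "Cardiomediastinal silhouette is normal.",
]

for score, text in format_hits(example_scores, example_idx, example_candidates)[0]:
    print(f"{score:.2f}  {text}")
```

Assuming `scores` and `idx` are PyTorch tensors, as the shape comments suggest, `format_hits(scores.tolist(), idx.tolist(), candidates)` should work with the outputs of the retrieval example.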

## API Reference

### `Chest2Vec.from_pretrained()`

Load the model from Hugging Face Hub or local path.

```python
m = Chest2Vec.from_pretrained(
    "chest2vec/chest2vec_0.6b_chest",  # repo_id_or_path: str, Hugging Face repo ID or local path
    device="cuda:0",                   # str, device to load the model on (default "cuda:0")
    use_4bit=False,                    # bool, use 4-bit quantization (default False)
    force_flash_attention_2=True,      # bool (default True)
)
```

### `embed_instruction_query()`

Embed instruction-query pairs. Returns `EmbedOutput` with:
- `section_matrix`: `[N, 9, H]` - embeddings for all 9 sections
- `global_embedding`: `[N, H]` - global embedding (mean of sections, L2-normalized)
- `by_section_name`: Dict mapping full section names to `[N, H]` tensors
- `by_alias`: Dict mapping aliases to `[N, H]` tensors

```python
out = m.embed_instruction_query(
    instructions,    # List[str]
    queries,         # List[str]
    max_len=512,     # int (default 512)
    batch_size=16,   # int (default 16)
)
```
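
The `global_embedding` field is described above as the mean of the section vectors, L2-normalized. As a worked illustration of that arithmetic for a single text, here is a plain-Python sketch; `global_from_sections` is an illustrative name, and the library itself operates on batched tensors.

```python
import math
from typing import List

Vector = List[float]


def global_from_sections(section_vectors: List[Vector]) -> Vector:
    """Average the section vectors element-wise, then L2-normalize the result."""
    count = len(section_vectors)
    dim = len(section_vectors[0])
    mean = [sum(vec[i] for vec in section_vectors) / count for i in range(dim)]
    norm = math.sqrt(sum(x * x for x in mean))
    # Guard against an all-zero mean to avoid dividing by zero.
    return [x / norm for x in mean] if norm > 0 else mean


# Toy input: 9 section vectors of width 2 for a single text.
sections = [[3.0, 4.0]] * 9
print(global_from_sections(sections))  # [0.6, 0.8]
```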

### `embed_texts()`

Embed plain texts (for document/candidate encoding).

```python
out = m.embed_texts(
    texts,           # List[str]
    max_len=512,     # int (default 512)
    batch_size=16,   # int (default 16)
)
```

### `cosine_topk()`

Static method for efficient top-k cosine similarity search.

```python
scores, idx = Chest2Vec.cosine_topk(
    query_emb,       # torch.Tensor, [Nq, H]
    cand_emb,        # torch.Tensor, [Nd, H]
    k=10,            # int (default 10)
    device="cuda",   # str (default "cuda")
)
```
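
For small banks, or for checking results on a machine without a GPU, a brute-force reference is easy to write. The sketch below computes cosine top-k over plain lists; it is not the implementation behind `Chest2Vec.cosine_topk`, and clamping `k` to the number of candidates is a choice made here, not documented behavior of the library.

```python
import math
from typing import List, Tuple

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two vectors; 0.0 if either has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a > 0 and norm_b > 0 else 0.0


def cosine_topk_reference(
    queries: List[Vector], candidates: List[Vector], k: int
) -> Tuple[List[List[float]], List[List[int]]]:
    """For each query, return the k highest cosine scores and their candidate indices."""
    k = min(k, len(candidates))  # cannot return more hits than candidates
    all_scores, all_idx = [], []
    for query in queries:
        sims = [cosine(query, cand) for cand in candidates]
        order = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:k]
        all_scores.append([sims[i] for i in order])
        all_idx.append(order)
    return all_scores, all_idx


ref_scores, ref_idx = cosine_topk_reference(
    [[1.0, 0.0]], [[0.0, 1.0], [2.0, 0.0], [1.0, 1.0]], k=2
)
print(ref_idx)  # [[1, 2]]
```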

## Model Files

- `chest2vec.py` - Model class and inference utilities
- `chest2vec_config.json` - Model configuration
- `section_pooler.pt` - Stage-3 pooler weights
- `section_pooler_config.json` - Pooler configuration
- `contrastive/` - Stage-2 LoRA adapter directory
  - `adapter_config.json` - LoRA adapter configuration
  - `adapter_model.safetensors` - LoRA adapter weights
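
When loading from a local path, it can be useful to confirm that the directory contains everything listed above before calling `from_pretrained()`. A minimal sketch, assuming the layout matches this list exactly; `missing_files` and the example directory are hypothetical.

```python
from pathlib import Path
from typing import List

# Relative paths taken from the file list above.
EXPECTED_FILES = (
    "chest2vec.py",
    "chest2vec_config.json",
    "section_pooler.pt",
    "section_pooler_config.json",
    "contrastive/adapter_config.json",
    "contrastive/adapter_model.safetensors",
)


def missing_files(model_dir: str) -> List[str]:
    """Return the expected files that are absent from model_dir."""
    root = Path(model_dir)
    return [rel for rel in EXPECTED_FILES if not (root / rel).is_file()]


# "./chest2vec_0.6b_chest" is a hypothetical local checkout path.
print(missing_files("./chest2vec_0.6b_chest"))
```

An empty list means every expected file was found.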

## Citation

If you use this model, please cite:

```bibtex
@misc{chest2vec_0.6b_chest,
  title={Chest2Vec: Section-Aware Embeddings for Chest X-Ray Reports},
  author={Your Name},
  year={2024},
  howpublished={\url{https://huggingface.co/chest2vec/chest2vec_0.6b_chest}}
}
```

## License

[Specify your license here]