lukeingawesome committed (verified)
Commit 554d398 · Parent(s): 94d6510

Upload folder using huggingface_hub
.gitignore ADDED
@@ -0,0 +1,23 @@
+ # Build artifacts
+ dist/
+ build/
+ *.egg-info/
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+
+ # Testing
+ .pytest_cache/
+ .coverage
+ htmlcov/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # Jupyter
+ .ipynb_checkpoints/
README.md ADDED
@@ -0,0 +1,195 @@
+ ---
+ tags:
+ - text-embeddings
+ - retrieval
+ - radiology
+ - chest
+ - qwen
+ library_name: transformers
+ ---
+
+ # chest2vec_4b_chest
+
+ This repository contains the *delta weights* for a global embedding model built on top of **Qwen/Qwen3-Embedding-4B**:
+
+ - **LoRA adapter**: contrastive adapter trained with a multi-positive sigmoid loss, stored under `./contrastive/`
+ - **Inference helper**: `chest2vec.py`
+
+ Base model weights are **not** included; they are downloaded from Hugging Face at runtime.
+
+ ## Model Architecture
+
+ Chest2Vec combines three components:
+ 1. **Base**: Qwen/Qwen3-Embedding-4B (downloaded at runtime)
+ 2. **LoRA adapter**: contrastive adapter trained with a multi-positive sigmoid loss
+ 3. **Pooling**: last-token pooling (EOS token) for global embeddings
+
+ The model produces **global embeddings only** (no section-specific embeddings).
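The last-token pooling step above can be sketched with plain lists; this is an illustration of the idea only, not the actual tensor implementation in `chest2vec.py`:

```python
def last_token_pool(hidden, mask):
    """Pick the hidden state at each sequence's last attended (non-pad) position.

    hidden: [N][T][H] nested lists; mask: [N][T] of 0/1 attention values.
    With left padding the last position is always attended, so this
    reduces to taking index T-1.
    """
    out = []
    for states, m in zip(hidden, mask):
        last = max(i for i, v in enumerate(m) if v == 1)  # last non-pad index
        out.append(states[last])
    return out


# Two toy sequences (T=3, H=2); the first is left-padded.
hidden = [[[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]],
          [[5.0, 6.0], [7.0, 8.0], [9.0, 9.5]]]
mask = [[0, 1, 1], [1, 1, 1]]
print(last_token_pool(hidden, mask))  # [[3.0, 4.0], [9.0, 9.5]]
```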
+
+ ## Installation
+
+ Install the package and all dependencies:
+
+ ```bash
+ # Install PyTorch with CUDA 12.6 support
+ pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
+
+ # Install transformers and trl
+ pip install transformers==4.57.3 trl==0.9.3
+
+ # Install deepspeed
+ pip install deepspeed==0.16.9
+
+ # Install flash-attention
+ pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
+
+ # Install chest2vec package
+ pip install chest2vec
+ ```
+
+ Or use the installation script:
+
+ ```bash
+ bash install_deps.sh
+ ```
+
+ ## Requirements
+
+ This model **requires FlashAttention-2** (CUDA) by default. Note that `pip install chest2vec` does not pull in flash-attn automatically; install the wheel shown above (or run `install_deps.sh`). To fall back to SDPA attention instead, pass `force_flash_attention_2=False` to `from_pretrained()`.
+
+ ## Quickstart
+
+ ### Loading the Model
+
+ ```python
+ from chest2vec import Chest2Vec
+
+ # Load model from Hugging Face Hub
+ m = Chest2Vec.from_pretrained("lukeingawesome/chest2vec_4b_chest", device="cuda:0")
+ ```
+
+ ### Instruction + Query Embeddings
+
+ ```python
+ instructions = ["Find findings about the lungs."]
+ queries = ["Consolidation in the right lower lobe."]
+
+ out = m.embed_instruction_query(instructions, queries, max_len=512, batch_size=8)
+
+ # Global embedding (last-token pooling)
+ emb = out.embedding  # [N, H]
+ ```
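Under the hood, `embed_instruction_query` joins each pair with the template used by `build_qwen_query` in `chest2vec.py`:

```python
# Same template as build_qwen_query in chest2vec.py
def build_qwen_query(instruction: str, query: str) -> str:
    return f"Instruct: {str(instruction).strip()}\nQuery: {str(query).strip()}"


text = build_qwen_query("Find findings about the lungs.",
                        "Consolidation in the right lower lobe.")
print(text)
# Instruct: Find findings about the lungs.
# Query: Consolidation in the right lower lobe.
```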
+
+ ### Candidate Embeddings (Retrieval Bank)
+
+ ```python
+ candidates = [
+     "Lungs are clear. No focal consolidation.",
+     "Pleural effusion on the left.",
+     "Cardiomediastinal silhouette is normal."
+ ]
+
+ cand_out = m.embed_texts(candidates, max_len=512, batch_size=16)
+
+ cand_emb = cand_out.embedding  # [N, H]
+ ```
+
+ ### Retrieval Example (Cosine Top-K)
+
+ ```python
+ # Query embeddings
+ q = out.embedding  # [Nq, H]
+
+ # Document embeddings
+ d = cand_out.embedding  # [Nd, H]
+
+ # Compute top-k cosine similarities
+ # (k is clamped internally to the number of candidates)
+ scores, idx = Chest2Vec.cosine_topk(q, d, k=5, device="cuda")
+ # scores: [Nq, k] - similarity scores
+ # idx: [Nq, k] - indices of top-k candidates
+
+ print(f"Top-5 scores: {scores[0]}")
+ print(f"Top-5 indices: {idx[0]}")
+ ```
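The returned indices refer to positions in the candidate list, so mapping them back to text is a simple lookup. A minimal sketch with dummy scores/indices standing in for `cosine_topk` output:

```python
# Dummy scores/indices standing in for Chest2Vec.cosine_topk output
candidates = [
    "Lungs are clear. No focal consolidation.",
    "Pleural effusion on the left.",
    "Cardiomediastinal silhouette is normal.",
]
scores = [[0.91, 0.54, 0.12]]   # [Nq, k]
idx = [[1, 2, 0]]               # [Nq, k]

# Resolve indices back to candidate texts for the first query
hits = [(candidates[i], s) for i, s in zip(idx[0], scores[0])]
for rank, (text, s) in enumerate(hits, 1):
    print(f"{rank}. {text} (cos={s:.2f})")
```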
+
+ ## API Reference
+
+ ### `Chest2Vec.from_pretrained()`
+
+ Load the model from the Hugging Face Hub or a local path.
+
+ ```python
+ m = Chest2Vec.from_pretrained(
+     repo_id_or_path: str,            # Hugging Face repo ID or local path
+     device: str = "cuda:0",          # Device to load the model on
+     use_4bit: bool = False,          # Use 4-bit quantization
+     force_flash_attention_2: bool = True
+ )
+ ```
+
+ ### `embed_instruction_query()`
+
+ Embed instruction-query pairs. Returns `EmbedOutput` with:
+ - `embedding`: `[N, H]` - global embeddings (L2-normalized, last-token pooling)
+
+ ```python
+ out = m.embed_instruction_query(
+     instructions: List[str],
+     queries: List[str],
+     max_len: int = 512,
+     batch_size: int = 16
+ )
+ ```
+
+ ### `embed_texts()`
+
+ Embed plain texts (for document/candidate encoding).
+
+ ```python
+ out = m.embed_texts(
+     texts: List[str],
+     max_len: int = 512,
+     batch_size: int = 16
+ )
+ ```
+
+ Returns `EmbedOutput` with:
+ - `embedding`: `[N, H]` - global embeddings (L2-normalized, last-token pooling)
+
+ ### `cosine_topk()`
+
+ Static method for efficient top-k cosine similarity search.
+
+ ```python
+ scores, idx = Chest2Vec.cosine_topk(
+     query_emb: torch.Tensor,  # [Nq, H]
+     cand_emb: torch.Tensor,   # [Nd, H]
+     k: int = 10,
+     device: str = "cuda"
+ )
+ ```
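`cosine_topk` streams candidates in chunks and merges each chunk's scores into a running top-k. The merge idea can be shown with a heap over plain lists; this sketch illustrates the algorithm, not the tensor `torch.topk` implementation in `chest2vec.py`:

```python
import heapq


def stream_topk(score_chunks, k):
    """Merge per-chunk similarity scores into one global top-k.

    score_chunks: iterable of lists of scores, in candidate order.
    Returns [(score, global_index), ...] sorted by descending score.
    """
    heap = []    # min-heap holding the current best k (score, index) pairs
    offset = 0   # global index of the first score in the current chunk
    for chunk in score_chunks:
        for j, s in enumerate(chunk):
            item = (s, offset + j)
            if len(heap) < k:
                heapq.heappush(heap, item)
            elif s > heap[0][0]:          # beats the current worst kept score
                heapq.heapreplace(heap, item)
        offset += len(chunk)
    return sorted(heap, reverse=True)


print(stream_topk([[0.2, 0.9], [0.5, 0.7, 0.1]], k=2))  # [(0.9, 1), (0.7, 3)]
```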
+
+ ## Model Files
+
+ - `chest2vec.py` - Model class and inference utilities
+ - `chest2vec_config.json` - Model configuration
+ - `contrastive/` - LoRA adapter directory
+   - `adapter_config.json` - LoRA adapter configuration
+   - `adapter_model.safetensors` - LoRA adapter weights
+
+ ## Citation
+
+ If you use this model, please cite:
+
+ ```bibtex
+ @misc{chest2vec_4b_chest,
+   title={Chest2Vec: Global Embeddings for Chest X-Ray Reports},
+   author={Your Name},
+   year={2024},
+   howpublished={\url{https://huggingface.co/lukeingawesome/chest2vec_4b_chest}}
+ }
+ ```
+
+ ## License
+
+ [Specify your license here]
__init__.py ADDED
@@ -0,0 +1,4 @@
+ from chest2vec import Chest2Vec, EmbedOutput
+
+ __all__ = ["Chest2Vec", "EmbedOutput"]
+ __version__ = "4.0.0"
chest2vec.py ADDED
@@ -0,0 +1,382 @@
+ import os
+ import json
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+
+ from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
+
+ try:
+     from peft import PeftModel
+     _HAS_PEFT = True
+ except Exception:
+     PeftModel = None
+     _HAS_PEFT = False
+
+ try:
+     from huggingface_hub import snapshot_download
+     _HAS_HUB = True
+ except Exception:
+     snapshot_download = None
+     _HAS_HUB = False
+
+
+ def require_flash_attention_2() -> str:
+     if not torch.cuda.is_available():
+         raise RuntimeError("FlashAttention-2 requires CUDA, but torch.cuda.is_available() is False.")
+     try:
+         import flash_attn  # noqa: F401
+         ver = getattr(flash_attn, "__version__", "0.0.0")
+         major = int(str(ver).split(".")[0])
+         if major < 2:
+             raise RuntimeError(f"flash-attn version {ver} < 2.0.0")
+     except Exception as e:
+         raise RuntimeError(
+             "FlashAttention-2 is REQUIRED but not available/importable.\n"
+             "Install flash-attn>=2 and ensure it matches your torch/CUDA.\n"
+             f"Import/Version error: {repr(e)}"
+         )
+     return "flash_attention_2"
+
+
+ def build_qwen_query(instruction: str, query: str) -> str:
+     instruction = str(instruction).strip()
+     query = str(query).strip()
+     return f"Instruct: {instruction}\nQuery: {query}"
+
+
+ def get_pool_token_id(tok) -> int:
+     eod_id = tok.convert_tokens_to_ids("<|endoftext|>")
+     if eod_id is None or eod_id < 0:
+         eod_id = tok.pad_token_id
+     return eod_id
+
+
+ def encode_with_eos_ids(tok, texts: List[str], max_len: int) -> Dict[str, torch.Tensor]:
+     """
+     Encode texts with a guaranteed <|endoftext|> at the end for Qwen3-Embedding pooling.
+     Reserves 1 slot for <|endoftext|>, left-pads to the batch max length.
+     Must match the training tokenization exactly.
+     """
+     pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
+     eod_id = get_pool_token_id(tok)
+
+     # Reserve 1 position for <|endoftext|>
+     enc = tok(
+         [str(t) for t in texts],
+         add_special_tokens=False,
+         truncation=True,
+         max_length=max_len - 1,
+         padding=False,
+         return_attention_mask=False,
+     )
+
+     input_ids = [ids + [eod_id] for ids in enc["input_ids"]]
+     attn_mask = [[1] * len(ids) for ids in input_ids]
+
+     # Left-pad to batch max length
+     T = max(len(ids) for ids in input_ids) if input_ids else 1
+     input_ids = [[pad_id] * (T - len(ids)) + ids for ids in input_ids]
+     attn_mask = [[0] * (T - len(m)) + m for m in attn_mask]
+
+     return {
+         "input_ids": torch.tensor(input_ids, dtype=torch.long),
+         "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
+     }
+
+
+ def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+     """
+     Left-padding aware last-token pooling (extracts the EOS token embedding).
+     """
+     left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+     if left_padding:
+         return last_hidden_states[:, -1]
+     idx = attention_mask.sum(dim=1) - 1
+     return last_hidden_states[torch.arange(last_hidden_states.size(0), device=last_hidden_states.device), idx]
+
+
+ def get_last_hidden_state(model, input_ids, attention_mask):
+     """
+     Get the final hidden state from the model.
+     Provides position_ids for left padding (required for FlashAttention-2).
+     """
+     m = model.module if hasattr(model, "module") else model
+
+     # Compute position_ids for left-padded sequences (required for flash_attention_2)
+     position_ids = attention_mask.long().cumsum(-1) - 1
+     position_ids.masked_fill_(attention_mask == 0, 0)
+
+     out = m(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         use_cache=False,
+         return_dict=True,
+     )
+     if hasattr(out, "last_hidden_state"):
+         return out.last_hidden_state
+
+     # Fallback: request hidden states explicitly and take the last layer
+     out = m(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         output_hidden_states=True,
+         use_cache=False,
+         return_dict=True,
+     )
+     return out.hidden_states[-1]
+
+
+ def _read_json(path: str) -> Dict[str, Any]:
+     with open(path, "r", encoding="utf-8") as f:
+         return json.load(f)
+
+
+ def _resolve_repo_path(repo_id_or_path: str) -> str:
+     # If it's a local directory, use it as-is.
+     if os.path.isdir(repo_id_or_path):
+         return repo_id_or_path
+     # Otherwise treat it as an HF repo_id and download a snapshot.
+     if not _HAS_HUB:
+         raise RuntimeError(
+             "huggingface_hub is required to load by repo_id. "
+             "Install it: pip install huggingface_hub"
+         )
+     return snapshot_download(repo_id_or_path)
+
+
+ @dataclass
+ class EmbedOutput:
+     """
+     Simplified output: only global embeddings (no section embeddings).
+     """
+     embedding: torch.Tensor  # [N,H], float32 on CPU by default
+
+
+ class Chest2Vec:
+     """
+     Simplified wrapper for global embeddings only:
+     - loads the base Qwen3-Embedding model
+     - applies the LoRA adapter
+     - returns global embeddings via last-token pooling
+     """
+     def __init__(self, tokenizer, model, device: torch.device):
+         self.tokenizer = tokenizer
+         self.model = model
+         self.device = device
+
+         self.model.eval()
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         repo_id_or_path: str,
+         *,
+         device: str = "cuda:0",
+         use_4bit: bool = False,
+         force_flash_attention_2: bool = True,
+     ) -> "Chest2Vec":
+         repo_path = _resolve_repo_path(repo_id_or_path)
+
+         cfg_path = os.path.join(repo_path, "chest2vec_config.json")
+         if not os.path.isfile(cfg_path):
+             raise FileNotFoundError(f"Missing chest2vec_config.json in {repo_path}")
+         cfg = _read_json(cfg_path)
+
+         base_model = str(cfg["base_model"])
+         adapter_subdir = str(cfg.get("adapter_subdir", "contrastive"))
+
+         if force_flash_attention_2 or bool(cfg.get("require_flash_attention_2", False)):
+             attn_impl = require_flash_attention_2()
+         else:
+             attn_impl = "sdpa"
+
+         if not _HAS_PEFT:
+             raise RuntimeError("peft is required. Install: pip install peft")
+
+         device_t = torch.device(device)
+
+         tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left", trust_remote_code=True)
+         if tokenizer.pad_token_id is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         device_map = {"": str(device_t)}
+
+         # Load the base model (optionally 4-bit quantized)
+         if use_4bit:
+             qconf = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_quant_type="nf4",
+                 bnb_4bit_use_double_quant=True,
+                 bnb_4bit_compute_dtype=torch.bfloat16,
+             )
+             try:
+                 base = AutoModel.from_pretrained(
+                     base_model,
+                     trust_remote_code=True,
+                     attn_implementation=attn_impl,
+                     quantization_config=qconf,
+                     device_map=device_map,
+                 )
+             except TypeError as e:
+                 raise RuntimeError(
+                     "Your transformers version does not support attn_implementation=... "
+                     "Upgrade transformers to use FlashAttention-2."
+                 ) from e
+         else:
+             try:
+                 base = AutoModel.from_pretrained(
+                     base_model,
+                     trust_remote_code=True,
+                     attn_implementation=attn_impl,
+                     torch_dtype=torch.bfloat16,
+                     device_map=device_map,
+                 )
+             except TypeError as e:
+                 raise RuntimeError(
+                     "Your transformers version does not support attn_implementation=... "
+                     "Upgrade transformers to use FlashAttention-2."
+                 ) from e
+
+         # Load the adapter from this repo folder
+         adapter_dir = os.path.join(repo_path, adapter_subdir)
+         if not os.path.isfile(os.path.join(adapter_dir, "adapter_config.json")):
+             raise FileNotFoundError(f"adapter_config.json not found under: {adapter_dir}")
+
+         model = PeftModel.from_pretrained(base, adapter_dir)
+         model.eval()
+
+         return cls(tokenizer=tokenizer, model=model, device=device_t)
+
+     @torch.inference_mode()
+     def embed_texts(
+         self,
+         texts: List[str],
+         *,
+         max_len: int = 512,
+         batch_size: int = 16,
+         return_cpu_float32: bool = True,
+     ) -> EmbedOutput:
+         """
+         Encodes arbitrary texts and returns global embeddings via last-token pooling.
+
+         Returns:
+         - embedding: [N,H] - global embeddings (L2-normalized)
+         """
+         # Determine AMP settings
+         device = self.device
+         if device.type == "cuda":
+             amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+             use_amp = True
+         else:
+             amp_dtype = torch.float32
+             use_amp = False
+
+         outs = []
+         for i in range(0, len(texts), batch_size):
+             chunk = [str(t) for t in texts[i:i + batch_size]]
+             enc = encode_with_eos_ids(self.tokenizer, chunk, max_len)
+             input_ids = enc["input_ids"].to(device, non_blocking=True)
+             attention_mask = enc["attention_mask"].to(device, non_blocking=True)
+
+             with torch.autocast(device_type=("cuda" if device.type == "cuda" else "cpu"),
+                                 dtype=amp_dtype, enabled=use_amp):
+                 h = get_last_hidden_state(self.model, input_ids, attention_mask)  # [B,T,H]
+
+             # Global embedding: extract the EOS token embedding via last-token pooling
+             emb = last_token_pool(h, attention_mask)  # [B,H]
+             emb = F.normalize(emb.float(), p=2, dim=-1)
+
+             outs.append(emb.detach())
+
+         embeddings = torch.cat(outs, dim=0)  # on device, float32 after normalization
+
+         # Move to CPU float32 if requested (recommended for retrieval stability)
+         if return_cpu_float32:
+             embeddings_cpu = embeddings.float().cpu()
+             # Re-normalize to fix any numerical drift
+             embeddings_cpu = F.normalize(embeddings_cpu, p=2, dim=-1)
+         else:
+             embeddings_cpu = embeddings
+
+         return EmbedOutput(embedding=embeddings_cpu)
+
+     @torch.inference_mode()
+     def embed_instruction_query(
+         self,
+         instructions: List[str],
+         queries: List[str],
+         *,
+         max_len: int = 512,
+         batch_size: int = 16,
+         return_cpu_float32: bool = True,
+     ) -> EmbedOutput:
+         """
+         Embed instruction-query pairs.
+
+         Returns:
+         - embedding: [N,H] - global embeddings (L2-normalized)
+         """
+         if len(instructions) != len(queries):
+             raise ValueError("instructions and queries must have the same length.")
+         q_texts = [build_qwen_query(i, q) for i, q in zip(instructions, queries)]
+         return self.embed_texts(
+             q_texts,
+             max_len=max_len,
+             batch_size=batch_size,
+             return_cpu_float32=return_cpu_float32,
+         )
+
+     @staticmethod
+     def cosine_topk(
+         query_emb: torch.Tensor,  # [Nq,H] CPU float32 recommended
+         cand_emb: torch.Tensor,   # [Nd,H] CPU float32 recommended
+         k: int = 10,
+         *,
+         device: str = "cuda",
+         query_batch_size: int = 256,
+         doc_chunk_size: int = 8192,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Chunked cosine top-k, stable in float32.
+         Returns (top_scores [Nq,k], top_indices [Nq,k]) on CPU.
+         """
+         device_t = torch.device(device)
+         q = F.normalize(query_emb.float(), p=2, dim=-1)
+         d = F.normalize(cand_emb.float(), p=2, dim=-1)
+         Nq, H = q.shape
+         Nd = d.shape[0]
+         k = min(int(k), Nd)
+
+         top_scores_all = torch.empty((Nq, k), dtype=torch.float32)
+         top_indices_all = torch.empty((Nq, k), dtype=torch.long)
+
+         for qs in range(0, Nq, query_batch_size):
+             qe = q[qs:qs + query_batch_size].to(device_t, non_blocking=True)
+             bq = qe.size(0)
+
+             top_scores = torch.full((bq, k), -1e9, device=device_t, dtype=torch.float32)
+             top_indices = torch.full((bq, k), -1, device=device_t, dtype=torch.long)
+
+             for ds in range(0, Nd, doc_chunk_size):
+                 de = d[ds:ds + doc_chunk_size].to(device_t, non_blocking=True)
+                 scores = (qe @ de.T).float()
+
+                 chunk = scores.size(1)
+                 idx_chunk = torch.arange(ds, ds + chunk, device=device_t, dtype=torch.long).unsqueeze(0).expand(bq, -1)
+
+                 # Merge this chunk's scores with the running top-k
+                 comb_scores = torch.cat([top_scores, scores], dim=1)
+                 comb_idx = torch.cat([top_indices, idx_chunk], dim=1)
+
+                 new_scores, new_pos = torch.topk(comb_scores, k, dim=1)
+                 new_idx = comb_idx.gather(1, new_pos)
+
+                 top_scores, top_indices = new_scores, new_idx
+
+             top_scores_all[qs:qs + bq] = top_scores.cpu()
+             top_indices_all[qs:qs + bq] = top_indices.cpu()
+
+         return top_scores_all, top_indices_all
chest2vec_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "name": "chest2vec_4b_chest",
+   "base_model": "Qwen/Qwen3-Embedding-4B",
+   "adapter_subdir": "contrastive",
+   "require_flash_attention_2": true,
+   "default_max_len": 512,
+   "pooling": "last_token"
+ }
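`Chest2Vec.from_pretrained()` reads this file to decide which base model to download and where to find the adapter. A minimal sketch of the lookups (with an inline copy of the config for illustration):

```python
import json

# Inline copy of chest2vec_config.json, for illustration only
cfg = json.loads("""
{
  "name": "chest2vec_4b_chest",
  "base_model": "Qwen/Qwen3-Embedding-4B",
  "adapter_subdir": "contrastive",
  "require_flash_attention_2": true,
  "default_max_len": 512,
  "pooling": "last_token"
}
""")

# Mirrors the lookups in Chest2Vec.from_pretrained()
base_model = str(cfg["base_model"])                             # required key
adapter_subdir = str(cfg.get("adapter_subdir", "contrastive"))  # optional, defaulted
print(base_model, adapter_subdir)
```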
contrastive/adapter_config.json ADDED
@@ -0,0 +1,42 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": {
+     "base_model_class": "Qwen3Model",
+     "parent_library": "transformers.models.qwen3.modeling_qwen3"
+   },
+   "base_model_name_or_path": "Qwen/Qwen3-Embedding-4B",
+   "bias": "none",
+   "corda_config": null,
+   "eva_config": null,
+   "exclude_modules": null,
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 32,
+   "lora_bias": false,
+   "lora_dropout": 0.1,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 16,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "o_proj",
+     "v_proj",
+     "q_proj",
+     "up_proj",
+     "down_proj",
+     "gate_proj",
+     "k_proj"
+   ],
+   "task_type": null,
+   "trainable_token_indices": null,
+   "use_dora": false,
+   "use_rslora": false
+ }
contrastive/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd8ad38ebccd46fa13e2522e8ef23dc80aef48d29f70a780418cf5a1019d7243
+ size 66124752
install_deps.sh ADDED
@@ -0,0 +1,16 @@
+ #!/bin/bash
+ # Installation script for chest2vec dependencies
+ # This script installs PyTorch and flash-attention with the correct versions
+
+ set -e
+
+ echo "Installing PyTorch packages from custom index..."
+ pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
+
+ echo "Installing flash-attention from GitHub release..."
+ pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
+
+ echo "Installing chest2vec package..."
+ pip install chest2vec
+
+ echo "Installation complete!"
pyproject.toml ADDED
@@ -0,0 +1,23 @@
+ [build-system]
+ requires = ["setuptools>=61.0", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "chest2vec"
+ version = "4.0.0"
+ description = "Global embeddings for chest X-ray reports"
+ readme = "README.md"
+ requires-python = ">=3.8"
+ dependencies = [
+     "transformers==4.57.3",
+     "trl==0.9.3",
+     "deepspeed==0.16.9",
+     "peft",
+     "huggingface_hub",
+     "bitsandbytes",
+     "accelerate",
+     "numpy",
+ ]
+
+ [project.urls]
+ Homepage = "https://github.com/chest2vec/chest2vec"
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch==2.6.0
+ torchvision==0.21.0
+ torchaudio==2.6.0
+ transformers==4.57.3
+ trl==0.9.3
+ deepspeed==0.16.9
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
+ peft
+ huggingface_hub
+ bitsandbytes
+ accelerate
+ numpy
setup.py ADDED
@@ -0,0 +1,43 @@
+ from setuptools import setup, find_packages
+ from pathlib import Path
+
+ # Read README for long description
+ readme_file = Path(__file__).parent / "README.md"
+ long_description = readme_file.read_text(encoding="utf-8") if readme_file.exists() else ""
+
+ setup(
+     name="chest2vec",
+     version="4.0.0",
+     description="Global embeddings for chest X-ray reports",
+     long_description=long_description,
+     long_description_content_type="text/markdown",
+     author="Chest2Vec Team",
+     url="https://github.com/chest2vec/chest2vec",
+     packages=find_packages(),
+     py_modules=["chest2vec"],
+     include_package_data=True,
+     package_data={"": ["__init__.py"]},
+     install_requires=[
+         "transformers==4.57.3",
+         "trl==0.9.3",
+         "deepspeed==0.16.9",
+         "peft",
+         "huggingface_hub",
+         "bitsandbytes",
+         "accelerate",
+         "numpy",
+     ],
+     python_requires=">=3.8",
+     classifiers=[
+         "Development Status :: 4 - Beta",
+         "Intended Audience :: Developers",
+         "Intended Audience :: Science/Research",
+         "License :: OSI Approved :: Apache Software License",
+         "Programming Language :: Python :: 3",
+         "Programming Language :: Python :: 3.8",
+         "Programming Language :: Python :: 3.9",
+         "Programming Language :: Python :: 3.10",
+         "Programming Language :: Python :: 3.11",
+         "Topic :: Scientific/Engineering :: Artificial Intelligence",
+     ],
+ )