Commit 414d633 by lukeingawesome
1 Parent(s): f6c76bd

Resolve conflicts after stash pop
.gitattributes CHANGED
@@ -1,3 +1,4 @@
+<<<<<<< HEAD
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +34,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+=======
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+>>>>>>> 2dc7409 (Release chest2vec_0.6b_cxr (Stage2 LoRA + Stage3 pooler + API))
README.md ADDED
@@ -0,0 +1,197 @@
---
tags:
- text-embeddings
- retrieval
- radiology
- cxr
- qwen
library_name: transformers
---

# chest2vec_0.6b_cxr

This repository contains the *delta weights and pooling head* for a section-aware embedding model on top of **Qwen/Qwen3-Embedding-0.6B**:

- **Stage-2**: Frozen LoRA adapter (contrastive) under `./contrastive/`
- **Stage-3**: Section pooler `section_pooler.pt` producing **9 section embeddings**
- **Inference helper**: `chest2vec.py`

Base model weights are **not** included; they are downloaded from Hugging Face at runtime.

## Model Architecture

Chest2Vec is a three-stage model:

1. **Base**: Qwen/Qwen3-Embedding-0.6B (downloaded at runtime)
2. **Stage-2**: Contrastive LoRA adapter trained with a multi-positive sigmoid loss
3. **Stage-3**: Section-aware query-attention pooler producing embeddings for 9 radiology report sections

## Sections

The model produces embeddings for 9 distinct sections:

1. Lungs and Airways
2. Pleura
3. Cardiovascular
4. Hila and Mediastinum
5. Tubes & Devices
6. Musculoskeletal and Chest Wall
7. Abdominal
8. impression
9. Other

## Requirements

This model **requires FlashAttention-2** (CUDA) by default.

```bash
pip install -U torch transformers peft huggingface_hub
pip install flash-attn --no-build-isolation
```
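Before loading the model, it can help to confirm the environment is usable; a minimal sketch mirroring the `require_flash_attention_2` guard in `chest2vec.py` (it assumes `flash_attn` exposes `__version__`, as recent releases do):

```python
import torch

# Environment check: CUDA must be available and flash-attn must be >= 2.
try:
    import flash_attn
    fa_ok = int(str(getattr(flash_attn, "__version__", "0")).split(".")[0]) >= 2
except ImportError:
    fa_ok = False

ready = torch.cuda.is_available() and fa_ok
print("FlashAttention-2 ready:", ready)
```

If this prints `False`, pass `force_flash_attention_2=False` to fall back to SDPA (CPU or non-FlashAttention GPUs).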
## Quickstart

### Installation + Loading

```python
from chest2vec import Chest2Vec

# Load model from Hugging Face Hub
m = Chest2Vec.from_pretrained("chest2vec/chest2vec_0.6b_cxr", device="cuda:0")
```

### Instruction + Query Embeddings

```python
instructions = ["Find findings about the lungs."]
queries = ["Consolidation in the right lower lobe."]

out = m.embed_instruction_query(instructions, queries, max_len=512, batch_size=8)

# Global embedding (derived): mean of 9 section vectors, then L2-normalized
g = out.global_embedding  # [N, H]

# Per-section embeddings (by full name)
lung = out.by_section_name["Lungs and Airways"]  # [N, H]
imp = out.by_section_name["impression"]          # [N, H]

# Or use aliases (case-insensitive)
lung = out.by_alias["lungs"]     # [N, H]
cardio = out.by_alias["cardio"]  # [N, H]
```

### Candidate Embeddings (Retrieval Bank)

```python
candidates = [
    "Lungs are clear. No focal consolidation.",
    "Pleural effusion on the left.",
    "Cardiomediastinal silhouette is normal.",
]

cand_out = m.embed_texts(candidates, max_len=512, batch_size=16)

cand_global = cand_out.global_embedding  # [N, H]
cand_lung = cand_out.by_alias["lungs"]   # [N, H]
```

### Retrieval Example (Cosine Top-K)

```python
# Query embeddings for the "Lungs and Airways" section
q = out.by_alias["lungs"]  # [Nq, H]

# Document embeddings for the "Lungs and Airways" section
d = cand_out.by_alias["lungs"]  # [Nd, H]

# Compute top-k cosine similarities
scores, idx = Chest2Vec.cosine_topk(q, d, k=5, device="cuda")
# scores: [Nq, k] - similarity scores
# idx: [Nq, k] - indices of top-k candidates

print(f"Top-5 scores: {scores[0]}")
print(f"Top-5 indices: {idx[0]}")
```

## API Reference

### `Chest2Vec.from_pretrained()`

Load the model from the Hugging Face Hub or a local path.

```python
m = Chest2Vec.from_pretrained(
    repo_id_or_path: str,             # Hugging Face repo ID or local path
    device: str = "cuda:0",           # Device to load the model on
    use_4bit: bool = False,           # Use 4-bit quantization
    force_flash_attention_2: bool = True,
)
```

### `embed_instruction_query()`

Embed instruction-query pairs. Returns `EmbedOutput` with:

- `section_matrix`: `[N, 9, H]` - embeddings for all 9 sections
- `global_embedding`: `[N, H]` - global embedding (mean of sections, L2-normalized)
- `by_section_name`: Dict mapping full section names to `[N, H]` tensors
- `by_alias`: Dict mapping aliases to `[N, H]` tensors

```python
out = m.embed_instruction_query(
    instructions: List[str],
    queries: List[str],
    max_len: int = 512,
    batch_size: int = 16,
)
```
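The derived global embedding can be recomputed from `section_matrix` yourself (mean over the 9 section vectors, then L2-normalize); a minimal sketch with random unit vectors standing in for real model output:

```python
import torch
import torch.nn.functional as F

# Stand-in for out.section_matrix: N=2 texts, S=9 sections, H=16 dims
section_matrix = F.normalize(torch.randn(2, 9, 16), p=2, dim=-1)

# Global embedding = mean over the section axis, then L2-normalize
global_embedding = F.normalize(section_matrix.mean(dim=1), p=2, dim=-1)

# Every row is unit-norm, so dot products are cosine similarities
print(global_embedding.shape)  # torch.Size([2, 16])
```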
### `embed_texts()`

Embed plain texts (for document/candidate encoding).

```python
out = m.embed_texts(
    texts: List[str],
    max_len: int = 512,
    batch_size: int = 16,
)
```

### `cosine_topk()`

Static method for efficient top-k cosine similarity search.

```python
scores, idx = Chest2Vec.cosine_topk(
    query_emb: torch.Tensor,  # [Nq, H]
    cand_emb: torch.Tensor,   # [Nd, H]
    k: int = 10,
    device: str = "cuda",
)
```
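For small candidate banks, `cosine_topk` reduces to a normalized matrix product plus `torch.topk`. A tiny reference version (`cosine_topk_simple` is a hypothetical helper for illustration, not the chunked implementation shipped in `chest2vec.py`):

```python
import torch
import torch.nn.functional as F

def cosine_topk_simple(q: torch.Tensor, d: torch.Tensor, k: int):
    # Normalize both sides so the dot product equals cosine similarity
    q = F.normalize(q.float(), p=2, dim=-1)
    d = F.normalize(d.float(), p=2, dim=-1)
    scores = q @ d.T  # [Nq, Nd]
    return torch.topk(scores, min(k, d.size(0)), dim=1)

q = torch.tensor([[1.0, 0.0]])
d = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
scores, idx = cosine_topk_simple(q, d, k=2)
print(idx.tolist())  # [[0, 2]]
```

The shipped version batches queries and chunks the document bank so the full `[Nq, Nd]` score matrix never has to fit in GPU memory.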
## Model Files

- `chest2vec.py` - Model class and inference utilities
- `chest2vec_config.json` - Model configuration
- `section_pooler.pt` - Stage-3 pooler weights
- `section_pooler_config.json` - Pooler configuration
- `contrastive/` - Stage-2 LoRA adapter directory
  - `adapter_config.json` - LoRA adapter configuration
  - `adapter_model.safetensors` - LoRA adapter weights

## Citation

If you use this model, please cite:

```bibtex
@misc{chest2vec_0.6b_cxr,
  title={Chest2Vec: Section-Aware Embeddings for Chest X-Ray Reports},
  author={Your Name},
  year={2024},
  howpublished={\url{https://huggingface.co/chest2vec/chest2vec_0.6b_cxr}}
}
```

## License

[Specify your license here]
chest2vec.py ADDED
@@ -0,0 +1,530 @@
import os
import json
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig

try:
    from peft import PeftModel
    _HAS_PEFT = True
except Exception:
    PeftModel = None
    _HAS_PEFT = False

try:
    from huggingface_hub import snapshot_download
    _HAS_HUB = True
except Exception:
    snapshot_download = None
    _HAS_HUB = False


# -----------------------------
# Sections (must match training)
# -----------------------------
SECTION_NAMES = [
    "Lungs and Airways",
    "Pleura",
    "Cardiovascular",
    "Hila and Mediastinum",
    "Tubes & Devices",
    "Musculoskeletal and Chest Wall",
    "Abdominal",
    "impression",
    "Other",
]

SECTION_ALIASES = {
    "global": "global",
    "lungs": "Lungs and Airways",
    "lung": "Lungs and Airways",
    "pleura": "Pleura",
    "cardio": "Cardiovascular",
    "cardiovascular": "Cardiovascular",
    "hila": "Hila and Mediastinum",
    "mediastinum": "Hila and Mediastinum",
    "tubes": "Tubes & Devices",
    "devices": "Tubes & Devices",
    "msk": "Musculoskeletal and Chest Wall",
    "musculoskeletal": "Musculoskeletal and Chest Wall",
    "abd": "Abdominal",
    "abdominal": "Abdominal",
    "impression": "impression",
    "other": "Other",
}


def require_flash_attention_2() -> str:
    if not torch.cuda.is_available():
        raise RuntimeError("FlashAttention-2 requires CUDA, but torch.cuda.is_available() is False.")
    try:
        import flash_attn  # noqa: F401
        ver = getattr(flash_attn, "__version__", "0.0.0")
        major = int(str(ver).split(".")[0])
        if major < 2:
            raise RuntimeError(f"flash-attn version {ver} < 2.0.0")
    except Exception as e:
        raise RuntimeError(
            "FlashAttention-2 is REQUIRED but not available/importable.\n"
            "Install flash-attn>=2 and ensure it matches your torch/CUDA.\n"
            f"Import/Version error: {repr(e)}"
        )
    return "flash_attention_2"


def build_qwen_query(instruction: str, query: str) -> str:
    instruction = str(instruction).strip()
    query = str(query).strip()
    return f"Instruct: {instruction}\nQuery: {query}"


def get_pool_token_id(tok) -> int:
    eod_id = tok.convert_tokens_to_ids("<|endoftext|>")
    if eod_id is None or eod_id < 0:
        eod_id = tok.pad_token_id
    return eod_id


def encode_with_eos_ids(tok, texts: List[str], max_len: int) -> Dict[str, torch.Tensor]:
    """
    Must match Stage-3 training:
      - add_special_tokens=False
      - truncation to max_len-1
      - append <|endoftext|>
      - left-pad
    """
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    eod_id = get_pool_token_id(tok)

    enc = tok(
        [str(t) for t in texts],
        add_special_tokens=False,
        truncation=True,
        max_length=max_len - 1,
        padding=False,
        return_attention_mask=False,
    )

    input_ids = [ids + [eod_id] for ids in enc["input_ids"]]
    attn_mask = [[1] * len(ids) for ids in input_ids]

    T = max(len(ids) for ids in input_ids) if input_ids else 1
    input_ids = [[pad_id] * (T - len(ids)) + ids for ids in input_ids]
    attn_mask = [[0] * (T - len(m)) + m for m in attn_mask]

    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
    }


def get_last_hidden_state(model, input_ids, attention_mask):
    """
    Provide position_ids for left padding (FlashAttention-2).
    """
    m = model.module if hasattr(model, "module") else model

    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 0)

    out = m(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        use_cache=False,
        return_dict=True,
    )
    if hasattr(out, "last_hidden_state"):
        return out.last_hidden_state

    out = m(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        output_hidden_states=True,
        use_cache=False,
        return_dict=True,
    )
    return out.hidden_states[-1]


# -----------------------------
# Stage-3 pooler (query_attn)
# -----------------------------
class SectionQueryAttnPooler(nn.Module):
    """
    Must match the Stage-3 training pooler.
    """
    def __init__(
        self,
        hidden_size: int,
        num_sections: int,
        mlp_hidden: int,
        use_layernorm: bool = True,
        pool_dropout: float = 0.1,
        pool_scale: float = 0.0,  # 0 => 1/sqrt(H)
    ):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.num_sections = int(num_sections)

        self.ln = nn.LayerNorm(self.hidden_size) if use_layernorm else nn.Identity()

        self.pool_queries = nn.Parameter(torch.empty(self.num_sections, self.hidden_size))
        nn.init.normal_(self.pool_queries, mean=0.0, std=0.02)

        self.pool_scale = float(pool_scale) if (pool_scale and pool_scale > 0) else (1.0 / math.sqrt(self.hidden_size))
        self.pool_dropout = nn.Dropout(pool_dropout) if pool_dropout and pool_dropout > 0 else nn.Identity()

        # Bias-free MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, int(mlp_hidden), bias=False),
            nn.GELU(),
            nn.Linear(int(mlp_hidden), self.hidden_size, bias=False),
        )

    def forward_all(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # hidden_states: [B,T,H] -> [B,S,H]
        if isinstance(self.ln, nn.LayerNorm):
            x = F.layer_norm(
                hidden_states.float(),
                self.ln.normalized_shape,
                self.ln.weight.float() if self.ln.weight is not None else None,
                self.ln.bias.float() if self.ln.bias is not None else None,
                self.ln.eps,
            ).to(dtype=hidden_states.dtype)
        else:
            x = hidden_states

        scores = torch.einsum("bth,sh->bts", x.float(), self.pool_queries.float()) * self.pool_scale
        scores = scores.masked_fill(attention_mask.unsqueeze(-1) == 0, -1e4)

        attn = torch.softmax(scores, dim=1).to(dtype=x.dtype)  # [B,T,S]
        attn = self.pool_dropout(attn)

        pooled = torch.einsum("bth,bts->bsh", x, attn)  # [B,S,H]
        pooled = pooled.to(dtype=next(self.mlp.parameters()).dtype)
        pooled = self.mlp(pooled)

        return F.normalize(pooled, p=2, dim=-1)


def _ensure_pooler_device_dtype(pooler: nn.Module, device: torch.device, dtype: torch.dtype) -> None:
    p = next(pooler.parameters(), None)
    if p is None:
        return
    if p.device != device or p.dtype != dtype:
        pooler.to(device=device, dtype=dtype)


def _read_json(path: str) -> Dict[str, Any]:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _resolve_repo_path(repo_id_or_path: str) -> str:
    # If it's a local directory, use it as-is.
    if os.path.isdir(repo_id_or_path):
        return repo_id_or_path
    # Otherwise treat as HF repo_id and download snapshot.
    if not _HAS_HUB:
        raise RuntimeError(
            "huggingface_hub is required to load by repo_id. "
            "Install it: pip install huggingface_hub"
        )
    return snapshot_download(repo_id_or_path)


@dataclass
class EmbedOutput:
    # Always available:
    section_matrix: torch.Tensor    # [N,S,H], float32 on CPU by default
    global_embedding: torch.Tensor  # [N,H], float32 on CPU by default
    # Convenience dicts:
    by_section_name: Dict[str, torch.Tensor]  # each [N,H]
    by_alias: Dict[str, torch.Tensor]         # alias -> [N,H]


class Chest2Vec:
    """
    Lightweight wrapper:
      - loads base Qwen3-Embedding
      - applies LoRA adapter
      - attaches Stage-3 section pooler
    """
    def __init__(self, tokenizer, model, pooler, sections: List[str], device: torch.device):
        self.tokenizer = tokenizer
        self.model = model
        self.pooler = pooler
        self.sections = list(sections)
        self.device = device

        self.model.eval()
        self.pooler.eval()

    @classmethod
    def from_pretrained(
        cls,
        repo_id_or_path: str,
        *,
        device: str = "cuda:0",
        use_4bit: bool = False,
        force_flash_attention_2: bool = True,
    ) -> "Chest2Vec":
        repo_path = _resolve_repo_path(repo_id_or_path)

        cfg_path = os.path.join(repo_path, "chest2vec_config.json")
        if not os.path.isfile(cfg_path):
            raise FileNotFoundError(f"Missing chest2vec_config.json in {repo_path}")
        cfg = _read_json(cfg_path)

        base_model = str(cfg["base_model"])
        adapter_subdir = str(cfg.get("adapter_subdir", "contrastive"))
        pooler_pt = str(cfg.get("pooler_pt", "section_pooler.pt"))
        pooler_cfg = str(cfg.get("pooler_cfg", "section_pooler_config.json"))
        sections = cfg.get("sections", SECTION_NAMES)

        if force_flash_attention_2 or bool(cfg.get("require_flash_attention_2", False)):
            attn_impl = require_flash_attention_2()
        else:
            attn_impl = "sdpa"

        if not _HAS_PEFT:
            raise RuntimeError("peft is required. Install: pip install peft")

        device_t = torch.device(device)

        tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left", trust_remote_code=True)
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token

        device_map = {"": str(device_t)}

        # Load base model with FlashAttention-2
        if use_4bit:
            qconf = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            try:
                base = AutoModel.from_pretrained(
                    base_model,
                    trust_remote_code=True,
                    attn_implementation=attn_impl,
                    quantization_config=qconf,
                    device_map=device_map,
                )
            except TypeError as e:
                raise RuntimeError(
                    "Your transformers version does not support attn_implementation=... "
                    "Upgrade transformers to use FlashAttention-2."
                ) from e
        else:
            try:
                base = AutoModel.from_pretrained(
                    base_model,
                    trust_remote_code=True,
                    attn_implementation=attn_impl,
                    torch_dtype=torch.bfloat16,
                    device_map=device_map,
                )
            except TypeError as e:
                raise RuntimeError(
                    "Your transformers version does not support attn_implementation=... "
                    "Upgrade transformers to use FlashAttention-2."
                ) from e

        # Load adapter from this repo folder
        adapter_dir = os.path.join(repo_path, adapter_subdir)
        if not os.path.isfile(os.path.join(adapter_dir, "adapter_config.json")):
            raise FileNotFoundError(f"adapter_config.json not found under: {adapter_dir}")

        model = PeftModel.from_pretrained(base, adapter_dir)
        model.eval()

        # Attach section pooler
        pooler_cfg_path = os.path.join(repo_path, pooler_cfg)
        pooler_pt_path = os.path.join(repo_path, pooler_pt)
        if not os.path.isfile(pooler_cfg_path):
            raise FileNotFoundError(f"Missing pooler config: {pooler_cfg_path}")
        if not os.path.isfile(pooler_pt_path):
            raise FileNotFoundError(f"Missing pooler weights: {pooler_pt_path}")

        pcfg = _read_json(pooler_cfg_path)

        hidden_size = int(getattr(model.module if hasattr(model, "module") else model, "config").hidden_size)
        mlp_hidden = int(pcfg.get("mlp_hidden", hidden_size))
        use_layernorm = bool(pcfg.get("use_layernorm", True))
        pool_dropout = float(pcfg.get("pool_dropout", 0.1))
        pool_scale = float(pcfg.get("pool_scale", 0.0))

        pooler = SectionQueryAttnPooler(
            hidden_size=hidden_size,
            num_sections=len(sections),
            mlp_hidden=mlp_hidden,
            use_layernorm=use_layernorm,
            pool_dropout=pool_dropout,
            pool_scale=pool_scale,
        )
        sd = torch.load(pooler_pt_path, map_location="cpu")
        pooler.load_state_dict(sd, strict=True)
        pooler.eval()

        # Move pooler to same device/dtype as hidden states
        # (we keep inference in autocast)
        pooler.to(device=device_t, dtype=torch.bfloat16 if device_t.type == "cuda" else torch.float32)

        return cls(tokenizer=tokenizer, model=model, pooler=pooler, sections=sections, device=device_t)

    @torch.inference_mode()
    def embed_texts(
        self,
        texts: List[str],
        *,
        max_len: int = 512,
        batch_size: int = 16,
        return_cpu_float32: bool = True,
    ) -> EmbedOutput:
        """
        Encodes arbitrary texts (candidates, section strings, etc.)
        Returns:
          - section_matrix: [N,9,H]
          - global_embedding: [N,H] = mean(section_matrix) then L2
          - by_section_name: dict[name] -> [N,H]
          - by_alias: dict['lungs'/'impression'/...] -> [N,H]
        """
        # Determine AMP
        device = self.device
        if device.type == "cuda":
            amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            use_amp = True
        else:
            amp_dtype = torch.float32
            use_amp = False

        outs = []
        for i in range(0, len(texts), batch_size):
            chunk = [str(t) for t in texts[i:i + batch_size]]
            enc = encode_with_eos_ids(self.tokenizer, chunk, max_len)
            input_ids = enc["input_ids"].to(device, non_blocking=True)
            attention_mask = enc["attention_mask"].to(device, non_blocking=True)

            with torch.autocast(device_type=("cuda" if device.type == "cuda" else "cpu"),
                                dtype=amp_dtype, enabled=use_amp):
                h = get_last_hidden_state(self.model, input_ids, attention_mask)  # [B,T,H]
                _ensure_pooler_device_dtype(self.pooler, device=h.device, dtype=h.dtype)
                sec = self.pooler.forward_all(h, attention_mask)  # [B,S,H] normalized

            outs.append(sec.detach())

        section_matrix = torch.cat(outs, dim=0)  # on device, dtype ~ bf16
        # Global embedding: mean over sections then normalize
        global_emb = F.normalize(section_matrix.float().mean(dim=1), p=2, dim=-1)

        # Move to CPU float32 if requested (recommended for retrieval stability)
        if return_cpu_float32:
            section_matrix_cpu = section_matrix.float().cpu()
            # re-normalize to fix any numerical drift
            section_matrix_cpu = F.normalize(section_matrix_cpu, p=2, dim=-1)
            global_cpu = global_emb.float().cpu()
            global_cpu = F.normalize(global_cpu, p=2, dim=-1)
        else:
            section_matrix_cpu = section_matrix
            global_cpu = global_emb

        by_section_name = {name: section_matrix_cpu[:, idx, :] for idx, name in enumerate(self.sections)}

        # Helpful aliases for quick access
        by_alias: Dict[str, torch.Tensor] = {}
        by_alias["global"] = global_cpu
        for alias, real in SECTION_ALIASES.items():
            if real == "global":
                continue
            if real in by_section_name:
                by_alias[alias] = by_section_name[real]

        return EmbedOutput(
            section_matrix=section_matrix_cpu,
            global_embedding=global_cpu,
            by_section_name=by_section_name,
            by_alias=by_alias,
        )

    @torch.inference_mode()
    def embed_instruction_query(
        self,
        instructions: List[str],
        queries: List[str],
        *,
        max_len: int = 512,
        batch_size: int = 16,
        return_cpu_float32: bool = True,
    ) -> EmbedOutput:
        if len(instructions) != len(queries):
            raise ValueError("instructions and queries must have the same length.")
        q_texts = [build_qwen_query(i, q) for i, q in zip(instructions, queries)]
        return self.embed_texts(
            q_texts,
            max_len=max_len,
            batch_size=batch_size,
            return_cpu_float32=return_cpu_float32,
        )

    @staticmethod
    def cosine_topk(
        query_emb: torch.Tensor,  # [Nq,H] CPU float32 recommended
        cand_emb: torch.Tensor,   # [Nd,H] CPU float32 recommended
        k: int = 10,
        *,
        device: str = "cuda",
        query_batch_size: int = 256,
        doc_chunk_size: int = 8192,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Chunked cosine top-k, stable in float32.
        Returns (top_scores [Nq,k], top_indices [Nq,k]) on CPU.
        """
        device_t = torch.device(device)
        q = F.normalize(query_emb.float(), p=2, dim=-1)
        d = F.normalize(cand_emb.float(), p=2, dim=-1)
        Nq, H = q.shape
        Nd = d.shape[0]
        k = min(int(k), Nd)

        top_scores_all = torch.empty((Nq, k), dtype=torch.float32)
        top_indices_all = torch.empty((Nq, k), dtype=torch.long)

        for qs in range(0, Nq, query_batch_size):
            qe = q[qs:qs + query_batch_size].to(device_t, non_blocking=True)
            bq = qe.size(0)

            top_scores = torch.full((bq, k), -1e9, device=device_t, dtype=torch.float32)
            top_indices = torch.full((bq, k), -1, device=device_t, dtype=torch.long)

            for ds in range(0, Nd, doc_chunk_size):
                de = d[ds:ds + doc_chunk_size].to(device_t, non_blocking=True)
                scores = (qe @ de.T).float()

                chunk = scores.size(1)
                idx_chunk = torch.arange(ds, ds + chunk, device=device_t, dtype=torch.long).unsqueeze(0).expand(bq, -1)

                comb_scores = torch.cat([top_scores, scores], dim=1)
                comb_idx = torch.cat([top_indices, idx_chunk], dim=1)

                new_scores, new_pos = torch.topk(comb_scores, k, dim=1)
                new_idx = comb_idx.gather(1, new_pos)

                top_scores, top_indices = new_scores, new_idx

            top_scores_all[qs:qs + bq] = top_scores.cpu()
            top_indices_all[qs:qs + bq] = top_indices.cpu()

        return top_scores_all, top_indices_all
chest2vec_config.json ADDED
@@ -0,0 +1,22 @@
{
  "name": "chest2vec_0.6b_cxr",
  "base_model": "Qwen/Qwen3-Embedding-0.6B",
  "adapter_subdir": "contrastive",
  "pooler_pt": "section_pooler.pt",
  "pooler_cfg": "section_pooler_config.json",
  "require_flash_attention_2": true,
  "default_max_len": 512,
  "sections": [
    "Lungs and Airways",
    "Pleura",
    "Cardiovascular",
    "Hila and Mediastinum",
    "Tubes & Devices",
    "Musculoskeletal and Chest Wall",
    "Abdominal",
    "impression",
    "Other"
  ],
  "global_pool": "mean_of_sections_then_l2"
}
contrastive/adapter_config.json ADDED
@@ -0,0 +1,42 @@
{
  "alpha_pattern": {},
  "auto_mapping": {
    "base_model_class": "Qwen3Model",
    "parent_library": "transformers.models.qwen3.modeling_qwen3"
  },
  "base_model_name_or_path": "Qwen/Qwen3-Embedding-0.6B",
  "bias": "none",
  "corda_config": null,
  "eva_config": null,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 32,
  "lora_bias": false,
  "lora_dropout": 0.1,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 16,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "down_proj",
    "up_proj",
    "v_proj",
    "k_proj",
    "gate_proj",
    "o_proj",
    "q_proj"
  ],
  "task_type": null,
  "trainable_token_indices": null,
  "use_dora": false,
  "use_rslora": false
}
contrastive/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7c22b272dfd1861cd967b63c634195fbfad937716be51a60d53ba58b167f9220
size 40419816
requirements.txt ADDED
@@ -0,0 +1,8 @@
torch
transformers
peft
huggingface_hub
flash-attn>=2
bitsandbytes
accelerate
numpy
section_pooler.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d766b5ef522274fd3238d1607d59dc48ccffbf1f5819a550ba16fcb6bda2ce2
size 4219403
section_pooler_config.json ADDED
@@ -0,0 +1,20 @@
{
  "pooler_type": "query_attn",
  "sections": [
    "Lungs and Airways",
    "Pleura",
    "Cardiovascular",
    "Hila and Mediastinum",
    "Tubes & Devices",
    "Musculoskeletal and Chest Wall",
    "Abdominal",
    "impression",
    "Other"
  ],
  "hidden_size": 1024,
  "mlp_hidden": 1024,
  "use_layernorm": true,
  "pool_dropout": 0.1,
  "pool_scale": 0.0,
  "loss": "multipos_sigmoid"
}
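The `query_attn` pooler this config describes gives each section one learned query vector that attends over token hidden states; `pool_scale = 0.0` means the scores are scaled by 1/sqrt(hidden_size). A self-contained sketch of just the attention step (random tensors and small dimensions for illustration; the real pooler also applies LayerNorm, dropout, and a bias-free MLP before normalizing):

```python
import math
import torch
import torch.nn.functional as F

B, T, H, S = 2, 5, 8, 9           # batch, tokens, hidden dim, sections
x = torch.randn(B, T, H)          # token hidden states from the encoder
queries = torch.randn(S, H)       # one learned query per section
mask = torch.ones(B, T)           # attention mask (1 = real token)
scale = 1.0 / math.sqrt(H)        # pool_scale = 0.0 => 1/sqrt(H)

scores = torch.einsum("bth,sh->bts", x, queries) * scale
scores = scores.masked_fill(mask.unsqueeze(-1) == 0, -1e4)  # ignore padding
attn = torch.softmax(scores, dim=1)            # softmax over tokens
pooled = torch.einsum("bth,bts->bsh", x, attn) # one vector per section
pooled = F.normalize(pooled, p=2, dim=-1)      # [B, S, H], unit-norm rows
print(pooled.shape)  # torch.Size([2, 9, 8])
```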