lukeingawesome committed on
Commit
9286661
·
verified ·
1 Parent(s): 9a59089

Add trust_remote_code integration (Qwen3-Embedding + LoRA)

.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +34,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
1
+ <<<<<<< HEAD
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
 
34
  *.zip filter=lfs diff=lfs merge=lfs -text
35
  *.zst filter=lfs diff=lfs merge=lfs -text
36
  *tfevents* filter=lfs diff=lfs merge=lfs -text
37
+ =======
38
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
39
+ *.pt filter=lfs diff=lfs merge=lfs -text
40
+ *.bin filter=lfs diff=lfs merge=lfs -text
41
+ >>>>>>> 2dc7409 (Release chest2vec_0.6b_cxr (Stage2 LoRA + Stage3 pooler + API))
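
Note on `.gitattributes`: the added lines above include `<<<<<<< HEAD`, `=======`, and `>>>>>>> 2dc7409 ...`, which are unresolved merge-conflict markers committed into the file. A minimal sketch of a check that flags such lines follows; the helper name `find_conflict_markers` is hypothetical and not part of this repository.

```python
import re
from typing import List, Tuple

# Git writes conflict markers at the start of a line: seven '<' or '>'
# followed by a space and a label, or exactly seven '=' on their own.
_MARKER = re.compile(r"^(<{7} |={7}$|>{7} )")


def find_conflict_markers(text: str) -> List[Tuple[int, str]]:
    """Return (1-based line number, line) for every conflict-marker line."""
    hits = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        if _MARKER.match(line):
            hits.append((lineno, line))
    return hits


# A shortened stand-in for the committed file, using lines from the diff above.
sample = "\n".join([
    "<<<<<<< HEAD",
    "*.7z filter=lfs diff=lfs merge=lfs -text",
    "=======",
    "*.safetensors filter=lfs diff=lfs merge=lfs -text",
    ">>>>>>> 2dc7409 (Release chest2vec_0.6b_cxr (Stage2 LoRA + Stage3 pooler + API))",
])
markers = find_conflict_markers(sample)
for lineno, line in markers:
    print(f"line {lineno}: {line}")
```

A check like this could run as a pre-commit hook so that marker lines never reach the Hub.
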
.gitignore ADDED
@@ -0,0 +1,24 @@
1
+ # Build artifacts
2
+ dist/
3
+ build/
4
+ *.egg-info/
5
+ __pycache__/
6
+ *.pyc
7
+ *.pyo
8
+ *.pyd
9
+ .Python
10
+
11
+ # Testing
12
+ .pytest_cache/
13
+ .coverage
14
+ htmlcov/
15
+
16
+ # IDE
17
+ .vscode/
18
+ .idea/
19
+ *.swp
20
+ *.swo
21
+
22
+ # Jupyter
23
+ .ipynb_checkpoints/
24
+
README.md ADDED
@@ -0,0 +1,215 @@
1
+ ---
2
+ tags:
3
+ - text-embeddings
4
+ - retrieval
5
+ - radiology
6
+ - chest
7
+ - qwen
8
+ library_name: transformers
9
+ ---
10
+
11
+ # chest2vec_0.6B
12
+
13
+ This repository contains the *delta weights* for a global embedding model on top of **Qwen/Qwen3-Embedding-0.6B**:
14
+
15
+ - **LoRA Adapter**: Contrastive LoRA adapter trained with multi-positive sigmoid loss under `./contrastive/`
16
+ - **Inference helper**: `chest2vec.py`
17
+
18
+ Base model weights are **not** included; they are downloaded from Hugging Face at runtime.
19
+
20
+ ## Model Architecture
21
+
22
+ Chest2Vec combines three components:
23
+ 1. **Base**: Qwen/Qwen3-Embedding-0.6B (downloaded at runtime)
24
+ 2. **LoRA Adapter**: Contrastive LoRA adapter trained with multi-positive sigmoid loss
25
+ 3. **Pooling**: Last-token pooling (EOS token) for global embeddings
26
+
27
+ The model produces **global embeddings only** (no section-specific embeddings).
28
+
29
+ ## Installation
30
+
31
+ Install the package and all dependencies:
32
+
33
+ ```bash
34
+ # Install PyTorch with CUDA 12.6 support
35
+ pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
36
+
37
+ # Install transformers and trl
38
+ pip install transformers==4.57.3 trl==0.9.3
39
+
40
+ # Install deepspeed
41
+ pip install deepspeed==0.16.9
42
+
43
+ # Install flash-attention
44
+ pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
45
+
46
+ # Install chest2vec package
47
+ pip install chest2vec
48
+ ```
49
+
50
+ Or use the installation script:
51
+
52
+ ```bash
53
+ bash install_deps.sh
54
+ ```
55
+
56
+ ## Requirements
57
+
58
+ This model **requires FlashAttention-2** (CUDA) by default, which is automatically installed with the package.
59
+
60
+ ## Quickstart
61
+
62
+ ### Installation + Loading
63
+
64
+ ```python
65
+ from chest2vec import Chest2Vec
66
+
67
+ # Load model from Hugging Face Hub
68
+ m = Chest2Vec.from_pretrained("lukeingawesome/chest2vec_0.6b_chest", device="cuda:0")
69
+ ```
70
+
71
+ ### Instruction + Query Embeddings
72
+
73
+ ```python
74
+ instructions = ["Find findings about the lungs."]
75
+ queries = ["Consolidation in the right lower lobe."]
76
+
77
+ out = m.embed_instruction_query(instructions, queries, max_len=512, batch_size=8)
78
+
79
+ # Global embedding (last-token pooling)
80
+ emb = out.embedding # [N, H]
81
+ ```
82
+
83
+ ### Candidate Embeddings (Retrieval Bank)
84
+
85
+ ```python
86
+ candidates = [
87
+ "Lungs are clear. No focal consolidation.",
88
+ "Pleural effusion on the left.",
89
+ "Cardiomediastinal silhouette is normal."
90
+ ]
91
+
92
+ cand_out = m.embed_texts(candidates, max_len=512, batch_size=16)
93
+
94
+ cand_emb = cand_out.embedding # [N, H]
95
+ ```
96
+
97
+ ### Retrieval Example (Cosine Top-K)
98
+
99
+ ```python
100
+ # Query embeddings
101
+ q = out.embedding # [Nq, H]
102
+
103
+ # Document embeddings
104
+ d = cand_out.embedding # [Nd, H]
105
+
106
+ # Compute top-k cosine similarities
107
+ scores, idx = Chest2Vec.cosine_topk(q, d, k=5, device="cuda")
108
+ # scores: [Nq, k] - similarity scores
109
+ # idx: [Nq, k] - indices of top-k candidates
110
+
111
+ print(f"Top-5 scores: {scores[0]}")
112
+ print(f"Top-5 indices: {idx[0]}")
113
+ ```
114
+
115
+ ## API Reference
116
+
117
+ ### `Chest2Vec.from_pretrained()`
118
+
119
+ Load the model from Hugging Face Hub or local path.
120
+
121
+ ```python
122
+ m = Chest2Vec.from_pretrained(
123
+ repo_id_or_path: str, # Hugging Face repo ID or local path
124
+ device: str = "cuda:0", # Device to load model on
125
+ use_4bit: bool = False, # Use 4-bit quantization
126
+ force_flash_attention_2: bool = True
127
+ )
128
+ ```
129
+
130
+ ### `embed_instruction_query()`
131
+
132
+ Embed instruction-query pairs. Returns `EmbedOutput` with:
133
+ - `embedding`: `[N, H]` - global embeddings (L2-normalized, last-token pooling)
134
+
135
+ ```python
136
+ out = m.embed_instruction_query(
137
+ instructions: List[str],
138
+ queries: List[str],
139
+ max_len: int = 512,
140
+ batch_size: int = 16
141
+ )
142
+ ```
143
+
144
+ ### `embed_texts()`
145
+
146
+ Embed plain texts (for document/candidate encoding).
147
+
148
+ ```python
149
+ out = m.embed_texts(
150
+ texts: List[str],
151
+ max_len: int = 512,
152
+ batch_size: int = 16
153
+ )
154
+ ```
155
+
156
+ Returns `EmbedOutput` with:
157
+ - `embedding`: `[N, H]` - global embeddings (L2-normalized, last-token pooling)
158
+
159
+ ### `cosine_topk()`
160
+
161
+ Static method for efficient top-k cosine similarity search.
162
+
163
+ ```python
164
+ scores, idx = Chest2Vec.cosine_topk(
165
+ query_emb: torch.Tensor, # [Nq, H]
166
+ cand_emb: torch.Tensor, # [Nd, H]
167
+ k: int = 10,
168
+ device: str = "cuda"
169
+ )
170
+ ```
171
+
172
+ ## Model Files
173
+
174
+ - `chest2vec.py` - Model class and inference utilities
175
+ - `chest2vec_config.json` - Model configuration
176
+ - `contrastive/` - LoRA adapter directory
177
+ - `adapter_config.json` - LoRA adapter configuration
178
+ - `adapter_model.safetensors` - LoRA adapter weights
179
+
180
+ ## Citation
181
+
182
+ If you use this model, please cite:
183
+
184
+ ```bibtex
185
+ @misc{chest2vec_0.6b_chest,
186
+ title={Chest2Vec: Global Embeddings for Chest X-Ray Reports},
187
+ author={Your Name},
188
+ year={2024},
189
+ howpublished={\url{https://huggingface.co/lukeingawesome/chest2vec_0.6b_chest}}
190
+ }
191
+ ```
192
+
193
+ ## License
194
+
195
+ [Specify your license here]
196
+
197
+ ## Usage (🤗 transformers)
198
+
199
+ ```python
200
+ from transformers import AutoModel
201
+
202
+ # base Qwen3-Embedding weights download automatically; needs trust_remote_code
203
+ model = AutoModel.from_pretrained("chest2vec/chest2vec_0.6B", trust_remote_code=True)
204
+
205
+ emb = model.embed_texts([
206
+ "Frontal chest radiograph. No focal consolidation. No pneumothorax. Heart size normal.", "Small left pleural effusion. No pneumothorax.",
207
+ ])
208
+ emb # [N, H] L2-normalized report embedding (last-token / EOS pooling)
209
+
210
+ # similarity
211
+ (emb[0] @ emb[1]) # cosine similarity (rows are unit-norm)
212
+ ```
213
+
214
+ FlashAttention-2 is used automatically on CUDA when `flash-attn>=2` is installed
215
+ (matching training); otherwise it falls back to SDPA so the model also loads on CPU.
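
Note on the README's "Retrieval Example (Cosine Top-K)": `cosine_topk` in `modeling_chest2vec.py` walks the candidate bank in chunks and keeps a running top-k, so the full score matrix is never materialized. A minimal pure-Python sketch of that merge step for a single query is given below. The function names and the tiny 2-D vectors are illustrative assumptions; the real method operates on batched `torch` tensors.

```python
import math
from typing import List, Sequence, Tuple


def l2_normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (left unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def chunked_cosine_topk(
    query: Sequence[float],
    docs: Sequence[Sequence[float]],
    k: int,
    chunk_size: int = 2,
) -> List[Tuple[float, int]]:
    """Return the k best (cosine score, doc index) pairs, scanning docs in chunks."""
    q = l2_normalize(query)
    k = min(k, len(docs))
    best: List[Tuple[float, int]] = []  # running top-k across chunks
    for start in range(0, len(docs), chunk_size):
        chunk = docs[start:start + chunk_size]
        scored = [
            (sum(a * b for a, b in zip(q, l2_normalize(d))), start + j)
            for j, d in enumerate(chunk)
        ]
        # Merge the previous winners with this chunk and keep only the best k.
        best = sorted(best + scored, key=lambda t: t[0], reverse=True)[:k]
    return best


top = chunked_cosine_topk([1.0, 0.0], [[0, 1], [1, 0], [1, 1], [-1, 0]], k=2)
print(top)
```

Because only `k` running winners plus one chunk are held at a time, memory stays bounded by the chunk size rather than by the size of the candidate bank.
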
config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "model_type": "chest2vec",
3
+ "architectures": [
4
+ "Chest2VecModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_chest2vec.Chest2VecConfig",
8
+ "AutoModel": "modeling_chest2vec.Chest2VecModel"
9
+ },
10
+ "base_model": "Qwen/Qwen3-Embedding-0.6B",
11
+ "adapter_subdir": "contrastive",
12
+ "require_flash_attention_2": true,
13
+ "default_max_len": 512
14
+ }
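
Note on `config.json`: when a model is loaded with `trust_remote_code=True`, each `auto_map` value of the form `module.ClassName` names a Python file in the repository and a class inside it. The stdlib-only sketch below shows that split for the entries above. The helper `split_auto_map` is hypothetical and only illustrates the naming convention, not the actual `transformers` loading code.

```python
import json
from typing import Tuple

# The config.json added in this commit, reproduced as a string.
CONFIG_JSON = """
{
  "model_type": "chest2vec",
  "architectures": ["Chest2VecModel"],
  "auto_map": {
    "AutoConfig": "configuration_chest2vec.Chest2VecConfig",
    "AutoModel": "modeling_chest2vec.Chest2VecModel"
  },
  "base_model": "Qwen/Qwen3-Embedding-0.6B",
  "adapter_subdir": "contrastive",
  "require_flash_attention_2": true,
  "default_max_len": 512
}
"""


def split_auto_map(entry: str) -> Tuple[str, str]:
    """Split 'module.ClassName' into (file name in the repo, class name)."""
    module_name, _, class_name = entry.rpartition(".")
    return module_name + ".py", class_name


config = json.loads(CONFIG_JSON)
resolved = {auto_cls: split_auto_map(target) for auto_cls, target in config["auto_map"].items()}
for auto_cls, (filename, class_name) in resolved.items():
    print(f"{auto_cls} -> {class_name} in {filename}")
```
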
configuration_chest2vec.py ADDED
@@ -0,0 +1,26 @@
1
+ """Configuration for Chest2Vec — a LoRA-tuned Qwen3-Embedding model for
2
+ chest radiology report embeddings.
3
+
4
+ Chest2Vec = Qwen3-Embedding base + contrastive LoRA adapter. It produces a
5
+ single L2-normalized report embedding (last-token / EOS pooling), matching the
6
+ Qwen3-Embedding convention.
7
+ """
8
+ from transformers import PretrainedConfig
9
+
10
+
11
+ class Chest2VecConfig(PretrainedConfig):
12
+ model_type = "chest2vec"
13
+
14
+ def __init__(
15
+ self,
16
+ base_model: str = "Qwen/Qwen3-Embedding-0.6B",
17
+ adapter_subdir: str = "contrastive",
18
+ require_flash_attention_2: bool = True,
19
+ default_max_len: int = 512,
20
+ **kwargs,
21
+ ):
22
+ self.base_model = base_model
23
+ self.adapter_subdir = adapter_subdir
24
+ self.require_flash_attention_2 = require_flash_attention_2
25
+ self.default_max_len = default_max_len
26
+ super().__init__(**kwargs)
contrastive/adapter_config.json ADDED
@@ -0,0 +1,42 @@
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": {
4
+ "base_model_class": "Qwen3Model",
5
+ "parent_library": "transformers.models.qwen3.modeling_qwen3"
6
+ },
7
+ "base_model_name_or_path": "Qwen/Qwen3-Embedding-0.6B",
8
+ "bias": "none",
9
+ "corda_config": null,
10
+ "eva_config": null,
11
+ "exclude_modules": null,
12
+ "fan_in_fan_out": false,
13
+ "inference_mode": true,
14
+ "init_lora_weights": true,
15
+ "layer_replication": null,
16
+ "layers_pattern": null,
17
+ "layers_to_transform": null,
18
+ "loftq_config": {},
19
+ "lora_alpha": 32,
20
+ "lora_bias": false,
21
+ "lora_dropout": 0.1,
22
+ "megatron_config": null,
23
+ "megatron_core": "megatron.core",
24
+ "modules_to_save": null,
25
+ "peft_type": "LORA",
26
+ "r": 16,
27
+ "rank_pattern": {},
28
+ "revision": null,
29
+ "target_modules": [
30
+ "up_proj",
31
+ "k_proj",
32
+ "o_proj",
33
+ "v_proj",
34
+ "gate_proj",
35
+ "q_proj",
36
+ "down_proj"
37
+ ],
38
+ "task_type": null,
39
+ "trainable_token_indices": null,
40
+ "use_dora": false,
41
+ "use_rslora": false
42
+ }
contrastive/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74eda0bb349ef0a65df7228172dca048da8478a16a223c7a81c8292ecd4eb75c
3
+ size 20234904
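
Note on the adapter size: `adapter_config.json` sets `r = 16` and `lora_alpha = 32` on seven projection modules, so the LoRA update is scaled by `lora_alpha / r` (since `use_rslora` is false). The sketch below estimates the adapter's parameter count. The layer dimensions are assumptions about Qwen3-Embedding-0.6B and do not appear in this commit; treat the result as a rough sanity check against the 20,234,904-byte LFS pointer, not as a verified figure.

```python
# Values taken from contrastive/adapter_config.json in this commit.
R, ALPHA = 16, 32
scale = ALPHA / R  # standard LoRA scaling factor applied to B @ A

# ASSUMED backbone dimensions (not stated anywhere in this commit).
HIDDEN, INTERMEDIATE, LAYERS = 1024, 3072, 28
Q_OUT = 16 * 128   # assumed 16 attention heads x head_dim 128
KV_OUT = 8 * 128   # assumed 8 key/value heads x head_dim 128

# (in_features, out_features) for each target module in one decoder layer.
shapes = {
    "q_proj": (HIDDEN, Q_OUT),
    "k_proj": (HIDDEN, KV_OUT),
    "v_proj": (HIDDEN, KV_OUT),
    "o_proj": (Q_OUT, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features: int, out_features: int, r: int) -> int:
    """A is [r, in] and B is [out, r], so one module adds r * (in + out) weights."""
    return r * (in_features + out_features)


per_layer = sum(lora_params(i, o, R) for i, o in shapes.values())
total_params = per_layer * LAYERS
bytes_16bit = total_params * 2  # if the weights are stored at 2 bytes each

print(f"scale={scale}, params={total_params:,}, ~{bytes_16bit:,} bytes at 16-bit")
```

Under these assumed shapes, the 16-bit estimate lands close to the file size recorded in the pointer above, with the remainder plausibly being the safetensors header.
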
modeling_chest2vec.py ADDED
@@ -0,0 +1,301 @@
1
+ """Chest2Vec — LoRA-tuned Qwen3-Embedding model for chest radiology reports.
2
+
3
+ Load with:
4
+
5
+ from transformers import AutoModel
6
+ model = AutoModel.from_pretrained("chest2vec/chest2vec_0.6B", trust_remote_code=True)
7
+ emb = model.embed_texts(["Frontal chest radiograph. No pneumothorax."]) # [N, H], L2-normalized
8
+
9
+ Architecture:
10
+ 1. Base : Qwen/Qwen3-Embedding-{0.6B,4B} (downloaded at runtime)
11
+ 2. Adapter: frozen contrastive LoRA adapter (./contrastive)
12
+
13
+ Embeddings use last-token (EOS) pooling with left padding, matching Qwen3-Embedding
14
+ and the Stage-2 training setup. FlashAttention-2 is used when CUDA + flash-attn>=2
15
+ are available (matching training); otherwise it falls back to SDPA so the model
16
+ also loads on CPU.
17
+ """
18
+ import os
19
+ from typing import Dict, List, Optional
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+
24
+ from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig, PreTrainedModel
25
+
26
+ from .configuration_chest2vec import Chest2VecConfig
27
+
28
+ try:
29
+ from peft import PeftModel
30
+ _HAS_PEFT = True
31
+ except Exception:
32
+ PeftModel = None
33
+ _HAS_PEFT = False
34
+
35
+ try:
36
+ from huggingface_hub import snapshot_download
37
+ _HAS_HUB = True
38
+ except Exception:
39
+ snapshot_download = None
40
+ _HAS_HUB = False
41
+
42
+
43
+ # ----------------------------------------------------------------------------
44
+ # Attention backend selection
45
+ # ----------------------------------------------------------------------------
46
+ def _flash_attn_available() -> bool:
47
+ if not torch.cuda.is_available():
48
+ return False
49
+ try:
50
+ import flash_attn # noqa: F401
51
+ ver = getattr(flash_attn, "__version__", "0.0.0")
52
+ return int(str(ver).split(".")[0]) >= 2
53
+ except Exception:
54
+ return False
55
+
56
+
57
+ def _pick_attn_impl(requested: Optional[str], want_flash: bool) -> str:
58
+ import warnings
59
+ if requested:
60
+ return requested
61
+ if want_flash and _flash_attn_available():
62
+ return "flash_attention_2"
63
+ if want_flash:
64
+ warnings.warn(
65
+ "Chest2Vec was trained with FlashAttention-2, but it is unavailable "
66
+ "(needs CUDA + flash-attn>=2). Falling back to 'sdpa'; embeddings may "
67
+ "differ very slightly from the reference implementation.",
68
+ RuntimeWarning,
69
+ )
70
+ return "sdpa"
71
+
72
+
73
+ # ----------------------------------------------------------------------------
74
+ # Tokenization / pooling helpers (match Qwen3-Embedding + training)
75
+ # ----------------------------------------------------------------------------
76
+ def build_qwen_query(instruction: str, query: str) -> str:
77
+ return f"Instruct: {str(instruction).strip()}\nQuery: {str(query).strip()}"
78
+
79
+
80
+ def get_pool_token_id(tok) -> int:
81
+ eod_id = tok.convert_tokens_to_ids("<|endoftext|>")
82
+ if eod_id is None or eod_id < 0:
83
+ eod_id = tok.pad_token_id
84
+ return eod_id
85
+
86
+
87
+ def encode_with_eos_ids(tok, texts: List[str], max_len: int) -> Dict[str, torch.Tensor]:
88
+ """add_special_tokens=False, truncate to max_len-1, append <|endoftext|>, left-pad."""
89
+ pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
90
+ eod_id = get_pool_token_id(tok)
91
+ enc = tok(
92
+ [str(t) for t in texts],
93
+ add_special_tokens=False,
94
+ truncation=True,
95
+ max_length=max_len - 1,
96
+ padding=False,
97
+ return_attention_mask=False,
98
+ )
99
+ input_ids = [ids + [eod_id] for ids in enc["input_ids"]]
100
+ attn_mask = [[1] * len(ids) for ids in input_ids]
101
+ T = max((len(ids) for ids in input_ids), default=1)
102
+ input_ids = [[pad_id] * (T - len(ids)) + ids for ids in input_ids]
103
+ attn_mask = [[0] * (T - len(m)) + m for m in attn_mask]
104
+ return {
105
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
106
+ "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
107
+ }
108
+
109
+
110
+ def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
111
+ """Left-padding-aware last-token (EOS) pooling."""
112
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
113
+ if left_padding:
114
+ return last_hidden_states[:, -1]
115
+ idx = attention_mask.sum(dim=1) - 1
116
+ return last_hidden_states[torch.arange(last_hidden_states.size(0), device=last_hidden_states.device), idx]
117
+
118
+
119
+ def get_last_hidden_state(model, input_ids, attention_mask):
120
+ m = model.module if hasattr(model, "module") else model
121
+ position_ids = attention_mask.long().cumsum(-1) - 1
122
+ position_ids.masked_fill_(attention_mask == 0, 0)
123
+ out = m(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
124
+ use_cache=False, return_dict=True)
125
+ if getattr(out, "last_hidden_state", None) is not None:
126
+ return out.last_hidden_state
127
+ out = m(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
128
+ output_hidden_states=True, use_cache=False, return_dict=True)
129
+ return out.hidden_states[-1]
130
+
131
+
132
+ class Chest2VecModel(PreTrainedModel):
133
+ """LoRA-tuned Qwen3-Embedding model producing L2-normalized report embeddings."""
134
+
135
+ config_class = Chest2VecConfig
136
+ base_model_prefix = "chest2vec"
137
+ # Attention is handled by the inner Qwen3 backbone; advertise support so the
138
+ # transformers attn-implementation validator on this wrapper passes.
139
+ _supports_sdpa = True
140
+ _supports_flash_attn_2 = True
141
+ _supports_flash_attn = True
142
+ _supports_attention_backend = True
143
+
144
+ def __init__(self, config: Chest2VecConfig):
145
+ super().__init__(config)
146
+ # The base+adapter are assembled in `from_pretrained` (base downloads at runtime).
147
+ self.backbone = None
148
+ self.tokenizer = None
149
+ self._device = torch.device("cpu")
150
+ self.register_buffer("_anchor", torch.zeros(1), persistent=False)
151
+
152
+ def get_input_embeddings(self):
153
+ return None
154
+
155
+ def set_input_embeddings(self, value):
156
+ pass
157
+
158
+ @classmethod
159
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
160
+ config = kwargs.pop("config", None)
161
+ device = kwargs.pop("device", None)
162
+ use_4bit = kwargs.pop("use_4bit", False)
163
+ attn_implementation = kwargs.pop("attn_implementation", None)
164
+ torch_dtype = kwargs.pop("torch_dtype", None)
165
+ token = kwargs.pop("token", None) or kwargs.pop("use_auth_token", None)
166
+ cache_dir = kwargs.pop("cache_dir", None)
167
+ # remaining HF plumbing kwargs (state_dict, low_cpu_mem_usage, ...) are ignored
168
+
169
+ repo_path = pretrained_model_name_or_path
170
+ if not os.path.isdir(repo_path):
171
+ if not _HAS_HUB:
172
+ raise RuntimeError("huggingface_hub is required to load by repo_id.")
173
+ repo_path = snapshot_download(repo_path, token=token, cache_dir=cache_dir)
174
+
175
+ if config is None:
176
+ config = Chest2VecConfig.from_pretrained(repo_path)
177
+
178
+ if device is None:
179
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
180
+ device_t = torch.device(device)
181
+ if torch_dtype is None:
182
+ torch_dtype = torch.bfloat16 if device_t.type == "cuda" else torch.float32
183
+
184
+ model = cls(config)
185
+ model._assemble(repo_path, device=device_t, use_4bit=use_4bit,
186
+ attn_implementation=attn_implementation, torch_dtype=torch_dtype, token=token)
187
+ return model
188
+
189
+ def _assemble(self, repo_path, *, device, use_4bit, attn_implementation, torch_dtype, token=None):
190
+ cfg = self.config
191
+ if not _HAS_PEFT:
192
+ raise RuntimeError("peft is required. Install: pip install peft")
193
+
194
+ attn_impl = _pick_attn_impl(attn_implementation, bool(cfg.require_flash_attention_2))
195
+
196
+ tokenizer = AutoTokenizer.from_pretrained(
197
+ cfg.base_model, padding_side="left", trust_remote_code=True, token=token
198
+ )
199
+ if tokenizer.pad_token_id is None:
200
+ tokenizer.pad_token = tokenizer.eos_token
201
+
202
+ base_kwargs = dict(trust_remote_code=True, attn_implementation=attn_impl, token=token)
203
+ if use_4bit:
204
+ base_kwargs["quantization_config"] = BitsAndBytesConfig(
205
+ load_in_4bit=True, bnb_4bit_quant_type="nf4",
206
+ bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16,
207
+ )
208
+ base_kwargs["device_map"] = {"": str(device)}
209
+ else:
210
+ base_kwargs["torch_dtype"] = torch_dtype
211
+ if device.type == "cuda":
212
+ base_kwargs["device_map"] = {"": str(device)}
213
+ try:
214
+ base = AutoModel.from_pretrained(cfg.base_model, **base_kwargs)
215
+ except TypeError as e:
216
+ raise RuntimeError("transformers too old for attn_implementation=...; please upgrade.") from e
217
+ if device.type != "cuda" and not use_4bit:
218
+ base = base.to(device)
219
+
220
+ adapter_dir = os.path.join(repo_path, cfg.adapter_subdir)
221
+ if not os.path.isfile(os.path.join(adapter_dir, "adapter_config.json")):
222
+ raise FileNotFoundError(f"adapter_config.json not found under: {adapter_dir}")
223
+ backbone = PeftModel.from_pretrained(base, adapter_dir)
224
+ backbone.eval()
225
+
226
+ self.backbone = backbone
227
+ self.tokenizer = tokenizer
228
+ self._device = device
229
+ self.eval()
230
+
231
+ @property
232
+ def device(self):
233
+ return self._device
234
+
235
+ @torch.inference_mode()
236
+ def embed_texts(self, texts: List[str], *, max_len: Optional[int] = None,
237
+ batch_size: int = 16, return_cpu_float32: bool = True) -> torch.Tensor:
238
+ """Return L2-normalized report embeddings, shape [N, H]."""
239
+ if self.backbone is None:
240
+ raise RuntimeError("Model not assembled; load via from_pretrained(...).")
241
+ max_len = int(max_len or self.config.default_max_len)
242
+ device = self._device
243
+ if device.type == "cuda":
244
+ amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
245
+ use_amp = True
246
+ else:
247
+ amp_dtype, use_amp = torch.float32, False
248
+
249
+ outs = []
250
+ for i in range(0, len(texts), batch_size):
251
+ chunk = [str(t) for t in texts[i:i + batch_size]]
252
+ enc = encode_with_eos_ids(self.tokenizer, chunk, max_len)
253
+ input_ids = enc["input_ids"].to(device, non_blocking=True)
254
+ attention_mask = enc["attention_mask"].to(device, non_blocking=True)
255
+ with torch.autocast(device_type=("cuda" if device.type == "cuda" else "cpu"),
256
+ dtype=amp_dtype, enabled=use_amp):
257
+ h = get_last_hidden_state(self.backbone, input_ids, attention_mask)
258
+ emb = F.normalize(last_token_pool(h, attention_mask).float(), p=2, dim=-1)
259
+ outs.append(emb.detach())
260
+ embeddings = torch.cat(outs, dim=0)
261
+ if return_cpu_float32:
262
+ embeddings = F.normalize(embeddings.float().cpu(), p=2, dim=-1)
263
+ return embeddings
264
+
265
+ @torch.inference_mode()
266
+ def embed_instruction_query(self, instructions: List[str], queries: List[str], **kw) -> torch.Tensor:
267
+ if len(instructions) != len(queries):
268
+ raise ValueError("instructions and queries must have the same length.")
269
+ return self.embed_texts([build_qwen_query(i, q) for i, q in zip(instructions, queries)], **kw)
270
+
271
+ def forward(self, texts: List[str], **kw) -> torch.Tensor: # type: ignore[override]
272
+ return self.embed_texts(texts, **kw)
273
+
274
+ @staticmethod
275
+ def cosine_topk(query_emb, cand_emb, k=10, *, device="cuda",
276
+ query_batch_size=256, doc_chunk_size=8192):
277
+ device_t = torch.device(device if torch.cuda.is_available() else "cpu")
278
+ q = F.normalize(query_emb.float(), p=2, dim=-1)
279
+ d = F.normalize(cand_emb.float(), p=2, dim=-1)
280
+ Nq, _ = q.shape
281
+ Nd = d.shape[0]
282
+ k = min(int(k), Nd)
283
+ top_scores_all = torch.empty((Nq, k), dtype=torch.float32)
284
+ top_indices_all = torch.empty((Nq, k), dtype=torch.long)
285
+ for qs in range(0, Nq, query_batch_size):
286
+ qe = q[qs:qs + query_batch_size].to(device_t, non_blocking=True)
287
+ bq = qe.size(0)
288
+ top_scores = torch.full((bq, k), -1e9, device=device_t, dtype=torch.float32)
289
+ top_indices = torch.full((bq, k), -1, device=device_t, dtype=torch.long)
290
+ for ds in range(0, Nd, doc_chunk_size):
291
+ de = d[ds:ds + doc_chunk_size].to(device_t, non_blocking=True)
292
+ scores = (qe @ de.T).float()
293
+ chunk = scores.size(1)
294
+ idx_chunk = torch.arange(ds, ds + chunk, device=device_t, dtype=torch.long).unsqueeze(0).expand(bq, -1)
295
+ comb_scores = torch.cat([top_scores, scores], dim=1)
296
+ comb_idx = torch.cat([top_indices, idx_chunk], dim=1)
297
+ new_scores, new_pos = torch.topk(comb_scores, k, dim=1)
298
+ top_scores, top_indices = new_scores, comb_idx.gather(1, new_pos)
299
+ top_scores_all[qs:qs + bq] = top_scores.cpu()
300
+ top_indices_all[qs:qs + bq] = top_indices.cpu()
301
+ return top_scores_all, top_indices_all
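
Note on `encode_with_eos_ids`, `get_last_hidden_state`, and `last_token_pool`: the three helpers rely on the same invariant. Each sequence is truncated to `max_len - 1`, the pooling token is appended, and the batch is left-padded, so the pooling token always sits in the final column. The sketch below reproduces that bookkeeping on plain Python lists. The token ids (`99` for the pooling token, `0` for padding) are made up for illustration, and the function names are hypothetical.

```python
from typing import List, Tuple


def left_pad_with_eos(
    token_ids: List[List[int]], max_len: int, eos_id: int, pad_id: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Truncate to max_len - 1, append eos_id, then left-pad to a common width."""
    rows = [ids[: max_len - 1] + [eos_id] for ids in token_ids]
    width = max((len(r) for r in rows), default=1)
    input_ids = [[pad_id] * (width - len(r)) + r for r in rows]
    attention_mask = [[0] * (width - len(r)) + [1] * len(r) for r in rows]
    return input_ids, attention_mask


def position_ids_from_mask(attention_mask: List[List[int]]) -> List[List[int]]:
    """Cumulative sum of the mask minus one, with padded positions set to 0."""
    out = []
    for row in attention_mask:
        running, positions = 0, []
        for m in row:
            running += m
            positions.append(running - 1 if m else 0)
        out.append(positions)
    return out


# Two "tokenized" texts of different lengths; the first one gets truncated.
input_ids, attention_mask = left_pad_with_eos(
    [[5, 6, 7], [8]], max_len=3, eos_id=99, pad_id=0
)
position_ids = position_ids_from_mask(attention_mask)

# With left padding, the last column of every row is the pooling token,
# which is why last_token_pool can simply take hidden_states[:, -1].
last_column = [row[-1] for row in input_ids]
print(input_ids, attention_mask, position_ids, last_column)
```
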
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ torch==2.6.0
2
+ torchvision==0.21.0
3
+ torchaudio==2.6.0
4
+ transformers==4.57.3
5
+ trl==0.9.3
6
+ deepspeed==0.16.9
7
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
8
+ peft
9
+ huggingface_hub
10
+ bitsandbytes
11
+ accelerate
12
+ numpy