lukeingawesome committed on
Commit 91c029d · 1 parent: 81bdf88

Update package structure for PyPI: fix setup.py, update README with installation steps, add install script

.gitignore ADDED
@@ -0,0 +1,24 @@
+ # Build artifacts
+ dist/
+ build/
+ *.egg-info/
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+
+ # Testing
+ .pytest_cache/
+ .coverage
+ htmlcov/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # Jupyter
+ .ipynb_checkpoints/
+
PUBLISH.md ADDED
@@ -0,0 +1,43 @@
+ # Publishing chest2vec to PyPI
+
+ ## Prerequisites
+
+ 1. Create a PyPI account at https://pypi.org/account/register/
+ 2. Create an API token at https://pypi.org/manage/account/token/
+ 3. Install twine: `pip install twine`
+
+ ## Build the package
+
+ ```bash
+ python3 -m build
+ ```
+
+ This creates `dist/chest2vec-0.6.0-py3-none-any.whl` and `dist/chest2vec-0.6.0.tar.gz`.
+
+ ## Upload to PyPI
+
+ ### Test first on TestPyPI
+
+ ```bash
+ # Upload to TestPyPI first to test
+ twine upload --repository testpypi dist/*
+ # You'll be prompted for a username (__token__) and a password (your API token)
+ ```
+
+ ### Then upload to PyPI
+
+ ```bash
+ twine upload dist/*
+ # You'll be prompted for a username (__token__) and a password (your API token)
+ ```
+
+ ## After publishing
+
+ Once published, users can install with:
+
+ ```bash
+ pip install chest2vec
+ ```
+
+ Note: users will still need to install PyTorch and flash-attention separately, as documented in the README, since PyPI doesn't support custom index URLs in dependencies.
+
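The wheel name `dist/chest2vec-0.6.0-py3-none-any.whl` above follows PEP 427 naming; a minimal sketch of pulling the tags out of such a filename before upload (the helper name is illustrative, not part of the package):

```python
# Minimal sketch: split a PEP 427 wheel filename into its tags,
# handy for sanity-checking dist/ contents before `twine upload`.
def parse_wheel_filename(filename: str) -> dict:
    stem = filename[: -len(".whl")]
    # name-version(-build)?-python-abi-platform; the build tag is optional,
    # so take the three tags from the end.
    parts = stem.split("-")
    return {
        "name": parts[0],
        "version": parts[1],
        "python": parts[-3],
        "abi": parts[-2],
        "platform": parts[-1],
    }

info = parse_wheel_filename("chest2vec-0.6.0-py3-none-any.whl")
print(info["version"])   # -> 0.6.0
print(info["platform"])  # -> any
```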
README.md CHANGED
@@ -41,17 +41,30 @@ The model produces embeddings for 9 distinct sections:
 
 ## Installation
 
- Install the package directly from PyPI:
+ Install the package and all dependencies:
 
 ```bash
+ # Install PyTorch with CUDA 12.6 support
+ pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
+
+ # Install transformers and trl
+ pip install transformers==4.57.3 trl==0.9.3
+
+ # Install deepspeed
+ pip install deepspeed==0.16.9
+
+ # Install flash-attention
+ pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
+
+ # Install chest2vec package
 pip install chest2vec
 ```
 
- This will automatically install all required dependencies including:
- - PyTorch 2.6.0 (with CUDA 12.6 support)
- - Transformers 4.57.3
- - FlashAttention 2.8.3
- - And other required packages
+ Or use the installation script:
+
+ ```bash
+ bash install_deps.sh
+ ```
 
 ## Requirements
 
__pycache__/chest2vec.cpython-310.pyc ADDED
Binary file (14.6 kB)
chest2vec.egg-info/PKG-INFO ADDED
@@ -0,0 +1,240 @@
+ Metadata-Version: 2.4
+ Name: chest2vec
+ Version: 0.6.0
+ Summary: Section-aware embeddings for chest X-ray reports
+ Home-page: https://github.com/chest2vec/chest2vec
+ Author: Chest2Vec Team
+ Project-URL: Homepage, https://github.com/chest2vec/chest2vec
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ Requires-Dist: transformers==4.57.3
+ Requires-Dist: trl==0.9.3
+ Requires-Dist: deepspeed==0.16.9
+ Requires-Dist: peft
+ Requires-Dist: huggingface_hub
+ Requires-Dist: bitsandbytes
+ Requires-Dist: accelerate
+ Requires-Dist: numpy
+ Dynamic: author
+ Dynamic: home-page
+ Dynamic: requires-python
+
+ ---
+ tags:
+ - text-embeddings
+ - retrieval
+ - radiology
+ - cxr
+ - qwen
+ library_name: transformers
+ ---
+
+ # chest2vec_0.6b_cxr
+
+ This repository contains the *delta weights and pooling head* for a section-aware embedding model on top of **Qwen/Qwen3-Embedding-0.6B**:
+
+ - **Stage-2**: Frozen LoRA adapter (contrastive) under `./contrastive/`
+ - **Stage-3**: Section pooler `section_pooler.pt` producing **9 section embeddings**
+ - **Inference helper**: `chest2vec.py`
+
+ Base model weights are **not** included; they are downloaded from Hugging Face at runtime.
+
+ ## Model Architecture
+
+ Chest2Vec is a three-stage model:
+ 1. **Base**: Qwen/Qwen3-Embedding-0.6B (downloaded at runtime)
+ 2. **Stage-2**: Contrastive LoRA adapter trained with multi-positive sigmoid loss
+ 3. **Stage-3**: Section-aware query-attention pooler producing embeddings for 9 radiology report sections
+
+ ## Sections
+
+ The model produces embeddings for 9 distinct sections:
+
+ 1. Lungs and Airways
+ 2. Pleura
+ 3. Cardiovascular
+ 4. Hila and Mediastinum
+ 5. Tubes & Devices
+ 6. Musculoskeletal and Chest Wall
+ 7. Abdominal
+ 8. impression
+ 9. Other
+
+ ## Installation
+
+ Install the package and all dependencies:
+
+ ```bash
+ # Install PyTorch with CUDA 12.6 support
+ pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
+
+ # Install transformers and trl
+ pip install transformers==4.57.3 trl==0.9.3
+
+ # Install deepspeed
+ pip install deepspeed==0.16.9
+
+ # Install flash-attention
+ pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
+
+ # Install chest2vec package
+ pip install chest2vec
+ ```
+
+ Or use the installation script:
+
+ ```bash
+ bash install_deps.sh
+ ```
+
+ ## Requirements
+
+ This model **requires FlashAttention-2** (CUDA) by default; it must be installed separately as shown in the Installation section above.
+
+ ## Quickstart
+
+ ### Installation + Loading
+
+ ```python
+ from chest2vec import Chest2Vec
+
+ # Load model from Hugging Face Hub
+ m = Chest2Vec.from_pretrained("chest2vec/chest2vec_0.6b_cxr", device="cuda:0")
+ ```
+
+ ### Instruction + Query Embeddings
+
+ ```python
+ instructions = ["Find findings about the lungs."]
+ queries = ["Consolidation in the right lower lobe."]
+
+ out = m.embed_instruction_query(instructions, queries, max_len=512, batch_size=8)
+
+ # Global embedding (derived): mean of 9 section vectors then L2-normalized
+ g = out.global_embedding  # [N, H]
+
+ # Per-section embeddings (by full name)
+ lung = out.by_section_name["Lungs and Airways"]  # [N, H]
+ imp = out.by_section_name["impression"]  # [N, H]
+
+ # Or use aliases (case-insensitive)
+ lung = out.by_alias["lungs"]  # [N, H]
+ cardio = out.by_alias["cardio"]  # [N, H]
+ ```
+
+ ### Candidate Embeddings (Retrieval Bank)
+
+ ```python
+ candidates = [
+     "Lungs are clear. No focal consolidation.",
+     "Pleural effusion on the left.",
+     "Cardiomediastinal silhouette is normal."
+ ]
+
+ cand_out = m.embed_texts(candidates, max_len=512, batch_size=16)
+
+ cand_global = cand_out.global_embedding  # [N, H]
+ cand_lung = cand_out.by_alias["lungs"]  # [N, H]
+ ```
+
+ ### Retrieval Example (Cosine Top-K)
+
+ ```python
+ # Query embeddings for "Lungs and Airways" section
+ q = out.by_alias["lungs"]  # [Nq, H]
+
+ # Document embeddings for "Lungs and Airways" section
+ d = cand_out.by_alias["lungs"]  # [Nd, H]
+
+ # Compute top-k cosine similarities
+ scores, idx = Chest2Vec.cosine_topk(q, d, k=5, device="cuda")
+ # scores: [Nq, k] - similarity scores
+ # idx: [Nq, k] - indices of top-k candidates
+
+ print(f"Top-5 scores: {scores[0]}")
+ print(f"Top-5 indices: {idx[0]}")
+ ```
+
+ ## API Reference
+
+ ### `Chest2Vec.from_pretrained()`
+
+ Load the model from Hugging Face Hub or a local path.
+
+ ```python
+ m = Chest2Vec.from_pretrained(
+     repo_id_or_path: str,       # Hugging Face repo ID or local path
+     device: str = "cuda:0",     # Device to load model on
+     use_4bit: bool = False,     # Use 4-bit quantization
+     force_flash_attention_2: bool = True
+ )
+ ```
+
+ ### `embed_instruction_query()`
+
+ Embed instruction-query pairs. Returns `EmbedOutput` with:
+ - `section_matrix`: `[N, 9, H]` - embeddings for all 9 sections
+ - `global_embedding`: `[N, H]` - global embedding (mean of sections, L2-normalized)
+ - `by_section_name`: Dict mapping full section names to `[N, H]` tensors
+ - `by_alias`: Dict mapping aliases to `[N, H]` tensors
+
+ ```python
+ out = m.embed_instruction_query(
+     instructions: List[str],
+     queries: List[str],
+     max_len: int = 512,
+     batch_size: int = 16
+ )
+ ```
+
+ ### `embed_texts()`
+
+ Embed plain texts (for document/candidate encoding).
+
+ ```python
+ out = m.embed_texts(
+     texts: List[str],
+     max_len: int = 512,
+     batch_size: int = 16
+ )
+ ```
+
+ ### `cosine_topk()`
+
+ Static method for efficient top-k cosine similarity search.
+
+ ```python
+ scores, idx = Chest2Vec.cosine_topk(
+     query_emb: torch.Tensor,  # [Nq, H]
+     cand_emb: torch.Tensor,   # [Nd, H]
+     k: int = 10,
+     device: str = "cuda"
+ )
+ ```
+
+ ## Model Files
+
+ - `chest2vec.py` - Model class and inference utilities
+ - `chest2vec_config.json` - Model configuration
+ - `section_pooler.pt` - Stage-3 pooler weights
+ - `section_pooler_config.json` - Pooler configuration
+ - `contrastive/` - Stage-2 LoRA adapter directory
+   - `adapter_config.json` - LoRA adapter configuration
+   - `adapter_model.safetensors` - LoRA adapter weights
+
+ ## Citation
+
+ If you use this model, please cite:
+
+ ```bibtex
+ @misc{chest2vec_0.6b_cxr,
+   title={Chest2Vec: Section-Aware Embeddings for Chest X-Ray Reports},
+   author={Your Name},
+   year={2024},
+   howpublished={\url{https://huggingface.co/chest2vec/chest2vec_0.6b_cxr}}
+ }
+ ```
+
+ ## License
+
+ [Specify your license here]
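The README above defines the global embedding as the mean of the 9 section vectors followed by L2 normalization; a dependency-free sketch of that pooling step (a toy 2-section, 3-dimensional example, not the actual pooler code):

```python
import math

# Sketch of the derived global embedding described above:
# mean over the section vectors, then L2-normalize the result.
def global_embedding(section_vectors):
    n = len(section_vectors)
    mean = [sum(col) / n for col in zip(*section_vectors)]
    norm = math.sqrt(sum(x * x for x in mean)) or 1.0  # avoid division by zero
    return [x / norm for x in mean]

g = global_embedding([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
# mean = [0.5, 0.5, 0.0]; after normalization the vector has unit length
print([round(x, 4) for x in g])  # -> [0.7071, 0.7071, 0.0]
```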
chest2vec.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,9 @@
+ README.md
+ chest2vec.py
+ pyproject.toml
+ setup.py
+ chest2vec.egg-info/PKG-INFO
+ chest2vec.egg-info/SOURCES.txt
+ chest2vec.egg-info/dependency_links.txt
+ chest2vec.egg-info/requires.txt
+ chest2vec.egg-info/top_level.txt
chest2vec.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
chest2vec.egg-info/requires.txt ADDED
@@ -0,0 +1,8 @@
+ transformers==4.57.3
+ trl==0.9.3
+ deepspeed==0.16.9
+ peft
+ huggingface_hub
+ bitsandbytes
+ accelerate
+ numpy
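A small sketch of reading pin lines like the ones in `requires.txt` into a name-to-version mapping (the helper name is illustrative; this handles only the `==` pins used here, not general PEP 508 specifiers):

```python
# Sketch: parse pinned requirement lines into {package: version};
# unpinned packages (e.g. "peft") map to None.
def parse_pins(lines):
    pins = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name] = version
        else:
            pins[line] = None  # unpinned
    return pins

pins = parse_pins(["transformers==4.57.3", "trl==0.9.3",
                   "deepspeed==0.16.9", "peft", "numpy"])
print(pins["transformers"])  # -> 4.57.3
```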
chest2vec.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ chest2vec
dist/chest2vec-0.6.0-py3-none-any.whl ADDED
Binary file (9.17 kB)
dist/chest2vec-0.6.0.tar.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d3a3a80bb958164a268395e79ce37d240639144ff2175a23fe188fa6991c6051
+ size 9487
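The tarball is stored as a Git LFS pointer, which is just the three `key value` lines above; a minimal sketch of parsing such a pointer (assuming the format shown):

```python
# Sketch: parse a git-lfs pointer file into a dict, e.g. to read back
# the object id and size of the real artifact.
def parse_lfs_pointer(text: str) -> dict:
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:d3a3a80bb958164a268395e79ce37d240639144ff2175a23fe188fa6991c6051
size 9487"""

fields = parse_lfs_pointer(pointer)
print(fields["size"])  # -> 9487
```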
install_deps.sh ADDED
@@ -0,0 +1,17 @@
+ #!/bin/bash
+ # Installation script for chest2vec dependencies
+ # This script installs PyTorch and flash-attention with the correct versions
+
+ set -e
+
+ echo "Installing PyTorch packages from custom index..."
+ pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
+
+ echo "Installing flash-attention from GitHub release..."
+ pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
+
+ echo "Installing chest2vec package..."
+ pip install chest2vec
+
+ echo "Installation complete!"
+
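The flash-attn wheel URL in the script encodes its compatibility tags (`cp310-cp310-linux_x86_64`), so it only installs on CPython 3.10 / Linux x86_64. A hedged sketch of checking the interpreter against the wheel's Python tag before running the script (helper names are illustrative):

```python
import sys

# The prebuilt flash-attn wheel referenced by install_deps.sh:
WHEEL = ("flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE"
         "-cp310-cp310-linux_x86_64.whl")

def wheel_python_tag(wheel_name: str) -> str:
    # PEP 427: ...-python-abi-platform.whl, so the python tag
    # is the third field from the end.
    return wheel_name[: -len(".whl")].split("-")[-3]

def interpreter_tag() -> str:
    return f"cp{sys.version_info.major}{sys.version_info.minor}"

tag = wheel_python_tag(WHEEL)
print(tag)                       # -> cp310
print(interpreter_tag() == tag)  # True only on CPython 3.10
```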
pyproject.toml CHANGED
@@ -9,9 +9,6 @@ description = "Section-aware embeddings for chest X-ray reports"
 readme = "README.md"
 requires-python = ">=3.8"
 dependencies = [
- "torch==2.6.0",
- "torchvision==0.21.0",
- "torchaudio==2.6.0",
 "transformers==4.57.3",
 "trl==0.9.3",
 "deepspeed==0.16.9",
@@ -22,11 +19,6 @@ dependencies = [
 "numpy",
 ]
 
- [project.optional-dependencies]
- flash-attn = [
- "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl",
- ]
-
 [project.urls]
 Homepage = "https://github.com/chest2vec/chest2vec"
 
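The `[project.optional-dependencies]` block is removed because PyPI rejects dependency metadata containing direct URL references; a small sketch of flagging such entries in a dependency list (a heuristic only, not a full PEP 508 parser):

```python
# Sketch: flag dependency entries that PyPI would reject because they
# are direct URL references rather than name/version specifiers.
def direct_url_deps(dependencies):
    return [d for d in dependencies
            if d.startswith(("http://", "https://")) or " @ " in d]

deps = [
    "transformers==4.57.3",
    "https://github.com/Dao-AILab/flash-attention/releases/download/"
    "v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl",
]
print(len(direct_url_deps(deps)))  # -> 1, the flash-attn wheel URL
```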
setup.py CHANGED
@@ -1,42 +1,10 @@
 from setuptools import setup, find_packages
- from setuptools.command.install import install
 from pathlib import Path
- import subprocess
- import sys
 
 # Read README for long description
 readme_file = Path(__file__).parent / "README.md"
 long_description = readme_file.read_text(encoding="utf-8") if readme_file.exists() else ""
 
-
- class CustomInstall(install):
- """Custom install command that installs PyTorch from custom index first."""
-
- def run(self):
- # Install PyTorch packages from custom index first
- pytorch_packages = [
- "torch==2.6.0",
- "torchvision==0.21.0",
- "torchaudio==2.6.0",
- ]
-
- print("Installing PyTorch packages from custom index...")
- subprocess.check_call([
- sys.executable, "-m", "pip", "install",
- "--index-url", "https://download.pytorch.org/whl/cu126"
- ] + pytorch_packages)
-
- # Install flash-attention from GitHub release
- print("Installing flash-attention from GitHub release...")
- flash_attn_url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl"
- subprocess.check_call([
- sys.executable, "-m", "pip", "install", flash_attn_url
- ])
-
- # Now run the standard install
- install.run(self)
-
-
 setup(
 name="chest2vec",
 version="0.6.0",
@@ -47,13 +15,12 @@ setup(
 url="https://github.com/chest2vec/chest2vec",
 packages=find_packages(),
 py_modules=["chest2vec"],
- cmdclass={"install": CustomInstall},
+ include_package_data=True,
+ package_data={"": ["__init__.py"]},
 install_requires=[
- # PyTorch packages are installed separately in CustomInstall
 "transformers==4.57.3",
 "trl==0.9.3",
 "deepspeed==0.16.9",
- # Flash attention is installed separately in CustomInstall
 "peft",
 "huggingface_hub",
 "bitsandbytes",
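With the `CustomInstall` hook gone (pip gives no guarantee that nested `pip install` calls from `setup.py` work, especially under build isolation), a runtime check could give users a clearer error when the separately-installed dependencies are missing. A hypothetical helper, not part of chest2vec:

```python
from importlib.util import find_spec

# Hypothetical helper (illustrative only): report which of the
# separately-installed dependencies are absent in this environment.
def missing_deps(names):
    return [n for n in names if find_spec(n) is None]

# torch and flash_attn must be installed manually (see README)
missing = missing_deps(["torch", "flash_attn"])
if missing:
    print("Install these before using chest2vec:", missing)
```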
test.ipynb ADDED
@@ -0,0 +1,373 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "5a6c76f2",
+ "metadata": {},
+ "source": [
+ "## Chest2VEC"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "26215417",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from chest2vec import Chest2Vec\n",
+ "import os\n",
+ "os.environ[\"HF_HOME\"] = \"/model/huggingface\"\n",
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
+ "\n",
+ "\n",
+ "\n",
+ "m = Chest2Vec.from_pretrained(\"chest2vec/chest2vec_0.6b_cxr\", device=\"cuda:0\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "624ad061",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "instructions = [\"Find findings about the lungs.\"]\n",
+ "queries = [\"Consolidation in the right lower lobe.\"]\n",
+ "\n",
+ "out = m.embed_instruction_query(instructions, queries, max_len=512, batch_size=8)\n",
+ "\n",
+ "# Global embedding (derived): mean of 9 section vectors then L2-normalized\n",
+ "g = out.global_embedding # [N, H]\n",
+ "\n",
+ "# Per-section embeddings (by full name)\n",
+ "lung = out.by_section_name[\"Lungs and Airways\"] # [N, H]\n",
+ "imp = out.by_section_name[\"impression\"] # [N, H]\n",
+ "\n",
+ "# Or use aliases (case-insensitive)\n",
+ "lung = out.by_alias[\"lungs\"] # [N, H]\n",
+ "cardio = out.by_alias[\"cardio\"] # [N, H]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "b083b9a8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "candidates = [\n",
+ " \"Lungs are clear. No focal consolidation.\",\n",
+ " \"Pleural effusion on the left.\",\n",
+ " \"Right lower lobe consolidation.\",\n",
+ " \"Cardiomediastinal silhouette is normal.\"\n",
+ "]\n",
+ "\n",
+ "cand_out = m.embed_texts(candidates, max_len=512, batch_size=16)\n",
+ "\n",
+ "cand_global = cand_out.global_embedding # [N, H]\n",
+ "cand_lung = cand_out.by_alias[\"lungs\"] # [N, H]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "98ebf6d5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Top-5 scores: tensor([ 0.3646, -0.0407, -0.0810, -0.1504])\n",
+ "Top-5 indices: tensor([2, 1, 3, 0])\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Query embeddings for \"Lungs and Airways\" section\n",
+ "q = out.global_embedding # [Nq, H]\n",
+ "\n",
+ "# Document embeddings for \"Lungs and Airways\" section\n",
+ "d = cand_out.global_embedding # [Nd, H]\n",
+ "\n",
+ "# Compute top-k cosine similarities\n",
+ "scores, idx = Chest2Vec.cosine_topk(q, d, k=5, device=\"cuda\")\n",
+ "# scores: [Nq, k] - similarity scores\n",
+ "# idx: [Nq, k] - indices of top-k candidates\n",
+ "\n",
+ "print(f\"Top-5 scores: {scores[0]}\")\n",
+ "print(f\"Top-5 indices: {idx[0]}\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "906d89b8",
+ "metadata": {},
+ "source": [
+ "## CT"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "347dd738",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "EmbedOutput(section_matrix=tensor([[[ 6.9175e-02, -1.6013e-04, -1.1102e-02, ..., -3.0460e-02,\n",
+ " -5.8357e-02, 3.6722e-02],\n",
+ " [ 7.2193e-02, -6.4974e-04, -1.3356e-02, ..., -2.9022e-02,\n",
+ " -6.0931e-02, 3.6963e-02],\n",
+ " [ 7.2526e-02, -2.2293e-03, -1.8355e-02, ..., -2.7643e-02,\n",
+ " -6.0637e-02, 4.1613e-02],\n",
+ " ...,\n",
+ " [ 7.4716e-02, -3.1762e-03, -2.3746e-02, ..., -1.9208e-02,\n",
+ " -5.9592e-02, 4.8399e-02],\n",
+ " [ 7.0074e-02, -5.4069e-05, -1.2760e-02, ..., -2.8404e-02,\n",
+ " -5.8251e-02, 3.8930e-02],\n",
+ " [ 7.5282e-02, -3.2632e-03, -2.3526e-02, ..., -1.8972e-02,\n",
+ " -6.0407e-02, 4.7962e-02]]]), global_embedding=tensor([[ 0.0731, -0.0017, -0.0180, ..., -0.0245, -0.0603, 0.0423]]), by_section_name={'Lungs and Airways': tensor([[ 0.0692, -0.0002, -0.0111, ..., -0.0305, -0.0584, 0.0367]]), 'Pleura': tensor([[ 0.0722, -0.0006, -0.0134, ..., -0.0290, -0.0609, 0.0370]]), 'Cardiovascular': tensor([[ 0.0725, -0.0022, -0.0184, ..., -0.0276, -0.0606, 0.0416]]), 'Hila and Mediastinum': tensor([[ 0.0749, -0.0023, -0.0224, ..., -0.0191, -0.0601, 0.0463]]), 'Tubes & Devices': tensor([[ 0.0730, -0.0007, -0.0161, ..., -0.0234, -0.0616, 0.0395]]), 'Musculoskeletal and Chest Wall': tensor([[ 0.0740, -0.0023, -0.0202, ..., -0.0237, -0.0614, 0.0432]]), 'Abdominal': tensor([[ 0.0747, -0.0032, -0.0237, ..., -0.0192, -0.0596, 0.0484]]), 'impression': tensor([[ 7.0074e-02, -5.4069e-05, -1.2760e-02, ..., -2.8404e-02,\n",
+ " -5.8251e-02, 3.8930e-02]]), 'Other': tensor([[ 0.0753, -0.0033, -0.0235, ..., -0.0190, -0.0604, 0.0480]])}, by_alias={'global': tensor([[ 0.0731, -0.0017, -0.0180, ..., -0.0245, -0.0603, 0.0423]]), 'lungs': tensor([[ 0.0692, -0.0002, -0.0111, ..., -0.0305, -0.0584, 0.0367]]), 'lung': tensor([[ 0.0692, -0.0002, -0.0111, ..., -0.0305, -0.0584, 0.0367]]), 'pleura': tensor([[ 0.0722, -0.0006, -0.0134, ..., -0.0290, -0.0609, 0.0370]]), 'cardio': tensor([[ 0.0725, -0.0022, -0.0184, ..., -0.0276, -0.0606, 0.0416]]), 'cardiovascular': tensor([[ 0.0725, -0.0022, -0.0184, ..., -0.0276, -0.0606, 0.0416]]), 'hila': tensor([[ 0.0749, -0.0023, -0.0224, ..., -0.0191, -0.0601, 0.0463]]), 'mediastinum': tensor([[ 0.0749, -0.0023, -0.0224, ..., -0.0191, -0.0601, 0.0463]]), 'tubes': tensor([[ 0.0730, -0.0007, -0.0161, ..., -0.0234, -0.0616, 0.0395]]), 'devices': tensor([[ 0.0730, -0.0007, -0.0161, ..., -0.0234, -0.0616, 0.0395]]), 'msk': tensor([[ 0.0740, -0.0023, -0.0202, ..., -0.0237, -0.0614, 0.0432]]), 'musculoskeletal': tensor([[ 0.0740, -0.0023, -0.0202, ..., -0.0237, -0.0614, 0.0432]]), 'abd': tensor([[ 0.0747, -0.0032, -0.0237, ..., -0.0192, -0.0596, 0.0484]]), 'abdominal': tensor([[ 0.0747, -0.0032, -0.0237, ..., -0.0192, -0.0596, 0.0484]]), 'impression': tensor([[ 7.0074e-02, -5.4069e-05, -1.2760e-02, ..., -2.8404e-02,\n",
+ " -5.8251e-02, 3.8930e-02]]), 'other': tensor([[ 0.0753, -0.0033, -0.0235, ..., -0.0190, -0.0604, 0.0480]])})"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# !pip install nibabel, monai\n",
+ "import numpy as np\n",
+ "from pathlib import Path\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# Optional (for nicer overlays). If scipy isn't installed, code will fall back.\n",
+ "try:\n",
+ " from scipy.ndimage import binary_erosion\n",
+ " _HAS_SCIPY = True\n",
+ "except Exception:\n",
+ " _HAS_SCIPY = False\n",
+ "\n",
+ "# Point this to your preprocessed folder\n",
+ "NPZ_ROOT = Path(\"./data/preprocessed\") # <-- EDIT\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "edbddaf7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def load_npz_case(npz_path: Path):\n",
+ " npz_path = Path(npz_path)\n",
+ " with np.load(npz_path, allow_pickle=False) as z:\n",
+ " keys = list(z.keys())\n",
+ "\n",
+ " # Support a few common key names\n",
+ " if \"ct\" in keys:\n",
+ " ct = z[\"ct\"]\n",
+ " elif \"image\" in keys:\n",
+ " ct = z[\"image\"]\n",
+ " else:\n",
+ " raise KeyError(f\"No CT key found. Available keys: {keys}\")\n",
+ "\n",
+ " rex = z[\"rex\"] if \"rex\" in keys else None\n",
+ " tot = z[\"totalseg\"] if \"totalseg\" in keys else (z[\"label\"] if \"label\" in keys else None)\n",
+ "\n",
+ " # Basic sanity checks\n",
+ " assert ct.ndim == 4 and ct.shape[0] == 1, f\"Expected ct shape (1,D,H,W), got {ct.shape}\"\n",
+ " D, H, W = ct.shape[1], ct.shape[2], ct.shape[3]\n",
+ "\n",
+ " if tot is not None:\n",
+ " assert tot.ndim == 4 and tot.shape[0] == 1, f\"Expected totalseg shape (1,D,H,W), got {tot.shape}\"\n",
+ " assert tot.shape[1:] == (D, H, W), f\"totalseg spatial mismatch: {tot.shape} vs ct {ct.shape}\"\n",
+ "\n",
+ " if rex is not None:\n",
+ " assert rex.ndim == 4, f\"Expected rex shape (F,D,H,W), got {rex.shape}\"\n",
+ " assert rex.shape[1:] == (D, H, W), f\"rex spatial mismatch: {rex.shape} vs ct {ct.shape}\"\n",
+ "\n",
+ " return ct, rex, tot, keys\n",
+ "\n",
+ "# List files\n",
+ "npz_files = sorted(NPZ_ROOT.rglob(\"*.npz\"))\n",
+ "print(\"Found npz files:\", len(npz_files))\n",
+ "print(\"Example:\", npz_files[0] if npz_files else \"NONE\")\n",
+ "\n",
+ "# Pick one (edit index or set by name)\n",
+ "case_path = npz_files[0] # <-- change to inspect a specific file\n",
+ "ct, rex, tot, keys = load_npz_case(case_path)\n",
+ "\n",
+ "print(\"Loaded:\", case_path.name)\n",
+ "print(\"Keys:\", keys)\n",
+ "print(\"CT:\", ct.shape, ct.dtype, f\"min={ct.min():.3f}, max={ct.max():.3f}\")\n",
+ "print(\"Rex:\", None if rex is None else (rex.shape, rex.dtype, f\"channels={rex.shape[0]}\"))\n",
+ "print(\"Tot:\", None if tot is None else (tot.shape, tot.dtype))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c1d10b70",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def choose_rex_channel(rex_arr: np.ndarray):\n",
+ " \"\"\"\n",
+ " Returns (best_channel_index, counts_per_channel)\n",
+ " counts = number of voxels > 0 in each channel\n",
+ " \"\"\"\n",
+ " if rex_arr is None:\n",
+ " return None, None\n",
+ " counts = (rex_arr > 0).reshape(rex_arr.shape[0], -1).sum(axis=1)\n",
+ " best = int(np.argmax(counts))\n",
+ " return best, counts\n",
+ "\n",
+ "rex_ch, rex_counts = choose_rex_channel(rex)\n",
+ "if rex is not None:\n",
+ " print(\"Top 10 ReX channels by voxel count:\")\n",
+ " top = np.argsort(-rex_counts)[:10]\n",
+ " for i in top:\n",
+ " print(f\" ch={int(i):4d} voxels={int(rex_counts[i])}\")\n",
+ " print(\"Auto-selected channel:\", rex_ch)\n",
+ "else:\n",
+ " print(\"No ReX mask in this NPZ.\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3c75853f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def mask_edges_2d(m2d: np.ndarray) -> np.ndarray:\n",
+ " \"\"\"Thin-ish edge for 2D mask.\"\"\"\n",
+ " m2d = (m2d > 0)\n",
+ " if not _HAS_SCIPY:\n",
+ " return m2d.astype(np.uint8) # fallback: filled mask\n",
+ " er = binary_erosion(m2d)\n",
+ " return (m2d ^ er).astype(np.uint8)\n",
+ "\n",
+ "def top_slices_by_area(mask_3d: np.ndarray, topk: int = 8):\n",
+ " \"\"\"\n",
+ " mask_3d: (D,H,W) boolean/int\n",
+ " returns list of axial slice indices with largest mask area\n",
+ " \"\"\"\n",
+ " areas = (mask_3d > 0).sum(axis=(1,2))\n",
+ " idx = np.argsort(-areas)[:topk]\n",
+ " return [int(i) for i in idx if areas[i] > 0], areas\n",
+ "\n",
+ "# Build binary masks for display\n",
+ "ct_vol = ct[0] # (D,H,W)\n",
+ "rex_mask = None\n",
+ "if rex is not None:\n",
+ " rex_mask = (rex[rex_ch] > 0) # (D,H,W)\n",
+ "\n",
+ "tot_mask = None\n",
+ "if tot is not None:\n",
+ " tot_mask = (tot[0] > 0) # (D,H,W)\n",
+ "\n",
+ "if rex_mask is not None:\n",
+ " idxs, areas = top_slices_by_area(rex_mask, topk=10)\n",
+ " print(\"Top axial slices by ReX area:\", idxs[:10])\n",
+ "else:\n",
+ " print(\"No ReX mask to suggest slices.\")\n",
+ "\n",
+ "if tot_mask is not None:\n",
+ " idxs2, areas2 = top_slices_by_area(tot_mask, topk=10)\n",
+ " print(\"Top axial slices by TotalSeg area:\", idxs2[:10])\n",
+ "else:\n",
+ " print(\"No TotalSeg mask to suggest slices.\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "683579a8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def show_axial_grid(ct_vol, rex_mask=None, tot_mask=None, slice_indices=None, rex_title=\"ReX\", tot_title=\"TotalSeg\"):\n",
+ " \"\"\"\n",
+ " ct_vol: (D,H,W) float\n",
+ " rex_mask / tot_mask: (D,H,W) bool\n",
+ " slice_indices: list[int]\n",
+ " \"\"\"\n",
+ " if slice_indices is None or len(slice_indices) == 0:\n",
+ " slice_indices = [ct_vol.shape[0] // 2]\n",
+ "\n",
+ " n = len(slice_indices)\n",
+ " fig, axes = plt.subplots(nrows=n, ncols=3, figsize=(12, 4*n))\n",
+ " if n == 1:\n",
+ " axes = np.array([axes])\n",
+ "\n",
+ " for r, d in enumerate(slice_indices):\n",
+ " ct2d = ct_vol[d]\n",
+ "\n",
+ " # Panel 1: CT\n",
+ " ax = axes[r, 0]\n",
+ " ax.imshow(ct2d, cmap=\"gray\", origin=\"lower\")\n",
+ " ax.set_title(f\"CT (axial d={d})\")\n",
+ " ax.axis(\"off\")\n",
+ "\n",
+ " # Panel 2: CT + ReX\n",
+ " ax = axes[r, 1]\n",
+ " ax.imshow(ct2d, cmap=\"gray\", origin=\"lower\")\n",
+ " if rex_mask is not None:\n",
+ " e = mask_edges_2d(rex_mask[d])\n",
+ " ax.imshow(e, cmap=\"Reds\", alpha=0.7, origin=\"lower\")\n",
+ " ax.set_title(f\"CT + {rex_title}\")\n",
+ " ax.axis(\"off\")\n",
+ "\n",
+ " # Panel 3: CT + TotalSeg\n",
+ " ax = axes[r, 2]\n",
+ " ax.imshow(ct2d, cmap=\"gray\", origin=\"lower\")\n",
+ " if tot_mask is not None:\n",
+ " e = mask_edges_2d(tot_mask[d])\n",
+ " ax.imshow(e, cmap=\"Blues\", alpha=0.6, origin=\"lower\")\n",
+ " ax.set_title(f\"CT + {tot_title}\")\n",
+ " ax.axis(\"off\")\n",
+ "\n",
+ " plt.tight_layout()\n",
+ " plt.show()\n",
+ "\n",
+ "# Choose slices to visualize (prefer slices with ReX content if present)\n",
+ "if rex_mask is not None:\n",
+ " slices, _ = top_slices_by_area(rex_mask, topk=3)\n",
+ " if len(slices) == 0:\n",
+ " slices = [ct_vol.shape[0]//2]\n",
+ "else:\n",
+ " slices = [ct_vol.shape[0]//2]\n",
+ "\n",
+ "show_axial_grid(ct_vol, rex_mask=rex_mask, tot_mask=tot_mask, slice_indices=slices[:3],\n",
+ " rex_title=f\"ReX(ch={rex_ch})\", tot_title=\"TotalSeg\")\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
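The notebook's retrieval cell relies on `Chest2Vec.cosine_topk`; a dependency-free sketch of the same top-k cosine ranking (a pure-Python stand-in for illustration, not the library implementation):

```python
import math

# Sketch of top-k cosine retrieval like the notebook's cosine_topk call:
# score every candidate against the query, return the k best.
def cosine_topk(query, candidates, k=5):
    def cos(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        na = math.sqrt(sum(x * x for x in a)) or 1.0
        nb = math.sqrt(sum(x * x for x in b)) or 1.0
        return dot / (na * nb)

    scored = sorted(((cos(query, c), i) for i, c in enumerate(candidates)),
                    reverse=True)[:k]  # k is clipped to len(candidates)
    return [s for s, _ in scored], [i for _, i in scored]

scores, idx = cosine_topk([1.0, 0.0], [[0.0, 1.0], [1.0, 1.0], [1.0, 0.0]], k=2)
print(idx)  # -> [2, 1]: exact match first, then the diagonal vector
```

As in the notebook's output above, asking for more neighbors than there are candidates simply returns all of them, ranked.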