manu02 committed
Commit 1ccd21e · verified · 1 Parent(s): 46fc6c1

Upload MIMIC test evaluation results
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/AnatomicalAttention.gif filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,197 @@
---
license: mit
library_name: transformers
pipeline_tag: image-to-text
tags:
- medical-ai
- radiology
- chest-xray
- report-generation
- segmentation
- anatomical-attention
metrics:
- BLEU
- METEOR
- ROUGE
- CIDEr
---

# LAnA

**Layer-Wise Anatomical Attention model**

[![ArXiv](https://img.shields.io/badge/ArXiv-2512.16841-B31B1B?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2512.16841)
[![LinkedIn](https://img.shields.io/badge/LinkedIn-devmuniz-0A66C2?logo=linkedin&logoColor=white)](https://www.linkedin.com/in/devmuniz)
[![GitHub Profile](https://img.shields.io/badge/GitHub-devMuniz02-181717?logo=github&logoColor=white)](https://github.com/devMuniz02)
[![Portfolio](https://img.shields.io/badge/Portfolio-devmuniz02.github.io-0F172A?logo=googlechrome&logoColor=white)](https://devmuniz02.github.io/)
[![GitHub Repo](https://img.shields.io/badge/Repository-layer--wise--anatomical--attention-181717?logo=github&logoColor=white)](https://github.com/devMuniz02/layer-wise-anatomical-attention)
[![Hugging Face](https://img.shields.io/badge/Hugging%20Face-manu02-FFD21E?logoColor=black)](https://huggingface.co/manu02)

![Layer-Wise Anatomical Attention](assets/AnatomicalAttention.gif)

## Overview

LAnA is a medical report-generation project for chest X-ray images. When complete, the project will generate radiology reports with a vision-language model guided by layer-wise anatomical attention built from predicted anatomical masks.

The architecture combines a DINOv3 vision encoder, lung and heart segmentation heads, and a GPT-2 decoder modified so that each transformer layer receives a different anatomical attention bias derived from the predicted segmentation masks.
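As a rough, framework-agnostic sketch (not the repository's implementation), a layer-specific additive bias enters scaled dot-product attention like this; `bias` here stands in for the anatomical prior a given decoder layer receives:

```python
import numpy as np

def attention_with_bias(q, k, v, bias):
    """Scaled dot-product attention with an additive bias on the logits."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d) + bias  # the bias differs per decoder layer
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 8)) for _ in range(3))
bias = np.zeros((4, 4))
bias[:, 0] = 2.0  # upweight attention toward the first (e.g. lung-region) token
out = attention_with_bias(q, k, v, bias)
print(out.shape)  # (4, 8)
```

Because the bias is added before the softmax, a zero bias recovers ordinary attention, and a per-layer bias tensor changes only where each layer looks, not the mechanism itself.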

## How to Run

Standard `AutoModel.from_pretrained(..., trust_remote_code=True)` loading is currently blocked for this repo because the custom model constructor performs nested pretrained submodel loads.
Instead, use the verified manual load path below: download the HF repo snapshot, import the downloaded package, and load the exported `model.safetensors` directly.
Set an `HF_TOKEN` environment variable with permission to access the DINOv3 model repositories used by this project; without it, the required vision backbones cannot be downloaded.

```python
from pathlib import Path
import sys

import numpy as np
import torch
from PIL import Image
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import AutoTokenizer

# Download the repo snapshot and make the bundled lana_radgen package importable.
repo_dir = Path(snapshot_download("manu02/LAnA"))
sys.path.insert(0, str(repo_dir))

from lana_radgen import LanaConfig, LanaForConditionalGeneration

config = LanaConfig.from_pretrained(repo_dir)
config.lung_segmenter_checkpoint = str(repo_dir / "segmenters" / "lung_segmenter_dinounet_finetuned.pth")
config.heart_segmenter_checkpoint = str(repo_dir / "segmenters" / "heart_segmenter_dinounet_best.pth")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model and load the exported weights; strict=True raises on any
# key mismatch, so the assert below is just a sanity check.
model = LanaForConditionalGeneration(config)
state_dict = load_file(str(repo_dir / "model.safetensors"))
missing, unexpected = model.load_state_dict(state_dict, strict=True)
assert not missing and not unexpected

model.tokenizer = AutoTokenizer.from_pretrained(repo_dir, trust_remote_code=True)
model.move_non_quantized_modules(device)
model.eval()

# Preprocess: resize to 512x512 and normalize with ImageNet mean/std.
image_path = Path("example.png")
image = Image.open(image_path).convert("RGB")
image = image.resize((512, 512), resample=Image.BICUBIC)
array = np.asarray(image, dtype=np.float32) / 255.0
pixel_values = torch.from_numpy(array).permute(2, 0, 1)
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
pixel_values = ((pixel_values - mean) / std).unsqueeze(0).to(device)

with torch.no_grad():
    generated = model.generate(pixel_values=pixel_values, max_new_tokens=128)

report = model.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
print(report)
```

## Intended Use

- Input: a chest X-ray image resized to `512x512` and normalized with ImageNet mean/std.
- Output: a generated radiology report.
- Best fit: research use, report-generation experiments, and anatomical-attention ablations.

## MIMIC Test Results

Evaluation uses frontal (`PA/AP`) studies only.

### Current Checkpoint Results

#### All Frontal Test Studies

| Metric | Value |
| --- | --- |
| Number of studies | `3041` |
| ROUGE-L | `0.1641` |
| BLEU-1 | `0.2243` |
| BLEU-4 | `0.0383` |
| METEOR | `0.2005` |
| RadGraph F1 | `0.0941` |
| RadGraph entity F1 | `0.1819` |
| RadGraph relation F1 | `0.1652` |
| CheXpert F1 14-micro | `0.1245` |
| CheXpert F1 5-micro | `0.2190` |
| CheXpert F1 14-macro | `0.0443` |
| CheXpert F1 5-macro | `0.0991` |

#### Findings-Only Frontal Test Studies

| Metric | Value |
| --- | --- |
| Number of studies | `2210` |
| ROUGE-L | `0.1721` |
| BLEU-1 | `0.2310` |
| BLEU-4 | `0.0429` |
| METEOR | `0.2125` |
| RadGraph F1 | `0.1017` |
| RadGraph entity F1 | `0.1922` |
| RadGraph relation F1 | `0.1741` |
| CheXpert F1 14-micro | `0.1166` |
| CheXpert F1 5-micro | `0.2071` |
| CheXpert F1 14-macro | `0.0406` |
| CheXpert F1 5-macro | `0.0920` |

### Final Completed Training Results

This table will be populated when the planned training run completes. Until then, final metrics remain `TBD`.

| Metric | Value |
| --- | --- |
| Number of studies | TBD |
| RadGraph F1 | TBD |
| RadGraph entity F1 | TBD |
| RadGraph relation F1 | TBD |
| CheXpert F1 14-micro | TBD |
| CheXpert F1 5-micro | TBD |
| CheXpert F1 14-macro | TBD |
| CheXpert F1 5-macro | TBD |

## Data

- Full project datasets: CheXpert and MIMIC-CXR.
- Intended project scope: train on curated chest X-ray/report data from both datasets and evaluate on MIMIC-CXR test studies.
- Current released checkpoint: trained and validated on `MIMIC-CXR (findings-only)`.
- Current published evaluation: MIMIC-CXR test split, `frontal-only (PA/AP)` studies.

## Evaluation

- Medical report metrics implemented in the repository include RadGraph F1 and CheXpert F1 (`14-micro`, `5-micro`, `14-macro`, `5-macro`).

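For readers comparing the `micro` and `macro` variants, here is a hedged sketch (not the repository's evaluation code) of the two aggregations: micro-F1 pools true/false positives and false negatives across all labels before computing F1, while macro-F1 averages per-label F1 scores, so the many zero-F1 labels in the per-label breakdown pull the macro values well below the micro values.

```python
def f1(tp, fp, fn):
    denom = 2 * tp + fp + fn
    return 0.0 if denom == 0 else 2 * tp / denom

def micro_macro_f1(counts):
    """counts: {label: (tp, fp, fn)}. Returns (micro_f1, macro_f1)."""
    tp = sum(c[0] for c in counts.values())
    fp = sum(c[1] for c in counts.values())
    fn = sum(c[2] for c in counts.values())
    micro = f1(tp, fp, fn)  # pool counts across labels, then one F1
    macro = sum(f1(*c) for c in counts.values()) / len(counts)  # mean of per-label F1
    return micro, macro

# Toy counts for three labels: one well-predicted, two never predicted.
counts = {
    "Pleural Effusion": (30, 10, 20),
    "Fracture": (0, 0, 5),
    "No Finding": (0, 0, 8),
}
micro, macro = micro_macro_f1(counts)
```

With these toy counts the micro score stays moderate (the pooled counts are dominated by the well-predicted label) while the macro score collapses, mirroring the micro/macro gap in the tables above.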

## Training Snapshot

This section describes the current public checkpoint, not the final completed project.

- Run: `mimic only`
- Method: `lora_adamw`
- Vision encoder: `facebook/dinov3-vits16-pretrain-lvd1689m`
- Text decoder: `gpt2`
- Segmentation encoder: `facebook/dinov3-convnext-small-pretrain-lvd1689m`
- Image size: `512`
- Local batch size: `1`
- Effective global batch size: `8`
- Scheduler: `cosine`
- Warmup steps: `2636`
- Weight decay: `0.01`
- Steps completed: `6692`
- Planned total steps: `52716`
- Images seen: `53540`
- Total training time: `1.0000` hours
- Hardware: `NVIDIA GeForce RTX 5070`
- Final train loss: `0.6296`
- Validation loss: `2.3133`

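The `Local batch size: 1` / `Effective global batch size: 8` pair implies eight gradient-accumulation micro-steps per optimizer update. A minimal sketch of why that is equivalent to one batch of 8, on a toy scalar model (assumed for illustration, not the training script):

```python
# Toy scalar model y = w * x with squared error; d/dw (w*x - y)^2 = 2*(w*x - y)*x.
def grad(w, x, y):
    return 2 * (w * x - y) * x

xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.0, 2.1, 2.9, 4.2, 5.1, 5.9, 7.2, 8.1]
w = 0.0

# Accumulate 8 micro-batches of size 1, then take one optimizer step...
accumulated = 0.0
for x, y in zip(xs, ys):
    accumulated += grad(w, x, y) / len(xs)

# ...which matches the mean gradient of a single batch of 8.
batch_grad = sum(grad(w, x, y) for x, y in zip(xs, ys)) / len(xs)
print(abs(accumulated - batch_grad) < 1e-9)  # True
```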
## Status

- Project status: `Training in progress`
- Release status: `Research preview checkpoint`
- Current checkpoint status: `Not final`
- Training completion toward planned run: `12.70%` (`0` / `3` epochs)
- Current published metrics are intermediate and will change as training continues.

## Notes

- Set `HF_TOKEN` with permission to access the DINOv3 repositories required by this model before downloading or running inference.
- `segmenters/` contains the lung and heart segmentation checkpoints used to build anatomical attention masks.
- `evaluations/mimic_test_metrics.json` contains the latest saved MIMIC test metrics.
assets/AnatomicalAttention.gif ADDED

Git LFS Details

  • SHA256: 3854885a631419336dca34b3375e29be91597a994349e3abdf1460e6908ec391
  • Pointer size: 133 Bytes
  • Size of remote file: 27.3 MB
benchmark_results.json ADDED
@@ -0,0 +1,391 @@
{
  "results": [
    {"method": "qlora_paged_adamw8bit", "local_batch_size": 1, "global_batch_size_requested": 1, "status": "failed", "error": "element 0 of tensors does not require grad and does not have a grad_fn"},
    {"method": "qlora_paged_adamw8bit", "local_batch_size": 1, "global_batch_size_requested": 8, "status": "failed", "error": "element 0 of tensors does not require grad and does not have a grad_fn"},
    {"method": "qlora_paged_adamw8bit", "local_batch_size": 1, "global_batch_size_requested": 16, "status": "failed", "error": "element 0 of tensors does not require grad and does not have a grad_fn"},
    {"method": "qlora_paged_adamw8bit", "local_batch_size": 2, "global_batch_size_requested": 2, "status": "failed", "error": "element 0 of tensors does not require grad and does not have a grad_fn"},
    {"method": "qlora_paged_adamw8bit", "local_batch_size": 2, "global_batch_size_requested": 8, "status": "failed", "error": "element 0 of tensors does not require grad and does not have a grad_fn"},
    {"method": "qlora_paged_adamw8bit", "local_batch_size": 2, "global_batch_size_requested": 16, "status": "failed", "error": "element 0 of tensors does not require grad and does not have a grad_fn"},
    {"method": "qlora_paged_adamw8bit", "local_batch_size": 4, "global_batch_size_requested": 4, "status": "failed", "error": "element 0 of tensors does not require grad and does not have a grad_fn"},
    {"method": "qlora_paged_adamw8bit", "local_batch_size": 4, "global_batch_size_requested": 8, "status": "failed", "error": "element 0 of tensors does not require grad and does not have a grad_fn"},
    {"method": "qlora_paged_adamw8bit", "local_batch_size": 4, "global_batch_size_requested": 16, "status": "failed", "error": "element 0 of tensors does not require grad and does not have a grad_fn"},
    {"method": "lora_adamw", "local_batch_size": 1, "global_batch_size_requested": 1, "status": "ok", "effective_global_batch_size": 1, "gradient_accumulation_steps": 1, "optimizer_step_time_sec": 0.12944729999981064, "images_per_sec": 7.7251514709187665, "mean_loss": 9.920842170715332, "trainable_params": 1106688},
    {"method": "lora_adamw", "local_batch_size": 1, "global_batch_size_requested": 8, "status": "ok", "effective_global_batch_size": 8, "gradient_accumulation_steps": 8, "optimizer_step_time_sec": 0.792737899999338, "images_per_sec": 10.091607831550228, "mean_loss": 8.131502032279968, "trainable_params": 1106688},
    {"method": "lora_adamw", "local_batch_size": 1, "global_batch_size_requested": 16, "status": "ok", "effective_global_batch_size": 16, "gradient_accumulation_steps": 16, "optimizer_step_time_sec": 1.6773667999987083, "images_per_sec": 9.538760395169572, "mean_loss": 8.80642619729042, "trainable_params": 1106688},
    {"method": "lora_adamw", "local_batch_size": 2, "global_batch_size_requested": 2, "status": "ok", "effective_global_batch_size": 2, "gradient_accumulation_steps": 1, "optimizer_step_time_sec": 0.20009290000052715, "images_per_sec": 9.995357156574427, "mean_loss": 9.088608741760254, "trainable_params": 1106688},
    {"method": "lora_adamw", "local_batch_size": 2, "global_batch_size_requested": 8, "status": "ok", "effective_global_batch_size": 8, "gradient_accumulation_steps": 4, "optimizer_step_time_sec": 0.8304937000011705, "images_per_sec": 9.63282442719159, "mean_loss": 8.245712995529175, "trainable_params": 1106688},
    {"method": "lora_adamw", "local_batch_size": 2, "global_batch_size_requested": 16, "status": "ok", "effective_global_batch_size": 16, "gradient_accumulation_steps": 8, "optimizer_step_time_sec": 1.6668036999981268, "images_per_sec": 9.599210752902685, "mean_loss": 9.106984257698059, "trainable_params": 1106688},
    {"method": "lora_adamw", "local_batch_size": 4, "global_batch_size_requested": 4, "status": "ok", "effective_global_batch_size": 4, "gradient_accumulation_steps": 1, "optimizer_step_time_sec": 0.4656030999994982, "images_per_sec": 8.591008092524106, "mean_loss": 8.862140655517578, "trainable_params": 1106688},
    {"method": "lora_adamw", "local_batch_size": 4, "global_batch_size_requested": 8, "status": "ok", "effective_global_batch_size": 8, "gradient_accumulation_steps": 2, "optimizer_step_time_sec": 2.6093234999989363, "images_per_sec": 3.0659287742601715, "mean_loss": 8.241507053375244, "trainable_params": 1106688},
    {"method": "lora_adamw", "local_batch_size": 4, "global_batch_size_requested": 16, "status": "ok", "effective_global_batch_size": 16, "gradient_accumulation_steps": 4, "optimizer_step_time_sec": 18.058491499999946, "images_per_sec": 0.8860097755119827, "mean_loss": 8.916554927825928, "trainable_params": 1106688},
    {"method": "full_adam", "local_batch_size": 1, "global_batch_size_requested": 1, "status": "ok", "effective_global_batch_size": 1, "gradient_accumulation_steps": 1, "optimizer_step_time_sec": 1.4309436000003188, "images_per_sec": 0.6988395629288094, "mean_loss": 8.042855262756348, "trainable_params": 125521920},
    {"method": "full_adam", "local_batch_size": 1, "global_batch_size_requested": 8, "status": "ok", "effective_global_batch_size": 8, "gradient_accumulation_steps": 8, "optimizer_step_time_sec": 2.7121656999988772, "images_per_sec": 2.9496722858796245, "mean_loss": 7.829526960849762, "trainable_params": 125521920},
    {"method": "full_adam", "local_batch_size": 1, "global_batch_size_requested": 16, "status": "ok", "effective_global_batch_size": 16, "gradient_accumulation_steps": 16, "optimizer_step_time_sec": 1.8378386999993381, "images_per_sec": 8.705878268863183, "mean_loss": 9.189274996519089, "trainable_params": 125521920},
    {"method": "full_adam", "local_batch_size": 2, "global_batch_size_requested": 2, "status": "ok", "effective_global_batch_size": 2, "gradient_accumulation_steps": 1, "optimizer_step_time_sec": 0.23647629999868514, "images_per_sec": 8.457507158269646, "mean_loss": 9.128178596496582, "trainable_params": 125521920},
    {"method": "full_adam", "local_batch_size": 2, "global_batch_size_requested": 8, "status": "ok", "effective_global_batch_size": 8, "gradient_accumulation_steps": 4, "optimizer_step_time_sec": 0.8083188999989943, "images_per_sec": 9.897083935572896, "mean_loss": 8.64337944984436, "trainable_params": 125521920},
    {"method": "full_adam", "local_batch_size": 2, "global_batch_size_requested": 16, "status": "ok", "effective_global_batch_size": 16, "gradient_accumulation_steps": 8, "optimizer_step_time_sec": 1.8274533999974665, "images_per_sec": 8.755353214490823, "mean_loss": 8.331470370292664, "trainable_params": 125521920},
    {"method": "full_adam", "local_batch_size": 4, "global_batch_size_requested": 4, "status": "ok", "effective_global_batch_size": 4, "gradient_accumulation_steps": 1, "optimizer_step_time_sec": 0.511095199999545, "images_per_sec": 7.826330593602838, "mean_loss": 8.954268455505371, "trainable_params": 125521920},
    {"method": "full_adam", "local_batch_size": 4, "global_batch_size_requested": 8, "status": "ok", "effective_global_batch_size": 8, "gradient_accumulation_steps": 2, "optimizer_step_time_sec": 2.2738564999981463, "images_per_sec": 3.518251921353226, "mean_loss": 9.192809581756592, "trainable_params": 125521920},
    {"method": "full_adam", "local_batch_size": 4, "global_batch_size_requested": 16, "status": "ok", "effective_global_batch_size": 16, "gradient_accumulation_steps": 4, "optimizer_step_time_sec": 18.631701800000883, "images_per_sec": 0.8587513997244869, "mean_loss": 8.159156560897827, "trainable_params": 125521920},
    {"method": "full_adam8bit", "local_batch_size": 1, "global_batch_size_requested": 1, "status": "ok", "effective_global_batch_size": 1, "gradient_accumulation_steps": 1, "optimizer_step_time_sec": 0.13992360000156623, "images_per_sec": 7.146757230294293, "mean_loss": 9.259998321533203, "trainable_params": 125521920},
    {"method": "full_adam8bit", "local_batch_size": 1, "global_batch_size_requested": 8, "status": "ok", "effective_global_batch_size": 8, "gradient_accumulation_steps": 8, "optimizer_step_time_sec": 0.8451360999988538, "images_per_sec": 9.465930990299492, "mean_loss": 8.10985803604126, "trainable_params": 125521920},
    {"method": "full_adam8bit", "local_batch_size": 1, "global_batch_size_requested": 16, "status": "ok", "effective_global_batch_size": 16, "gradient_accumulation_steps": 16, "optimizer_step_time_sec": 1.8945816999930685, "images_per_sec": 8.445135936897595, "mean_loss": 8.591163873672485, "trainable_params": 125521920},
    {"method": "full_adam8bit", "local_batch_size": 2, "global_batch_size_requested": 2, "status": "ok", "effective_global_batch_size": 2, "gradient_accumulation_steps": 1, "optimizer_step_time_sec": 0.23971350000101666, "images_per_sec": 8.343293139483249, "mean_loss": 9.75894832611084, "trainable_params": 125521920},
    {"method": "full_adam8bit", "local_batch_size": 2, "global_batch_size_requested": 8, "status": "ok", "effective_global_batch_size": 8, "gradient_accumulation_steps": 4, "optimizer_step_time_sec": 0.9259438999997656, "images_per_sec": 8.6398322835779, "mean_loss": 8.462790489196777, "trainable_params": 125521920},
    {"method": "full_adam8bit", "local_batch_size": 2, "global_batch_size_requested": 16, "status": "ok", "effective_global_batch_size": 16, "gradient_accumulation_steps": 8, "optimizer_step_time_sec": 1.8237968999983423, "images_per_sec": 8.772906676184471, "mean_loss": 10.191668510437012, "trainable_params": 125521920},
    {"method": "full_adam8bit", "local_batch_size": 4, "global_batch_size_requested": 4, "status": "ok", "effective_global_batch_size": 4, "gradient_accumulation_steps": 1, "optimizer_step_time_sec": 0.5224713000006886, "images_per_sec": 7.655922918626779, "mean_loss": 8.14057445526123, "trainable_params": 125521920},
    {"method": "full_adam8bit", "local_batch_size": 4, "global_batch_size_requested": 8, "status": "ok", "effective_global_batch_size": 8, "gradient_accumulation_steps": 2, "optimizer_step_time_sec": 3.7809107000011863, "images_per_sec": 2.1158923430795364, "mean_loss": 8.521550178527832, "trainable_params": 125521920},
    {"method": "full_adam8bit", "local_batch_size": 4, "global_batch_size_requested": 16, "status": "ok", "effective_global_batch_size": 16, "gradient_accumulation_steps": 4, "optimizer_step_time_sec": 27.688971800002037, "images_per_sec": 0.5778473868790903, "mean_loss": 9.247632026672363, "trainable_params": 125521920}
  ]
}
config.json ADDED
@@ -0,0 +1,33 @@
{
  "anatomical_attention_bias": 2.0,
  "architectures": [
    "LanaForConditionalGeneration"
  ],
  "decoder_compute_dtype": "bfloat16",
  "decoder_load_in_4bit": false,
  "dtype": "float32",
  "freeze_segmenter": true,
  "heart_segmenter_checkpoint": "segmenters/heart_segmenter_dinounet_best.pth",
  "image_size": 512,
  "layer_mask_base_kernel_size": 3,
  "layer_mask_kernel_growth": 2,
  "lung_segmenter_checkpoint": "segmenters/lung_segmenter_dinounet_finetuned.pth",
  "mask_size": 32,
  "max_position_embeddings": 2048,
  "model_type": "lana_radgen",
  "num_attention_layers": 12,
  "segmentation_attention_implementation": "sdpa",
  "segmentation_model_name": "facebook/dinov3-convnext-small-pretrain-lvd1689m",
  "text_hidden_size": 768,
  "text_model_name": "gpt2",
  "transformers_version": "5.3.0",
  "use_cache": true,
  "use_segmentation_mask": true,
  "vision_model_name": "facebook/dinov3-vits16-pretrain-lvd1689m",
  "visual_feature_dim": 384,
  "vocab_size": 50257,
  "auto_map": {
    "AutoConfig": "configuration_lana.LanaConfig",
    "AutoModel": "modeling_lana.LanaForConditionalGeneration"
  }
}
configuration_lana.py ADDED
@@ -0,0 +1,3 @@
from lana_radgen.configuration_lana import LanaConfig

__all__ = ["LanaConfig"]
evaluations/mimic_test_findings_only_metrics.json ADDED
@@ -0,0 +1,38 @@
{
  "split": "test",
  "subset": "findings-only frontal studies",
  "dataset": "mimic-cxr",
  "view_filter": "frontal-only (PA/AP), structured Findings section only",
  "num_examples": 2210,
  "bleu_1": 0.23099023872215996,
  "bleu_4": 0.0429479479188206,
  "meteor": 0.21248313160360002,
  "rouge_l": 0.17210734193417726,
  "chexpert_f1_14_micro": 0.11655011655011654,
  "chexpert_f1_5_micro": 0.20709914320685435,
  "chexpert_f1_14_macro": 0.04057070914402376,
  "chexpert_f1_5_macro": 0.09202593660588262,
  "chexpert_f1_micro": 0.11655011655011654,
  "chexpert_f1_macro": 0.04057070914402376,
  "chexpert_per_label_f1": {
    "Enlarged Cardiomediastinum": 0.0,
    "Cardiomegaly": 0.0,
    "Lung Opacity": 0.0,
    "Lung Lesion": 0.0,
    "Edema": 0.022471910112359553,
    "Consolidation": 0.05797101449275362,
    "Pneumonia": 0.01673640167364017,
    "Atelectasis": 0.0,
    "Pneumothorax": 0.05716798592788039,
    "Pleural Effusion": 0.3796867584243,
    "Pleural Other": 0.0,
    "Fracture": 0.0,
    "Support Devices": 0.03395585738539898,
    "No Finding": 0.0
  },
  "radgraph_f1": 0.10172866854646034,
  "radgraph_f1_entity": 0.19217701907879298,
  "radgraph_f1_relation": 0.17414731467894073,
  "radgraph_available": true,
  "radgraph_error": null
}
evaluations/mimic_test_findings_only_predictions.csv ADDED
The diff for this file is too large to render. See raw diff
 
evaluations/mimic_test_metrics.json ADDED
@@ -0,0 +1,115 @@
{
  "split": "test",
  "subset": "all frontal studies",
  "dataset": "mimic-cxr",
  "view_filter": "frontal-only (PA/AP)",
  "num_examples": 3041,
  "bleu_1": 0.224283665754139,
  "bleu_4": 0.038250603325884945,
  "meteor": 0.20053919101008125,
  "rouge_l": 0.16408262967411588,
  "chexpert_f1_14_micro": 0.1244686823321838,
  "chexpert_f1_5_micro": 0.21896350608231963,
  "chexpert_f1_14_macro": 0.04429260458577706,
  "chexpert_f1_5_macro": 0.09914404243362482,
  "chexpert_f1_micro": 0.1244686823321838,
  "chexpert_f1_macro": 0.04429260458577706,
  "chexpert_per_label_f1": {
    "Enlarged Cardiomediastinum": 0.0,
    "Cardiomegaly": 0.0,
    "Lung Opacity": 0.0,
    "Lung Lesion": 0.0,
    "Edema": 0.023121387283236997,
    "Consolidation": 0.056790123456790124,
    "Pneumonia": 0.02762430939226519,
    "Atelectasis": 0.0,
    "Pneumothorax": 0.059987236758136574,
    "Pleural Effusion": 0.415808701428097,
    "Pleural Other": 0.0,
    "Fracture": 0.0,
    "Support Devices": 0.036764705882352935,
    "No Finding": 0.0
  },
  "radgraph_f1": 0.0941067057393548,
  "radgraph_f1_entity": 0.18191243977753782,
  "radgraph_f1_relation": 0.1652384677607375,
  "radgraph_available": true,
  "radgraph_error": null,
  "evaluation_suite": "mimic_test_dual",
  "all_test": {
    "split": "test",
    "subset": "all frontal studies",
    "dataset": "mimic-cxr",
    "view_filter": "frontal-only (PA/AP)",
    "num_examples": 3041,
    "bleu_1": 0.224283665754139,
    "bleu_4": 0.038250603325884945,
    "meteor": 0.20053919101008125,
    "rouge_l": 0.16408262967411588,
    "chexpert_f1_14_micro": 0.1244686823321838,
    "chexpert_f1_5_micro": 0.21896350608231963,
    "chexpert_f1_14_macro": 0.04429260458577706,
    "chexpert_f1_5_macro": 0.09914404243362482,
    "chexpert_f1_micro": 0.1244686823321838,
    "chexpert_f1_macro": 0.04429260458577706,
    "chexpert_per_label_f1": {
      "Enlarged Cardiomediastinum": 0.0,
      "Cardiomegaly": 0.0,
      "Lung Opacity": 0.0,
      "Lung Lesion": 0.0,
      "Edema": 0.023121387283236997,
      "Consolidation": 0.056790123456790124,
      "Pneumonia": 0.02762430939226519,
      "Atelectasis": 0.0,
      "Pneumothorax": 0.059987236758136574,
      "Pleural Effusion": 0.415808701428097,
      "Pleural Other": 0.0,
      "Fracture": 0.0,
      "Support Devices": 0.036764705882352935,
      "No Finding": 0.0
    },
    "radgraph_f1": 0.0941067057393548,
    "radgraph_f1_entity": 0.18191243977753782,
    "radgraph_f1_relation": 0.1652384677607375,
    "radgraph_available": true,
    "radgraph_error": null
  },
  "findings_only_test": {
    "split": "test",
    "subset": "findings-only frontal studies",
    "dataset": "mimic-cxr",
    "view_filter": "frontal-only (PA/AP), structured Findings section only",
    "num_examples": 2210,
    "bleu_1": 0.23099023872215996,
    "bleu_4": 0.0429479479188206,
    "meteor": 0.21248313160360002,
    "rouge_l": 0.17210734193417726,
    "chexpert_f1_14_micro": 0.11655011655011654,
    "chexpert_f1_5_micro": 0.20709914320685435,
    "chexpert_f1_14_macro": 0.04057070914402376,
    "chexpert_f1_5_macro": 0.09202593660588262,
    "chexpert_f1_micro": 0.11655011655011654,
    "chexpert_f1_macro": 0.04057070914402376,
    "chexpert_per_label_f1": {
      "Enlarged Cardiomediastinum": 0.0,
      "Cardiomegaly": 0.0,
      "Lung Opacity": 0.0,
      "Lung Lesion": 0.0,
      "Edema": 0.022471910112359553,
      "Consolidation": 0.05797101449275362,
      "Pneumonia": 0.01673640167364017,
      "Atelectasis": 0.0,
      "Pneumothorax": 0.05716798592788039,
      "Pleural Effusion": 0.3796867584243,
      "Pleural Other": 0.0,
      "Fracture": 0.0,
      "Support Devices": 0.03395585738539898,
      "No Finding": 0.0
    },
    "radgraph_f1": 0.10172866854646034,
    "radgraph_f1_entity": 0.19217701907879298,
    "radgraph_f1_relation": 0.17414731467894073,
    "radgraph_available": true,
    "radgraph_error": null
  }
}
evaluations/mimic_test_predictions.csv ADDED
The diff for this file is too large to render. See raw diff
 
lana_radgen/__init__.py ADDED
@@ -0,0 +1,9 @@
from .configuration_lana import LanaConfig
from .modeling_lana import LanaForConditionalGeneration
from .modeling_outputs import LanaModelOutput

__all__ = [
    "LanaConfig",
    "LanaForConditionalGeneration",
    "LanaModelOutput",
]
lana_radgen/attention/__init__.py ADDED
@@ -0,0 +1,3 @@
from .layerwise_anatomical_attention import build_layerwise_attention_bias

__all__ = ["build_layerwise_attention_bias"]
lana_radgen/attention/layerwise_anatomical_attention.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def _gaussian_kernel_1d(kernel_size: int, sigma: float, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
6
+ radius = kernel_size // 2
7
+ x = torch.arange(-radius, radius + 1, device=device, dtype=dtype)
8
+ kernel = torch.exp(-(x * x) / (2.0 * sigma * sigma))
9
+ return kernel / kernel.sum()
10
+
11
+
12
+ @torch.no_grad()
13
+ def build_layerwise_attention_bias(
14
+ masks: torch.Tensor,
15
+ num_layers: int,
16
+ target_tokens: int,
17
+ base_kernel_size: int = 3,
18
+ kernel_growth: int = 2,
19
+ strength: float = 2.0,
20
+ eps: float = 1e-8,
21
+ ) -> torch.Tensor:
22
+ if masks.ndim == 3:
23
+ masks = masks.unsqueeze(1)
24
+ if masks.ndim != 4 or masks.shape[1] != 1:
25
+ raise ValueError(f"Expected masks shaped (B,1,H,W) or (B,H,W), got {tuple(masks.shape)}")
26
+
27
+ masks = masks.float()
28
+ batch_size = masks.shape[0]
29
+ resized = F.interpolate(masks, size=(32, 32), mode="bilinear", align_corners=False).clamp(0.0, 1.0)
30
+
31
+ max_kernel = base_kernel_size + max(num_layers, 0) * kernel_growth
32
+ if max_kernel % 2 == 0:
33
+ max_kernel += 1
34
+ pad = max_kernel // 2
35
+
36
+ weight_h = torch.zeros((num_layers, 1, 1, max_kernel), device=resized.device, dtype=resized.dtype)
37
+ weight_v = torch.zeros((num_layers, 1, max_kernel, 1), device=resized.device, dtype=resized.dtype)
38
+
39
+ for layer_idx in range(num_layers):
40
+ kernel_size = base_kernel_size + (num_layers - layer_idx) * kernel_growth
41
+ if kernel_size % 2 == 0:
42
+ kernel_size += 1
43
+ sigma = max((kernel_size - 1) / 6.0, 1e-3)
44
+ kernel = _gaussian_kernel_1d(kernel_size, sigma, resized.device, resized.dtype)
45
+ start = (max_kernel - kernel_size) // 2
46
+ end = start + kernel_size
47
+ weight_h[layer_idx, 0, 0, start:end] = kernel
48
+ weight_v[layer_idx, 0, start:end, 0] = kernel
49
+
50
+ repeated = resized.expand(batch_size, num_layers, 32, 32).contiguous()
51
+ horizontal = F.conv2d(F.pad(repeated, (pad, pad, 0, 0), mode="reflect"), weight_h, groups=num_layers)
52
+ vertical = F.conv2d(F.pad(horizontal, (0, 0, pad, pad), mode="reflect"), weight_v, groups=num_layers)
53
+
54
+ min_vals = vertical.amin(dim=(2, 3), keepdim=True)
55
+ max_vals = vertical.amax(dim=(2, 3), keepdim=True)
56
+ normalized = (vertical - min_vals) / (max_vals - min_vals).clamp_min(eps)
57
+
58
+ flat = normalized.view(batch_size, num_layers, -1)
59
+ if flat.shape[-1] != target_tokens:
60
+ flat = F.interpolate(flat, size=target_tokens, mode="linear", align_corners=False)
61
+ layerwise_bias = flat.unsqueeze(-2).expand(-1, -1, target_tokens, -1)
62
+ return torch.tril(layerwise_bias) * strength
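The bias above blurs each anatomical mask with a different Gaussian width per decoder layer: early layers see a wide, coarse blur and later layers a progressively sharper one. A minimal, dependency-free sketch of that per-layer kernel schedule (the `kernel_schedule` helper name is hypothetical, not part of the repository):

```python
# Hypothetical sketch of the per-layer kernel schedule used inside
# build_layerwise_attention_bias: layer_idx 0 gets the widest kernel,
# the last layer the narrowest, and even sizes are bumped to stay odd
# so the Gaussian remains centered.
def kernel_schedule(num_layers, base_kernel_size=3, kernel_growth=2):
    sizes = []
    for layer_idx in range(num_layers):
        k = base_kernel_size + (num_layers - layer_idx) * kernel_growth
        if k % 2 == 0:  # keep kernels odd-sized
            k += 1
        sizes.append(k)
    return sizes

print(kernel_schedule(4))  # [11, 9, 7, 5]
```

With the defaults, a 12-layer decoder would range from a 27-wide blur at layer 0 down to 5 at the final layer, matching the coarse-to-fine intent of the bias.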
lana_radgen/configuration_lana.py ADDED
@@ -0,0 +1,53 @@
+ from transformers import PretrainedConfig
+
+
+ class LanaConfig(PretrainedConfig):
+     model_type = "lana_radgen"
+
+     def __init__(
+         self,
+         vision_model_name: str = "facebook/dinov3-vits16-pretrain-lvd1689m",
+         text_model_name: str = "gpt2",
+         image_size: int = 512,
+         mask_size: int = 32,
+         num_attention_layers: int = 12,
+         max_position_embeddings: int = 2048,
+         visual_feature_dim: int = 384,
+         text_hidden_size: int = 768,
+         vocab_size: int = 50257,
+         layer_mask_base_kernel_size: int = 3,
+         layer_mask_kernel_growth: int = 2,
+         anatomical_attention_bias: float = 2.0,
+         use_segmentation_mask: bool = True,
+         segmentation_model_name: str = "facebook/dinov3-convnext-small-pretrain-lvd1689m",
+         segmentation_attention_implementation: str = "sdpa",
+         freeze_segmenter: bool = True,
+         lung_segmenter_checkpoint: str = "",
+         heart_segmenter_checkpoint: str = "",
+         use_cache: bool = True,
+         decoder_load_in_4bit: bool = False,
+         decoder_compute_dtype: str = "float16",
+         **kwargs,
+     ):
+         self.vision_model_name = vision_model_name
+         self.text_model_name = text_model_name
+         self.image_size = image_size
+         self.mask_size = mask_size
+         self.num_attention_layers = num_attention_layers
+         self.max_position_embeddings = max_position_embeddings
+         self.visual_feature_dim = visual_feature_dim
+         self.text_hidden_size = text_hidden_size
+         self.vocab_size = vocab_size
+         self.layer_mask_base_kernel_size = layer_mask_base_kernel_size
+         self.layer_mask_kernel_growth = layer_mask_kernel_growth
+         self.anatomical_attention_bias = anatomical_attention_bias
+         self.use_segmentation_mask = use_segmentation_mask
+         self.segmentation_model_name = segmentation_model_name
+         self.segmentation_attention_implementation = segmentation_attention_implementation
+         self.freeze_segmenter = freeze_segmenter
+         self.lung_segmenter_checkpoint = lung_segmenter_checkpoint
+         self.heart_segmenter_checkpoint = heart_segmenter_checkpoint
+         self.use_cache = use_cache
+         self.decoder_load_in_4bit = decoder_load_in_4bit
+         self.decoder_compute_dtype = decoder_compute_dtype
+         super().__init__(**kwargs)
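Since `LanaConfig` subclasses `PretrainedConfig`, a saved checkpoint serializes these fields into `config.json`. The fragment below is an illustrative sketch assembled from the defaults above, not a file copied from the repository:

```json
{
  "model_type": "lana_radgen",
  "vision_model_name": "facebook/dinov3-vits16-pretrain-lvd1689m",
  "text_model_name": "gpt2",
  "image_size": 512,
  "mask_size": 32,
  "num_attention_layers": 12,
  "max_position_embeddings": 2048,
  "anatomical_attention_bias": 2.0,
  "use_segmentation_mask": true,
  "decoder_load_in_4bit": false
}
```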
lana_radgen/gpt2_modified.py ADDED
@@ -0,0 +1,379 @@
+ from typing import Optional, Union
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model
+ from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
+ from transformers.masking_utils import create_causal_mask
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, eager_attention_forward
+
+
+ class GPT2AttentionModified(GPT2Attention):
+     def forward(
+         self,
+         hidden_states: Optional[tuple[torch.FloatTensor]],
+         past_key_values: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = False,
+         **kwargs,
+     ):
+         is_cross_attention = encoder_hidden_states is not None
+         if past_key_values is not None:
+             if isinstance(past_key_values, EncoderDecoderCache):
+                 is_updated = past_key_values.is_updated.get(self.layer_idx)
+                 curr_past_key_value = past_key_values.cross_attention_cache if is_cross_attention else past_key_values.self_attention_cache
+             else:
+                 curr_past_key_value = past_key_values
+
+         if is_cross_attention:
+             if not hasattr(self, "q_attn"):
+                 raise ValueError("Cross-attention requires q_attn to be defined.")
+             query_states = self.q_attn(hidden_states)
+             attention_mask = encoder_attention_mask
+             if past_key_values is not None and is_updated:
+                 key_states = curr_past_key_value.layers[self.layer_idx].keys
+                 value_states = curr_past_key_value.layers[self.layer_idx].values
+             else:
+                 key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+                 shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+                 key_states = key_states.view(shape_kv).transpose(1, 2)
+                 value_states = value_states.view(shape_kv).transpose(1, 2)
+         else:
+             query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+             shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+             key_states = key_states.view(shape_kv).transpose(1, 2)
+             value_states = value_states.view(shape_kv).transpose(1, 2)
+
+         shape_q = (*query_states.shape[:-1], -1, self.head_dim)
+         query_states = query_states.view(shape_q).transpose(1, 2)
+
+         if (past_key_values is not None and not is_cross_attention) or (
+             past_key_values is not None and is_cross_attention and not is_updated
+         ):
+             cache_position = cache_position if not is_cross_attention else None
+             key_states, value_states = curr_past_key_value.update(
+                 key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+             )
+             if is_cross_attention:
+                 past_key_values.is_updated[self.layer_idx] = True
+
+         is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
+         attention_interface = eager_attention_forward
+         if self.config._attn_implementation != "eager":
+             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             head_mask=head_mask,
+             dropout=self.attn_dropout.p if self.training else 0.0,
+             is_causal=is_causal,
+             **kwargs,
+         )
+
+         attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
+         attn_output = self.c_proj(attn_output)
+         attn_output = self.resid_dropout(attn_output)
+         return attn_output, attn_weights
+
+
+ class GPT2BlockModified(GPT2Block):
+     def __init__(self, config, layer_idx=None):
+         super().__init__(config=config, layer_idx=layer_idx)
+         self.attn = GPT2AttentionModified(config=config, layer_idx=layer_idx)
+
+
+ class GPT2ModelModified(GPT2Model):
+     def __init__(self, config):
+         super().__init__(config)
+         self.config_causal = config
+         self.config_causal._attn_implementation = "eager"
+         self.h = nn.ModuleList([GPT2BlockModified(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[tuple[tuple[torch.Tensor]], Cache]] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         segmentation_mask: Optional[torch.FloatTensor] = None,
+         **kwargs,
+     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         if input_ids is not None:
+             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+             input_shape = input_ids.size()
+             input_ids = input_ids.view(-1, input_shape[-1])
+             batch_size = input_ids.shape[0]
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+             batch_size = inputs_embeds.shape[0]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+         if token_type_ids is not None:
+             token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+         if self.gradient_checkpointing and self.training and use_cache:
+             use_cache = False
+
+         if use_cache:
+             if past_key_values is None:
+                 past_key_values = DynamicCache()
+             elif isinstance(past_key_values, tuple):
+                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+             if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
+                 past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
+
+         if inputs_embeds is None:
+             inputs_embeds = self.wte(input_ids)
+
+         if cache_position is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+
+         position_embeds = self.wpe(position_ids)
+         hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
+
+         if attention_mask is not None and attention_mask.ndim < 4:
+             attention_mask = attention_mask.view(batch_size, -1)
+
+         causal_mask = create_causal_mask(
+             config=self.config_causal,
+             input_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+             past_key_values=past_key_values,
+             position_ids=position_ids,
+         )
+
+         _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
+         if self.config.add_cross_attention and encoder_hidden_states is not None:
+             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+             if encoder_attention_mask is None:
+                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+             if _use_sdpa:
+                 encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                     mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
+                 )
+             elif self._attn_implementation != "flash_attention_2":
+                 encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+         else:
+             encoder_attention_mask = None
+
+         if head_mask is None:
+             head_mask = [None] * self.config.n_layer
+
+         if token_type_ids is not None:
+             hidden_states = hidden_states + self.wte(token_type_ids)
+
+         hidden_states = self.drop(hidden_states)
+         output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+         all_self_attentions = () if output_attentions else None
+         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+         all_hidden_states = () if output_hidden_states else None
+
+         for i, block in enumerate(self.h):
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             block_mask = causal_mask
+             if segmentation_mask is not None and causal_mask is not None:
+                 block_mask = causal_mask.clone()
+                 seq_len = input_shape[-1]
+                 if block_mask.shape[2] != seq_len or block_mask.shape[3] != seq_len:
+                     block_mask = block_mask[:, :, :seq_len, :seq_len]
+                 layer_bias = segmentation_mask[:, i, : block_mask.shape[2], : block_mask.shape[3]].unsqueeze(1)
+                 block_mask = block_mask + layer_bias.to(dtype=block_mask.dtype, device=block_mask.device)
+
+             outputs = block(
+                 hidden_states=hidden_states,
+                 past_key_values=past_key_values if not (self.gradient_checkpointing and self.training) else None,
+                 cache_position=cache_position,
+                 attention_mask=block_mask,
+                 encoder_hidden_states=encoder_hidden_states,
+                 encoder_attention_mask=encoder_attention_mask,
+                 use_cache=use_cache,
+                 output_attentions=output_attentions,
+                 head_mask=head_mask[i],
+                 **kwargs,
+             )
+             if isinstance(outputs, tuple):
+                 hidden_states = outputs[0]
+                 if output_attentions and len(outputs) > 1:
+                     all_self_attentions = all_self_attentions + (outputs[1],)
+                 if self.config.add_cross_attention and len(outputs) > 2:
+                     all_cross_attentions = all_cross_attentions + (outputs[2],)
+             else:
+                 hidden_states = outputs
+
+         hidden_states = self.ln_f(hidden_states)
+         hidden_states = hidden_states.view(output_shape)
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         past_key_values = past_key_values if use_cache else None
+         if not return_dict:
+             return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None)
+
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             past_key_values=past_key_values,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+             cross_attentions=all_cross_attentions,
+         )
+
+
+ class GPT2LMHeadModelModified(GPT2LMHeadModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.transformer = GPT2ModelModified(config)
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         logits_to_keep: Union[int, torch.Tensor] = 0,
+         segmentation_mask: Optional[torch.FloatTensor] = None,
+         **kwargs,
+     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         transformer_outputs = self.transformer(
+             input_ids,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             segmentation_mask=segmentation_mask,
+             **kwargs,
+         )
+         hidden_states = transformer_outputs[0]
+         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None)
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
+
+         if not return_dict:
+             output = (logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutputWithCrossAttentions(
+             loss=loss,
+             logits=logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+             cross_attentions=transformer_outputs.cross_attentions,
+         )
+
+
+ @torch.no_grad()
+ def expand_gpt2_positional_embeddings(
+     model: torch.nn.Module,
+     new_max_positions: int,
+     mode: str = "linear",
+     align_corners: bool = True,
+ ):
+     if hasattr(model, "transformer") and hasattr(model.transformer, "wpe"):
+         model_for_wpe = model.transformer
+     elif hasattr(model, "wpe"):
+         model_for_wpe = model
+     else:
+         raise ValueError("Model does not expose GPT-2 positional embeddings.")
+
+     wpe = model_for_wpe.wpe
+     old_n, d = wpe.weight.shape
+     if new_max_positions == old_n:
+         return model
+
+     device = wpe.weight.device
+     dtype = wpe.weight.dtype
+     if new_max_positions < old_n:
+         new_weight = wpe.weight[:new_max_positions].clone()
+     else:
+         if mode != "linear":
+             raise ValueError(f"Unsupported positional expansion mode: {mode}")
+         w = wpe.weight.transpose(0, 1).unsqueeze(0)
+         w_new = F.interpolate(w, size=new_max_positions, mode="linear", align_corners=align_corners)
+         new_weight = w_new.squeeze(0).transpose(0, 1).contiguous()
+
+     new_wpe = torch.nn.Embedding(new_max_positions, d, device=device, dtype=dtype)
+     new_wpe.weight.copy_(new_weight)
+     if hasattr(model, "transformer") and hasattr(model.transformer, "wpe"):
+         model.transformer.wpe = new_wpe
+     else:
+         model.wpe = new_wpe
+     if hasattr(model.config, "n_positions"):
+         model.config.n_positions = new_max_positions
+     if hasattr(model.config, "n_ctx"):
+         model.config.n_ctx = new_max_positions
+     return model
+
+
+ def create_decoder(text_model_name: str, attention_implementation: str, max_position_embeddings: int, **decoder_kwargs):
+     config = GPT2Config.from_pretrained(text_model_name)
+     config._attn_implementation = attention_implementation
+     config.n_positions = max_position_embeddings
+     config.n_ctx = max_position_embeddings
+     config.use_cache = decoder_kwargs.pop("use_cache", True)
+     decoder = GPT2LMHeadModelModified.from_pretrained(text_model_name, config=config, **decoder_kwargs)
+     decoder.config._attn_implementation = attention_implementation
+     return expand_gpt2_positional_embeddings(decoder, new_max_positions=max_position_embeddings, mode="linear")
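`expand_gpt2_positional_embeddings` stretches the 1024-position GPT-2 table to `max_position_embeddings` by linearly interpolating along the position axis (`align_corners=True`, so the first and last rows are preserved). A torch-free 1-D sketch of that resampling rule (`expand_positions` is a hypothetical helper name for illustration only; the real code interpolates whole embedding vectors, and simply truncates when shrinking):

```python
# Hypothetical 1-D sketch of align_corners=True linear interpolation:
# each new index i samples the old table at i * (old_n - 1) / (new_n - 1),
# blending the two nearest old entries.
def expand_positions(table, new_n):
    old_n = len(table)
    if new_n == old_n:
        return list(table)
    out = []
    for i in range(new_n):
        pos = i * (old_n - 1) / (new_n - 1)  # map onto old index range
        lo = int(pos)
        hi = min(lo + 1, old_n - 1)
        frac = pos - lo
        out.append(table[lo] * (1 - frac) + table[hi] * frac)
    return out

print(expand_positions([0.0, 1.0, 2.0], 5))  # [0.0, 0.5, 1.0, 1.5, 2.0]
```

Because the endpoints are anchored, the original first and last position embeddings survive the expansion unchanged.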
lana_radgen/modeling_lana.py ADDED
@@ -0,0 +1,214 @@
+ import logging
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ from transformers import AutoConfig, AutoModel, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel
+
+ from .attention import build_layerwise_attention_bias
+ from .configuration_lana import LanaConfig
+ from .gpt2_modified import create_decoder
+ from .modeling_outputs import LanaModelOutput
+ from .segmenters import AnatomicalSegmenter
+
+ logger = logging.getLogger(__name__)
+
+
+ class LanaForConditionalGeneration(PreTrainedModel):
+     config_class = LanaConfig
+     base_model_prefix = "lana"
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config: LanaConfig):
+         super().__init__(config)
+         vision_config = AutoConfig.from_pretrained(config.vision_model_name, trust_remote_code=True)
+         if getattr(vision_config, "hidden_size", None) is not None:
+             config.visual_feature_dim = vision_config.hidden_size
+
+         self.vision_encoder = AutoModel.from_pretrained(config.vision_model_name, trust_remote_code=True)
+         decoder_kwargs = {
+             "ignore_mismatched_sizes": True,
+             "use_cache": config.use_cache,
+         }
+         if config.decoder_load_in_4bit:
+             compute_dtype = getattr(torch, config.decoder_compute_dtype, torch.float16)
+             decoder_kwargs["quantization_config"] = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_quant_type="nf4",
+                 bnb_4bit_use_double_quant=True,
+                 bnb_4bit_compute_dtype=compute_dtype,
+             )
+             decoder_kwargs["device_map"] = {"": 0}
+         self.text_decoder = create_decoder(
+             text_model_name=config.text_model_name,
+             attention_implementation=config.segmentation_attention_implementation,
+             max_position_embeddings=config.max_position_embeddings,
+             **decoder_kwargs,
+         )
+         self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name)
+         if self.tokenizer.pad_token_id is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         config.vocab_size = self.text_decoder.config.vocab_size
+         config.text_hidden_size = self.text_decoder.config.hidden_size
+         config.num_attention_layers = self.text_decoder.config.n_layer
+
+         self.visual_projection = nn.Sequential(
+             nn.Linear(config.visual_feature_dim, config.text_hidden_size),
+             nn.GELU(),
+             nn.Linear(config.text_hidden_size, config.text_hidden_size),
+             nn.GELU(),
+             nn.Linear(config.text_hidden_size, config.text_hidden_size),
+             nn.GELU(),
+             nn.Linear(config.text_hidden_size, config.text_hidden_size),
+         )
+         self.segmenter = None
+         if config.use_segmentation_mask:
+             self.segmenter = AnatomicalSegmenter(
+                 model_name=config.segmentation_model_name,
+                 freeze=config.freeze_segmenter,
+                 lung_checkpoint=config.lung_segmenter_checkpoint,
+                 heart_checkpoint=config.heart_segmenter_checkpoint,
+             )
+         self.post_init()
+
+     def move_non_quantized_modules(self, device: torch.device) -> None:
+         self.vision_encoder.to(device)
+         self.visual_projection.to(device)
+         if self.segmenter is not None:
+             self.segmenter.to(device)
+         if not getattr(self.config, "decoder_load_in_4bit", False):
+             self.text_decoder.to(device)
+
+     def _encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         if any(param.requires_grad for param in self.vision_encoder.parameters()):
+             outputs = self.vision_encoder(pixel_values=pixel_values)
+         else:
+             with torch.no_grad():
+                 outputs = self.vision_encoder(pixel_values=pixel_values)
+         hidden = outputs.last_hidden_state
+         if hidden.shape[1] > 1:
+             hidden = hidden[:, 1:, :]
+         return self.visual_projection(hidden)
+
+     def _build_layerwise_bias(self, anatomical_masks: Optional[torch.Tensor], total_sequence_length: int) -> Optional[torch.Tensor]:
+         if anatomical_masks is None:
+             return None
+         return build_layerwise_attention_bias(
+             masks=anatomical_masks,
+             num_layers=self.config.num_attention_layers,
+             target_tokens=total_sequence_length,
+             base_kernel_size=self.config.layer_mask_base_kernel_size,
+             kernel_growth=self.config.layer_mask_kernel_growth,
+             strength=self.config.anatomical_attention_bias,
+         )
+
+     def _resolve_attention_bias(self, pixel_values: torch.Tensor, anatomical_masks: Optional[torch.Tensor], total_sequence_length: int):
+         if anatomical_masks is not None:
+             return self._build_layerwise_bias(anatomical_masks, total_sequence_length=total_sequence_length)
+         if self.segmenter is None:
+             return None
+         layerwise_bias = self.segmenter(
+             pixel_values,
+             num_layers=self.config.num_attention_layers,
+             target_tokens=total_sequence_length,
+             strength=self.config.anatomical_attention_bias,
+         )
+         if layerwise_bias is None:
+             logger.warning("Segmentation attention is enabled but no segmenter checkpoints were loaded; continuing without anatomical attention.")
+         return layerwise_bias
+
+     def forward(
+         self,
+         pixel_values: torch.Tensor,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         anatomical_masks: Optional[torch.Tensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = True,
+         **kwargs,
+     ) -> LanaModelOutput:
+         vision_features = self._encode_images(pixel_values)
+         batch_size, prefix_length, _ = vision_features.shape
+
+         if input_ids is None:
+             bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
+             input_ids = torch.full((batch_size, 1), bos, device=vision_features.device, dtype=torch.long)
+             attention_mask = torch.ones_like(input_ids)
+         elif attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+
+         text_embeds = self.text_decoder.transformer.wte(input_ids)
+         inputs_embeds = torch.cat([vision_features, text_embeds], dim=1)
+         merged_attention_mask = torch.cat(
+             [
+                 torch.ones((batch_size, prefix_length), device=attention_mask.device, dtype=attention_mask.dtype),
+                 attention_mask,
+             ],
+             dim=1,
+         )
+
+         merged_labels = None
+         if labels is not None:
+             ignore_prefix = torch.full((batch_size, prefix_length), -100, device=labels.device, dtype=labels.dtype)
+             merged_labels = torch.cat([ignore_prefix, labels], dim=1)
+
+         layerwise_bias = self._resolve_attention_bias(
+             pixel_values=pixel_values,
+             anatomical_masks=anatomical_masks,
+             total_sequence_length=inputs_embeds.shape[1],
+         )
+         decoder_outputs = self.text_decoder(
+             inputs_embeds=inputs_embeds,
+             attention_mask=merged_attention_mask,
+             labels=merged_labels,
+             segmentation_mask=layerwise_bias,
+             use_cache=False,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=True,
+             **kwargs,
+         )
+
+         return LanaModelOutput(
+             loss=decoder_outputs.loss,
+             logits=decoder_outputs.logits,
+             attentions=decoder_outputs.attentions,
+             layerwise_attentions=layerwise_bias,
+             hidden_states=decoder_outputs.hidden_states,
+             vision_features=vision_features,
+         )
+
+     @torch.inference_mode()
+     def generate(
+         self,
+         pixel_values: torch.Tensor,
+         anatomical_masks: Optional[torch.Tensor] = None,
+         max_new_tokens: int = 128,
+         **kwargs,
+     ):
+         vision_features = self._encode_images(pixel_values)
+         batch_size = pixel_values.shape[0]
+         bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
+         start_tokens = torch.full((batch_size, 1), bos, device=pixel_values.device, dtype=torch.long)
+         text_embeds = self.text_decoder.transformer.wte(start_tokens)
+         inputs_embeds = torch.cat([vision_features, text_embeds], dim=1)
+         attention_mask = torch.ones(inputs_embeds.shape[:2], device=pixel_values.device, dtype=torch.long)
+
+         layerwise_bias = self._resolve_attention_bias(
+             pixel_values=pixel_values,
+             anatomical_masks=anatomical_masks,
+             total_sequence_length=inputs_embeds.shape[1] + max_new_tokens,
+         )
+         return self.text_decoder.generate(
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             max_new_tokens=max_new_tokens,
+             pad_token_id=self.tokenizer.pad_token_id,
+             eos_token_id=self.tokenizer.eos_token_id,
+             segmentation_mask=layerwise_bias,
+             use_cache=True,
+             **kwargs,
+         )
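In `forward`, the visual prefix tokens are excluded from the language-modeling loss by prepending `-100` (the default `ignore_index` of `CrossEntropyLoss`) in front of the text labels. A minimal list-based sketch of that masking step (`merge_labels` is a hypothetical name; the model does this with `torch.cat` on tensors):

```python
# Sketch of prefix label masking: loss is computed only on report tokens,
# never on the image-feature prefix, because -100 positions are ignored
# by the cross-entropy loss.
def merge_labels(labels, prefix_length, ignore_index=-100):
    return [[ignore_index] * prefix_length + row for row in labels]

print(merge_labels([[5, 6]], 3))  # [[-100, -100, -100, 5, 6]]
```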
lana_radgen/modeling_outputs.py ADDED
@@ -0,0 +1,15 @@
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import torch
+ from transformers.utils import ModelOutput
+
+
+ @dataclass
+ class LanaModelOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     logits: Optional[torch.FloatTensor] = None
+     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+     layerwise_attentions: Optional[torch.FloatTensor] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+     vision_features: Optional[torch.FloatTensor] = None
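`LanaModelOutput` follows the `ModelOutput` pattern: a dataclass of optional fields accessed by attribute, with unset fields defaulting to `None`. A dependency-free sketch of that access pattern (`SketchOutput` is illustrative only; the real class additionally inherits tuple- and dict-style access from `transformers.utils.ModelOutput`):

```python
from dataclasses import dataclass
from typing import Optional

# Plain-dataclass stand-in for the ModelOutput idiom used by LanaModelOutput.
@dataclass
class SketchOutput:
    loss: Optional[float] = None
    logits: Optional[list] = None

out = SketchOutput(logits=[0.1, 0.9])
print(out.loss, out.logits)  # None [0.1, 0.9]
```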
lana_radgen/segmenters.py ADDED
@@ -0,0 +1,123 @@
+ import logging
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModel
+
+ from .attention.layerwise_anatomical_attention import build_layerwise_attention_bias
+
+ LOGGER = logging.getLogger(__name__)
+
+
+ def _freeze_module(module: nn.Module) -> None:
+     for param in module.parameters():
+         param.requires_grad = False
+
+
+ class _DinoUNetLung(nn.Module):
+     def __init__(self, model_name: str, freeze: bool = True):
+         super().__init__()
+         self.encoder = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+         self.channel_adapter = nn.Conv2d(768, 512, kernel_size=1)
+         self.decoder = nn.Sequential(
+             nn.Conv2d(512, 256, 3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.ConvTranspose2d(256, 128, 2, stride=2),
+             nn.ReLU(inplace=True),
+             nn.ConvTranspose2d(128, 64, 2, stride=2),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(64, 1, 1),
+         )
+         if freeze:
+             _freeze_module(self)
+
+     @torch.no_grad()
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         enc_feats = self.encoder(x, output_hidden_states=True, return_dict=True)
+         feats = next(h for h in reversed(enc_feats.hidden_states) if isinstance(h, torch.Tensor) and h.ndim == 4)
+         feats = self.channel_adapter(feats)
+         pred = self.decoder(feats)
+         return (torch.sigmoid(pred) > 0.5).float()
+
+
+ class _DinoUNetHeart(nn.Module):
+     def __init__(self, model_name: str, freeze: bool = True):
+         super().__init__()
+         self.encoder = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+         self.adapter = nn.Conv2d(768, 512, 1)
+         self.decoder = nn.Sequential(
+             nn.Conv2d(512, 256, 3, padding=1),
+             nn.ReLU(True),
+             nn.ConvTranspose2d(256, 128, 2, 2),
+             nn.ReLU(True),
+             nn.ConvTranspose2d(128, 64, 2, 2),
+             nn.ReLU(True),
+             nn.Conv2d(64, 3, 1),
+         )
+         if freeze:
+             _freeze_module(self)
+
+     @torch.no_grad()
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         enc = self.encoder(x, output_hidden_states=True, return_dict=True)
+         feat = next(h for h in reversed(enc.hidden_states) if isinstance(h, torch.Tensor) and h.ndim == 4)
+         feat = self.adapter(feat)
+         logits = self.decoder(feat)
+         pred = torch.argmax(logits, dim=1)
+         return (pred == 2).unsqueeze(1).float()
+
+
+ class AnatomicalSegmenter(nn.Module):
+     def __init__(
+         self,
+         model_name: str,
+         freeze: bool = True,
+         lung_checkpoint: str = "",
+         heart_checkpoint: str = "",
+     ):
+         super().__init__()
+         self.lung_model = _DinoUNetLung(model_name=model_name, freeze=freeze)
+         self.heart_model = _DinoUNetHeart(model_name=model_name, freeze=freeze)
+         self.loaded_lung_checkpoint = self._load_submodule(self.lung_model, lung_checkpoint, "lung")
+         self.loaded_heart_checkpoint = self._load_submodule(self.heart_model, heart_checkpoint, "heart")
+
+     @staticmethod
+     def _load_submodule(module: nn.Module, checkpoint_path: str, label: str) -> bool:
+         if not checkpoint_path:
+             return False
+         path = Path(checkpoint_path)
+         if not path.exists():
+             LOGGER.warning("Requested %s segmenter checkpoint does not exist: %s", label, path)
+             return False
+         state = torch.load(path, map_location="cpu", weights_only=False)
+         if isinstance(state, dict) and "state_dict" in state:
+             state = state["state_dict"]
+         module.load_state_dict(state, strict=False)
+         LOGGER.info("Loaded %s segmenter checkpoint from %s", label, path)
+         return True
+
+     @property
+     def has_any_checkpoint(self) -> bool:
+         return self.loaded_lung_checkpoint or self.loaded_heart_checkpoint
+
+     @torch.no_grad()
+     def forward(self, pixel_values: torch.Tensor, num_layers: int, target_tokens: int, strength: float) -> torch.Tensor | None:
+         if not self.has_any_checkpoint:
+             return None
+
+         masks = []
+         if self.loaded_heart_checkpoint:
+             masks.append(self.heart_model(pixel_values))
+         if self.loaded_lung_checkpoint:
+             masks.append(self.lung_model(pixel_values))
+         if not masks:
+             return None
+
+         combined_mask = torch.clamp(sum(masks), 0.0, 1.0)
+         return build_layerwise_attention_bias(
+             masks=combined_mask,
+             num_layers=num_layers,
+             target_tokens=target_tokens,
+             strength=strength,
+         )
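`AnatomicalSegmenter.forward` unions the heart and lung masks with `torch.clamp(sum(masks), 0.0, 1.0)`; since each sub-model emits a binary mask, this amounts to a per-pixel logical OR. A torch-free sketch of the same operation on flat pixel lists:

```python
def union_masks(*masks):
    # Mirrors torch.clamp(sum(masks), 0.0, 1.0) for binary masks:
    # a pixel is kept if any of the masks marks it.
    return [min(1.0, sum(pixels)) for pixels in zip(*masks)]


heart = [0.0, 1.0, 1.0, 0.0]
lung = [1.0, 1.0, 0.0, 0.0]
print(union_masks(heart, lung))  # [1.0, 1.0, 1.0, 0.0]
```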
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a5b80f4c62ee4863205f671c7d8670ebdd8e119bb449194633b15fa80b6479a7
+ size 1159628024
modeling_lana.py ADDED
@@ -0,0 +1,3 @@
+ from lana_radgen.modeling_lana import LanaForConditionalGeneration
+
+ __all__ = ["LanaForConditionalGeneration"]
pipeline_autotune.json ADDED
@@ -0,0 +1,111 @@
+ {
+   "train": {
+     "method": "lora_adamw",
+     "batch_size": 1,
+     "global_batch_size": 8,
+     "candidate_dir": "C:\\Users\\emman\\Desktop\\PROYECTOS_VS_CODE\\PRUEBAS_DE_PYTHON\\Chest-X-ray-Diagnosis-Automated-Reporting-using-CNNs-and-LLMs---UDEM-PEF-Thesis-Fall-2025\\artifacts\\full_3_epoch_mask_run\\_autotune\\train\\candidate_0_lora_adamw_b1_g8",
+     "status": "ok",
+     "elapsed_seconds": 111.93326840000373,
+     "images_per_second": 9.704605089103115,
+     "steps": 16,
+     "train_loss_last": 7.840930461883545
+   },
+   "eval": {
+     "batch_size": 8,
+     "status": "ok",
+     "elapsed_seconds": 37.971331600005215,
+     "examples_per_second": 1.685482107243013
+   },
+   "benchmarks": {
+     "train": [
+       {
+         "method": "lora_adamw",
+         "batch_size": 1,
+         "global_batch_size": 8,
+         "candidate_dir": "C:\\Users\\emman\\Desktop\\PROYECTOS_VS_CODE\\PRUEBAS_DE_PYTHON\\Chest-X-ray-Diagnosis-Automated-Reporting-using-CNNs-and-LLMs---UDEM-PEF-Thesis-Fall-2025\\artifacts\\full_3_epoch_mask_run\\_autotune\\train\\candidate_0_lora_adamw_b1_g8",
+         "status": "ok",
+         "elapsed_seconds": 111.93326840000373,
+         "images_per_second": 9.704605089103115,
+         "steps": 16,
+         "train_loss_last": 7.840930461883545
+       },
+       {
+         "method": "lora_adamw",
+         "batch_size": 2,
+         "global_batch_size": 8,
+         "candidate_dir": "C:\\Users\\emman\\Desktop\\PROYECTOS_VS_CODE\\PRUEBAS_DE_PYTHON\\Chest-X-ray-Diagnosis-Automated-Reporting-using-CNNs-and-LLMs---UDEM-PEF-Thesis-Fall-2025\\artifacts\\full_3_epoch_mask_run\\_autotune\\train\\candidate_1_lora_adamw_b2_g8",
+         "status": "ok",
+         "elapsed_seconds": 114.21902520000003,
+         "images_per_second": 8.770855046196548,
+         "steps": 16,
+         "train_loss_last": 7.009981155395508
+       },
+       {
+         "method": "lora_adamw",
+         "batch_size": 2,
+         "global_batch_size": 4,
+         "candidate_dir": "C:\\Users\\emman\\Desktop\\PROYECTOS_VS_CODE\\PRUEBAS_DE_PYTHON\\Chest-X-ray-Diagnosis-Automated-Reporting-using-CNNs-and-LLMs---UDEM-PEF-Thesis-Fall-2025\\artifacts\\full_3_epoch_mask_run\\_autotune\\train\\candidate_2_lora_adamw_b2_g4",
+         "status": "ok",
+         "elapsed_seconds": 116.37137660000008,
+         "images_per_second": 8.927318141708337,
+         "steps": 16,
+         "train_loss_last": 7.958044052124023
+       },
+       {
+         "method": "full_adam8bit",
+         "batch_size": 1,
+         "global_batch_size": 8,
+         "candidate_dir": "C:\\Users\\emman\\Desktop\\PROYECTOS_VS_CODE\\PRUEBAS_DE_PYTHON\\Chest-X-ray-Diagnosis-Automated-Reporting-using-CNNs-and-LLMs---UDEM-PEF-Thesis-Fall-2025\\artifacts\\full_3_epoch_mask_run\\_autotune\\train\\candidate_3_full_adam8bit_b1_g8",
+         "status": "ok",
+         "elapsed_seconds": 121.47497799999837,
+         "images_per_second": 8.749786245453723,
+         "steps": 16,
+         "train_loss_last": 6.867033958435059
+       },
+       {
+         "method": "full_adamw",
+         "batch_size": 1,
+         "global_batch_size": 8,
+         "candidate_dir": "C:\\Users\\emman\\Desktop\\PROYECTOS_VS_CODE\\PRUEBAS_DE_PYTHON\\Chest-X-ray-Diagnosis-Automated-Reporting-using-CNNs-and-LLMs---UDEM-PEF-Thesis-Fall-2025\\artifacts\\full_3_epoch_mask_run\\_autotune\\train\\candidate_4_full_adamw_b1_g8",
+         "status": "ok",
+         "elapsed_seconds": 120.72124660000554,
+         "images_per_second": 8.777123798241934,
+         "steps": 16,
+         "train_loss_last": 6.959325313568115
+       },
+       {
+         "method": "qlora_paged_adamw8bit",
+         "batch_size": 1,
+         "global_batch_size": 8,
+         "status": "failed",
+         "error": "Command '['C:\\\\Users\\\\emman\\\\Desktop\\\\PROYECTOS_VS_CODE\\\\PRUEBAS_DE_PYTHON\\\\Chest-X-ray-Diagnosis-Automated-Reporting-using-CNNs-and-LLMs---UDEM-PEF-Thesis-Fall-2025\\\\venv310\\\\Scripts\\\\python.exe', 'scripts/train.py', '--run-name', 'autotune_train_5', '--dataset', 'combined', '--epochs', '1', '--batch-size', '1', '--global-batch-size', '8', '--eval-batch-size', '1', '--image-size', '512', '--device', 'cuda', '--output-dir', 'C:\\\\Users\\\\emman\\\\Desktop\\\\PROYECTOS_VS_CODE\\\\PRUEBAS_DE_PYTHON\\\\Chest-X-ray-Diagnosis-Automated-Reporting-using-CNNs-and-LLMs---UDEM-PEF-Thesis-Fall-2025\\\\artifacts\\\\full_3_epoch_mask_run\\\\_autotune\\\\train\\\\candidate_5_qlora_paged_adamw8bit_b1_g8', '--method', 'qlora_paged_adamw8bit', '--max-train-steps', '16', '--save-every-n-steps', '1000', '--log-level', 'INFO', '--disable-wandb']' returned non-zero exit status 1."
+       }
+     ],
+     "eval": [
+       {
+         "batch_size": 1,
+         "status": "ok",
+         "elapsed_seconds": 105.64014480000333,
+         "examples_per_second": 0.6058302941667114
+       },
+       {
+         "batch_size": 2,
+         "status": "ok",
+         "elapsed_seconds": 70.24700000000303,
+         "examples_per_second": 0.9110709354135728
+       },
+       {
+         "batch_size": 4,
+         "status": "ok",
+         "elapsed_seconds": 46.93239729999914,
+         "examples_per_second": 1.3636635603952234
+       },
+       {
+         "batch_size": 8,
+         "status": "ok",
+         "elapsed_seconds": 37.971331600005215,
+         "examples_per_second": 1.685482107243013
+       }
+     ]
+   }
+ }
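The top-level `eval` entry matches the fastest candidate in `benchmarks.eval`, consistent with the autotuner keeping the highest-throughput setting. Selecting it programmatically is a one-liner (data copied from the JSON above):

```python
# Eval throughput measurements from pipeline_autotune.json ("benchmarks" -> "eval").
eval_benchmarks = [
    {"batch_size": 1, "examples_per_second": 0.6058302941667114},
    {"batch_size": 2, "examples_per_second": 0.9110709354135728},
    {"batch_size": 4, "examples_per_second": 1.3636635603952234},
    {"batch_size": 8, "examples_per_second": 1.685482107243013},
]

# Pick the candidate with the highest measured throughput.
best = max(eval_benchmarks, key=lambda r: r["examples_per_second"])
print(best["batch_size"])  # 8
```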
run_summary.json ADDED
@@ -0,0 +1,162 @@
+ {
+   "method": "lora_adamw",
+   "run_name": "mimic only",
+   "steps": 6692,
+   "epochs_completed": 0,
+   "epoch_index": 0,
+   "target_epochs": 3,
+   "progress_epochs": 0.38088056399348363,
+   "training_completion_percent": 12.696018799782788,
+   "elapsed_seconds": 3600.0493718999996,
+   "images_seen": 53540,
+   "train_loss_last": 0.6295745968818665,
+   "train_loss_mean": 2.933644461443075,
+   "val_loss": 2.31326277256012,
+   "images_per_second": 14.872018261167115,
+   "trainable_params": 2878464,
+   "vision_model_name": "facebook/dinov3-vits16-pretrain-lvd1689m",
+   "text_model_name": "gpt2",
+   "segmentation_model_name": "facebook/dinov3-convnext-small-pretrain-lvd1689m",
+   "lung_segmenter_checkpoint": "models/lung_segmenter_dinounet_finetuned.pth",
+   "heart_segmenter_checkpoint": "models/heart_segmenter_dinounet_best.pth",
+   "image_size": 512,
+   "batch_size": 1,
+   "global_batch_size": 8,
+   "gradient_accumulation_steps": 8,
+   "steps_per_epoch": 17572,
+   "planned_total_steps": 52716,
+   "scheduler": "cosine",
+   "warmup_steps": 2636,
+   "warmup_ratio": 0.05,
+   "weight_decay": 0.01,
+   "precision": "bf16",
+   "torch_compile": false,
+   "torch_compile_mode": "default",
+   "hardware": "NVIDIA GeForce RTX 5070",
+   "seed": 42,
+   "resume_supported": true,
+   "checkpoint_every_n_steps": 1000,
+   "cumulative_loss_sum": 157067.32446566224,
+   "cumulative_loss_count": 53540,
+   "completed": false,
+   "target_duration_seconds": 3600,
+   "target_duration_mode": "per_invocation",
+   "train_datasets": "MIMIC-CXR (findings-only)",
+   "validation_datasets": "MIMIC-CXR (findings-only)",
+   "latest_evaluation": {
+     "split": "test",
+     "subset": "all frontal studies",
+     "dataset": "mimic-cxr",
+     "view_filter": "frontal-only (PA/AP)",
+     "num_examples": 3041,
+     "bleu_1": 0.224283665754139,
+     "bleu_4": 0.038250603325884945,
+     "meteor": 0.20053919101008125,
+     "rouge_l": 0.16408262967411588,
+     "chexpert_f1_14_micro": 0.1244686823321838,
+     "chexpert_f1_5_micro": 0.21896350608231963,
+     "chexpert_f1_14_macro": 0.04429260458577706,
+     "chexpert_f1_5_macro": 0.09914404243362482,
+     "chexpert_f1_micro": 0.1244686823321838,
+     "chexpert_f1_macro": 0.04429260458577706,
+     "chexpert_per_label_f1": {
+       "Enlarged Cardiomediastinum": 0.0,
+       "Cardiomegaly": 0.0,
+       "Lung Opacity": 0.0,
+       "Lung Lesion": 0.0,
+       "Edema": 0.023121387283236997,
+       "Consolidation": 0.056790123456790124,
+       "Pneumonia": 0.02762430939226519,
+       "Atelectasis": 0.0,
+       "Pneumothorax": 0.059987236758136574,
+       "Pleural Effusion": 0.415808701428097,
+       "Pleural Other": 0.0,
+       "Fracture": 0.0,
+       "Support Devices": 0.036764705882352935,
+       "No Finding": 0.0
+     },
+     "radgraph_f1": 0.0941067057393548,
+     "radgraph_f1_entity": 0.18191243977753782,
+     "radgraph_f1_relation": 0.1652384677607375,
+     "radgraph_available": true,
+     "radgraph_error": null
+   },
+   "latest_evaluations": {
+     "all_test": {
+       "split": "test",
+       "subset": "all frontal studies",
+       "dataset": "mimic-cxr",
+       "view_filter": "frontal-only (PA/AP)",
+       "num_examples": 3041,
+       "bleu_1": 0.224283665754139,
+       "bleu_4": 0.038250603325884945,
+       "meteor": 0.20053919101008125,
+       "rouge_l": 0.16408262967411588,
+       "chexpert_f1_14_micro": 0.1244686823321838,
+       "chexpert_f1_5_micro": 0.21896350608231963,
+       "chexpert_f1_14_macro": 0.04429260458577706,
+       "chexpert_f1_5_macro": 0.09914404243362482,
+       "chexpert_f1_micro": 0.1244686823321838,
+       "chexpert_f1_macro": 0.04429260458577706,
+       "chexpert_per_label_f1": {
+         "Enlarged Cardiomediastinum": 0.0,
+         "Cardiomegaly": 0.0,
+         "Lung Opacity": 0.0,
+         "Lung Lesion": 0.0,
+         "Edema": 0.023121387283236997,
+         "Consolidation": 0.056790123456790124,
+         "Pneumonia": 0.02762430939226519,
+         "Atelectasis": 0.0,
+         "Pneumothorax": 0.059987236758136574,
+         "Pleural Effusion": 0.415808701428097,
+         "Pleural Other": 0.0,
+         "Fracture": 0.0,
+         "Support Devices": 0.036764705882352935,
+         "No Finding": 0.0
+       },
+       "radgraph_f1": 0.0941067057393548,
+       "radgraph_f1_entity": 0.18191243977753782,
+       "radgraph_f1_relation": 0.1652384677607375,
+       "radgraph_available": true,
+       "radgraph_error": null
+     },
+     "findings_only_test": {
+       "split": "test",
+       "subset": "findings-only frontal studies",
+       "dataset": "mimic-cxr",
+       "view_filter": "frontal-only (PA/AP), structured Findings section only",
+       "num_examples": 2210,
+       "bleu_1": 0.23099023872215996,
+       "bleu_4": 0.0429479479188206,
+       "meteor": 0.21248313160360002,
+       "rouge_l": 0.17210734193417726,
+       "chexpert_f1_14_micro": 0.11655011655011654,
+       "chexpert_f1_5_micro": 0.20709914320685435,
+       "chexpert_f1_14_macro": 0.04057070914402376,
+       "chexpert_f1_5_macro": 0.09202593660588262,
+       "chexpert_f1_micro": 0.11655011655011654,
+       "chexpert_f1_macro": 0.04057070914402376,
+       "chexpert_per_label_f1": {
+         "Enlarged Cardiomediastinum": 0.0,
+         "Cardiomegaly": 0.0,
+         "Lung Opacity": 0.0,
+         "Lung Lesion": 0.0,
+         "Edema": 0.022471910112359553,
+         "Consolidation": 0.05797101449275362,
+         "Pneumonia": 0.01673640167364017,
+         "Atelectasis": 0.0,
+         "Pneumothorax": 0.05716798592788039,
+         "Pleural Effusion": 0.3796867584243,
+         "Pleural Other": 0.0,
+         "Fracture": 0.0,
+         "Support Devices": 0.03395585738539898,
+         "No Finding": 0.0
+       },
+       "radgraph_f1": 0.10172866854646034,
+       "radgraph_f1_entity": 0.19217701907879298,
+       "radgraph_f1_relation": 0.17414731467894073,
+       "radgraph_available": true,
+       "radgraph_error": null
+     }
+   }
+ }
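As a quick sanity check on the reported metrics, `chexpert_f1_14_macro` is simply the unweighted mean of the 14 per-label F1 scores in the `all_test` evaluation:

```python
# Per-label CheXpert F1 scores copied from run_summary.json ("all_test").
per_label_f1 = {
    "Enlarged Cardiomediastinum": 0.0,
    "Cardiomegaly": 0.0,
    "Lung Opacity": 0.0,
    "Lung Lesion": 0.0,
    "Edema": 0.023121387283236997,
    "Consolidation": 0.056790123456790124,
    "Pneumonia": 0.02762430939226519,
    "Atelectasis": 0.0,
    "Pneumothorax": 0.059987236758136574,
    "Pleural Effusion": 0.415808701428097,
    "Pleural Other": 0.0,
    "Fracture": 0.0,
    "Support Devices": 0.036764705882352935,
    "No Finding": 0.0,
}

# Macro F1 = unweighted mean over all 14 labels; matches chexpert_f1_14_macro.
macro_f1 = sum(per_label_f1.values()) / len(per_label_f1)
print(round(macro_f1, 6))  # 0.044293
```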
segmenters/heart_segmenter_dinounet_best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e7f17093041df317bdd22440789ce3aed407a8bda9d7527751d23e8c106fb59b
+ size 204910713
segmenters/lung_segmenter_dinounet_finetuned.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:086027098b3e2243dd56e5ef3b7a248a0532c3ae401da27091d94617d41b7403
+ size 204911991
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "add_prefix_space": false,
+   "backend": "tokenizers",
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "errors": "replace",
+   "is_local": false,
+   "model_max_length": 1024,
+   "pad_token": "<|endoftext|>",
+   "tokenizer_class": "GPT2Tokenizer",
+   "unk_token": "<|endoftext|>"
+ }