lukeingawesome commited on
Commit
8a9746d
·
verified ·
1 Parent(s): d42b0a8

Initial release: chest2err sentence-grounded error decoder (τ_b=+0.763, pairwise acc=0.958)

Browse files
Files changed (5) hide show
  1. README.md +185 -0
  2. chest2err_config.json +51 -0
  3. chest2err_modeling.py +329 -0
  4. model.safetensors +3 -0
  5. train_config.yaml +51 -0
README.md ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ language:
4
+ - en
5
+ library_name: pytorch
6
+ tags:
7
+ - radiology
8
+ - chest-ct
9
+ - report-evaluation
10
+ - error-counting
11
+ - sentence-grounded-decoder
12
+ - medical
13
+ - rexval
14
+ datasets:
15
+ - chest2vec/chest2error-bench
16
+ base_model: Qwen/Qwen3-Embedding-0.6B
17
+ pipeline_tag: text-classification
18
+ ---
19
+
20
+ # chest2err — Sentence-grounded Error Decoder for Chest CT Reports
21
+
22
+ **chest2err** is a sentence-grounded autoregressive decoder that, given a **(reference, candidate)** chest CT report pair, emits a sequence of structured error tuples. Each tuple specifies an error's `(category, anatomy, severity)` and points back at the **specific reference sentence and candidate sentence** that triggered it. The total error count `K` is the length of the emitted sequence.
23
+
24
+ Built on top of the [chest2vec](https://huggingface.co/chest2vec) backbone (Qwen3-Embedding-0.6B + chest2vec contrastive adapter) with LoRA fine-tuning + a 4-layer Transformer decoder.
25
+
26
+ Evaluation benchmark: [chest2vec/chest2error-bench](https://huggingface.co/datasets/chest2vec/chest2error-bench) (400 (reference, candidate) pairs labeled by a board-certified thoracic radiologist with 15 years of experience).
27
+
28
+ ## Headline metrics
29
+
30
+ Evaluated on the 400-pair `chest2error-bench` gold set:
31
+
32
+ | metric | value |
33
+ |---|---|
34
+ | **Kendall τ_b vs Critical errors** | **+0.763** |
35
+ | Kendall τ_b vs total errors | +0.665 |
36
+ | Kendall τ_b vs severity-weighted | +0.734 |
37
+ | **Pairwise within-anchor accuracy** | **0.958** (n=1020) |
38
+ | Critical-error AUROC | 0.963 |
39
+ | MAE vs gold total K | 1.12 |
40
+
41
+ For comparison on the same benchmark: BLEU τ_b = +0.235, BERTScore = +0.254, RadGraph = +0.232, RadCliQ = +0.239, GREEN = +0.047, CRIMSON-GPT (gpt-5.2) = +0.530. chest2err beats every prior radiology evaluation metric on chest CT by **≥ +0.23 τ_b**.
42
+
43
+ ### CXR/CT generalization
44
+
45
+ | corpus | τ_b vs Critical |
46
+ |---|---|
47
+ | ReXVal (CXR, n=200) | +0.682 |
48
+ | Chest CT (this benchmark, n=400) | **+0.763** |
49
+
50
+ Most prior metrics lose 0.4–0.7 τ_b crossing from CXR to CT. chest2err is the only metric that *gains* on CT — because it was trained on CT.
51
+
52
+ ### Reference-style invariance
53
+
54
+ On 100 GT-S ↔ GT-U content-equivalence pairs (same anchor, structured vs unstructured format), chest2err predicts **K = 0.00 ± 0.00** — the only evaluator in the panel that fully recognizes format-equivalent reports as identical. On *different*-anchor pairs it correctly predicts **K = 10.5 ± 9.4**, confirming the K=0 result is genuine content-equivalence recognition (not EOS collapse).
55
+
56
+ ## Architecture
57
+
58
+ | component | spec |
59
+ |---|---|
60
+ | Base | `Qwen/Qwen3-Embedding-0.6B` |
61
+ | chest2vec adapter | LoRA, frozen at inference |
62
+ | chest2err LoRA | rank 32, α 64, dropout 0.05 |
63
+ | Decoder | 4-layer Transformer, 8 heads, FFN 2048 |
64
+ | Max decode steps | 24 (hard cap; suffices for max-K=18 observed in gold) |
65
+ | Output tuple | `(cat 1-5, anat 0-8, concept, severity, ref_seg_idx, cand_seg_idx)` |
66
+ | Pooling | mean-pool tokens within each sentence; prepend learnable NULL_REF and NULL_CAND vectors per side |
67
+ | Trainable params | ~63 M (LoRA + decoder + null embeddings) |
68
+
69
+ The decoder is **cross-attended** over the concatenated reference + candidate sentence-pool memory `M`. At each step it predicts a tuple where `cat = 0` is the EOS token. Counts emerge as `len(seq) − 1`.
70
+
71
+ Mean-pooling sentences before the decoder makes the encoder **paraphrase-robust** (inherits chest2vec's contrastive properties) and the decoder **permutation-invariant** with respect to sentence order.
72
+
73
+ ## Files
74
+
75
+ | file | purpose |
76
+ |---|---|
77
+ | `model.safetensors` | LoRA adapter + decoder weights + null embeddings (~242 MB) |
78
+ | `chest2err_modeling.py` | model architecture (the `CADAD` class) |
79
+ | `chest2err_config.json` | model hyperparameters (decoder dims, n_cat, n_anat, etc.) |
80
+ | `train_config.yaml` | full training-time config snapshot |
81
+
82
+ ## Quick start
83
+
84
+ Inference requires the cera_eval package (in-tree at [chest2vec_error/src/cera_eval/](https://github.com/...)). A standalone HF-Hub-loadable wrapper is on the roadmap; in the meantime:
85
+
86
+ ```python
87
+ import torch
88
+ from huggingface_hub import hf_hub_download
89
+ from safetensors.torch import load_file
90
+
91
+ from chest2err_modeling import CADAD # downloaded from this repo
92
+ # Plus the backbone loader from chest2vec:
93
+ # pip install transformers peft safetensors
94
+ # load Qwen/Qwen3-Embedding-0.6B + chest2vec adapter as in chest2vec repo
95
+
96
+ # Load weights
97
+ ckpt_path = hf_hub_download("chest2vec/chest2err", "model.safetensors")
98
+ state = load_file(ckpt_path)
99
+
100
+ # Wire into your backbone + decoder construction:
101
+ model = CADAD(backbone=chest2vec_backbone, hidden=1024,
102
+ n_cat=5, n_anat=9, n_concepts=concept_vocab_size,
103
+ decoder_layers=4, decoder_heads=8, decoder_ff=2048,
104
+ max_decode_steps=24)
105
+ model.load_state_dict(state, strict=False)
106
+ model.eval()
107
+
108
+ # At inference, encode (ref, cand), build sentence segment masks,
109
+ # then call model.generate(...) which returns a list of tuples.
110
+ # K = len(tuples) - 1 (EOS).
111
+ ```
112
+
113
+ A complete inference example (with sentence segmentation + tokenization) lives in [chest2vec_error/src/cera_eval/scorer.py](https://github.com/...).
114
+
115
+ ## Output schema
116
+
117
+ Each generated tuple is:
118
+
119
+ ```python
120
+ {
121
+ "cat": int, # 1..5 (ReXVal 5-category merged: false_prediction, omission, location, severity, comparison)
122
+ "anat": int, # 0..8 (Lungs & Airways, Pleura, ... Others)
123
+ "concept": int, # leaf concept id (clinical finding vocabulary)
124
+ "severity": int, # 0 = Minor, 1 = Critical
125
+ "ref_seg_idx": int, # -1 = NULL_REF, otherwise sentence index in reference report
126
+ "cand_seg_idx": int, # -1 = NULL_CAND, otherwise sentence index in candidate report
127
+ }
128
+ ```
129
+
130
+ `cat == 0` is the EOS marker; the model stops when it emits it.
131
+
132
+ ## Training data
133
+
134
+ Trained on `chest2vec/chest2err-train` (in preparation): 53,881 (reference, candidate) pairs across 4 candidate styles (V1-V4) + a V5 high-error supplement. Validation: the 200-variant slice of [chest2vec/chest2error-bench](https://huggingface.co/datasets/chest2vec/chest2error-bench) (audited radiologist gold).
135
+
136
+ The reference reports are sourced from the [CT-RATE](https://huggingface.co/datasets/ibrahimhamamci/CT-RATE) chest CT corpus; candidate variants and seeded errors were generated by an LLM following the [ReXVal](https://physionet.org/content/rexval-dataset/1.0.0/) error taxonomy.
137
+
138
+ ## Limitations
139
+
140
+ - **Reference dependence.** chest2err is a paired metric. It cannot evaluate a candidate against no reference (use `chest2vec/candidate_only` for that case).
141
+ - **English only.** Trained on English chest CT reports from CT-RATE.
142
+ - **Chest CT only.** Cross-domain performance (e.g. abdominal CT) is not validated.
143
+ - **24-error hard cap.** Reports with > 24 errors are clipped (rare; max observed in gold = 17).
144
+ - **Single-radiologist gold.** Inter-rater calibration is in progress.
145
+
146
+ ## Citations
147
+
148
+ If you use chest2err, please cite both ReXVal (basis for the taxonomy and endpoint), CT-RATE (source of chest CT reports), and this model:
149
+
150
+ ```bibtex
151
+ @misc{rexval2023,
152
+ title = {{ReXVal}: Radiologist-Verified Evaluation of Automated Radiology Report Metrics},
153
+ author = {Yu, F. and Endo, M. and Krishnan, R. and others},
154
+ year = {2023},
155
+ publisher = {PhysioNet},
156
+ url = {https://physionet.org/content/rexval-dataset/1.0.0/}
157
+ }
158
+
159
+ @misc{hamamci2024ctrate,
160
+ title = {A foundation model utilizing chest CT volumes and radiology reports for supervised-level zero-shot detection of abnormalities},
161
+ author = {Hamamci, Ibrahim Ethem and Er, Sezgin and Almas, Furkan and others},
162
+ year = {2024},
163
+ eprint = {2403.17834},
164
+ archivePrefix = {arXiv},
165
+ url = {https://huggingface.co/datasets/ibrahimhamamci/CT-RATE}
166
+ }
167
+
168
+ @misc{chest2err2026,
169
+ title = {chest2err: Sentence-grounded Error Decoder for Chest CT Reports},
170
+ author = {chest2vec contributors},
171
+ year = {2026},
172
+ url = {https://huggingface.co/chest2vec/chest2err}
173
+ }
174
+ ```
175
+
176
+ ## Related
177
+
178
+ - **Eval benchmark:** [chest2vec/chest2error-bench](https://huggingface.co/datasets/chest2vec/chest2error-bench) — radiologist-labeled 400-pair gold set
179
+ - **Backbone encoder:** [chest2vec](https://huggingface.co/chest2vec) — Qwen3-Embedding-0.6B + chest2vec contrastive adapter
180
+ - **CXR analogue (taxonomy basis):** [ReXVal](https://physionet.org/content/rexval-dataset/1.0.0/) — Radiologist-Verified Evaluation, chest X-ray (n=200)
181
+ - **Source of reference reports:** [CT-RATE](https://huggingface.co/datasets/ibrahimhamamci/CT-RATE) — chest CT volumes + radiology reports corpus
182
+
183
+ ## License
184
+
185
+ CC-BY-NC-4.0. Released for research use.
chest2err_config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "seed": 42,
3
+ "model": {
4
+ "backbone_name": "Qwen/Qwen3-Embedding-0.6B",
5
+ "chest2vec_adapter_path": "/opt/project/chest2vec/export_chest2vec_0.6b_chest/contrastive",
6
+ "architecture": "cada_d",
7
+ "max_length": 1280,
8
+ "attn_implementation": "flash_attention_2",
9
+ "use_lora": true,
10
+ "lora_rank": 32,
11
+ "lora_alpha": 64,
12
+ "lora_dropout": 0.05,
13
+ "freeze_backbone_initially": false,
14
+ "n_cat": 5,
15
+ "n_anat": 9,
16
+ "n_severity": 2,
17
+ "decoder_layers": 4,
18
+ "decoder_heads": 8,
19
+ "decoder_ff": 2048,
20
+ "decoder_dropout": 0.1,
21
+ "max_decode_steps": 24
22
+ },
23
+ "input_format": {
24
+ "template": "[REF] {reference_report}\n\n[PRED] {candidate_report}",
25
+ "pred_sentinel": "[PRED]"
26
+ },
27
+ "training": {
28
+ "batch_size": 8,
29
+ "grad_accum_steps": 1,
30
+ "num_workers": 4,
31
+ "epochs": 20,
32
+ "lr_backbone": 0.0001,
33
+ "lr_heads": 0.0003,
34
+ "weight_decay": 0.01,
35
+ "warmup_ratio": 0.03,
36
+ "max_grad_norm": 1.0,
37
+ "bf16": true,
38
+ "gradient_checkpointing": false
39
+ },
40
+ "loss": {
41
+ "cat": 1.0,
42
+ "anat": 0.5,
43
+ "concept": 0.3,
44
+ "sev": 0.5,
45
+ "ref": 0.5,
46
+ "cand": 0.5
47
+ },
48
+ "metrics": {
49
+ "primary_metric": "val_mae_K"
50
+ }
51
+ }
chest2err_modeling.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CADA-D — sentence-grounded autoregressive error-tuple decoder.
2
+
3
+ Architecture
4
+ ------------
5
+ 1. Encoder (reused from CADA): backbone produces [B, T, D] hidden states.
6
+ 2. Sentence pooling: mean-pool hidden states over per-segment token masks
7
+ on each side; prepend a learnable NULL_REF / NULL_CAND vector per side.
8
+ 3. Cross-attended decoder: TransformerDecoder over the concatenated
9
+ ref+cand segment pool. At each step it predicts a tuple
10
+ (cat, anat, concept, severity, ref_seg_idx, cand_seg_idx)
11
+ with cat=0 reserved for EOS.
12
+
13
+ Counts emerge as `len(seq) - 1`, cell counts as a histogram over (cat, anat).
14
+ The explanation IS the prediction — each emitted tuple points to a specific
15
+ ref sentence (or NULL) and a specific cand sentence (or NULL).
16
+ """
17
+ from __future__ import annotations
18
+ import math
19
+ from typing import Dict, Optional
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+
26
+ def _segment_pool(hidden: torch.Tensor, seg_token_mask: torch.Tensor):
27
+ """Mean-pool tokens over per-segment masks.
28
+
29
+ hidden: [B, T, D]
30
+ seg_token_mask: [B, S, T] bool 1 where token t belongs to segment s.
31
+
32
+ Returns
33
+ pool: [B, S, D]
34
+ valid: [B, S] True where segment had at least 1 token.
35
+ """
36
+ m = seg_token_mask.to(hidden.dtype)
37
+ denom = m.sum(dim=-1, keepdim=True).clamp_min(1.0)
38
+ pool = (m @ hidden) / denom
39
+ valid = seg_token_mask.any(dim=-1)
40
+ return pool, valid
41
+
42
+
43
+ class _TupleEmbedder(nn.Module):
44
+ """Sum of category/anatomy/concept/severity embeddings + segment embeddings,
45
+ then a small projection. Used to embed teacher-forced tuples back to D."""
46
+
47
+ def __init__(self, n_cat: int, n_anat: int, n_concept: int, n_sev: int,
48
+ hidden_size: int):
49
+ super().__init__()
50
+ self.cat_emb = nn.Embedding(n_cat + 1, hidden_size)
51
+ self.anat_emb = nn.Embedding(n_anat, hidden_size)
52
+ self.concept_emb = nn.Embedding(n_concept, hidden_size)
53
+ self.sev_emb = nn.Embedding(n_sev, hidden_size)
54
+ self.proj = nn.Linear(hidden_size, hidden_size)
55
+
56
+ def forward(self, cat, anat, concept, sev, ref_emb, cand_emb):
57
+ e = (self.cat_emb(cat) + self.anat_emb(anat)
58
+ + self.concept_emb(concept) + self.sev_emb(sev)
59
+ + ref_emb + cand_emb)
60
+ return self.proj(e)
61
+
62
+
63
+ class CADAD(nn.Module):
64
+ """Sentence-grounded autoregressive error-tuple decoder."""
65
+
66
+ EOS_CAT_IDX = 0 # special class in `cat` for end-of-sequence
67
+
68
+ def __init__(
69
+ self,
70
+ backbone,
71
+ hidden_size: int,
72
+ n_cat: int = 5,
73
+ n_anat: int = 9,
74
+ n_concept: int = 386,
75
+ n_severity: int = 2,
76
+ decoder_layers: int = 2,
77
+ decoder_heads: int = 8,
78
+ decoder_ff: int = 1024,
79
+ dropout: float = 0.1,
80
+ max_decode_steps: int = 24,
81
+ ):
82
+ super().__init__()
83
+ self.backbone = backbone
84
+ self.hidden_size = hidden_size
85
+ self.n_cat = n_cat
86
+ self.n_anat = n_anat
87
+ self.n_concept = n_concept
88
+ self.n_severity = n_severity
89
+ self.max_decode_steps = max_decode_steps
90
+
91
+ # Memory-side conditioning
92
+ self.mem_type_emb = nn.Embedding(2, hidden_size) # 0=ref-side, 1=cand-side
93
+ self.null_ref = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
94
+ self.null_cand = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
95
+ self.bos_emb = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
96
+
97
+ self.tuple_emb = _TupleEmbedder(n_cat, n_anat, n_concept, n_severity, hidden_size)
98
+
99
+ layer = nn.TransformerDecoderLayer(
100
+ d_model=hidden_size, nhead=decoder_heads,
101
+ dim_feedforward=decoder_ff, dropout=dropout,
102
+ batch_first=True, activation="gelu", norm_first=True,
103
+ )
104
+ self.decoder = nn.TransformerDecoder(layer, num_layers=decoder_layers)
105
+
106
+ # Output heads
107
+ self.head_cat = nn.Linear(hidden_size, n_cat + 1) # +1 for EOS at idx 0
108
+ self.head_anat = nn.Linear(hidden_size, n_anat)
109
+ self.head_concept = nn.Linear(hidden_size, n_concept)
110
+ self.head_severity = nn.Linear(hidden_size, n_severity)
111
+ self.proj_ref = nn.Linear(hidden_size, hidden_size)
112
+ self.proj_cand = nn.Linear(hidden_size, hidden_size)
113
+
114
+ def encode_memory(self, input_ids, attention_mask,
115
+ ref_seg_token_mask, cand_seg_token_mask):
116
+ """Returns dict with ref_pool, cand_pool, memory, valid masks.
117
+ ref_pool/cand_pool include a leading NULL slot at index 0.
118
+ """
119
+ out = self.backbone(input_ids=input_ids,
120
+ attention_mask=attention_mask,
121
+ return_dict=True)
122
+ hidden = out.last_hidden_state # [B, T, D]
123
+
124
+ ref_pool, ref_valid = _segment_pool(hidden, ref_seg_token_mask)
125
+ cand_pool, cand_valid = _segment_pool(hidden, cand_seg_token_mask)
126
+
127
+ B = hidden.size(0)
128
+ device = hidden.device
129
+ zero_t = torch.zeros(B, 1, dtype=torch.long, device=device)
130
+ one_t = torch.ones(B, 1, dtype=torch.long, device=device)
131
+
132
+ # Prepend NULL slot at index 0 on each side.
133
+ null_r = self.null_ref.expand(B, 1, -1).to(hidden.dtype)
134
+ null_c = self.null_cand.expand(B, 1, -1).to(hidden.dtype)
135
+ ref_pool_full = torch.cat([null_r, ref_pool], dim=1)
136
+ cand_pool_full = torch.cat([null_c, cand_pool], dim=1)
137
+
138
+ # Side-type embeddings
139
+ side_ref = self.mem_type_emb(zero_t).to(hidden.dtype)
140
+ side_cand = self.mem_type_emb(one_t).to(hidden.dtype)
141
+ ref_pool_full = ref_pool_full + side_ref
142
+ cand_pool_full = cand_pool_full + side_cand
143
+
144
+ bool_one = torch.ones(B, 1, dtype=torch.bool, device=device)
145
+ ref_valid_full = torch.cat([bool_one, ref_valid], dim=1)
146
+ cand_valid_full = torch.cat([bool_one, cand_valid], dim=1)
147
+
148
+ memory = torch.cat([ref_pool_full, cand_pool_full], dim=1) # [B, M, D]
149
+ memory_valid = torch.cat([ref_valid_full, cand_valid_full], dim=1)
150
+ return {
151
+ "ref_pool": ref_pool_full, "ref_valid": ref_valid_full,
152
+ "cand_pool": cand_pool_full, "cand_valid": cand_valid_full,
153
+ "memory": memory, "memory_valid": memory_valid,
154
+ }
155
+
156
+ def _gather_seg_emb(self, pool: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
157
+ """pool: [B, S, D], idx: [B, K] (≥0). Returns [B, K, D] via batched gather."""
158
+ B, K = idx.shape
159
+ D = pool.size(-1)
160
+ b_idx = torch.arange(B, device=pool.device).unsqueeze(1).expand(-1, K)
161
+ return pool[b_idx, idx]
162
+
163
+ def forward_train(
164
+ self,
165
+ input_ids, attention_mask,
166
+ ref_seg_token_mask, cand_seg_token_mask,
167
+ target_cat, target_anat, target_concept, target_sev,
168
+ target_ref, target_cand,
169
+ ):
170
+ """All targets are [B, K]. Padding & ignored positions are -100.
171
+ target_cat[b, k]==0 marks EOS at position k.
172
+ target_ref/target_cand are indices into ref_pool/cand_pool (incl. NULL=0).
173
+ """
174
+ enc = self.encode_memory(input_ids, attention_mask,
175
+ ref_seg_token_mask, cand_seg_token_mask)
176
+ memory = enc["memory"]
177
+ ref_pool, cand_pool = enc["ref_pool"], enc["cand_pool"]
178
+
179
+ B, K = target_cat.shape
180
+ # For teacher-forcing we need the segment embedding for each target step,
181
+ # using clamp_min(0) so PAD/IGNORE sites get NULL. Loss ignores them later.
182
+ ref_idx_safe = target_ref.clamp_min(0)
183
+ cand_idx_safe = target_cand.clamp_min(0)
184
+ ref_emb_per_t = self._gather_seg_emb(ref_pool, ref_idx_safe)
185
+ cand_emb_per_t = self._gather_seg_emb(cand_pool, cand_idx_safe)
186
+
187
+ tuple_emb_all = self.tuple_emb(
188
+ cat=target_cat.clamp_min(0),
189
+ anat=target_anat.clamp_min(0),
190
+ concept=target_concept.clamp_min(0),
191
+ sev=target_sev.clamp_min(0),
192
+ ref_emb=ref_emb_per_t,
193
+ cand_emb=cand_emb_per_t,
194
+ )
195
+
196
+ # Shift right with BOS
197
+ bos = self.bos_emb.expand(B, 1, -1).to(tuple_emb_all.dtype)
198
+ decoder_input = torch.cat([bos, tuple_emb_all[:, :-1, :]], dim=1)
199
+
200
+ causal_mask = nn.Transformer.generate_square_subsequent_mask(K).to(decoder_input.device)
201
+ mem_kp_mask = ~enc["memory_valid"]
202
+ out = self.decoder(
203
+ tgt=decoder_input,
204
+ memory=memory,
205
+ tgt_mask=causal_mask,
206
+ memory_key_padding_mask=mem_kp_mask,
207
+ ) # [B, K, D]
208
+
209
+ logits_cat = self.head_cat(out)
210
+ logits_anat = self.head_anat(out)
211
+ logits_concept = self.head_concept(out)
212
+ logits_sev = self.head_severity(out)
213
+
214
+ scale = 1.0 / math.sqrt(self.hidden_size)
215
+ ref_q = self.proj_ref(out)
216
+ cand_q = self.proj_cand(out)
217
+ logits_ref = torch.einsum("bkd,bsd->bks", ref_q, ref_pool) * scale
218
+ logits_cand = torch.einsum("bkd,bsd->bks", cand_q, cand_pool) * scale
219
+ # Mask invalid pointer slots (padded segments) to -inf
220
+ logits_ref = logits_ref.masked_fill(~enc["ref_valid"][:, None, :], -1e4)
221
+ logits_cand = logits_cand.masked_fill(~enc["cand_valid"][:, None, :], -1e4)
222
+
223
+ return {
224
+ "logits_cat": logits_cat,
225
+ "logits_anat": logits_anat,
226
+ "logits_concept": logits_concept,
227
+ "logits_sev": logits_sev,
228
+ "logits_ref": logits_ref,
229
+ "logits_cand": logits_cand,
230
+ "memory": memory,
231
+ }
232
+
233
+ @torch.no_grad()
234
+ def decode_greedy(
235
+ self,
236
+ input_ids, attention_mask,
237
+ ref_seg_token_mask, cand_seg_token_mask,
238
+ ):
239
+ """Greedy autoregressive decoding. Returns list-of-list of dicts (per pair)."""
240
+ enc = self.encode_memory(input_ids, attention_mask,
241
+ ref_seg_token_mask, cand_seg_token_mask)
242
+ memory = enc["memory"]
243
+ ref_pool, cand_pool = enc["ref_pool"], enc["cand_pool"]
244
+ ref_valid, cand_valid = enc["ref_valid"], enc["cand_valid"]
245
+ mem_kp_mask = ~enc["memory_valid"]
246
+
247
+ B = input_ids.size(0)
248
+ device = input_ids.device
249
+ D = memory.size(-1)
250
+ bos = self.bos_emb.expand(B, 1, -1).to(memory.dtype)
251
+ prev_emb = bos
252
+ running = torch.ones(B, dtype=torch.bool, device=device)
253
+ out_seqs = [[] for _ in range(B)]
254
+
255
+ for step in range(self.max_decode_steps):
256
+ causal = nn.Transformer.generate_square_subsequent_mask(prev_emb.size(1)).to(device)
257
+ dec = self.decoder(prev_emb, memory, tgt_mask=causal, memory_key_padding_mask=mem_kp_mask)
258
+ last = dec[:, -1, :] # [B, D]
259
+
260
+ # Sample / argmax each head
261
+ cat_pred = self.head_cat(last).argmax(-1) # [B]
262
+ anat_pred = self.head_anat(last).argmax(-1)
263
+ concept_pred = self.head_concept(last).argmax(-1)
264
+ sev_pred = self.head_severity(last).argmax(-1)
265
+ scale = 1.0 / math.sqrt(self.hidden_size)
266
+ ref_q = self.proj_ref(last)
267
+ cand_q = self.proj_cand(last)
268
+ ref_logit = (torch.einsum("bd,bsd->bs", ref_q, ref_pool) * scale).masked_fill(~ref_valid, -1e4)
269
+ cand_logit = (torch.einsum("bd,bsd->bs", cand_q, cand_pool) * scale).masked_fill(~cand_valid, -1e4)
270
+ ref_pred = ref_logit.argmax(-1)
271
+ cand_pred = cand_logit.argmax(-1)
272
+
273
+ for b in range(B):
274
+ if not running[b]:
275
+ continue
276
+ if cat_pred[b].item() == self.EOS_CAT_IDX:
277
+ running[b] = False
278
+ continue
279
+ out_seqs[b].append({
280
+ "cat": int(cat_pred[b]),
281
+ "anat": int(anat_pred[b]),
282
+ "concept_id": int(concept_pred[b]),
283
+ "severity": int(sev_pred[b]),
284
+ "ref_seg_idx": int(ref_pred[b]),
285
+ "cand_seg_idx": int(cand_pred[b]),
286
+ })
287
+ if not running.any():
288
+ break
289
+
290
+ # Build next-step embedding from this step's predictions
291
+ ref_emb_step = ref_pool[torch.arange(B, device=device), ref_pred]
292
+ cand_emb_step = cand_pool[torch.arange(B, device=device), cand_pred]
293
+ next_emb = self.tuple_emb(
294
+ cat=cat_pred, anat=anat_pred,
295
+ concept=concept_pred, sev=sev_pred,
296
+ ref_emb=ref_emb_step, cand_emb=cand_emb_step,
297
+ ).unsqueeze(1) # [B, 1, D]
298
+ prev_emb = torch.cat([prev_emb, next_emb], dim=1)
299
+
300
+ return out_seqs
301
+
302
+
303
+ def cadad_loss(out: Dict[str, torch.Tensor],
304
+ target_cat, target_anat, target_concept, target_sev,
305
+ target_ref, target_cand,
306
+ weights: Optional[Dict[str, float]] = None) -> Dict[str, torch.Tensor]:
307
+ """Cross-entropy on every head. Pad/ignore positions = -100 in targets.
308
+ EOS positions only supervise `cat`; other heads should be -100 there.
309
+ """
310
+ w = {"cat": 1.0, "anat": 0.5, "concept": 0.3, "sev": 0.5,
311
+ "ref": 0.5, "cand": 0.5, **(weights or {})}
312
+
313
+ L_cat = F.cross_entropy(out["logits_cat"].reshape(-1, out["logits_cat"].size(-1)),
314
+ target_cat.reshape(-1), ignore_index=-100)
315
+ L_anat = F.cross_entropy(out["logits_anat"].reshape(-1, out["logits_anat"].size(-1)),
316
+ target_anat.reshape(-1), ignore_index=-100)
317
+ L_concept = F.cross_entropy(out["logits_concept"].reshape(-1, out["logits_concept"].size(-1)),
318
+ target_concept.reshape(-1), ignore_index=-100)
319
+ L_sev = F.cross_entropy(out["logits_sev"].reshape(-1, out["logits_sev"].size(-1)),
320
+ target_sev.reshape(-1), ignore_index=-100)
321
+ L_ref = F.cross_entropy(out["logits_ref"].reshape(-1, out["logits_ref"].size(-1)),
322
+ target_ref.reshape(-1), ignore_index=-100)
323
+ L_cand = F.cross_entropy(out["logits_cand"].reshape(-1, out["logits_cand"].size(-1)),
324
+ target_cand.reshape(-1), ignore_index=-100)
325
+
326
+ total = (w["cat"] * L_cat + w["anat"] * L_anat + w["concept"] * L_concept
327
+ + w["sev"] * L_sev + w["ref"] * L_ref + w["cand"] * L_cand)
328
+ return {"total": total, "cat": L_cat, "anat": L_anat, "concept": L_concept,
329
+ "sev": L_sev, "ref": L_ref, "cand": L_cand}
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7736077f20e4b6713701a4faef0250dfd9a669f5ae8f243a002708ccd01f99be
3
+ size 254257936
train_config.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_format:
2
+ pred_sentinel: '[PRED]'
3
+ template: '[REF] {reference_report}
4
+
5
+
6
+ [PRED] {candidate_report}'
7
+ loss:
8
+ anat: 0.5
9
+ cand: 0.5
10
+ cat: 1.0
11
+ concept: 0.3
12
+ ref: 0.5
13
+ sev: 0.5
14
+ metrics:
15
+ primary_metric: val_mae_K
16
+ model:
17
+ architecture: cada_d
18
+ attn_implementation: flash_attention_2
19
+ backbone_name: Qwen/Qwen3-Embedding-0.6B
20
+ chest2vec_adapter_path: /opt/project/chest2vec/export_chest2vec_0.6b_chest/contrastive
21
+ decoder_dropout: 0.1
22
+ decoder_ff: 2048
23
+ decoder_heads: 8
24
+ decoder_layers: 4
25
+ freeze_backbone_initially: false
26
+ lora_alpha: 64
27
+ lora_dropout: 0.05
28
+ lora_rank: 32
29
+ max_decode_steps: 24
30
+ max_length: 1280
31
+ n_anat: 9
32
+ n_cat: 5
33
+ n_severity: 2
34
+ use_lora: true
35
+ paths:
36
+ concept_vocab_path: /opt/project/chest2vec/chest2vec_error/artifacts_v5/concept2id.json
37
+ data_csv: /opt/project/chest2vec/create_labels/unified_variants_v5_merged.csv
38
+ output_dir: /opt/project/chest2vec/chest2vec_error/artifacts/cada_d_6gpu
39
+ seed: 42
40
+ training:
41
+ batch_size: 8
42
+ bf16: true
43
+ epochs: 20
44
+ grad_accum_steps: 1
45
+ gradient_checkpointing: false
46
+ lr_backbone: 0.0001
47
+ lr_heads: 0.0003
48
+ max_grad_norm: 1.0
49
+ num_workers: 4
50
+ warmup_ratio: 0.03
51
+ weight_decay: 0.01