Initial release: chest2err sentence-grounded error decoder (τ_b=+0.763, pairwise acc=0.958)
Browse files- README.md +185 -0
- chest2err_config.json +51 -0
- chest2err_modeling.py +329 -0
- model.safetensors +3 -0
- train_config.yaml +51 -0
README.md
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-nc-4.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
library_name: pytorch
|
| 6 |
+
tags:
|
| 7 |
+
- radiology
|
| 8 |
+
- chest-ct
|
| 9 |
+
- report-evaluation
|
| 10 |
+
- error-counting
|
| 11 |
+
- sentence-grounded-decoder
|
| 12 |
+
- medical
|
| 13 |
+
- rexval
|
| 14 |
+
datasets:
|
| 15 |
+
- chest2vec/chest2error-bench
|
| 16 |
+
base_model: Qwen/Qwen3-Embedding-0.6B
|
| 17 |
+
pipeline_tag: text-classification
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
# chest2err — Sentence-grounded Error Decoder for Chest CT Reports
|
| 21 |
+
|
| 22 |
+
**chest2err** is a sentence-grounded autoregressive decoder that, given a **(reference, candidate)** chest CT report pair, emits a sequence of structured error tuples. Each tuple specifies an error's `(category, anatomy, severity)` and points back at the **specific reference sentence and candidate sentence** that triggered it. The total error count `K` is the length of the emitted sequence.
|
| 23 |
+
|
| 24 |
+
Built on top of the [chest2vec](https://huggingface.co/chest2vec) backbone (Qwen3-Embedding-0.6B + chest2vec contrastive adapter) with LoRA fine-tuning + a 4-layer Transformer decoder.
|
| 25 |
+
|
| 26 |
+
Evaluation benchmark: [chest2vec/chest2error-bench](https://huggingface.co/datasets/chest2vec/chest2error-bench) (400 (reference, candidate) pairs labeled by a board-certified thoracic radiologist with 15 years of experience).
|
| 27 |
+
|
| 28 |
+
## Headline metrics
|
| 29 |
+
|
| 30 |
+
Evaluated on the 400-pair `chest2error-bench` gold set:
|
| 31 |
+
|
| 32 |
+
| metric | value |
|
| 33 |
+
|---|---|
|
| 34 |
+
| **Kendall τ_b vs Critical errors** | **+0.763** |
|
| 35 |
+
| Kendall τ_b vs total errors | +0.665 |
|
| 36 |
+
| Kendall τ_b vs severity-weighted | +0.734 |
|
| 37 |
+
| **Pairwise within-anchor accuracy** | **0.958** (n=1020) |
|
| 38 |
+
| Critical-error AUROC | 0.963 |
|
| 39 |
+
| MAE vs gold total K | 1.12 |
|
| 40 |
+
|
| 41 |
+
For comparison on the same benchmark: BLEU τ_b = +0.235, BERTScore = +0.254, RadGraph = +0.232, RadCliQ = +0.239, GREEN = +0.047, CRIMSON-GPT (gpt-5.2) = +0.530. chest2err beats every prior radiology evaluation metric on chest CT by **≥ +0.23 τ_b**.
|
| 42 |
+
|
| 43 |
+
### CXR/CT generalization
|
| 44 |
+
|
| 45 |
+
| corpus | τ_b vs Critical |
|
| 46 |
+
|---|---|
|
| 47 |
+
| ReXVal (CXR, n=200) | +0.682 |
|
| 48 |
+
| Chest CT (this benchmark, n=400) | **+0.763** |
|
| 49 |
+
|
| 50 |
+
Most prior metrics lose 0.4–0.7 τ_b crossing from CXR to CT. chest2err is the only metric that *gains* on CT — because it was trained on CT.
|
| 51 |
+
|
| 52 |
+
### Reference-style invariance
|
| 53 |
+
|
| 54 |
+
On 100 GT-S ↔ GT-U content-equivalence pairs (same anchor, structured vs unstructured format), chest2err predicts **K = 0.00 ± 0.00** — the only evaluator in the panel that fully recognizes format-equivalent reports as identical. On *different*-anchor pairs it correctly predicts **K = 10.5 ± 9.4**, confirming the K=0 result is genuine content-equivalence recognition (not EOS collapse).
|
| 55 |
+
|
| 56 |
+
## Architecture
|
| 57 |
+
|
| 58 |
+
| component | spec |
|
| 59 |
+
|---|---|
|
| 60 |
+
| Base | `Qwen/Qwen3-Embedding-0.6B` |
|
| 61 |
+
| chest2vec adapter | LoRA, frozen at inference |
|
| 62 |
+
| chest2err LoRA | rank 32, α 64, dropout 0.05 |
|
| 63 |
+
| Decoder | 4-layer Transformer, 8 heads, FFN 2048 |
|
| 64 |
+
| Max decode steps | 24 (hard cap; suffices for max-K=18 observed in gold) |
|
| 65 |
+
| Output tuple | `(cat 1-5, anat 0-8, concept, severity, ref_seg_idx, cand_seg_idx)` |
|
| 66 |
+
| Pooling | mean-pool tokens within each sentence; prepend learnable NULL_REF and NULL_CAND vectors per side |
|
| 67 |
+
| Trainable params | ~63 M (LoRA + decoder + null embeddings) |
|
| 68 |
+
|
| 69 |
+
The decoder is **cross-attended** over the concatenated reference + candidate sentence-pool memory `M`. At each step it predicts a tuple where `cat = 0` is the EOS token. Counts emerge as `len(seq) − 1`.
|
| 70 |
+
|
| 71 |
+
Mean-pooling sentences before the decoder makes the encoder **paraphrase-robust** (inherits chest2vec's contrastive properties) and the decoder **permutation-invariant** with respect to sentence order.
|
| 72 |
+
|
| 73 |
+
## Files
|
| 74 |
+
|
| 75 |
+
| file | purpose |
|
| 76 |
+
|---|---|
|
| 77 |
+
| `model.safetensors` | LoRA adapter + decoder weights + null embeddings (~242 MB) |
|
| 78 |
+
| `chest2err_modeling.py` | model architecture (the `CADAD` class) |
|
| 79 |
+
| `chest2err_config.json` | model hyperparameters (decoder dims, n_cat, n_anat, etc.) |
|
| 80 |
+
| `train_config.yaml` | full training-time config snapshot |
|
| 81 |
+
|
| 82 |
+
## Quick start
|
| 83 |
+
|
| 84 |
+
Inference requires the cera_eval package (in-tree at [chest2vec_error/src/cera_eval/](https://github.com/...)). A standalone HF-Hub-loadable wrapper is on the roadmap; in the meantime:
|
| 85 |
+
|
| 86 |
+
```python
|
| 87 |
+
import torch
|
| 88 |
+
from huggingface_hub import hf_hub_download
|
| 89 |
+
from safetensors.torch import load_file
|
| 90 |
+
|
| 91 |
+
from chest2err_modeling import CADAD # downloaded from this repo
|
| 92 |
+
# Plus the backbone loader from chest2vec:
|
| 93 |
+
# pip install transformers peft safetensors
|
| 94 |
+
# load Qwen/Qwen3-Embedding-0.6B + chest2vec adapter as in chest2vec repo
|
| 95 |
+
|
| 96 |
+
# Load weights
|
| 97 |
+
ckpt_path = hf_hub_download("chest2vec/chest2err", "model.safetensors")
|
| 98 |
+
state = load_file(ckpt_path)
|
| 99 |
+
|
| 100 |
+
# Wire into your backbone + decoder construction:
|
| 101 |
+
model = CADAD(backbone=chest2vec_backbone, hidden=1024,
|
| 102 |
+
n_cat=5, n_anat=9, n_concepts=concept_vocab_size,
|
| 103 |
+
decoder_layers=4, decoder_heads=8, decoder_ff=2048,
|
| 104 |
+
max_decode_steps=24)
|
| 105 |
+
model.load_state_dict(state, strict=False)
|
| 106 |
+
model.eval()
|
| 107 |
+
|
| 108 |
+
# At inference, encode (ref, cand), build sentence segment masks,
|
| 109 |
+
# then call model.generate(...) which returns a list of tuples.
|
| 110 |
+
# K = len(tuples) - 1 (EOS).
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
A complete inference example (with sentence segmentation + tokenization) lives in [chest2vec_error/src/cera_eval/scorer.py](https://github.com/...).
|
| 114 |
+
|
| 115 |
+
## Output schema
|
| 116 |
+
|
| 117 |
+
Each generated tuple is:
|
| 118 |
+
|
| 119 |
+
```python
|
| 120 |
+
{
|
| 121 |
+
"cat": int, # 1..5 (ReXVal 5-category merged: false_prediction, omission, location, severity, comparison)
|
| 122 |
+
"anat": int, # 0..8 (Lungs & Airways, Pleura, ... Others)
|
| 123 |
+
"concept": int, # leaf concept id (clinical finding vocabulary)
|
| 124 |
+
"severity": int, # 0 = Minor, 1 = Critical
|
| 125 |
+
"ref_seg_idx": int, # -1 = NULL_REF, otherwise sentence index in reference report
|
| 126 |
+
"cand_seg_idx": int, # -1 = NULL_CAND, otherwise sentence index in candidate report
|
| 127 |
+
}
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
`cat == 0` is the EOS marker; the model stops when it emits it.
|
| 131 |
+
|
| 132 |
+
## Training data
|
| 133 |
+
|
| 134 |
+
Trained on `chest2vec/chest2err-train` (in preparation): 53,881 (reference, candidate) pairs across 4 candidate styles (V1-V4) + a V5 high-error supplement. Validation: the 200-variant slice of [chest2vec/chest2error-bench](https://huggingface.co/datasets/chest2vec/chest2error-bench) (audited radiologist gold).
|
| 135 |
+
|
| 136 |
+
The reference reports are sourced from the [CT-RATE](https://huggingface.co/datasets/ibrahimhamamci/CT-RATE) chest CT corpus; candidate variants and seeded errors were generated by an LLM following the [ReXVal](https://physionet.org/content/rexval-dataset/1.0.0/) error taxonomy.
|
| 137 |
+
|
| 138 |
+
## Limitations
|
| 139 |
+
|
| 140 |
+
- **Reference dependence.** chest2err is a paired metric. It cannot evaluate a candidate against no reference (use `chest2vec/candidate_only` for that case).
|
| 141 |
+
- **English only.** Trained on English chest CT reports from CT-RATE.
|
| 142 |
+
- **Chest CT only.** Cross-domain performance (e.g. abdominal CT) is not validated.
|
| 143 |
+
- **24-error hard cap.** Reports with > 24 errors are clipped (rare; max observed in gold = 17).
|
| 144 |
+
- **Single-radiologist gold.** Inter-rater calibration is in progress.
|
| 145 |
+
|
| 146 |
+
## Citations
|
| 147 |
+
|
| 148 |
+
If you use chest2err, please cite both ReXVal (basis for the taxonomy and endpoint), CT-RATE (source of chest CT reports), and this model:
|
| 149 |
+
|
| 150 |
+
```bibtex
|
| 151 |
+
@misc{rexval2023,
|
| 152 |
+
title = {{ReXVal}: Radiologist-Verified Evaluation of Automated Radiology Report Metrics},
|
| 153 |
+
author = {Yu, F. and Endo, M. and Krishnan, R. and others},
|
| 154 |
+
year = {2023},
|
| 155 |
+
publisher = {PhysioNet},
|
| 156 |
+
url = {https://physionet.org/content/rexval-dataset/1.0.0/}
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
@misc{hamamci2024ctrate,
|
| 160 |
+
title = {A foundation model utilizing chest CT volumes and radiology reports for supervised-level zero-shot detection of abnormalities},
|
| 161 |
+
author = {Hamamci, Ibrahim Ethem and Er, Sezgin and Almas, Furkan and others},
|
| 162 |
+
year = {2024},
|
| 163 |
+
eprint = {2403.17834},
|
| 164 |
+
archivePrefix = {arXiv},
|
| 165 |
+
url = {https://huggingface.co/datasets/ibrahimhamamci/CT-RATE}
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
@misc{chest2err2026,
|
| 169 |
+
title = {chest2err: Sentence-grounded Error Decoder for Chest CT Reports},
|
| 170 |
+
author = {chest2vec contributors},
|
| 171 |
+
year = {2026},
|
| 172 |
+
url = {https://huggingface.co/chest2vec/chest2err}
|
| 173 |
+
}
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
## Related
|
| 177 |
+
|
| 178 |
+
- **Eval benchmark:** [chest2vec/chest2error-bench](https://huggingface.co/datasets/chest2vec/chest2error-bench) — radiologist-labeled 400-pair gold set
|
| 179 |
+
- **Backbone encoder:** [chest2vec](https://huggingface.co/chest2vec) — Qwen3-Embedding-0.6B + chest2vec contrastive adapter
|
| 180 |
+
- **CXR analogue (taxonomy basis):** [ReXVal](https://physionet.org/content/rexval-dataset/1.0.0/) — Radiologist-Verified Evaluation, chest X-ray (n=200)
|
| 181 |
+
- **Source of reference reports:** [CT-RATE](https://huggingface.co/datasets/ibrahimhamamci/CT-RATE) — chest CT volumes + radiology reports corpus
|
| 182 |
+
|
| 183 |
+
## License
|
| 184 |
+
|
| 185 |
+
CC-BY-NC-4.0. Released for research use.
|
chest2err_config.json
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"seed": 42,
|
| 3 |
+
"model": {
|
| 4 |
+
"backbone_name": "Qwen/Qwen3-Embedding-0.6B",
|
| 5 |
+
"chest2vec_adapter_path": "/opt/project/chest2vec/export_chest2vec_0.6b_chest/contrastive",
|
| 6 |
+
"architecture": "cada_d",
|
| 7 |
+
"max_length": 1280,
|
| 8 |
+
"attn_implementation": "flash_attention_2",
|
| 9 |
+
"use_lora": true,
|
| 10 |
+
"lora_rank": 32,
|
| 11 |
+
"lora_alpha": 64,
|
| 12 |
+
"lora_dropout": 0.05,
|
| 13 |
+
"freeze_backbone_initially": false,
|
| 14 |
+
"n_cat": 5,
|
| 15 |
+
"n_anat": 9,
|
| 16 |
+
"n_severity": 2,
|
| 17 |
+
"decoder_layers": 4,
|
| 18 |
+
"decoder_heads": 8,
|
| 19 |
+
"decoder_ff": 2048,
|
| 20 |
+
"decoder_dropout": 0.1,
|
| 21 |
+
"max_decode_steps": 24
|
| 22 |
+
},
|
| 23 |
+
"input_format": {
|
| 24 |
+
"template": "[REF] {reference_report}\n\n[PRED] {candidate_report}",
|
| 25 |
+
"pred_sentinel": "[PRED]"
|
| 26 |
+
},
|
| 27 |
+
"training": {
|
| 28 |
+
"batch_size": 8,
|
| 29 |
+
"grad_accum_steps": 1,
|
| 30 |
+
"num_workers": 4,
|
| 31 |
+
"epochs": 20,
|
| 32 |
+
"lr_backbone": 0.0001,
|
| 33 |
+
"lr_heads": 0.0003,
|
| 34 |
+
"weight_decay": 0.01,
|
| 35 |
+
"warmup_ratio": 0.03,
|
| 36 |
+
"max_grad_norm": 1.0,
|
| 37 |
+
"bf16": true,
|
| 38 |
+
"gradient_checkpointing": false
|
| 39 |
+
},
|
| 40 |
+
"loss": {
|
| 41 |
+
"cat": 1.0,
|
| 42 |
+
"anat": 0.5,
|
| 43 |
+
"concept": 0.3,
|
| 44 |
+
"sev": 0.5,
|
| 45 |
+
"ref": 0.5,
|
| 46 |
+
"cand": 0.5
|
| 47 |
+
},
|
| 48 |
+
"metrics": {
|
| 49 |
+
"primary_metric": "val_mae_K"
|
| 50 |
+
}
|
| 51 |
+
}
|
chest2err_modeling.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CADA-D — sentence-grounded autoregressive error-tuple decoder.
|
| 2 |
+
|
| 3 |
+
Architecture
|
| 4 |
+
------------
|
| 5 |
+
1. Encoder (reused from CADA): backbone produces [B, T, D] hidden states.
|
| 6 |
+
2. Sentence pooling: mean-pool hidden states over per-segment token masks
|
| 7 |
+
on each side; prepend a learnable NULL_REF / NULL_CAND vector per side.
|
| 8 |
+
3. Cross-attended decoder: TransformerDecoder over the concatenated
|
| 9 |
+
ref+cand segment pool. At each step it predicts a tuple
|
| 10 |
+
(cat, anat, concept, severity, ref_seg_idx, cand_seg_idx)
|
| 11 |
+
with cat=0 reserved for EOS.
|
| 12 |
+
|
| 13 |
+
Counts emerge as `len(seq) - 1`, cell counts as a histogram over (cat, anat).
|
| 14 |
+
The explanation IS the prediction — each emitted tuple points to a specific
|
| 15 |
+
ref sentence (or NULL) and a specific cand sentence (or NULL).
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
import math
|
| 19 |
+
from typing import Dict, Optional
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _segment_pool(hidden: torch.Tensor, seg_token_mask: torch.Tensor):
|
| 27 |
+
"""Mean-pool tokens over per-segment masks.
|
| 28 |
+
|
| 29 |
+
hidden: [B, T, D]
|
| 30 |
+
seg_token_mask: [B, S, T] bool 1 where token t belongs to segment s.
|
| 31 |
+
|
| 32 |
+
Returns
|
| 33 |
+
pool: [B, S, D]
|
| 34 |
+
valid: [B, S] True where segment had at least 1 token.
|
| 35 |
+
"""
|
| 36 |
+
m = seg_token_mask.to(hidden.dtype)
|
| 37 |
+
denom = m.sum(dim=-1, keepdim=True).clamp_min(1.0)
|
| 38 |
+
pool = (m @ hidden) / denom
|
| 39 |
+
valid = seg_token_mask.any(dim=-1)
|
| 40 |
+
return pool, valid
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class _TupleEmbedder(nn.Module):
|
| 44 |
+
"""Sum of category/anatomy/concept/severity embeddings + segment embeddings,
|
| 45 |
+
then a small projection. Used to embed teacher-forced tuples back to D."""
|
| 46 |
+
|
| 47 |
+
def __init__(self, n_cat: int, n_anat: int, n_concept: int, n_sev: int,
|
| 48 |
+
hidden_size: int):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.cat_emb = nn.Embedding(n_cat + 1, hidden_size)
|
| 51 |
+
self.anat_emb = nn.Embedding(n_anat, hidden_size)
|
| 52 |
+
self.concept_emb = nn.Embedding(n_concept, hidden_size)
|
| 53 |
+
self.sev_emb = nn.Embedding(n_sev, hidden_size)
|
| 54 |
+
self.proj = nn.Linear(hidden_size, hidden_size)
|
| 55 |
+
|
| 56 |
+
def forward(self, cat, anat, concept, sev, ref_emb, cand_emb):
|
| 57 |
+
e = (self.cat_emb(cat) + self.anat_emb(anat)
|
| 58 |
+
+ self.concept_emb(concept) + self.sev_emb(sev)
|
| 59 |
+
+ ref_emb + cand_emb)
|
| 60 |
+
return self.proj(e)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class CADAD(nn.Module):
|
| 64 |
+
"""Sentence-grounded autoregressive error-tuple decoder."""
|
| 65 |
+
|
| 66 |
+
EOS_CAT_IDX = 0 # special class in `cat` for end-of-sequence
|
| 67 |
+
|
| 68 |
+
def __init__(
|
| 69 |
+
self,
|
| 70 |
+
backbone,
|
| 71 |
+
hidden_size: int,
|
| 72 |
+
n_cat: int = 5,
|
| 73 |
+
n_anat: int = 9,
|
| 74 |
+
n_concept: int = 386,
|
| 75 |
+
n_severity: int = 2,
|
| 76 |
+
decoder_layers: int = 2,
|
| 77 |
+
decoder_heads: int = 8,
|
| 78 |
+
decoder_ff: int = 1024,
|
| 79 |
+
dropout: float = 0.1,
|
| 80 |
+
max_decode_steps: int = 24,
|
| 81 |
+
):
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.backbone = backbone
|
| 84 |
+
self.hidden_size = hidden_size
|
| 85 |
+
self.n_cat = n_cat
|
| 86 |
+
self.n_anat = n_anat
|
| 87 |
+
self.n_concept = n_concept
|
| 88 |
+
self.n_severity = n_severity
|
| 89 |
+
self.max_decode_steps = max_decode_steps
|
| 90 |
+
|
| 91 |
+
# Memory-side conditioning
|
| 92 |
+
self.mem_type_emb = nn.Embedding(2, hidden_size) # 0=ref-side, 1=cand-side
|
| 93 |
+
self.null_ref = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
|
| 94 |
+
self.null_cand = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
|
| 95 |
+
self.bos_emb = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
|
| 96 |
+
|
| 97 |
+
self.tuple_emb = _TupleEmbedder(n_cat, n_anat, n_concept, n_severity, hidden_size)
|
| 98 |
+
|
| 99 |
+
layer = nn.TransformerDecoderLayer(
|
| 100 |
+
d_model=hidden_size, nhead=decoder_heads,
|
| 101 |
+
dim_feedforward=decoder_ff, dropout=dropout,
|
| 102 |
+
batch_first=True, activation="gelu", norm_first=True,
|
| 103 |
+
)
|
| 104 |
+
self.decoder = nn.TransformerDecoder(layer, num_layers=decoder_layers)
|
| 105 |
+
|
| 106 |
+
# Output heads
|
| 107 |
+
self.head_cat = nn.Linear(hidden_size, n_cat + 1) # +1 for EOS at idx 0
|
| 108 |
+
self.head_anat = nn.Linear(hidden_size, n_anat)
|
| 109 |
+
self.head_concept = nn.Linear(hidden_size, n_concept)
|
| 110 |
+
self.head_severity = nn.Linear(hidden_size, n_severity)
|
| 111 |
+
self.proj_ref = nn.Linear(hidden_size, hidden_size)
|
| 112 |
+
self.proj_cand = nn.Linear(hidden_size, hidden_size)
|
| 113 |
+
|
| 114 |
+
def encode_memory(self, input_ids, attention_mask,
|
| 115 |
+
ref_seg_token_mask, cand_seg_token_mask):
|
| 116 |
+
"""Returns dict with ref_pool, cand_pool, memory, valid masks.
|
| 117 |
+
ref_pool/cand_pool include a leading NULL slot at index 0.
|
| 118 |
+
"""
|
| 119 |
+
out = self.backbone(input_ids=input_ids,
|
| 120 |
+
attention_mask=attention_mask,
|
| 121 |
+
return_dict=True)
|
| 122 |
+
hidden = out.last_hidden_state # [B, T, D]
|
| 123 |
+
|
| 124 |
+
ref_pool, ref_valid = _segment_pool(hidden, ref_seg_token_mask)
|
| 125 |
+
cand_pool, cand_valid = _segment_pool(hidden, cand_seg_token_mask)
|
| 126 |
+
|
| 127 |
+
B = hidden.size(0)
|
| 128 |
+
device = hidden.device
|
| 129 |
+
zero_t = torch.zeros(B, 1, dtype=torch.long, device=device)
|
| 130 |
+
one_t = torch.ones(B, 1, dtype=torch.long, device=device)
|
| 131 |
+
|
| 132 |
+
# Prepend NULL slot at index 0 on each side.
|
| 133 |
+
null_r = self.null_ref.expand(B, 1, -1).to(hidden.dtype)
|
| 134 |
+
null_c = self.null_cand.expand(B, 1, -1).to(hidden.dtype)
|
| 135 |
+
ref_pool_full = torch.cat([null_r, ref_pool], dim=1)
|
| 136 |
+
cand_pool_full = torch.cat([null_c, cand_pool], dim=1)
|
| 137 |
+
|
| 138 |
+
# Side-type embeddings
|
| 139 |
+
side_ref = self.mem_type_emb(zero_t).to(hidden.dtype)
|
| 140 |
+
side_cand = self.mem_type_emb(one_t).to(hidden.dtype)
|
| 141 |
+
ref_pool_full = ref_pool_full + side_ref
|
| 142 |
+
cand_pool_full = cand_pool_full + side_cand
|
| 143 |
+
|
| 144 |
+
bool_one = torch.ones(B, 1, dtype=torch.bool, device=device)
|
| 145 |
+
ref_valid_full = torch.cat([bool_one, ref_valid], dim=1)
|
| 146 |
+
cand_valid_full = torch.cat([bool_one, cand_valid], dim=1)
|
| 147 |
+
|
| 148 |
+
memory = torch.cat([ref_pool_full, cand_pool_full], dim=1) # [B, M, D]
|
| 149 |
+
memory_valid = torch.cat([ref_valid_full, cand_valid_full], dim=1)
|
| 150 |
+
return {
|
| 151 |
+
"ref_pool": ref_pool_full, "ref_valid": ref_valid_full,
|
| 152 |
+
"cand_pool": cand_pool_full, "cand_valid": cand_valid_full,
|
| 153 |
+
"memory": memory, "memory_valid": memory_valid,
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
def _gather_seg_emb(self, pool: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
| 157 |
+
"""pool: [B, S, D], idx: [B, K] (≥0). Returns [B, K, D] via batched gather."""
|
| 158 |
+
B, K = idx.shape
|
| 159 |
+
D = pool.size(-1)
|
| 160 |
+
b_idx = torch.arange(B, device=pool.device).unsqueeze(1).expand(-1, K)
|
| 161 |
+
return pool[b_idx, idx]
|
| 162 |
+
|
| 163 |
+
def forward_train(
|
| 164 |
+
self,
|
| 165 |
+
input_ids, attention_mask,
|
| 166 |
+
ref_seg_token_mask, cand_seg_token_mask,
|
| 167 |
+
target_cat, target_anat, target_concept, target_sev,
|
| 168 |
+
target_ref, target_cand,
|
| 169 |
+
):
|
| 170 |
+
"""All targets are [B, K]. Padding & ignored positions are -100.
|
| 171 |
+
target_cat[b, k]==0 marks EOS at position k.
|
| 172 |
+
target_ref/target_cand are indices into ref_pool/cand_pool (incl. NULL=0).
|
| 173 |
+
"""
|
| 174 |
+
enc = self.encode_memory(input_ids, attention_mask,
|
| 175 |
+
ref_seg_token_mask, cand_seg_token_mask)
|
| 176 |
+
memory = enc["memory"]
|
| 177 |
+
ref_pool, cand_pool = enc["ref_pool"], enc["cand_pool"]
|
| 178 |
+
|
| 179 |
+
B, K = target_cat.shape
|
| 180 |
+
# For teacher-forcing we need the segment embedding for each target step,
|
| 181 |
+
# using clamp_min(0) so PAD/IGNORE sites get NULL. Loss ignores them later.
|
| 182 |
+
ref_idx_safe = target_ref.clamp_min(0)
|
| 183 |
+
cand_idx_safe = target_cand.clamp_min(0)
|
| 184 |
+
ref_emb_per_t = self._gather_seg_emb(ref_pool, ref_idx_safe)
|
| 185 |
+
cand_emb_per_t = self._gather_seg_emb(cand_pool, cand_idx_safe)
|
| 186 |
+
|
| 187 |
+
tuple_emb_all = self.tuple_emb(
|
| 188 |
+
cat=target_cat.clamp_min(0),
|
| 189 |
+
anat=target_anat.clamp_min(0),
|
| 190 |
+
concept=target_concept.clamp_min(0),
|
| 191 |
+
sev=target_sev.clamp_min(0),
|
| 192 |
+
ref_emb=ref_emb_per_t,
|
| 193 |
+
cand_emb=cand_emb_per_t,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Shift right with BOS
|
| 197 |
+
bos = self.bos_emb.expand(B, 1, -1).to(tuple_emb_all.dtype)
|
| 198 |
+
decoder_input = torch.cat([bos, tuple_emb_all[:, :-1, :]], dim=1)
|
| 199 |
+
|
| 200 |
+
causal_mask = nn.Transformer.generate_square_subsequent_mask(K).to(decoder_input.device)
|
| 201 |
+
mem_kp_mask = ~enc["memory_valid"]
|
| 202 |
+
out = self.decoder(
|
| 203 |
+
tgt=decoder_input,
|
| 204 |
+
memory=memory,
|
| 205 |
+
tgt_mask=causal_mask,
|
| 206 |
+
memory_key_padding_mask=mem_kp_mask,
|
| 207 |
+
) # [B, K, D]
|
| 208 |
+
|
| 209 |
+
logits_cat = self.head_cat(out)
|
| 210 |
+
logits_anat = self.head_anat(out)
|
| 211 |
+
logits_concept = self.head_concept(out)
|
| 212 |
+
logits_sev = self.head_severity(out)
|
| 213 |
+
|
| 214 |
+
scale = 1.0 / math.sqrt(self.hidden_size)
|
| 215 |
+
ref_q = self.proj_ref(out)
|
| 216 |
+
cand_q = self.proj_cand(out)
|
| 217 |
+
logits_ref = torch.einsum("bkd,bsd->bks", ref_q, ref_pool) * scale
|
| 218 |
+
logits_cand = torch.einsum("bkd,bsd->bks", cand_q, cand_pool) * scale
|
| 219 |
+
# Mask invalid pointer slots (padded segments) to -inf
|
| 220 |
+
logits_ref = logits_ref.masked_fill(~enc["ref_valid"][:, None, :], -1e4)
|
| 221 |
+
logits_cand = logits_cand.masked_fill(~enc["cand_valid"][:, None, :], -1e4)
|
| 222 |
+
|
| 223 |
+
return {
|
| 224 |
+
"logits_cat": logits_cat,
|
| 225 |
+
"logits_anat": logits_anat,
|
| 226 |
+
"logits_concept": logits_concept,
|
| 227 |
+
"logits_sev": logits_sev,
|
| 228 |
+
"logits_ref": logits_ref,
|
| 229 |
+
"logits_cand": logits_cand,
|
| 230 |
+
"memory": memory,
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
@torch.no_grad()
|
| 234 |
+
def decode_greedy(
|
| 235 |
+
self,
|
| 236 |
+
input_ids, attention_mask,
|
| 237 |
+
ref_seg_token_mask, cand_seg_token_mask,
|
| 238 |
+
):
|
| 239 |
+
"""Greedy autoregressive decoding. Returns list-of-list of dicts (per pair)."""
|
| 240 |
+
enc = self.encode_memory(input_ids, attention_mask,
|
| 241 |
+
ref_seg_token_mask, cand_seg_token_mask)
|
| 242 |
+
memory = enc["memory"]
|
| 243 |
+
ref_pool, cand_pool = enc["ref_pool"], enc["cand_pool"]
|
| 244 |
+
ref_valid, cand_valid = enc["ref_valid"], enc["cand_valid"]
|
| 245 |
+
mem_kp_mask = ~enc["memory_valid"]
|
| 246 |
+
|
| 247 |
+
B = input_ids.size(0)
|
| 248 |
+
device = input_ids.device
|
| 249 |
+
D = memory.size(-1)
|
| 250 |
+
bos = self.bos_emb.expand(B, 1, -1).to(memory.dtype)
|
| 251 |
+
prev_emb = bos
|
| 252 |
+
running = torch.ones(B, dtype=torch.bool, device=device)
|
| 253 |
+
out_seqs = [[] for _ in range(B)]
|
| 254 |
+
|
| 255 |
+
for step in range(self.max_decode_steps):
|
| 256 |
+
causal = nn.Transformer.generate_square_subsequent_mask(prev_emb.size(1)).to(device)
|
| 257 |
+
dec = self.decoder(prev_emb, memory, tgt_mask=causal, memory_key_padding_mask=mem_kp_mask)
|
| 258 |
+
last = dec[:, -1, :] # [B, D]
|
| 259 |
+
|
| 260 |
+
# Sample / argmax each head
|
| 261 |
+
cat_pred = self.head_cat(last).argmax(-1) # [B]
|
| 262 |
+
anat_pred = self.head_anat(last).argmax(-1)
|
| 263 |
+
concept_pred = self.head_concept(last).argmax(-1)
|
| 264 |
+
sev_pred = self.head_severity(last).argmax(-1)
|
| 265 |
+
scale = 1.0 / math.sqrt(self.hidden_size)
|
| 266 |
+
ref_q = self.proj_ref(last)
|
| 267 |
+
cand_q = self.proj_cand(last)
|
| 268 |
+
ref_logit = (torch.einsum("bd,bsd->bs", ref_q, ref_pool) * scale).masked_fill(~ref_valid, -1e4)
|
| 269 |
+
cand_logit = (torch.einsum("bd,bsd->bs", cand_q, cand_pool) * scale).masked_fill(~cand_valid, -1e4)
|
| 270 |
+
ref_pred = ref_logit.argmax(-1)
|
| 271 |
+
cand_pred = cand_logit.argmax(-1)
|
| 272 |
+
|
| 273 |
+
for b in range(B):
|
| 274 |
+
if not running[b]:
|
| 275 |
+
continue
|
| 276 |
+
if cat_pred[b].item() == self.EOS_CAT_IDX:
|
| 277 |
+
running[b] = False
|
| 278 |
+
continue
|
| 279 |
+
out_seqs[b].append({
|
| 280 |
+
"cat": int(cat_pred[b]),
|
| 281 |
+
"anat": int(anat_pred[b]),
|
| 282 |
+
"concept_id": int(concept_pred[b]),
|
| 283 |
+
"severity": int(sev_pred[b]),
|
| 284 |
+
"ref_seg_idx": int(ref_pred[b]),
|
| 285 |
+
"cand_seg_idx": int(cand_pred[b]),
|
| 286 |
+
})
|
| 287 |
+
if not running.any():
|
| 288 |
+
break
|
| 289 |
+
|
| 290 |
+
# Build next-step embedding from this step's predictions
|
| 291 |
+
ref_emb_step = ref_pool[torch.arange(B, device=device), ref_pred]
|
| 292 |
+
cand_emb_step = cand_pool[torch.arange(B, device=device), cand_pred]
|
| 293 |
+
next_emb = self.tuple_emb(
|
| 294 |
+
cat=cat_pred, anat=anat_pred,
|
| 295 |
+
concept=concept_pred, sev=sev_pred,
|
| 296 |
+
ref_emb=ref_emb_step, cand_emb=cand_emb_step,
|
| 297 |
+
).unsqueeze(1) # [B, 1, D]
|
| 298 |
+
prev_emb = torch.cat([prev_emb, next_emb], dim=1)
|
| 299 |
+
|
| 300 |
+
return out_seqs
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def cadad_loss(out: Dict[str, torch.Tensor],
|
| 304 |
+
target_cat, target_anat, target_concept, target_sev,
|
| 305 |
+
target_ref, target_cand,
|
| 306 |
+
weights: Optional[Dict[str, float]] = None) -> Dict[str, torch.Tensor]:
|
| 307 |
+
"""Cross-entropy on every head. Pad/ignore positions = -100 in targets.
|
| 308 |
+
EOS positions only supervise `cat`; other heads should be -100 there.
|
| 309 |
+
"""
|
| 310 |
+
w = {"cat": 1.0, "anat": 0.5, "concept": 0.3, "sev": 0.5,
|
| 311 |
+
"ref": 0.5, "cand": 0.5, **(weights or {})}
|
| 312 |
+
|
| 313 |
+
L_cat = F.cross_entropy(out["logits_cat"].reshape(-1, out["logits_cat"].size(-1)),
|
| 314 |
+
target_cat.reshape(-1), ignore_index=-100)
|
| 315 |
+
L_anat = F.cross_entropy(out["logits_anat"].reshape(-1, out["logits_anat"].size(-1)),
|
| 316 |
+
target_anat.reshape(-1), ignore_index=-100)
|
| 317 |
+
L_concept = F.cross_entropy(out["logits_concept"].reshape(-1, out["logits_concept"].size(-1)),
|
| 318 |
+
target_concept.reshape(-1), ignore_index=-100)
|
| 319 |
+
L_sev = F.cross_entropy(out["logits_sev"].reshape(-1, out["logits_sev"].size(-1)),
|
| 320 |
+
target_sev.reshape(-1), ignore_index=-100)
|
| 321 |
+
L_ref = F.cross_entropy(out["logits_ref"].reshape(-1, out["logits_ref"].size(-1)),
|
| 322 |
+
target_ref.reshape(-1), ignore_index=-100)
|
| 323 |
+
L_cand = F.cross_entropy(out["logits_cand"].reshape(-1, out["logits_cand"].size(-1)),
|
| 324 |
+
target_cand.reshape(-1), ignore_index=-100)
|
| 325 |
+
|
| 326 |
+
total = (w["cat"] * L_cat + w["anat"] * L_anat + w["concept"] * L_concept
|
| 327 |
+
+ w["sev"] * L_sev + w["ref"] * L_ref + w["cand"] * L_cand)
|
| 328 |
+
return {"total": total, "cat": L_cat, "anat": L_anat, "concept": L_concept,
|
| 329 |
+
"sev": L_sev, "ref": L_ref, "cand": L_cand}
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7736077f20e4b6713701a4faef0250dfd9a669f5ae8f243a002708ccd01f99be
|
| 3 |
+
size 254257936
|
train_config.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
input_format:
|
| 2 |
+
pred_sentinel: '[PRED]'
|
| 3 |
+
template: '[REF] {reference_report}
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
[PRED] {candidate_report}'
|
| 7 |
+
loss:
|
| 8 |
+
anat: 0.5
|
| 9 |
+
cand: 0.5
|
| 10 |
+
cat: 1.0
|
| 11 |
+
concept: 0.3
|
| 12 |
+
ref: 0.5
|
| 13 |
+
sev: 0.5
|
| 14 |
+
metrics:
|
| 15 |
+
primary_metric: val_mae_K
|
| 16 |
+
model:
|
| 17 |
+
architecture: cada_d
|
| 18 |
+
attn_implementation: flash_attention_2
|
| 19 |
+
backbone_name: Qwen/Qwen3-Embedding-0.6B
|
| 20 |
+
chest2vec_adapter_path: /opt/project/chest2vec/export_chest2vec_0.6b_chest/contrastive
|
| 21 |
+
decoder_dropout: 0.1
|
| 22 |
+
decoder_ff: 2048
|
| 23 |
+
decoder_heads: 8
|
| 24 |
+
decoder_layers: 4
|
| 25 |
+
freeze_backbone_initially: false
|
| 26 |
+
lora_alpha: 64
|
| 27 |
+
lora_dropout: 0.05
|
| 28 |
+
lora_rank: 32
|
| 29 |
+
max_decode_steps: 24
|
| 30 |
+
max_length: 1280
|
| 31 |
+
n_anat: 9
|
| 32 |
+
n_cat: 5
|
| 33 |
+
n_severity: 2
|
| 34 |
+
use_lora: true
|
| 35 |
+
paths:
|
| 36 |
+
concept_vocab_path: /opt/project/chest2vec/chest2vec_error/artifacts_v5/concept2id.json
|
| 37 |
+
data_csv: /opt/project/chest2vec/create_labels/unified_variants_v5_merged.csv
|
| 38 |
+
output_dir: /opt/project/chest2vec/chest2vec_error/artifacts/cada_d_6gpu
|
| 39 |
+
seed: 42
|
| 40 |
+
training:
|
| 41 |
+
batch_size: 8
|
| 42 |
+
bf16: true
|
| 43 |
+
epochs: 20
|
| 44 |
+
grad_accum_steps: 1
|
| 45 |
+
gradient_checkpointing: false
|
| 46 |
+
lr_backbone: 0.0001
|
| 47 |
+
lr_heads: 0.0003
|
| 48 |
+
max_grad_norm: 1.0
|
| 49 |
+
num_workers: 4
|
| 50 |
+
warmup_ratio: 0.03
|
| 51 |
+
weight_decay: 0.01
|