Initial release: chest2err sentence-grounded error decoder (τ_b=+0.763, pairwise acc=0.958)
8a9746d verified | """CADA-D — sentence-grounded autoregressive error-tuple decoder. | |
| Architecture | |
| ------------ | |
| 1. Encoder (reused from CADA): backbone produces [B, T, D] hidden states. | |
| 2. Sentence pooling: mean-pool hidden states over per-segment token masks | |
| on each side; prepend a learnable NULL_REF / NULL_CAND vector per side. | |
| 3. Cross-attended decoder: TransformerDecoder over the concatenated | |
| ref+cand segment pool. At each step it predicts a tuple | |
| (cat, anat, concept, severity, ref_seg_idx, cand_seg_idx) | |
| with cat=0 reserved for EOS. | |
| Counts emerge as `len(seq) - 1`, cell counts as a histogram over (cat, anat). | |
| The explanation IS the prediction — each emitted tuple points to a specific | |
| ref sentence (or NULL) and a specific cand sentence (or NULL). | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from typing import Dict, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def _segment_pool(
    hidden: torch.Tensor, seg_token_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Mean-pool token hidden states over per-segment token masks.

    Args:
        hidden: [B, T, D] token hidden states.
        seg_token_mask: [B, S, T] mask that is nonzero (True) where token t
            belongs to segment s. Documented as bool; 0/1 integer masks are
            accepted as well.

    Returns:
        pool: [B, S, D] mean of the member tokens of each segment. A segment
            that owns no tokens has a zero numerator and a clamped
            denominator, so it pools to zeros for finite ``hidden``.
        valid: [B, S] bool, True where the segment has at least one token.
    """
    weights = seg_token_mask.to(hidden.dtype)
    # clamp_min(1) keeps empty (padding) segments from dividing by zero.
    denom = weights.sum(dim=-1, keepdim=True).clamp_min(1.0)
    pool = (weights @ hidden) / denom
    # Compare against zero before reducing: `.any()` on a uint8 mask returns
    # uint8 rather than bool, and callers apply `~` to this result, which
    # would then be a bitwise NOT instead of a logical one.
    valid = (seg_token_mask != 0).any(dim=-1)
    return pool, valid
class _TupleEmbedder(nn.Module):
    """Embed one error tuple back into the decoder's hidden space.

    The category, anatomy, concept and severity ids are each looked up in
    their own embedding table; the four vectors are summed together with the
    already-gathered ref/cand segment embeddings and passed through a linear
    projection. Used to feed teacher-forced (or previously decoded) tuples
    back into the decoder.
    """

    def __init__(self, n_cat: int, n_anat: int, n_concept: int, n_sev: int,
                 hidden_size: int):
        super().__init__()
        # The category table has one extra row: the category vocabulary
        # carries an additional EOS class on top of the n_cat real classes.
        table_sizes = (
            ("cat_emb", n_cat + 1),
            ("anat_emb", n_anat),
            ("concept_emb", n_concept),
            ("sev_emb", n_sev),
        )
        for attr_name, vocab_size in table_sizes:
            setattr(self, attr_name, nn.Embedding(vocab_size, hidden_size))
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, cat, anat, concept, sev, ref_emb, cand_emb):
        """Return the projected sum of all field and segment embeddings.

        ``cat``/``anat``/``concept``/``sev`` are integer id tensors of a
        common shape; ``ref_emb``/``cand_emb`` are float tensors of that shape
        plus a trailing hidden dimension.
        """
        field_sum = self.cat_emb(cat)
        id_lookups = (
            (self.anat_emb, anat),
            (self.concept_emb, concept),
            (self.sev_emb, sev),
        )
        for table, ids in id_lookups:
            field_sum = field_sum + table(ids)
        grounded = field_sum + ref_emb + cand_emb
        return self.proj(grounded)
class CADAD(nn.Module):
    """Sentence-grounded autoregressive error-tuple decoder.

    A backbone encodes the input tokens; the token states are mean-pooled
    into per-segment vectors with a learnable NULL slot prepended on the ref
    and cand side. A Transformer decoder cross-attends over the concatenated
    pool and emits (cat, anat, concept, severity, ref_seg_idx, cand_seg_idx)
    tuples until the category head predicts ``EOS_CAT_IDX``.
    """

    EOS_CAT_IDX = 0  # special class in `cat` for end-of-sequence
    # Fill value for the pointer logits of padded segment slots. It is a large
    # finite negative number (not -inf), so padded slots lose every
    # softmax/argmax while the logits stay finite.
    _PTR_MASK_FILL = -1e4

    def __init__(
        self,
        backbone,
        hidden_size: int,
        n_cat: int = 5,
        n_anat: int = 9,
        n_concept: int = 386,
        n_severity: int = 2,
        decoder_layers: int = 2,
        decoder_heads: int = 8,
        decoder_ff: int = 1024,
        dropout: float = 0.1,
        max_decode_steps: int = 24,
    ):
        """Build the decoder around an existing encoder.

        Args:
            backbone: module called as ``backbone(input_ids=...,
                attention_mask=..., return_dict=True)`` whose result exposes
                ``last_hidden_state`` of shape [B, T, D].
            hidden_size: D. The NULL slots are concatenated with the pooled
                backbone states, so this must equal the backbone width.
            n_cat: number of real error categories; the category head has
                ``n_cat + 1`` classes because index 0 is EOS.
            n_anat: number of anatomy classes.
            n_concept: number of concept classes.
            n_severity: number of severity classes.
            decoder_layers: number of Transformer decoder layers.
            decoder_heads: attention heads per decoder layer.
            decoder_ff: feed-forward width of each decoder layer.
            dropout: dropout probability inside the decoder layers.
            max_decode_steps: upper bound on greedy decoding steps.
        """
        super().__init__()
        self.backbone = backbone
        self.hidden_size = hidden_size
        self.n_cat = n_cat
        self.n_anat = n_anat
        self.n_concept = n_concept
        self.n_severity = n_severity
        self.max_decode_steps = max_decode_steps

        # Memory-side conditioning
        self.mem_type_emb = nn.Embedding(2, hidden_size)  # 0=ref-side, 1=cand-side
        self.null_ref = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.null_cand = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.bos_emb = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.tuple_emb = _TupleEmbedder(n_cat, n_anat, n_concept, n_severity, hidden_size)

        layer = nn.TransformerDecoderLayer(
            d_model=hidden_size, nhead=decoder_heads,
            dim_feedforward=decoder_ff, dropout=dropout,
            batch_first=True, activation="gelu", norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers=decoder_layers)

        # Output heads
        self.head_cat = nn.Linear(hidden_size, n_cat + 1)  # +1 for EOS at idx 0
        self.head_anat = nn.Linear(hidden_size, n_anat)
        self.head_concept = nn.Linear(hidden_size, n_concept)
        self.head_severity = nn.Linear(hidden_size, n_severity)
        # Pointer heads: project the decoder state into a query that is
        # dot-product scored against the ref / cand segment pools.
        self.proj_ref = nn.Linear(hidden_size, hidden_size)
        self.proj_cand = nn.Linear(hidden_size, hidden_size)

    def encode_memory(self, input_ids, attention_mask,
                      ref_seg_token_mask, cand_seg_token_mask) -> Dict[str, torch.Tensor]:
        """Encode the input and build the segment-level decoder memory.

        Args:
            input_ids, attention_mask: forwarded unchanged to the backbone.
            ref_seg_token_mask: [B, S_ref, T] token membership per ref segment.
            cand_seg_token_mask: [B, S_cand, T] token membership per cand segment.

        Returns:
            dict with
              ref_pool   [B, 1 + S_ref, D]  and ref_valid   [B, 1 + S_ref]
              cand_pool  [B, 1 + S_cand, D] and cand_valid  [B, 1 + S_cand]
              memory     [B, M, D]          and memory_valid [B, M]
            where M = 2 + S_ref + S_cand. Index 0 of each pool is the NULL
            slot, which is always valid. Both pools already carry their
            side-type embedding.
        """
        encoded = self.backbone(input_ids=input_ids,
                                attention_mask=attention_mask,
                                return_dict=True)
        hidden = encoded.last_hidden_state  # [B, T, D]
        ref_pool, ref_valid = _segment_pool(hidden, ref_seg_token_mask)
        cand_pool, cand_valid = _segment_pool(hidden, cand_seg_token_mask)

        B = hidden.size(0)
        device = hidden.device
        zero_t = torch.zeros(B, 1, dtype=torch.long, device=device)
        one_t = torch.ones(B, 1, dtype=torch.long, device=device)

        # Prepend NULL slot at index 0 on each side.
        null_r = self.null_ref.expand(B, 1, -1).to(hidden.dtype)
        null_c = self.null_cand.expand(B, 1, -1).to(hidden.dtype)
        ref_pool_full = torch.cat([null_r, ref_pool], dim=1)
        cand_pool_full = torch.cat([null_c, cand_pool], dim=1)

        # Side-type embeddings ([B, 1, D], broadcast over the segment axis).
        side_ref = self.mem_type_emb(zero_t).to(hidden.dtype)
        side_cand = self.mem_type_emb(one_t).to(hidden.dtype)
        ref_pool_full = ref_pool_full + side_ref
        cand_pool_full = cand_pool_full + side_cand

        # The NULL slot is always a legal pointer target.
        bool_one = torch.ones(B, 1, dtype=torch.bool, device=device)
        ref_valid_full = torch.cat([bool_one, ref_valid], dim=1)
        cand_valid_full = torch.cat([bool_one, cand_valid], dim=1)

        memory = torch.cat([ref_pool_full, cand_pool_full], dim=1)  # [B, M, D]
        memory_valid = torch.cat([ref_valid_full, cand_valid_full], dim=1)
        return {
            "ref_pool": ref_pool_full, "ref_valid": ref_valid_full,
            "cand_pool": cand_pool_full, "cand_valid": cand_valid_full,
            "memory": memory, "memory_valid": memory_valid,
        }

    def _gather_seg_emb(self, pool: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """pool: [B, S, D], idx: [B, K] (≥0). Returns [B, K, D] via batched gather."""
        # A [B, 1] batch index broadcasts against the [B, K] segment index.
        b_idx = torch.arange(idx.size(0), device=pool.device).unsqueeze(1)
        return pool[b_idx, idx]

    def _head_logits(self, out: torch.Tensor,
                     enc: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Apply every output head to decoder states ``out`` of shape [B, K, D].

        The pointer logits are scaled dot products between a projected decoder
        state and the ref / cand pools; padded segment slots are filled with
        ``_PTR_MASK_FILL``.
        """
        scale = 1.0 / math.sqrt(self.hidden_size)
        ref_scores = torch.einsum(
            "bkd,bsd->bks", self.proj_ref(out), enc["ref_pool"]) * scale
        cand_scores = torch.einsum(
            "bkd,bsd->bks", self.proj_cand(out), enc["cand_pool"]) * scale
        return {
            "logits_cat": self.head_cat(out),
            "logits_anat": self.head_anat(out),
            "logits_concept": self.head_concept(out),
            "logits_sev": self.head_severity(out),
            "logits_ref": ref_scores.masked_fill(
                ~enc["ref_valid"][:, None, :], self._PTR_MASK_FILL),
            "logits_cand": cand_scores.masked_fill(
                ~enc["cand_valid"][:, None, :], self._PTR_MASK_FILL),
        }

    def forward_train(
        self,
        input_ids, attention_mask,
        ref_seg_token_mask, cand_seg_token_mask,
        target_cat, target_anat, target_concept, target_sev,
        target_ref, target_cand,
    ) -> Dict[str, torch.Tensor]:
        """Teacher-forced forward pass.

        All targets are [B, K]. Padding & ignored positions are -100.
        target_cat[b, k]==0 marks EOS at position k.
        target_ref/target_cand are indices into ref_pool/cand_pool (incl. NULL=0).

        Returns:
            dict with ``logits_cat`` [B, K, n_cat + 1], ``logits_anat``,
            ``logits_concept``, ``logits_sev`` (each [B, K, n_classes]),
            ``logits_ref`` [B, K, 1 + S_ref], ``logits_cand``
            [B, K, 1 + S_cand] and the encoder ``memory`` [B, M, D].
        """
        enc = self.encode_memory(input_ids, attention_mask,
                                 ref_seg_token_mask, cand_seg_token_mask)
        memory = enc["memory"]
        B, K = target_cat.shape

        # Teacher forcing needs an embedding for every target step. Ignored
        # (-100) entries are clamped to index 0 so the lookups stay in range;
        # the loss skips those positions.
        tuple_emb_all = self.tuple_emb(
            cat=target_cat.clamp_min(0),
            anat=target_anat.clamp_min(0),
            concept=target_concept.clamp_min(0),
            sev=target_sev.clamp_min(0),
            ref_emb=self._gather_seg_emb(enc["ref_pool"], target_ref.clamp_min(0)),
            cand_emb=self._gather_seg_emb(enc["cand_pool"], target_cand.clamp_min(0)),
        )

        # Shift right with BOS: step k sees the tuples of steps < k.
        bos = self.bos_emb.expand(B, 1, -1).to(tuple_emb_all.dtype)
        decoder_input = torch.cat([bos, tuple_emb_all[:, :-1, :]], dim=1)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(K).to(decoder_input.device)
        out = self.decoder(
            tgt=decoder_input,
            memory=memory,
            tgt_mask=causal_mask,
            memory_key_padding_mask=~enc["memory_valid"],
        )  # [B, K, D]

        outputs = self._head_logits(out, enc)
        outputs["memory"] = memory
        return outputs

    @torch.no_grad()
    def decode_greedy(
        self,
        input_ids, attention_mask,
        ref_seg_token_mask, cand_seg_token_mask,
    ) -> list[list[dict[str, int]]]:
        """Greedy autoregressive decoding.

        Runs under ``torch.no_grad()``: the result is plain Python ints, so no
        autograd graph is needed. Train/eval mode is left to the caller.

        Returns:
            One list per batch element holding the emitted tuples in order,
            each a dict with keys ``cat``, ``anat``, ``concept_id``,
            ``severity``, ``ref_seg_idx`` and ``cand_seg_idx``. The EOS step is
            not included. A sequence stops at its first EOS or after
            ``max_decode_steps`` tuples.
        """
        enc = self.encode_memory(input_ids, attention_mask,
                                 ref_seg_token_mask, cand_seg_token_mask)
        memory = enc["memory"]
        ref_pool, cand_pool = enc["ref_pool"], enc["cand_pool"]
        mem_kp_mask = ~enc["memory_valid"]
        B = memory.size(0)
        device = memory.device
        batch_idx = torch.arange(B, device=device)

        prev_emb = self.bos_emb.expand(B, 1, -1).to(memory.dtype)
        # Host-side bookkeeping avoids a device sync per sample per step.
        running = [True] * B
        out_seqs = [[] for _ in range(B)]

        for step in range(self.max_decode_steps):
            causal = nn.Transformer.generate_square_subsequent_mask(prev_emb.size(1)).to(device)
            dec = self.decoder(prev_emb, memory, tgt_mask=causal,
                               memory_key_padding_mask=mem_kp_mask)
            # Only the newest position is decoded; keep the K=1 axis so the
            # same head code as in training applies.
            step_logits = self._head_logits(dec[:, -1:, :], enc)
            cat_pred = step_logits["logits_cat"][:, 0].argmax(-1)  # [B]
            anat_pred = step_logits["logits_anat"][:, 0].argmax(-1)
            concept_pred = step_logits["logits_concept"][:, 0].argmax(-1)
            sev_pred = step_logits["logits_sev"][:, 0].argmax(-1)
            ref_pred = step_logits["logits_ref"][:, 0].argmax(-1)
            cand_pred = step_logits["logits_cand"][:, 0].argmax(-1)

            # One device-to-host transfer per step for all six fields.
            rows = torch.stack(
                [cat_pred, anat_pred, concept_pred, sev_pred, ref_pred, cand_pred],
                dim=1,
            ).tolist()
            for b, (cat, anat, concept, sev, ref, cand) in enumerate(rows):
                if not running[b]:
                    continue
                if cat == self.EOS_CAT_IDX:
                    running[b] = False
                    continue
                out_seqs[b].append({
                    "cat": cat,
                    "anat": anat,
                    "concept_id": concept,
                    "severity": sev,
                    "ref_seg_idx": ref,
                    "cand_seg_idx": cand,
                })

            if not any(running):
                break
            if step + 1 == self.max_decode_steps:
                break  # no further step will consume another embedding

            # Build next-step embedding from this step's predictions. Finished
            # rows keep being fed; their later outputs are simply discarded.
            next_emb = self.tuple_emb(
                cat=cat_pred, anat=anat_pred,
                concept=concept_pred, sev=sev_pred,
                ref_emb=ref_pool[batch_idx, ref_pred],
                cand_emb=cand_pool[batch_idx, cand_pred],
            ).unsqueeze(1)  # [B, 1, D]
            prev_emb = torch.cat([prev_emb, next_emb], dim=1)

        return out_seqs
def _masked_mean_ce(logits: torch.Tensor, target: torch.Tensor,
                    ignore_index: int = -100) -> torch.Tensor:
    """Mean cross-entropy over non-ignored targets; 0 when there are none.

    ``F.cross_entropy(..., reduction="mean")`` divides by the number of
    non-ignored targets and therefore returns NaN when every target is
    ignored. Summing and dividing by ``max(count, 1)`` gives the same mean
    otherwise and a zero loss (with zero gradient) in the all-ignored case.
    """
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_target = target.reshape(-1)
    loss_sum = F.cross_entropy(flat_logits, flat_target,
                               ignore_index=ignore_index, reduction="sum")
    n_valid = (flat_target != ignore_index).sum().clamp_min(1)
    return loss_sum / n_valid.to(loss_sum.dtype)


def cadad_loss(out: Dict[str, torch.Tensor],
               target_cat, target_anat, target_concept, target_sev,
               target_ref, target_cand,
               weights: Optional[Dict[str, float]] = None) -> Dict[str, torch.Tensor]:
    """Weighted sum of per-head cross-entropies.

    Pad/ignore positions = -100 in targets. EOS positions only supervise
    `cat`; other heads should be -100 there. A head whose targets are all
    -100 (e.g. a batch in which every sequence is just EOS) contributes a
    loss of 0 instead of NaN.

    Args:
        out: dict holding ``logits_cat``, ``logits_anat``, ``logits_concept``,
            ``logits_sev``, ``logits_ref`` and ``logits_cand``, each with the
            class dimension last.
        target_*: integer targets whose shape matches the leading dimensions
            of the corresponding logits.
        weights: optional per-head overrides for the default weights
            (cat 1.0, anat 0.5, concept 0.3, sev 0.5, ref 0.5, cand 0.5).

    Returns:
        dict with ``total`` plus the unweighted loss of each head under the
        keys ``cat``, ``anat``, ``concept``, ``sev``, ``ref`` and ``cand``.
    """
    w = {"cat": 1.0, "anat": 0.5, "concept": 0.3, "sev": 0.5,
         "ref": 0.5, "cand": 0.5, **(weights or {})}
    targets = {
        "cat": target_cat, "anat": target_anat, "concept": target_concept,
        "sev": target_sev, "ref": target_ref, "cand": target_cand,
    }
    losses = {head: _masked_mean_ce(out[f"logits_{head}"], tgt)
              for head, tgt in targets.items()}
    total = sum(w[head] * loss for head, loss in losses.items())
    return {"total": total, **losses}