"""CADA-D — sentence-grounded autoregressive error-tuple decoder. Architecture ------------ 1. Encoder (reused from CADA): backbone produces [B, T, D] hidden states. 2. Sentence pooling: mean-pool hidden states over per-segment token masks on each side; prepend a learnable NULL_REF / NULL_CAND vector per side. 3. Cross-attended decoder: TransformerDecoder over the concatenated ref+cand segment pool. At each step it predicts a tuple (cat, anat, concept, severity, ref_seg_idx, cand_seg_idx) with cat=0 reserved for EOS. Counts emerge as `len(seq) - 1`, cell counts as a histogram over (cat, anat). The explanation IS the prediction — each emitted tuple points to a specific ref sentence (or NULL) and a specific cand sentence (or NULL). """ from __future__ import annotations import math from typing import Dict, Optional import torch import torch.nn as nn import torch.nn.functional as F def _segment_pool(hidden: torch.Tensor, seg_token_mask: torch.Tensor): """Mean-pool tokens over per-segment masks. hidden: [B, T, D] seg_token_mask: [B, S, T] bool 1 where token t belongs to segment s. Returns pool: [B, S, D] valid: [B, S] True where segment had at least 1 token. """ m = seg_token_mask.to(hidden.dtype) denom = m.sum(dim=-1, keepdim=True).clamp_min(1.0) pool = (m @ hidden) / denom valid = seg_token_mask.any(dim=-1) return pool, valid class _TupleEmbedder(nn.Module): """Sum of category/anatomy/concept/severity embeddings + segment embeddings, then a small projection. Used to embed teacher-forced tuples back to D.""" def __init__(self, n_cat: int, n_anat: int, n_concept: int, n_sev: int, hidden_size: int): super().__init__() self.cat_emb = nn.Embedding(n_cat + 1, hidden_size) self.anat_emb = nn.Embedding(n_anat, hidden_size) self.concept_emb = nn.Embedding(n_concept, hidden_size) self.sev_emb = nn.Embedding(n_sev, hidden_size) self.proj = nn.Linear(hidden_size, hidden_size) def forward(self, cat, anat, concept, sev, ref_emb, cand_emb): e = (self.cat_emb(cat) + self.anat_emb(anat) + self.concept_emb(concept) + self.sev_emb(sev) + ref_emb + cand_emb) return self.proj(e) class CADAD(nn.Module): """Sentence-grounded autoregressive error-tuple decoder.""" EOS_CAT_IDX = 0 # special class in `cat` for end-of-sequence def __init__( self, backbone, hidden_size: int, n_cat: int = 5, n_anat: int = 9, n_concept: int = 386, n_severity: int = 2, decoder_layers: int = 2, decoder_heads: int = 8, decoder_ff: int = 1024, dropout: float = 0.1, max_decode_steps: int = 24, ): super().__init__() self.backbone = backbone self.hidden_size = hidden_size self.n_cat = n_cat self.n_anat = n_anat self.n_concept = n_concept self.n_severity = n_severity self.max_decode_steps = max_decode_steps # Memory-side conditioning self.mem_type_emb = nn.Embedding(2, hidden_size) # 0=ref-side, 1=cand-side self.null_ref = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02) self.null_cand = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02) self.bos_emb = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02) self.tuple_emb = _TupleEmbedder(n_cat, n_anat, n_concept, n_severity, hidden_size) layer = nn.TransformerDecoderLayer( d_model=hidden_size, nhead=decoder_heads, dim_feedforward=decoder_ff, dropout=dropout, batch_first=True, activation="gelu", norm_first=True, ) self.decoder = nn.TransformerDecoder(layer, num_layers=decoder_layers) # Output heads self.head_cat = nn.Linear(hidden_size, n_cat + 1) # +1 for EOS at idx 0 self.head_anat = nn.Linear(hidden_size, n_anat) self.head_concept = nn.Linear(hidden_size, n_concept) self.head_severity = nn.Linear(hidden_size, n_severity) self.proj_ref = nn.Linear(hidden_size, hidden_size) self.proj_cand = nn.Linear(hidden_size, hidden_size) def encode_memory(self, input_ids, attention_mask, ref_seg_token_mask, cand_seg_token_mask): """Returns dict with ref_pool, cand_pool, memory, valid masks. ref_pool/cand_pool include a leading NULL slot at index 0. """ out = self.backbone(input_ids=input_ids, attention_mask=attention_mask, return_dict=True) hidden = out.last_hidden_state # [B, T, D] ref_pool, ref_valid = _segment_pool(hidden, ref_seg_token_mask) cand_pool, cand_valid = _segment_pool(hidden, cand_seg_token_mask) B = hidden.size(0) device = hidden.device zero_t = torch.zeros(B, 1, dtype=torch.long, device=device) one_t = torch.ones(B, 1, dtype=torch.long, device=device) # Prepend NULL slot at index 0 on each side. null_r = self.null_ref.expand(B, 1, -1).to(hidden.dtype) null_c = self.null_cand.expand(B, 1, -1).to(hidden.dtype) ref_pool_full = torch.cat([null_r, ref_pool], dim=1) cand_pool_full = torch.cat([null_c, cand_pool], dim=1) # Side-type embeddings side_ref = self.mem_type_emb(zero_t).to(hidden.dtype) side_cand = self.mem_type_emb(one_t).to(hidden.dtype) ref_pool_full = ref_pool_full + side_ref cand_pool_full = cand_pool_full + side_cand bool_one = torch.ones(B, 1, dtype=torch.bool, device=device) ref_valid_full = torch.cat([bool_one, ref_valid], dim=1) cand_valid_full = torch.cat([bool_one, cand_valid], dim=1) memory = torch.cat([ref_pool_full, cand_pool_full], dim=1) # [B, M, D] memory_valid = torch.cat([ref_valid_full, cand_valid_full], dim=1) return { "ref_pool": ref_pool_full, "ref_valid": ref_valid_full, "cand_pool": cand_pool_full, "cand_valid": cand_valid_full, "memory": memory, "memory_valid": memory_valid, } def _gather_seg_emb(self, pool: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: """pool: [B, S, D], idx: [B, K] (≥0). Returns [B, K, D] via batched gather.""" B, K = idx.shape D = pool.size(-1) b_idx = torch.arange(B, device=pool.device).unsqueeze(1).expand(-1, K) return pool[b_idx, idx] def forward_train( self, input_ids, attention_mask, ref_seg_token_mask, cand_seg_token_mask, target_cat, target_anat, target_concept, target_sev, target_ref, target_cand, ): """All targets are [B, K]. Padding & ignored positions are -100. target_cat[b, k]==0 marks EOS at position k. target_ref/target_cand are indices into ref_pool/cand_pool (incl. NULL=0). """ enc = self.encode_memory(input_ids, attention_mask, ref_seg_token_mask, cand_seg_token_mask) memory = enc["memory"] ref_pool, cand_pool = enc["ref_pool"], enc["cand_pool"] B, K = target_cat.shape # For teacher-forcing we need the segment embedding for each target step, # using clamp_min(0) so PAD/IGNORE sites get NULL. Loss ignores them later. ref_idx_safe = target_ref.clamp_min(0) cand_idx_safe = target_cand.clamp_min(0) ref_emb_per_t = self._gather_seg_emb(ref_pool, ref_idx_safe) cand_emb_per_t = self._gather_seg_emb(cand_pool, cand_idx_safe) tuple_emb_all = self.tuple_emb( cat=target_cat.clamp_min(0), anat=target_anat.clamp_min(0), concept=target_concept.clamp_min(0), sev=target_sev.clamp_min(0), ref_emb=ref_emb_per_t, cand_emb=cand_emb_per_t, ) # Shift right with BOS bos = self.bos_emb.expand(B, 1, -1).to(tuple_emb_all.dtype) decoder_input = torch.cat([bos, tuple_emb_all[:, :-1, :]], dim=1) causal_mask = nn.Transformer.generate_square_subsequent_mask(K).to(decoder_input.device) mem_kp_mask = ~enc["memory_valid"] out = self.decoder( tgt=decoder_input, memory=memory, tgt_mask=causal_mask, memory_key_padding_mask=mem_kp_mask, ) # [B, K, D] logits_cat = self.head_cat(out) logits_anat = self.head_anat(out) logits_concept = self.head_concept(out) logits_sev = self.head_severity(out) scale = 1.0 / math.sqrt(self.hidden_size) ref_q = self.proj_ref(out) cand_q = self.proj_cand(out) logits_ref = torch.einsum("bkd,bsd->bks", ref_q, ref_pool) * scale logits_cand = torch.einsum("bkd,bsd->bks", cand_q, cand_pool) * scale # Mask invalid pointer slots (padded segments) to -inf logits_ref = logits_ref.masked_fill(~enc["ref_valid"][:, None, :], -1e4) logits_cand = logits_cand.masked_fill(~enc["cand_valid"][:, None, :], -1e4) return { "logits_cat": logits_cat, "logits_anat": logits_anat, "logits_concept": logits_concept, "logits_sev": logits_sev, "logits_ref": logits_ref, "logits_cand": logits_cand, "memory": memory, } @torch.no_grad() def decode_greedy( self, input_ids, attention_mask, ref_seg_token_mask, cand_seg_token_mask, ): """Greedy autoregressive decoding. Returns list-of-list of dicts (per pair).""" enc = self.encode_memory(input_ids, attention_mask, ref_seg_token_mask, cand_seg_token_mask) memory = enc["memory"] ref_pool, cand_pool = enc["ref_pool"], enc["cand_pool"] ref_valid, cand_valid = enc["ref_valid"], enc["cand_valid"] mem_kp_mask = ~enc["memory_valid"] B = input_ids.size(0) device = input_ids.device D = memory.size(-1) bos = self.bos_emb.expand(B, 1, -1).to(memory.dtype) prev_emb = bos running = torch.ones(B, dtype=torch.bool, device=device) out_seqs = [[] for _ in range(B)] for step in range(self.max_decode_steps): causal = nn.Transformer.generate_square_subsequent_mask(prev_emb.size(1)).to(device) dec = self.decoder(prev_emb, memory, tgt_mask=causal, memory_key_padding_mask=mem_kp_mask) last = dec[:, -1, :] # [B, D] # Sample / argmax each head cat_pred = self.head_cat(last).argmax(-1) # [B] anat_pred = self.head_anat(last).argmax(-1) concept_pred = self.head_concept(last).argmax(-1) sev_pred = self.head_severity(last).argmax(-1) scale = 1.0 / math.sqrt(self.hidden_size) ref_q = self.proj_ref(last) cand_q = self.proj_cand(last) ref_logit = (torch.einsum("bd,bsd->bs", ref_q, ref_pool) * scale).masked_fill(~ref_valid, -1e4) cand_logit = (torch.einsum("bd,bsd->bs", cand_q, cand_pool) * scale).masked_fill(~cand_valid, -1e4) ref_pred = ref_logit.argmax(-1) cand_pred = cand_logit.argmax(-1) for b in range(B): if not running[b]: continue if cat_pred[b].item() == self.EOS_CAT_IDX: running[b] = False continue out_seqs[b].append({ "cat": int(cat_pred[b]), "anat": int(anat_pred[b]), "concept_id": int(concept_pred[b]), "severity": int(sev_pred[b]), "ref_seg_idx": int(ref_pred[b]), "cand_seg_idx": int(cand_pred[b]), }) if not running.any(): break # Build next-step embedding from this step's predictions ref_emb_step = ref_pool[torch.arange(B, device=device), ref_pred] cand_emb_step = cand_pool[torch.arange(B, device=device), cand_pred] next_emb = self.tuple_emb( cat=cat_pred, anat=anat_pred, concept=concept_pred, sev=sev_pred, ref_emb=ref_emb_step, cand_emb=cand_emb_step, ).unsqueeze(1) # [B, 1, D] prev_emb = torch.cat([prev_emb, next_emb], dim=1) return out_seqs def cadad_loss(out: Dict[str, torch.Tensor], target_cat, target_anat, target_concept, target_sev, target_ref, target_cand, weights: Optional[Dict[str, float]] = None) -> Dict[str, torch.Tensor]: """Cross-entropy on every head. Pad/ignore positions = -100 in targets. EOS positions only supervise `cat`; other heads should be -100 there. """ w = {"cat": 1.0, "anat": 0.5, "concept": 0.3, "sev": 0.5, "ref": 0.5, "cand": 0.5, **(weights or {})} L_cat = F.cross_entropy(out["logits_cat"].reshape(-1, out["logits_cat"].size(-1)), target_cat.reshape(-1), ignore_index=-100) L_anat = F.cross_entropy(out["logits_anat"].reshape(-1, out["logits_anat"].size(-1)), target_anat.reshape(-1), ignore_index=-100) L_concept = F.cross_entropy(out["logits_concept"].reshape(-1, out["logits_concept"].size(-1)), target_concept.reshape(-1), ignore_index=-100) L_sev = F.cross_entropy(out["logits_sev"].reshape(-1, out["logits_sev"].size(-1)), target_sev.reshape(-1), ignore_index=-100) L_ref = F.cross_entropy(out["logits_ref"].reshape(-1, out["logits_ref"].size(-1)), target_ref.reshape(-1), ignore_index=-100) L_cand = F.cross_entropy(out["logits_cand"].reshape(-1, out["logits_cand"].size(-1)), target_cand.reshape(-1), ignore_index=-100) total = (w["cat"] * L_cat + w["anat"] * L_anat + w["concept"] * L_concept + w["sev"] * L_sev + w["ref"] * L_ref + w["cand"] * L_cand) return {"total": total, "cat": L_cat, "anat": L_anat, "concept": L_concept, "sev": L_sev, "ref": L_ref, "cand": L_cand}