# chest2err / chest2err_modeling.py
# Author: lukeingawesome
# Initial release: chest2err sentence-grounded error decoder (τ_b=+0.763, pairwise acc=0.958)
# Commit: 8a9746d (verified)
"""CADA-D — sentence-grounded autoregressive error-tuple decoder.
Architecture
------------
1. Encoder (reused from CADA): backbone produces [B, T, D] hidden states.
2. Sentence pooling: mean-pool hidden states over per-segment token masks
on each side; prepend a learnable NULL_REF / NULL_CAND vector per side.
3. Cross-attended decoder: TransformerDecoder over the concatenated
ref+cand segment pool. At each step it predicts a tuple
(cat, anat, concept, severity, ref_seg_idx, cand_seg_idx)
with cat=0 reserved for EOS.
Counts emerge as `len(seq) - 1`, cell counts as a histogram over (cat, anat).
The explanation IS the prediction — each emitted tuple points to a specific
ref sentence (or NULL) and a specific cand sentence (or NULL).
"""
from __future__ import annotations
import math
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
def _segment_pool(hidden: torch.Tensor, seg_token_mask: torch.Tensor):
"""Mean-pool tokens over per-segment masks.
hidden: [B, T, D]
seg_token_mask: [B, S, T] bool 1 where token t belongs to segment s.
Returns
pool: [B, S, D]
valid: [B, S] True where segment had at least 1 token.
"""
m = seg_token_mask.to(hidden.dtype)
denom = m.sum(dim=-1, keepdim=True).clamp_min(1.0)
pool = (m @ hidden) / denom
valid = seg_token_mask.any(dim=-1)
return pool, valid
class _TupleEmbedder(nn.Module):
"""Sum of category/anatomy/concept/severity embeddings + segment embeddings,
then a small projection. Used to embed teacher-forced tuples back to D."""
def __init__(self, n_cat: int, n_anat: int, n_concept: int, n_sev: int,
hidden_size: int):
super().__init__()
self.cat_emb = nn.Embedding(n_cat + 1, hidden_size)
self.anat_emb = nn.Embedding(n_anat, hidden_size)
self.concept_emb = nn.Embedding(n_concept, hidden_size)
self.sev_emb = nn.Embedding(n_sev, hidden_size)
self.proj = nn.Linear(hidden_size, hidden_size)
def forward(self, cat, anat, concept, sev, ref_emb, cand_emb):
e = (self.cat_emb(cat) + self.anat_emb(anat)
+ self.concept_emb(concept) + self.sev_emb(sev)
+ ref_emb + cand_emb)
return self.proj(e)
class CADAD(nn.Module):
    """Sentence-grounded autoregressive error-tuple decoder.

    A single backbone pass encodes one token sequence containing both the
    reference and the candidate text. Tokens are mean-pooled per sentence on
    each side, and a learnable NULL slot is prepended at index 0 of each side.
    The resulting pools form the cross-attention memory of a causal
    TransformerDecoder. Each decoding step predicts one tuple
        (cat, anat, concept, severity, ref_seg_idx, cand_seg_idx)
    where ``cat == EOS_CAT_IDX`` ends the sequence. The two segment indices are
    pointer predictions into the ref / cand pools.

    NOTE: submodule/parameter attribute names are the state_dict keys used when
    loading weights; renaming any of them breaks existing checkpoints.
    """

    EOS_CAT_IDX = 0  # reserved class in `cat` for end-of-sequence

    def __init__(
        self,
        backbone,
        hidden_size: int,
        n_cat: int = 5,
        n_anat: int = 9,
        n_concept: int = 386,
        n_severity: int = 2,
        decoder_layers: int = 2,
        decoder_heads: int = 8,
        decoder_ff: int = 1024,
        dropout: float = 0.1,
        max_decode_steps: int = 24,
    ):
        """Build the memory conditioning, tuple decoder and output heads.

        Args:
            backbone: encoder called as ``backbone(input_ids=..., attention_mask=...,
                return_dict=True)``. Its output must expose ``last_hidden_state``
                with width ``hidden_size``.
            hidden_size: model width D shared by backbone output and decoder.
            n_cat: number of real error categories. The category head has
                ``n_cat + 1`` outputs because index 0 is EOS.
            n_anat, n_concept, n_severity: class counts of the other heads.
            decoder_layers, decoder_heads, decoder_ff, dropout: decoder
                hyper-parameters.
            max_decode_steps: maximum number of steps in ``decode_greedy``.
        """
        super().__init__()
        self.backbone = backbone
        self.hidden_size = hidden_size
        self.n_cat = n_cat
        self.n_anat = n_anat
        self.n_concept = n_concept
        self.n_severity = n_severity
        self.max_decode_steps = max_decode_steps
        # Creation order below is kept stable so seeded initialisation reproduces.
        # Memory-side conditioning
        self.mem_type_emb = nn.Embedding(2, hidden_size)  # 0=ref-side, 1=cand-side
        self.null_ref = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.null_cand = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.bos_emb = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.tuple_emb = _TupleEmbedder(n_cat, n_anat, n_concept, n_severity, hidden_size)
        layer = nn.TransformerDecoderLayer(
            d_model=hidden_size, nhead=decoder_heads,
            dim_feedforward=decoder_ff, dropout=dropout,
            batch_first=True, activation="gelu", norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers=decoder_layers)
        # Output heads
        self.head_cat = nn.Linear(hidden_size, n_cat + 1)  # +1 for EOS at idx 0
        self.head_anat = nn.Linear(hidden_size, n_anat)
        self.head_concept = nn.Linear(hidden_size, n_concept)
        self.head_severity = nn.Linear(hidden_size, n_severity)
        # Query projections for the ref / cand pointer heads.
        self.proj_ref = nn.Linear(hidden_size, hidden_size)
        self.proj_cand = nn.Linear(hidden_size, hidden_size)

    def encode_memory(self, input_ids, attention_mask,
                      ref_seg_token_mask, cand_seg_token_mask):
        """Encode the pair and build the decoder's cross-attention memory.

        Args:
            input_ids, attention_mask: backbone inputs (presumably [B, T]).
            ref_seg_token_mask, cand_seg_token_mask: [B, S_side, T] bool masks
                selecting each sentence's tokens (see ``_segment_pool``).

        Returns:
            dict with
              ref_pool  [B, S_ref + 1, D],  ref_valid  [B, S_ref + 1]
              cand_pool [B, S_cand + 1, D], cand_valid [B, S_cand + 1]
              memory [B, S_ref + S_cand + 2, D] (ref slots then cand slots)
              memory_valid [B, S_ref + S_cand + 2]
            Index 0 of each pool is the always-valid NULL slot. Every slot also
            carries a side-type embedding (0 = ref, 1 = cand).
        """
        out = self.backbone(input_ids=input_ids,
                            attention_mask=attention_mask,
                            return_dict=True)
        hidden = out.last_hidden_state  # [B, T, D]
        ref_pool, ref_valid = _segment_pool(hidden, ref_seg_token_mask)
        cand_pool, cand_valid = _segment_pool(hidden, cand_seg_token_mask)
        B = hidden.size(0)
        device = hidden.device
        zero_t = torch.zeros(B, 1, dtype=torch.long, device=device)
        one_t = torch.ones(B, 1, dtype=torch.long, device=device)
        # Prepend NULL slot at index 0 on each side.
        null_r = self.null_ref.expand(B, 1, -1).to(hidden.dtype)
        null_c = self.null_cand.expand(B, 1, -1).to(hidden.dtype)
        ref_pool_full = torch.cat([null_r, ref_pool], dim=1)
        cand_pool_full = torch.cat([null_c, cand_pool], dim=1)
        # Side-type embeddings let the decoder tell ref slots from cand slots.
        side_ref = self.mem_type_emb(zero_t).to(hidden.dtype)
        side_cand = self.mem_type_emb(one_t).to(hidden.dtype)
        ref_pool_full = ref_pool_full + side_ref
        cand_pool_full = cand_pool_full + side_cand
        # NULL slots are always valid, so the memory is never fully masked.
        bool_one = torch.ones(B, 1, dtype=torch.bool, device=device)
        ref_valid_full = torch.cat([bool_one, ref_valid], dim=1)
        cand_valid_full = torch.cat([bool_one, cand_valid], dim=1)
        memory = torch.cat([ref_pool_full, cand_pool_full], dim=1)  # [B, M, D]
        memory_valid = torch.cat([ref_valid_full, cand_valid_full], dim=1)
        return {
            "ref_pool": ref_pool_full, "ref_valid": ref_valid_full,
            "cand_pool": cand_pool_full, "cand_valid": cand_valid_full,
            "memory": memory, "memory_valid": memory_valid,
        }

    def _gather_seg_emb(self, pool: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """Pick per-step segment vectors: pool [B, S, D], idx [B, K] (>= 0) -> [B, K, D]."""
        B, K = idx.shape
        batch_idx = torch.arange(B, device=pool.device).unsqueeze(1).expand(-1, K)
        return pool[batch_idx, idx]

    def _pointer_logits(self, proj, query: torch.Tensor, pool: torch.Tensor,
                        valid: torch.Tensor) -> torch.Tensor:
        """Score decoder states against one side's segment pool.

        Computes scaled dot products between ``proj(query)`` and each pool slot.
        query [B, K, D], pool [B, S, D], valid [B, S] -> logits [B, K, S].
        Padded (invalid) slots are filled with -1e4, a large finite negative
        rather than -inf.
        """
        scale = 1.0 / math.sqrt(self.hidden_size)
        logits = torch.einsum("bkd,bsd->bks", proj(query), pool) * scale
        return logits.masked_fill(~valid[:, None, :], -1e4)

    def forward_train(
        self,
        input_ids, attention_mask,
        ref_seg_token_mask, cand_seg_token_mask,
        target_cat, target_anat, target_concept, target_sev,
        target_ref, target_cand,
    ):
        """Run a teacher-forced forward pass over a padded batch of target tuples.

        All targets are [B, K], with -100 at padding/ignored positions.
        ``target_cat[b, k] == 0`` marks EOS at step k. ``target_ref`` /
        ``target_cand`` index into ref_pool / cand_pool, where 0 is the NULL slot.

        Returns:
            dict of logits: cat [B, K, n_cat + 1], anat [B, K, n_anat],
            concept [B, K, n_concept], sev [B, K, n_severity],
            ref [B, K, S_ref + 1], cand [B, K, S_cand + 1]. It also includes
            the cross-attention ``memory``.
        """
        enc = self.encode_memory(input_ids, attention_mask,
                                 ref_seg_token_mask, cand_seg_token_mask)
        memory = enc["memory"]
        ref_pool, cand_pool = enc["ref_pool"], enc["cand_pool"]
        B, K = target_cat.shape
        # Teacher forcing needs a segment embedding for every target step.
        # clamp_min(0) maps PAD/IGNORE (-100) to the NULL slot; the loss ignores
        # those positions anyway.
        ref_emb_per_t = self._gather_seg_emb(ref_pool, target_ref.clamp_min(0))
        cand_emb_per_t = self._gather_seg_emb(cand_pool, target_cand.clamp_min(0))
        tuple_emb_all = self.tuple_emb(
            cat=target_cat.clamp_min(0),
            anat=target_anat.clamp_min(0),
            concept=target_concept.clamp_min(0),
            sev=target_sev.clamp_min(0),
            ref_emb=ref_emb_per_t,
            cand_emb=cand_emb_per_t,
        )
        # Shift right with BOS: step k is conditioned on targets 0..k-1.
        bos = self.bos_emb.expand(B, 1, -1).to(tuple_emb_all.dtype)
        decoder_input = torch.cat([bos, tuple_emb_all[:, :-1, :]], dim=1)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(K).to(decoder_input.device)
        out = self.decoder(
            tgt=decoder_input,
            memory=memory,
            tgt_mask=causal_mask,
            memory_key_padding_mask=~enc["memory_valid"],
        )  # [B, K, D]
        return {
            "logits_cat": self.head_cat(out),
            "logits_anat": self.head_anat(out),
            "logits_concept": self.head_concept(out),
            "logits_sev": self.head_severity(out),
            "logits_ref": self._pointer_logits(self.proj_ref, out, ref_pool, enc["ref_valid"]),
            "logits_cand": self._pointer_logits(self.proj_cand, out, cand_pool, enc["cand_valid"]),
            "memory": memory,
        }

    @torch.no_grad()
    def decode_greedy(
        self,
        input_ids, attention_mask,
        ref_seg_token_mask, cand_seg_token_mask,
    ):
        """Greedily decode error tuples autoregressively.

        At every step each head takes its argmax. A sequence stops the first
        time its category argmax equals ``EOS_CAT_IDX``, and that step is not
        emitted. Decoding ends after ``max_decode_steps`` or once every sequence
        has stopped. This method does not switch the module to eval mode, so
        decoder dropout stays active if the caller left it in training mode.

        Returns:
            list of length B. Element b is a list of dicts with keys "cat",
            "anat", "concept_id", "severity", "ref_seg_idx" and "cand_seg_idx"
            (segment index 0 is the NULL slot).
        """
        enc = self.encode_memory(input_ids, attention_mask,
                                 ref_seg_token_mask, cand_seg_token_mask)
        memory = enc["memory"]
        ref_pool, cand_pool = enc["ref_pool"], enc["cand_pool"]
        ref_valid, cand_valid = enc["ref_valid"], enc["cand_valid"]
        mem_kp_mask = ~enc["memory_valid"]
        B = input_ids.size(0)
        device = input_ids.device
        batch_idx = torch.arange(B, device=device)

        prev_emb = self.bos_emb.expand(B, 1, -1).to(memory.dtype)
        running = [True] * B
        out_seqs = [[] for _ in range(B)]
        for _ in range(self.max_decode_steps):
            # The whole prefix is re-decoded each step (no incremental cache).
            causal = nn.Transformer.generate_square_subsequent_mask(prev_emb.size(1)).to(device)
            dec = self.decoder(prev_emb, memory, tgt_mask=causal,
                               memory_key_padding_mask=mem_kp_mask)
            last = dec[:, -1, :]  # [B, D]
            cat_pred = self.head_cat(last).argmax(-1)  # [B]
            anat_pred = self.head_anat(last).argmax(-1)
            concept_pred = self.head_concept(last).argmax(-1)
            sev_pred = self.head_severity(last).argmax(-1)
            last_q = last.unsqueeze(1)  # [B, 1, D] for the pointer helper
            ref_pred = self._pointer_logits(self.proj_ref, last_q, ref_pool, ref_valid)[:, 0].argmax(-1)
            cand_pred = self._pointer_logits(self.proj_cand, last_q, cand_pool, cand_valid)[:, 0].argmax(-1)

            # Move this step's predictions to the host once per head instead of
            # one `.item()` / `int()` device sync per batch element.
            step_rows = zip(cat_pred.tolist(), anat_pred.tolist(),
                            concept_pred.tolist(), sev_pred.tolist(),
                            ref_pred.tolist(), cand_pred.tolist())
            for b, (cat, anat, concept, sev, ref_idx, cand_idx) in enumerate(step_rows):
                if not running[b]:
                    continue
                if cat == self.EOS_CAT_IDX:
                    running[b] = False
                    continue
                out_seqs[b].append({
                    "cat": cat,
                    "anat": anat,
                    "concept_id": concept,
                    "severity": sev,
                    "ref_seg_idx": ref_idx,
                    "cand_seg_idx": cand_idx,
                })
            if not any(running):
                break

            # Feed this step's predictions back as the next decoder input.
            next_emb = self.tuple_emb(
                cat=cat_pred, anat=anat_pred,
                concept=concept_pred, sev=sev_pred,
                ref_emb=ref_pool[batch_idx, ref_pred],
                cand_emb=cand_pool[batch_idx, cand_pred],
            ).unsqueeze(1)  # [B, 1, D]
            prev_emb = torch.cat([prev_emb, next_emb], dim=1)
        return out_seqs
def _masked_mean_ce(logits: torch.Tensor, target: torch.Tensor,
                    ignore_index: int = -100) -> torch.Tensor:
    """Return cross-entropy averaged over non-ignored targets, or 0 if there are none.

    ``F.cross_entropy(..., reduction="mean")`` divides by the number of
    non-ignored targets and therefore returns NaN when every target is ignored.
    Summing and dividing by ``max(count, 1)`` gives the same value otherwise,
    and a finite 0 (with zero gradient) in that degenerate case.

    logits: [..., C]; target: [...] holding ``ignore_index`` where unsupervised.
    """
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_target = target.reshape(-1)
    loss_sum = F.cross_entropy(flat_logits, flat_target,
                               ignore_index=ignore_index, reduction="sum")
    n_valid = (flat_target != ignore_index).sum().clamp_min(1)
    return loss_sum / n_valid.to(loss_sum.dtype)


def cadad_loss(out: Dict[str, torch.Tensor],
               target_cat, target_anat, target_concept, target_sev,
               target_ref, target_cand,
               weights: Optional[Dict[str, float]] = None) -> Dict[str, torch.Tensor]:
    """Compute the weighted sum of per-head cross-entropies for ``CADAD.forward_train`` outputs.

    Pad/ignore positions are -100 in the targets. EOS positions only supervise
    ``cat``, so the other heads' targets should be -100 there. A head with no
    supervised position in the batch contributes 0 instead of NaN. This happens,
    for example, when every pair in the batch is error-free and only EOS is
    present.

    Args:
        out: dict with "logits_cat", "logits_anat", "logits_concept",
            "logits_sev", "logits_ref", "logits_cand".
        target_*: integer targets whose shape matches the leading dims of the
            corresponding logits.
        weights: optional per-head overrides of the default weights
            (keys "cat", "anat", "concept", "sev", "ref", "cand").

    Returns:
        dict with "total" and each per-head loss under its head key.
    """
    w = {"cat": 1.0, "anat": 0.5, "concept": 0.3, "sev": 0.5,
         "ref": 0.5, "cand": 0.5, **(weights or {})}
    losses = {
        "cat": _masked_mean_ce(out["logits_cat"], target_cat),
        "anat": _masked_mean_ce(out["logits_anat"], target_anat),
        "concept": _masked_mean_ce(out["logits_concept"], target_concept),
        "sev": _masked_mean_ce(out["logits_sev"], target_sev),
        "ref": _masked_mean_ce(out["logits_ref"], target_ref),
        "cand": _masked_mean_ce(out["logits_cand"], target_cand),
    }
    total = sum(w[name] * loss for name, loss in losses.items())
    return {"total": total, **losses}