"""
CitationModule: Understands scientific citation structure.

Detects citation spans, tracks provenance, and estimates claim confidence.
"""

import re
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class CitationModule(nn.Module):
    """
    Understands scientific citation structure.

    - Detects citation spans: (Author, Year), [Author, Year] and [1] style
    - Learns that cited claims carry different epistemic weight
    - Distinguishes established facts vs recent/contested findings
    - Tracks claim provenance through the context window
    """

    # Approximate number of characters per token, used to map character spans
    # from ``detect_citation_spans`` onto token positions when only raw text
    # (no tokenizer offsets) is available.
    _CHARS_PER_TOKEN = 4

    # Row of ``citation_type_embedding`` used for each citation type; any other
    # type string falls back to row 0 ("none").
    _CITATION_TYPE_IDS: Dict[str, int] = {"inline": 1, "reference": 2}

    # Bracketed / parenthesised inline citation forms, compiled once.
    _INLINE_PATTERNS = (
        # (Author, Year), (Author Year), (Author et al., Year)
        re.compile(r"\([A-Za-z\s]+(?:et al\.)?,?\s*\d{4}\)"),
        # [1], [1-3], [1,2,3], [1, 2] -- whitespace around separators allowed,
        # hyphen or en dash for ranges.
        re.compile(r"\[\d+(?:\s*[-\u2013,]\s*\d+)*\]"),
        # [Author, Year]
        re.compile(r"\[[A-Za-z\s]+,?\s*\d{4}\]"),
    )
    # Bare "et al." (often indicates a citation). No trailing \b: after the
    # period the next character is usually a space or comma, where \b fails.
    _ET_AL_PATTERN = re.compile(r"\bet al\.")

    def __init__(self, d_model: int):
        """
        Initialize CitationModule.

        Args:
            d_model: Model dimension
        """
        super().__init__()
        self.d_model = d_model

        # Citation span detector (3 classes: none, inline, reference)
        # Inline: (Author, Year) or [1]
        # Reference: full citation at end of paper
        self.citation_detector = nn.Linear(d_model, 3)

        # Provenance gate: modulates information flow based on citation context
        self.provenance_gate = nn.Linear(d_model, d_model)

        # Claim confidence head: estimates how well-supported a claim is
        self.confidence_head = nn.Linear(d_model, 1)

        # Citation type embeddings (0 = none, 1 = inline, 2 = reference)
        self.citation_type_embedding = nn.Embedding(3, d_model)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize all weights from N(0, 0.02) and all biases to zero."""
        for module in (
            self.citation_detector,
            self.provenance_gate,
            self.confidence_head,
            self.citation_type_embedding,
        ):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # nn.Embedding has no bias attribute.
            bias = getattr(module, "bias", None)
            if bias is not None:
                nn.init.zeros_(bias)

    def detect_citation_spans(
        self,
        text: str,
    ) -> List[Tuple[int, int, str]]:
        """
        Detect citation spans in text.

        Supports: (Author, Year), [1] / [1-3] / [1, 2, 3], [Author, Year], and
        bare "et al.". An "et al." that lies inside an already detected
        bracketed citation is not reported a second time.

        Args:
            text: Input text string

        Returns:
            List of (start_char, end_char, citation_type), sorted by position.
            end_char is exclusive. citation_type is currently always "inline".
        """
        spans = [
            (match.start(), match.end(), "inline")
            for pattern in self._INLINE_PATTERNS
            for match in pattern.finditer(text)
        ]

        # Only add a bare "et al." when it is not already covered by a span.
        bracketed = list(spans)
        for match in self._ET_AL_PATTERN.finditer(text):
            start, end = match.start(), match.end()
            if not any(s <= start and end <= e for s, e, _ in bracketed):
                spans.append((start, end, "inline"))

        spans.sort()
        return spans

    def _char_spans_to_token_spans(
        self,
        spans: List[Tuple[int, int, str]],
        seq_len: int,
    ) -> List[Tuple[int, int, str]]:
        """Map character spans to approximate token spans clipped to seq_len."""
        per_tok = self._CHARS_PER_TOKEN
        return [
            (max(0, start_char // per_tok), min(seq_len, end_char // per_tok + 1), ctype)
            for start_char, end_char, ctype in spans
        ]

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        citation_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through citation module.

        Tokens inside a citation span are replaced by
        ``x * sigmoid(provenance_gate(x)) + type_embedding``; all other tokens
        pass through unchanged.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            text: Optional original text strings, one per batch element. Used
                only when ``citation_spans`` is not given. Char positions are
                mapped to tokens with a rough chars-per-token heuristic.
            citation_spans: Optional pre-computed token spans per batch element,
                as (start_tok, end_tok, citation_type) with end_tok exclusive.
                Indices are clamped to [0, seq_len]. Entries beyond ``batch``
                are ignored.

        Returns:
            Tuple of:
              - citation-enhanced representation (batch, seq_len, d_model)
              - confidence scores in (0, 1), shape (batch, seq_len, 1),
                intended for the auxiliary loss
        """
        batch, seq_len, _ = x.shape

        if citation_spans is None and text is not None:
            citation_spans = [
                self._char_spans_to_token_spans(self.detect_citation_spans(text[b]), seq_len)
                for b in range(batch)
            ]

        output = x.clone()
        for b, spans_b in enumerate((citation_spans or [])[:batch]):
            for start_tok, end_tok, ctype in spans_b:
                # Clamp so negative / oversized indices cannot wrap around.
                start_tok = max(0, start_tok)
                end_tok = min(seq_len, end_tok)
                if end_tok <= start_tok:
                    continue

                type_id = self._CITATION_TYPE_IDS.get(ctype, 0)
                type_emb = self.citation_type_embedding(
                    torch.tensor(type_id, device=x.device)
                )  # (d_model,)

                # Gate the original (not previously modified) span representation.
                span = x[b, start_tok:end_tok, :]
                gated = span * torch.sigmoid(self.provenance_gate(span))
                output[b, start_tok:end_tok, :] = gated + type_emb

        confidence = torch.sigmoid(self.confidence_head(x))  # (batch, seq_len, 1)

        return output, confidence

    def compute_citation_loss(
        self,
        x: torch.Tensor,
        citation_mask: torch.Tensor,
        confidence: torch.Tensor,
        confidence_weight: float = 0.1,
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for citation detection and confidence.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            citation_mask: Ground-truth labels (batch, seq_len). Either binary
                (1 if token is in a citation) or 3-class labels
                {0: none, 1: inline, 2: reference}.
            confidence: Predicted confidence scores (batch, seq_len, 1)
            confidence_weight: Weight of the confidence MSE term (default 0.1)

        Returns:
            Scalar: cross-entropy detection loss + confidence_weight * MSE
            between confidence and the "is citation" indicator (mask > 0).
        """
        logits = self.citation_detector(x)  # (batch, seq_len, 3)
        detection_loss = F.cross_entropy(
            logits.reshape(-1, 3),
            citation_mask.long().reshape(-1),
        )

        # Encourage high confidence for any citation token, regardless of type;
        # for a binary mask this equals citation_mask itself.
        confidence_target = (citation_mask > 0).to(confidence.dtype)
        confidence_loss = F.mse_loss(confidence.squeeze(-1), confidence_target)

        return detection_loss + confidence_weight * confidence_loss


def test_citation_module():
    """Smoke-test CitationModule shapes and the auxiliary loss."""
    d_model = 512
    batch_size = 2
    seq_len = 128

    module = CitationModule(d_model)

    x = torch.randn(batch_size, seq_len, d_model)
    text = [
        "The theory of relativity (Einstein, 1905) revolutionized physics. See also [1, 2].",
        "According to Smith et al., the results are significant. Further reading: [Doe, 2020].",
    ]

    output, confidence = module(x, text=text)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Confidence shape: {confidence.shape}")

    assert output.shape == x.shape
    assert confidence.shape == (batch_size, seq_len, 1)

    # Test loss
    citation_mask = torch.zeros(batch_size, seq_len)
    citation_mask[0, 20:25] = 1.0  # Simulate citation span
    citation_mask[1, 10:18] = 1.0

    loss = module.compute_citation_loss(x, citation_mask, confidence)
    print(f"Citation loss: {loss.item():.4f}")

    print("CitationModule test passed!")


if __name__ == "__main__":
    test_citation_module()