| | """
|
| | CitationModule: Understands scientific citation structure.
|
| | Detects citation spans, tracks provenance, and estimates claim confidence.
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | import re
|
| | from typing import Optional, Tuple, List
|
| |
|
| |
|
| | class CitationModule(nn.Module):
|
| | """
|
| | Understands scientific citation structure.
|
| | - Detects citation spans [Author, Year] or (1) style
|
| | - Learns that cited claims carry different epistemic weight
|
| | - Distinguishes established facts vs recent/contested findings
|
| | - Tracks claim provenance through the context window
|
| | """
|
| |
|
| | def __init__(self, d_model: int):
|
| | """
|
| | Initialize CitationModule.
|
| |
|
| | Args:
|
| | d_model: Model dimension
|
| | """
|
| | super().__init__()
|
| | self.d_model = d_model
|
| |
|
| |
|
| |
|
| |
|
| | self.citation_detector = nn.Linear(d_model, 3)
|
| |
|
| |
|
| | self.provenance_gate = nn.Linear(d_model, d_model)
|
| |
|
| |
|
| | self.confidence_head = nn.Linear(d_model, 1)
|
| |
|
| |
|
| | self.citation_type_embedding = nn.Embedding(3, d_model)
|
| |
|
| |
|
| | self._initialize_weights()
|
| |
|
| | def _initialize_weights(self):
|
| | """Initialize weights."""
|
| | for module in [self.citation_detector, self.provenance_gate, self.confidence_head, self.citation_type_embedding]:
|
| | if hasattr(module, 'weight'):
|
| | nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| | if hasattr(module, 'bias') and module.bias is not None:
|
| | nn.init.zeros_(module.bias)
|
| |
|
| | def detect_citation_spans(
|
| | self,
|
| | text: str,
|
| | ) -> List[Tuple[int, int, str]]:
|
| | """
|
| | Detect citation spans in text.
|
| | Supports: (Author, Year), [1], [Author, Year], et al.
|
| |
|
| | Args:
|
| | text: Input text string
|
| |
|
| | Returns:
|
| | List of (start_char, end_char, citation_type)
|
| | citation_type: "inline" or "reference"
|
| | """
|
| | spans = []
|
| |
|
| |
|
| | for match in re.finditer(r'\([A-Za-z\s]+(?:et al\.)?,?\s*\d{4}\)', text):
|
| | spans.append((match.start(), match.end(), "inline"))
|
| |
|
| |
|
| | for match in re.finditer(r'\[\d+(?:[-,]\d+)*\]', text):
|
| | spans.append((match.start(), match.end(), "inline"))
|
| |
|
| |
|
| | for match in re.finditer(r'\[[A-Za-z\s]+,?\s*\d{4}\]', text):
|
| | spans.append((match.start(), match.end(), "inline"))
|
| |
|
| |
|
| | for match in re.finditer(r'\bet al\.\b', text):
|
| | spans.append((match.start(), match.end(), "inline"))
|
| |
|
| | return spans
|
| |
|
| | def forward(
|
| | self,
|
| | x: torch.Tensor,
|
| | text: Optional[List[str]] = None,
|
| | citation_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
|
| | ) -> torch.Tensor:
|
| | """
|
| | Forward pass through citation module.
|
| |
|
| | Args:
|
| | x: Input tensor (batch, seq_len, d_model)
|
| | text: Optional original text strings
|
| | citation_spans: Optional pre-computed citation spans per batch
|
| |
|
| | Returns:
|
| | Citation-enhanced representation (batch, seq_len, d_model)
|
| | """
|
| | batch, seq_len, d_model = x.shape
|
| |
|
| |
|
| | if citation_spans is None and text is not None:
|
| | citation_spans = []
|
| | for b in range(batch):
|
| | spans = self.detect_citation_spans(text[b])
|
| |
|
| | token_spans = []
|
| | for start_char, end_char, ctype in spans:
|
| | start_tok = max(0, start_char // 4)
|
| | end_tok = min(seq_len, end_char // 4 + 1)
|
| | token_spans.append((start_tok, end_tok, ctype))
|
| | citation_spans.append(token_spans)
|
| |
|
| |
|
| | citation_logits = self.citation_detector(x)
|
| | citation_probs = F.softmax(citation_logits, dim=-1)
|
| |
|
| |
|
| | output = x.clone()
|
| |
|
| | if citation_spans:
|
| | for b in range(batch):
|
| | spans_b = citation_spans[b] if b < len(citation_spans) else []
|
| |
|
| | for start_tok, end_tok, ctype in spans_b:
|
| | if end_tok <= start_tok:
|
| | continue
|
| |
|
| |
|
| | if ctype == "inline":
|
| | type_id = 1
|
| | elif ctype == "reference":
|
| | type_id = 2
|
| | else:
|
| | type_id = 0
|
| |
|
| | type_emb = self.citation_type_embedding(
|
| | torch.tensor(type_id, device=x.device)
|
| | )
|
| |
|
| |
|
| | span_slice = x[b, start_tok:end_tok, :]
|
| | gated = span_slice * torch.sigmoid(self.provenance_gate(span_slice))
|
| |
|
| |
|
| | gated = gated + type_emb.unsqueeze(0).unsqueeze(0)
|
| |
|
| | output[b, start_tok:end_tok, :] = gated
|
| |
|
| |
|
| | confidence = torch.sigmoid(self.confidence_head(x))
|
| |
|
| | return output, confidence
|
| |
|
| | def compute_citation_loss(
|
| | self,
|
| | x: torch.Tensor,
|
| | citation_mask: torch.Tensor,
|
| | confidence: torch.Tensor,
|
| | ) -> torch.Tensor:
|
| | """
|
| | Compute auxiliary loss for citation detection and confidence.
|
| |
|
| | Args:
|
| | x: Input tensor (batch, seq_len, d_model)
|
| | citation_mask: Ground truth citation mask (batch, seq_len), 1 if token is in citation
|
| | confidence: Predicted confidence scores (batch, seq_len, 1)
|
| |
|
| | Returns:
|
| | Combined citation loss
|
| | """
|
| |
|
| | logits = self.citation_detector(x)
|
| | detection_loss = F.cross_entropy(
|
| | logits.view(-1, 3),
|
| | citation_mask.long().view(-1),
|
| | )
|
| |
|
| |
|
| | confidence_loss = F.mse_loss(
|
| | confidence.squeeze(-1),
|
| | citation_mask.float(),
|
| | )
|
| |
|
| | return detection_loss + 0.1 * confidence_loss
|
| |
|
| |
|
def test_citation_module():
    """Smoke-test CitationModule: forward output shapes and auxiliary loss."""
    dim, n_batch, n_tokens = 512, 2, 128

    citation_module = CitationModule(dim)

    hidden = torch.randn(n_batch, n_tokens, dim)
    sample_texts = [
        "The theory of relativity (Einstein, 1905) revolutionized physics. See also [1, 2].",
        "According to Smith et al., the results are significant. Further reading: [Doe, 2020]."
    ]

    # Forward pass must preserve the hidden-state shape and emit one
    # confidence scalar per token.
    enhanced, conf = citation_module(hidden, text=sample_texts)
    print(f"Input shape: {hidden.shape}")
    print(f"Output shape: {enhanced.shape}")
    print(f"Confidence shape: {conf.shape}")
    assert enhanced.shape == hidden.shape
    assert conf.shape == (n_batch, n_tokens, 1)

    # Synthetic ground-truth mask marking a short citation span per sample.
    gt_mask = torch.zeros(n_batch, n_tokens)
    gt_mask[0, 20:25] = 1.0
    gt_mask[1, 10:18] = 1.0
    aux_loss = citation_module.compute_citation_loss(hidden, gt_mask, conf)
    print(f"Citation loss: {aux_loss.item():.4f}")

    print("CitationModule test passed!")
|
| |
|
| |
|
# Run the smoke test when this module is executed directly.
if __name__ == "__main__":
    test_citation_module()
|
| |
|