# Source: Vortex-7b-V1 / models/science_modules/citation_module.py
# Uploaded by Zandy-Wandy ("Upload Vortex model", commit bf64b03, verified)
"""
CitationModule: Understands scientific citation structure.
Detects citation spans, tracks provenance, and estimates claim confidence.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
from typing import Optional, Tuple, List
class CitationModule(nn.Module):
    """
    Understands scientific citation structure.

    - Detects citation spans [Author, Year] or (1) style
    - Learns that cited claims carry different epistemic weight
    - Distinguishes established facts vs recent/contested findings
    - Tracks claim provenance through the context window
    """

    # Regexes compiled once at class-creation time instead of per call.
    # All four patterns mark "inline" citations; "reference" spans (full
    # bibliography entries) are expected to arrive via the pre-computed
    # `citation_spans` argument of forward().
    _INLINE_PATTERNS = (
        re.compile(r'\([A-Za-z\s]+(?:et al\.)?,?\s*\d{4}\)'),  # (Author, Year)
        re.compile(r'\[\d+(?:[-,]\d+)*\]'),                    # [1], [1-3], [1,2,3]
        re.compile(r'\[[A-Za-z\s]+,?\s*\d{4}\]'),              # [Author, Year]
        # Bare "et al." often indicates a narrative citation.  The previous
        # pattern ended with `\b` after the escaped period; since both "."
        # and a following space/comma are non-word characters, no word
        # boundary exists there and the pattern effectively never matched.
        # The trailing `\b` is dropped so "Smith et al., ..." is detected.
        re.compile(r'\bet al\.'),
    )

    def __init__(self, d_model: int):
        """
        Initialize CitationModule.

        Args:
            d_model: Model (hidden-state) dimension.
        """
        super().__init__()
        self.d_model = d_model
        # Citation span detector (3 classes: 0 = none, 1 = inline, 2 = reference).
        # Inline: (Author, Year) or [1]; reference: full citation at end of paper.
        self.citation_detector = nn.Linear(d_model, 3)
        # Provenance gate: modulates information flow based on citation context.
        self.provenance_gate = nn.Linear(d_model, d_model)
        # Claim confidence head: estimates how well-supported a claim is.
        self.confidence_head = nn.Linear(d_model, 1)
        # One learned embedding per citation class (none / inline / reference).
        self.citation_type_embedding = nn.Embedding(3, d_model)
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize all heads: weights ~ N(0, 0.02), biases zeroed."""
        heads = (
            self.citation_detector,
            self.provenance_gate,
            self.confidence_head,
            self.citation_type_embedding,
        )
        for module in heads:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # Embedding has no bias; getattr keeps this loop uniform.
            if getattr(module, 'bias', None) is not None:
                nn.init.zeros_(module.bias)

    def detect_citation_spans(
        self,
        text: str,
    ) -> List[Tuple[int, int, str]]:
        """
        Detect citation spans in text.

        Supports: (Author, Year), [1] / [1-3] / [1,2,3], [Author, Year],
        and bare "et al.".

        Args:
            text: Input text string

        Returns:
            List of (start_char, end_char, citation_type); every
            regex-detected span is typed "inline".  Overlapping spans
            (e.g. "et al." inside "(Smith et al., 2020)") may both be
            reported, in per-pattern match order.
        """
        return [
            (match.start(), match.end(), "inline")
            for pattern in self._INLINE_PATTERNS
            for match in pattern.finditer(text)
        ]

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        citation_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through citation module.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            text: Optional original text strings, one per batch element;
                used to auto-detect spans when `citation_spans` is None.
            citation_spans: Optional pre-computed citation spans per batch,
                as (start_tok, end_tok, citation_type) tuples.

        Returns:
            Tuple of:
                - citation-enhanced representation (batch, seq_len, d_model)
                - per-token confidence scores in [0, 1] (batch, seq_len, 1)

            (The method has always returned this pair; the previous
            `-> torch.Tensor` annotation was incorrect and is fixed here.)
        """
        batch, seq_len, d_model = x.shape

        # Auto-detect character spans from raw text, then map characters to
        # token positions with a rough ~4-chars-per-token heuristic
        # (approximate; TODO confirm against the actual tokenizer).
        if citation_spans is None and text is not None:
            citation_spans = []
            # min() guards against fewer text strings than batch rows.
            for b in range(min(batch, len(text))):
                token_spans = []
                for start_char, end_char, ctype in self.detect_citation_spans(text[b]):
                    start_tok = max(0, start_char // 4)
                    end_tok = min(seq_len, end_char // 4 + 1)
                    token_spans.append((start_tok, end_tok, ctype))
                citation_spans.append(token_spans)

        # NOTE: the previous version also ran the citation detector and a
        # softmax here, but never used the result; that dead computation
        # has been removed (detector training happens in compute_citation_loss).

        # Re-encode tokens inside citation spans: gate the representation by
        # a learned provenance signal, then add the citation-type embedding.
        output = x.clone()
        if citation_spans:
            type_ids = {"inline": 1, "reference": 2}  # anything else -> 0 (none)
            for b in range(batch):
                spans_b = citation_spans[b] if b < len(citation_spans) else []
                for start_tok, end_tok, ctype in spans_b:
                    if end_tok <= start_tok:
                        continue  # degenerate span after token mapping
                    type_emb = self.citation_type_embedding(
                        torch.tensor(type_ids.get(ctype, 0), device=x.device)
                    )
                    span_slice = x[b, start_tok:end_tok, :]
                    gated = span_slice * torch.sigmoid(self.provenance_gate(span_slice))
                    # (span, d_model) + (d_model,) broadcasts row-wise.
                    output[b, start_tok:end_tok, :] = gated + type_emb

        # Per-token claim confidence in [0, 1] (consumed by compute_citation_loss).
        confidence = torch.sigmoid(self.confidence_head(x))  # (batch, seq_len, 1)
        return output, confidence

    def compute_citation_loss(
        self,
        x: torch.Tensor,
        citation_mask: torch.Tensor,
        confidence: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for citation detection and confidence.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            citation_mask: Ground-truth citation mask (batch, seq_len),
                1 if token is in citation.  The mask doubles as the class
                target, so a binary mask only supervises classes 0 ("none")
                and 1 ("inline"); class 2 ("reference") is never a target
                unless the mask carries 2s.
            confidence: Predicted confidence scores (batch, seq_len, 1)

        Returns:
            Scalar combined loss: detection cross-entropy
            + 0.1 * confidence-calibration MSE.
        """
        # Token-level citation classification against the mask.
        logits = self.citation_detector(x)  # (batch, seq_len, 3)
        detection_loss = F.cross_entropy(
            logits.view(-1, 3),
            citation_mask.long().view(-1),
        )
        # Calibration: push confidence toward 1 inside citations, 0 outside.
        confidence_loss = F.mse_loss(
            confidence.squeeze(-1),
            citation_mask.float(),
        )
        return detection_loss + 0.1 * confidence_loss
def test_citation_module():
    """Smoke-test CitationModule: forward shapes, span detection, loss."""
    d_model, batch_size, seq_len = 512, 2, 128

    module = CitationModule(d_model)
    hidden = torch.randn(batch_size, seq_len, d_model)
    sample_texts = [
        "The theory of relativity (Einstein, 1905) revolutionized physics. See also [1, 2].",
        "According to Smith et al., the results are significant. Further reading: [Doe, 2020]."
    ]

    output, confidence = module(hidden, text=sample_texts)

    print(f"Input shape: {hidden.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Confidence shape: {confidence.shape}")
    assert output.shape == hidden.shape
    assert confidence.shape == (batch_size, seq_len, 1)

    # Exercise the auxiliary loss with a hand-made ground-truth mask.
    gt_mask = torch.zeros(batch_size, seq_len)
    gt_mask[0, 20:25] = 1.0  # Simulate citation span
    gt_mask[1, 10:18] = 1.0
    loss = module.compute_citation_loss(hidden, gt_mask, confidence)
    print(f"Citation loss: {loss.item():.4f}")
    print("CitationModule test passed!")
# Run the smoke test when this module is executed directly as a script.
if __name__ == "__main__":
    test_citation_module()