File size: 8,099 Bytes
bf64b03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""

CitationModule: Understands scientific citation structure.

Detects citation spans, tracks provenance, and estimates claim confidence.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import re
from typing import Optional, Tuple, List


class CitationModule(nn.Module):
    """Understands scientific citation structure.

    - Detects citation spans such as ``(Author, Year)`` or ``[1]`` style
    - Learns that cited claims carry different epistemic weight
    - Distinguishes established facts vs recent/contested findings
    - Tracks claim provenance through the context window
    """

    def __init__(self, d_model: int):
        """Initialize CitationModule.

        Args:
            d_model: Model dimension.
        """
        super().__init__()
        self.d_model = d_model

        # Citation span detector (3 classes: none, inline, reference).
        # Inline: (Author, Year) or [1]; reference: full citation at end of paper.
        # Trained only through compute_citation_loss; forward() does not use it.
        self.citation_detector = nn.Linear(d_model, 3)

        # Provenance gate: modulates information flow based on citation context.
        self.provenance_gate = nn.Linear(d_model, d_model)

        # Claim confidence head: estimates how well-supported a claim is.
        self.confidence_head = nn.Linear(d_model, 1)

        # One learned embedding per citation class (0=none, 1=inline, 2=reference).
        self.citation_type_embedding = nn.Embedding(3, d_model)

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize sub-module weights with N(0, 0.02) and zero biases."""
        for module in (
            self.citation_detector,
            self.provenance_gate,
            self.confidence_head,
            self.citation_type_embedding,
        ):
            if hasattr(module, "weight"):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # nn.Embedding has no bias; the hasattr guard skips it.
            if hasattr(module, "bias") and module.bias is not None:
                nn.init.zeros_(module.bias)

    def detect_citation_spans(
        self,
        text: str,
    ) -> List[Tuple[int, int, str]]:
        """Detect citation spans in ``text`` using regex heuristics.

        Supports: ``(Author, Year)``, ``[1]``/``[1-3]``/``[1,2,3]``,
        ``[Author, Year]``, and bare ``et al.`` mentions.

        Args:
            text: Input text string.

        Returns:
            List of ``(start_char, end_char, citation_type)`` tuples sorted by
            start offset. ``citation_type`` is currently always ``"inline"``.

        Fixes vs. previous version:
            - The ``et al.`` pattern ended in ``\\b`` after ``\\.``; a word
              boundary only exists there when a word character immediately
              follows the period, so the pattern almost never matched. The
              trailing ``\\b`` is removed.
            - Spans from different patterns could overlap (e.g. ``et al.``
              inside ``(Smith et al., 2019)``), which made forward() gate the
              same tokens twice. Overlapping matches are now deduplicated,
              keeping the earliest (and on ties, longest) span.
        """
        patterns = (
            r'\([A-Za-z\s]+(?:et al\.)?,?\s*\d{4}\)',  # (Author, Year)
            r'\[\d+(?:[-,]\d+)*\]',                    # [1], [1-3], [1,2,3]
            r'\[[A-Za-z\s]+,?\s*\d{4}\]',              # [Author, Year]
            r'\bet al\.',                              # bare "et al."
        )

        matches: List[Tuple[int, int, str]] = []
        for pattern in patterns:
            for match in re.finditer(pattern, text):
                matches.append((match.start(), match.end(), "inline"))

        # Deduplicate: sort by start (longest first on ties) and keep only
        # spans that do not overlap an already-kept span.
        matches.sort(key=lambda span: (span[0], -span[1]))
        spans: List[Tuple[int, int, str]] = []
        last_end = -1
        for start, end, ctype in matches:
            if start >= last_end:
                spans.append((start, end, ctype))
                last_end = end
        return spans

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        citation_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the citation module.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            text: Optional original text strings, one per batch element.
            citation_spans: Optional pre-computed token spans per batch element,
                each ``(start_tok, end_tok, citation_type)``. When omitted and
                ``text`` is given, spans are detected and converted from
                character to token offsets with a crude heuristic.

        Returns:
            Tuple of:
                - citation-enhanced representation (batch, seq_len, d_model)
                - confidence scores (batch, seq_len, 1) for the auxiliary loss

        Note: the previous return annotation said ``torch.Tensor`` although a
        tuple was returned; the annotation is now correct.
        """
        batch, seq_len, _ = x.shape

        # Detect citation spans from raw text when not supplied by the caller.
        if citation_spans is None and text is not None:
            citation_spans = []
            for b in range(batch):
                char_spans = self.detect_citation_spans(text[b])
                token_spans = []
                for start_char, end_char, ctype in char_spans:
                    # Crude char->token mapping assuming ~4 chars per token.
                    # TODO(review): replace with real tokenizer offsets.
                    start_tok = max(0, start_char // 4)
                    end_tok = min(seq_len, end_char // 4 + 1)
                    token_spans.append((start_tok, end_tok, ctype))
                citation_spans.append(token_spans)

        # Start from a copy so non-citation tokens pass through unchanged.
        output = x.clone()

        if citation_spans:
            type_ids = {"inline": 1, "reference": 2}
            for b in range(batch):
                spans_b = citation_spans[b] if b < len(citation_spans) else []

                for start_tok, end_tok, ctype in spans_b:
                    if end_tok <= start_tok:
                        continue

                    type_id = type_ids.get(ctype, 0)
                    type_emb = self.citation_type_embedding(
                        torch.tensor(type_id, device=x.device)
                    )

                    # Gate the span's information flow, then tag it with the
                    # citation-type embedding (broadcasts over the span dim).
                    span_slice = x[b, start_tok:end_tok, :]
                    gated = span_slice * torch.sigmoid(self.provenance_gate(span_slice))
                    output[b, start_tok:end_tok, :] = gated + type_emb

        # Per-token confidence scores, consumed by compute_citation_loss.
        confidence = torch.sigmoid(self.confidence_head(x))  # (batch, seq_len, 1)

        return output, confidence

    def compute_citation_loss(
        self,
        x: torch.Tensor,
        citation_mask: torch.Tensor,
        confidence: torch.Tensor,
    ) -> torch.Tensor:
        """Compute auxiliary loss for citation detection and confidence.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            citation_mask: Ground-truth mask (batch, seq_len); 1 if the token is
                inside a citation. The 0/1 values double as class targets for
                the 3-way detector (0 = none, 1 = inline); class 2 (reference)
                is never supervised through this binary mask.
            confidence: Predicted confidence scores (batch, seq_len, 1).

        Returns:
            Combined scalar citation loss (detection + 0.1 * calibration).
        """
        # Citation detection loss over the 3-way detector head.
        logits = self.citation_detector(x)  # (batch, seq_len, 3)
        detection_loss = F.cross_entropy(
            logits.view(-1, 3),
            citation_mask.long().view(-1),
        )

        # Confidence calibration: push confidence toward 1 on true citations.
        confidence_loss = F.mse_loss(
            confidence.squeeze(-1),
            citation_mask.float(),
        )

        return detection_loss + 0.1 * confidence_loss


def test_citation_module():
    """Smoke-test CitationModule: shapes of forward outputs and loss value."""
    dim = 512
    n_batch, n_tokens = 2, 128

    sample_texts = [
        "The theory of relativity (Einstein, 1905) revolutionized physics. See also [1, 2].",
        "According to Smith et al., the results are significant. Further reading: [Doe, 2020]."
    ]
    hidden = torch.randn(n_batch, n_tokens, dim)

    citation_module = CitationModule(dim)
    enhanced, conf = citation_module(hidden, text=sample_texts)

    print(f"Input shape: {hidden.shape}")
    print(f"Output shape: {enhanced.shape}")
    print(f"Confidence shape: {conf.shape}")
    assert enhanced.shape == hidden.shape
    assert conf.shape == (n_batch, n_tokens, 1)

    # Exercise the auxiliary loss with a synthetic ground-truth mask.
    gold_mask = torch.zeros(n_batch, n_tokens)
    gold_mask[0, 20:25] = 1.0  # Simulate citation span
    gold_mask[1, 10:18] = 1.0
    aux_loss = citation_module.compute_citation_loss(hidden, gold_mask, conf)
    print(f"Citation loss: {aux_loss.item():.4f}")

    print("CitationModule test passed!")


if __name__ == "__main__":
    test_citation_module()