File size: 5,469 Bytes
454ecdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""Causal Counterfactual Attention (CCA) for MANIFOLD."""

from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any
from manifold.models.layers.attention import MultiHeadLinearAttention


class CounterfactualProbe(nn.Module):
    """Bank of learnable query vectors that interrogate a sequence.

    Each probe is a trainable embedding that attends over the input and
    produces one attention-pooled summary vector — a "what if" question
    asked of the sequence.
    """

    def __init__(self, embed_dim: int = 256, num_probes: int = 16):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_probes = num_probes

        # Small-init trainable probe bank: [num_probes, embed_dim].
        self.probes = nn.Parameter(torch.randn(num_probes, embed_dim) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend each probe over the full sequence.

        Args:
            x: Input tensor of shape [batch, seq, embed_dim].

        Returns:
            Tensor of shape [batch, num_probes, embed_dim] — one pooled
            summary per probe.
        """
        bsz, _, dim = x.shape

        # Broadcast the probe bank across the batch to serve as queries.
        queries = self.probes.unsqueeze(0).expand(bsz, -1, -1)

        # Scaled dot-product scores against the sequence (keys == x):
        # [batch, num_probes, seq]. Only num_probes queries -> sparse.
        scores = torch.bmm(queries, x.transpose(1, 2)) * (dim ** -0.5)
        weights = scores.softmax(dim=-1)

        # Pool the sequence (values == x) under each probe's weights.
        return torch.bmm(weights, x)


class CausalCounterfactualAttention(nn.Module):
    """Two-path attention: a factual stream plus sparse counterfactual probes.

    The factual stream runs causal linear attention (O(T)) over the actual
    sequence. The counterfactual stream pools the sequence through a small
    bank of learnable probes ("what if" questions); their summary is
    broadcast back to every position and fused with the factual output.
    """

    def __init__(
        self,
        embed_dim: int = 256,
        num_cf_probes: int = 16,
        num_heads: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_cf_probes = num_cf_probes

        # Factual stream: O(T) causal linear attention with rotary embeddings.
        self.factual_attention = MultiHeadLinearAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            causal=True,
            use_rotary=True,
        )

        # Counterfactual stream: sparse learnable probe queries.
        self.cf_probes = CounterfactualProbe(
            embed_dim=embed_dim,
            num_probes=num_cf_probes,
        )

        # Post-process probe summaries before broadcasting to the sequence.
        self.cf_proj = nn.Linear(embed_dim, embed_dim)

        # Learned collapse of the probe axis: num_cf_probes -> 1.
        self.cf_to_seq = nn.Linear(num_cf_probes, 1)

        # Fuse the concatenated factual + counterfactual features.
        self.combine = nn.Linear(embed_dim * 2, embed_dim)

        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run both attention paths and fuse them.

        Args:
            x: Input tensor of shape [batch, seq, embed_dim].
            mask: Optional attention mask, forwarded to the factual path.

        Returns:
            Dict with keys:
            - "output": fused result, [batch, seq, embed_dim]
            - "factual": factual-path output, [batch, seq, embed_dim]
            - "counterfactual": raw probe summaries,
              [batch, num_probes, embed_dim]
        """
        _, seq_len, _ = x.shape

        # Factual stream over the actual sequence.
        factual_out = self.factual_attention(x, mask=mask)["output"]

        # Counterfactual stream: probe summaries, then a projection.
        cf_out = self.cf_probes(x)              # [B, P, D]
        cf_projected = self.cf_proj(cf_out)     # [B, P, D]

        # Collapse the probe axis with a learned linear mix (apply the
        # Linear over the probe dimension by transposing it to the end),
        # then tile the per-batch vector across every sequence position.
        pooled = self.cf_to_seq(cf_projected.transpose(1, 2))   # [B, D, 1]
        cf_contribution = (
            pooled.squeeze(-1).unsqueeze(1).expand(-1, seq_len, -1)
        )                                                       # [B, T, D]

        # Fuse the two streams, then regularize and normalize.
        fused = self.combine(torch.cat([factual_out, cf_contribution], dim=-1))
        output = self.norm(self.dropout(fused))

        return {
            "output": output,
            "factual": factual_out,
            "counterfactual": cf_out,
        }

    @classmethod
    def from_config(cls, config) -> "CausalCounterfactualAttention":
        """Build an instance from a ModelConfig-style object."""
        return cls(
            embed_dim=config.embed_dim,
            num_cf_probes=config.num_cf_probes,
            num_heads=config.cca_heads,
            dropout=config.dropout,
        )