"""Causal Counterfactual Attention (CCA) for MANIFOLD.""" from __future__ import annotations import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Dict, Any from manifold.models.layers.attention import MultiHeadLinearAttention class CounterfactualProbe(nn.Module): """ Learnable query vectors for counterfactual reasoning. These probes ask "what if" questions about the input sequence. """ def __init__(self, embed_dim: int = 256, num_probes: int = 16): super().__init__() self.embed_dim = embed_dim self.num_probes = num_probes # Learnable probe vectors - "what if" questions self.probes = nn.Parameter(torch.randn(num_probes, embed_dim) * 0.02) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Compute attention between probes and sequence. Args: x: Input [batch, seq, embed_dim] Returns: Probe outputs [batch, num_probes, embed_dim] """ batch, seq, dim = x.shape # Probes as queries: [num_probes, embed_dim] -> [batch, num_probes, embed_dim] q = self.probes.unsqueeze(0).expand(batch, -1, -1) # x as keys and values: [batch, seq, embed_dim] k = x v = x # Scaled dot-product attention (sparse: only num_probes queries) # Attention weights: [batch, num_probes, seq] scale = dim ** -0.5 attn = torch.bmm(q, k.transpose(1, 2)) * scale attn = F.softmax(attn, dim=-1) # Weighted sum of values: [batch, num_probes, embed_dim] output = torch.bmm(attn, v) return output class CausalCounterfactualAttention(nn.Module): """ Dual-path attention: factual (standard) + counterfactual (sparse probes). Factual path: Linear attention O(T) on actual sequence Counterfactual path: 16 sparse probes asking "what if" questions """ def __init__( self, embed_dim: int = 256, num_cf_probes: int = 16, num_heads: int = 8, dropout: float = 0.1, ): super().__init__() self.embed_dim = embed_dim self.num_cf_probes = num_cf_probes # Factual path: causal linear attention O(T) self.factual_attention = MultiHeadLinearAttention( embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, causal=True, use_rotary=True, ) # Counterfactual path: sparse probes self.cf_probes = CounterfactualProbe( embed_dim=embed_dim, num_probes=num_cf_probes, ) # Project counterfactual probe outputs to sequence contribution self.cf_proj = nn.Linear(embed_dim, embed_dim) # Learnable weights to broadcast cf probes to sequence positions # Maps [batch, num_probes, embed_dim] -> contribution at each position self.cf_to_seq = nn.Linear(num_cf_probes, 1) # Combine factual + counterfactual self.combine = nn.Linear(embed_dim * 2, embed_dim) # Layer normalization self.norm = nn.LayerNorm(embed_dim) self.dropout = nn.Dropout(dropout) def forward( self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: """ Args: x: Input [batch, seq, embed_dim] Returns: Dict with: - "output": combined output [batch, seq, embed_dim] - "factual": factual attention output - "counterfactual": counterfactual probe outputs [batch, num_probes, embed_dim] """ batch, seq, _ = x.shape # Factual path: linear attention on sequence factual_out = self.factual_attention(x, mask=mask)["output"] # Counterfactual path: probe attention cf_out = self.cf_probes(x) # [batch, num_probes, embed_dim] cf_projected = self.cf_proj(cf_out) # [batch, num_probes, embed_dim] # Broadcast counterfactual to sequence length # [batch, num_probes, embed_dim] -> [batch, seq, embed_dim] # Transpose for linear: [batch, embed_dim, num_probes] cf_transposed = cf_projected.transpose(1, 2) # Apply linear to last dim: [batch, embed_dim, 1] cf_seq = self.cf_to_seq(cf_transposed) # Squeeze and expand: [batch, embed_dim] -> [batch, seq, embed_dim] cf_contribution = cf_seq.squeeze(-1).unsqueeze(1).expand(-1, seq, -1) # Combine: concatenate factual and counterfactual contributions combined = torch.cat([factual_out, cf_contribution], dim=-1) output = self.combine(combined) output = self.dropout(output) # Normalize output = self.norm(output) return { "output": output, "factual": factual_out, "counterfactual": cf_out, } @classmethod def from_config(cls, config) -> "CausalCounterfactualAttention": """Create from ModelConfig.""" return cls( embed_dim=config.embed_dim, num_cf_probes=config.num_cf_probes, num_heads=config.cca_heads, dropout=config.dropout, )