# (Hugging Face Spaces page banner captured during extraction: "Running on Zero")
"""Causal Counterfactual Attention (CCA) for MANIFOLD."""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any
from manifold.models.layers.attention import MultiHeadLinearAttention
class CounterfactualProbe(nn.Module):
    """
    Bank of learnable query vectors for counterfactual reasoning.

    Each probe is a trainable vector that asks a "what if" question of the
    input: it attends over every sequence position via scaled dot-product
    attention and returns one summary vector per probe.
    """

    def __init__(self, embed_dim: int = 256, num_probes: int = 16):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_probes = num_probes
        # Learnable probe vectors; the 0.02 std keeps initial attention
        # logits small so early attention is close to uniform.
        self.probes = nn.Parameter(torch.randn(num_probes, embed_dim) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Attend each probe over the input sequence.

        Args:
            x: Input tensor of shape [batch, seq, embed_dim].

        Returns:
            Probe summaries of shape [batch, num_probes, embed_dim].
        """
        n_batch, _, dim = x.shape
        # Broadcast the shared probes across the batch:
        # [num_probes, embed_dim] -> [batch, num_probes, embed_dim].
        queries = self.probes.unsqueeze(0).expand(n_batch, -1, -1)
        # Attention logits over sequence positions: [batch, num_probes, seq].
        # x serves as both keys and values (no projections).
        scale = dim ** -0.5
        logits = torch.bmm(queries, x.transpose(1, 2)) * scale
        weights = logits.softmax(dim=-1)
        # Convex combination of sequence states for every probe:
        # [batch, num_probes, embed_dim].
        return torch.bmm(weights, x)
class CausalCounterfactualAttention(nn.Module):
    """
    Dual-path attention: factual (standard) + counterfactual (sparse probes).

    Factual path: causal linear attention, O(T), over the actual sequence.
    Counterfactual path: ``num_cf_probes`` sparse probes asking "what if"
    questions; their outputs are projected, collapsed to a single vector,
    and broadcast to every sequence position before fusion.
    """

    def __init__(
        self,
        embed_dim: int = 256,
        num_cf_probes: int = 16,
        num_heads: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_cf_probes = num_cf_probes
        # Factual path: causal linear attention O(T).
        self.factual_attention = MultiHeadLinearAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            causal=True,
            use_rotary=True,
        )
        # Counterfactual path: sparse learnable probes.
        self.cf_probes = CounterfactualProbe(
            embed_dim=embed_dim,
            num_probes=num_cf_probes,
        )
        # Project counterfactual probe outputs before broadcasting.
        self.cf_proj = nn.Linear(embed_dim, embed_dim)
        # Learnable weighting that collapses the probe axis:
        # [batch, embed_dim, num_probes] -> [batch, embed_dim, 1].
        self.cf_to_seq = nn.Linear(num_cf_probes, 1)
        # Fuse concatenated factual + counterfactual features back to embed_dim.
        self.combine = nn.Linear(embed_dim * 2, embed_dim)
        # Post-fusion regularization and normalization.
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Run both attention paths and fuse them.

        Args:
            x: Input [batch, seq, embed_dim].
            mask: Optional attention mask, forwarded to the factual path only.
                (Shape/semantics defined by MultiHeadLinearAttention.)

        Returns:
            Dict with:
                - "output": combined output [batch, seq, embed_dim]
                - "factual": factual attention output [batch, seq, embed_dim]
                - "counterfactual": probe outputs [batch, num_probes, embed_dim]
        """
        batch, seq, _ = x.shape
        # Factual path: linear attention on the sequence itself.
        factual_out = self.factual_attention(x, mask=mask)["output"]
        # Counterfactual path: probe attention.
        cf_out = self.cf_probes(x)  # [batch, num_probes, embed_dim]
        cf_projected = self.cf_proj(cf_out)  # [batch, num_probes, embed_dim]
        # Collapse the probe axis with learned weights, then broadcast the
        # resulting single vector to every sequence position.
        # [batch, num_probes, embed_dim] -> [batch, embed_dim, num_probes]
        cf_transposed = cf_projected.transpose(1, 2)
        # Linear over the probe axis: [batch, embed_dim, 1]
        cf_seq = self.cf_to_seq(cf_transposed)
        # [batch, embed_dim] -> [batch, seq, embed_dim] (same vector at
        # every position; expand creates a view, no copy).
        cf_contribution = cf_seq.squeeze(-1).unsqueeze(1).expand(-1, seq, -1)
        # Fuse: concatenate factual and counterfactual features, project back.
        combined = torch.cat([factual_out, cf_contribution], dim=-1)
        output = self.combine(combined)
        output = self.dropout(output)
        # Normalize after dropout.
        output = self.norm(output)
        return {
            "output": output,
            "factual": factual_out,
            "counterfactual": cf_out,
        }

    @classmethod
    def from_config(cls, config) -> "CausalCounterfactualAttention":
        """Create from ModelConfig.

        Fix: this alternate constructor uses ``cls`` and must be a
        ``@classmethod``; without the decorator, ``Model.from_config(cfg)``
        would receive ``cfg`` as ``cls`` and fail.

        Assumes ``config`` exposes ``embed_dim``, ``num_cf_probes``,
        ``cca_heads``, and ``dropout`` attributes.
        """
        return cls(
            embed_dim=config.embed_dim,
            num_cf_probes=config.num_cf_probes,
            num_heads=config.cca_heads,
            dropout=config.dropout,
        )