# LimmeDev's picture
# Initial MANIFOLD upload - CS2 cheat detection training
# 454ecdd verified
"""Causal Counterfactual Attention (CCA) for MANIFOLD."""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any
from manifold.models.layers.attention import MultiHeadLinearAttention
class CounterfactualProbe(nn.Module):
    """Sparse bank of learnable "what if" query vectors.

    Each probe acts as a query that attends over the full input
    sequence via scaled dot-product attention, yielding one summary
    vector per probe.
    """
    def __init__(self, embed_dim: int = 256, num_probes: int = 16):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_probes = num_probes
        # Small-variance init keeps early attention logits near zero.
        self.probes = nn.Parameter(torch.randn(num_probes, embed_dim) * 0.02)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Attend each probe over the sequence.
        Args:
            x: Input [batch, seq, embed_dim]
        Returns:
            Probe summaries [batch, num_probes, embed_dim]
        """
        batch_size, _, model_dim = x.shape
        # Broadcast the shared probe bank across the batch dimension.
        queries = self.probes.unsqueeze(0).expand(batch_size, -1, -1)
        # Attention logits [batch, num_probes, seq], scaled by 1/sqrt(d).
        logits = torch.matmul(queries, x.transpose(1, 2)) * (model_dim ** -0.5)
        weights = torch.softmax(logits, dim=-1)
        # Each probe output is a convex combination of sequence states.
        return torch.matmul(weights, x)
class CausalCounterfactualAttention(nn.Module):
    """Dual-path attention fusing factual and counterfactual views.

    Factual path: causal linear attention O(T) over the actual sequence.
    Counterfactual path: a small bank of sparse probes asking
    "what if" questions, pooled and broadcast to every position.
    """
    def __init__(
        self,
        embed_dim: int = 256,
        num_cf_probes: int = 16,
        num_heads: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_cf_probes = num_cf_probes
        # Factual branch: causal, rotary-embedded linear attention.
        self.factual_attention = MultiHeadLinearAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            causal=True,
            use_rotary=True,
        )
        # Counterfactual branch: learnable sparse probes.
        self.cf_probes = CounterfactualProbe(
            embed_dim=embed_dim,
            num_probes=num_cf_probes,
        )
        # Per-probe projection applied before fusing into the sequence.
        self.cf_proj = nn.Linear(embed_dim, embed_dim)
        # Learned reduction over the probe axis: num_probes -> 1.
        self.cf_to_seq = nn.Linear(num_cf_probes, 1)
        # Fuses the concatenated factual/counterfactual features.
        self.combine = nn.Linear(embed_dim * 2, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Run both attention paths and fuse them position-wise.
        Args:
            x: Input [batch, seq, embed_dim]
            mask: Optional attention mask forwarded to the factual path.
        Returns:
            Dict with:
                - "output": combined output [batch, seq, embed_dim]
                - "factual": factual attention output
                - "counterfactual": counterfactual probe outputs [batch, num_probes, embed_dim]
        """
        seq_len = x.size(1)
        # Factual branch over the real sequence.
        factual = self.factual_attention(x, mask=mask)["output"]
        # Counterfactual branch: probe summaries, then projection.
        probe_out = self.cf_probes(x)        # [batch, num_probes, embed_dim]
        projected = self.cf_proj(probe_out)  # [batch, num_probes, embed_dim]
        # Collapse the probe axis with a learned weighting; the result is
        # one vector per batch element, shared by every sequence position.
        pooled = self.cf_to_seq(projected.transpose(1, 2)).squeeze(-1)  # [batch, embed_dim]
        cf_term = pooled.unsqueeze(1).expand(-1, seq_len, -1)           # [batch, seq, embed_dim]
        # Fuse: concat -> linear -> dropout -> layer norm.
        fused = torch.cat([factual, cf_term], dim=-1)
        fused = self.combine(fused)
        fused = self.dropout(fused)
        fused = self.norm(fused)
        return {
            "output": fused,
            "factual": factual,
            "counterfactual": probe_out,
        }
    @classmethod
    def from_config(cls, config) -> "CausalCounterfactualAttention":
        """Build an instance from a ModelConfig-style object."""
        return cls(
            embed_dim=config.embed_dim,
            num_cf_probes=config.num_cf_probes,
            num_heads=config.cca_heads,
            dropout=config.dropout,
        )