"""Information Horizon Encoder - Causal transformer with linear attention."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn

from manifold.models.layers.attention import MultiHeadLinearAttention, RotaryPositionEncoding


class IHEBlock(nn.Module):
    """
    Single IHE transformer block with linear attention + FFN.

    Uses pre-norm architecture for training stability.
    """

    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 8,
        ff_dim: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attention = MultiHeadLinearAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            causal=True,
            use_rotary=True,
        )
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the transformer block.

        Args:
            x: Input tensor [batch, seq, embed_dim]
            mask: Optional attention mask [batch, seq]

        Returns:
            Dict with 'output' and 'attention_weights' (None for linear attention)
        """
        # Pre-norm residual attention: normalize first, then add the residual.
        normed = self.norm1(x)
        attn_out = self.attention(normed, mask=mask)
        x = x + attn_out["output"]

        # Pre-norm residual feed-forward.
        normed = self.norm2(x)
        x = x + self.ffn(normed)

        return {
            "output": x,
            # Linear attention never materializes the full [seq, seq] attention
            # matrix, so per-token weights are not available.
            "attention_weights": None,
        }
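

# A quick shape check for a single block (a sketch, assuming
# MultiHeadLinearAttention preserves the input shape and returns a dict
# with an "output" tensor):
#
#     block = IHEBlock(embed_dim=256, num_heads=8)
#     out = block(torch.randn(2, 64, 256))
#     out["output"].shape  # -> torch.Size([2, 64, 256])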


class InformationHorizonEncoder(nn.Module):
    """
    Multi-layer causal transformer for encoding player action sequences.

    Uses linear attention (O(T) in sequence length) and rotary position
    encoding. Causal masking ensures actions can't see future information.
    """

    def __init__(
        self,
        embed_dim: int = 256,
        num_layers: int = 4,
        num_heads: int = 8,
        ff_dim: int = 1024,
        dropout: float = 0.1,
        max_seq_len: int = 128,
    ):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len

        # Rotary tables are sized to the per-head dimension; each attention
        # layer also applies rotary internally via use_rotary=True.
        head_dim = embed_dim // num_heads
        self.pos_encoding = RotaryPositionEncoding(
            dim=head_dim,
            max_seq_len=max_seq_len,
        )

        self.layers = nn.ModuleList([
            IHEBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                ff_dim=ff_dim,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Encode an action sequence through the causal transformer layers.

        Args:
            x: Input tensor [batch, seq, embed_dim]
            mask: Optional attention mask [batch, seq]

        Returns:
            Dict with 'encoding' and 'all_layer_outputs'
        """
        # Keep each block's output so downstream heads can probe
        # intermediate representations.
        all_layer_outputs: List[torch.Tensor] = []
        for layer in self.layers:
            layer_out = layer(x, mask=mask)
            x = layer_out["output"]
            all_layer_outputs.append(x)

        encoding = self.final_norm(x)
        return {
            "encoding": encoding,
            "all_layer_outputs": all_layer_outputs,
        }

    @classmethod
    def from_config(cls, config: Any) -> "InformationHorizonEncoder":
        """
        Create an InformationHorizonEncoder from a ModelConfig.

        Args:
            config: ModelConfig instance with IHE parameters

        Returns:
            Configured InformationHorizonEncoder instance
        """
        return cls(
            embed_dim=config.embed_dim,
            num_layers=config.ihe_layers,
            num_heads=config.ihe_heads,
            ff_dim=config.ihe_ff_dim,
            dropout=config.ihe_dropout,
            max_seq_len=config.sequence_length,
        )
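

if __name__ == "__main__":
    # A minimal smoke test (a sketch, assuming MultiHeadLinearAttention
    # returns a dict with an "output" tensor shaped like its input).
    encoder = InformationHorizonEncoder(
        embed_dim=256,
        num_layers=4,
        num_heads=8,
        max_seq_len=128,
    )
    actions = torch.randn(2, 128, 256)  # [batch, seq, embed_dim]
    out = encoder(actions)
    assert out["encoding"].shape == (2, 128, 256)
    assert len(out["all_layer_outputs"]) == 4
    print("smoke test passed:", out["encoding"].shape)

    # Building from config would look like the following, given a ModelConfig
    # exposing embed_dim, ihe_layers, ihe_heads, ihe_ff_dim, ihe_dropout,
    # and sequence_length (the fields read by from_config above):
    #
    #     encoder = InformationHorizonEncoder.from_config(config)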