"""Information Horizon Encoder - Causal transformer with linear attention."""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any, List

from manifold.models.layers.attention import MultiHeadLinearAttention, RotaryPositionEncoding


class IHEBlock(nn.Module):
    """
    Single IHE transformer block: causal linear self-attention + feed-forward.

    Uses a pre-norm residual layout for training stability:

        x = x + Attention(LayerNorm(x))
        x = x + FFN(LayerNorm(x))

    The attention sub-layer is a ``MultiHeadLinearAttention`` built with
    ``causal=True`` and ``use_rotary=True``.
    """

    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 8,
        ff_dim: int = 1024,
        dropout: float = 0.1,
    ) -> None:
        """
        Args:
            embed_dim: Model (residual stream) width.
            num_heads: Number of attention heads.
            ff_dim: Hidden width of the feed-forward network.
            dropout: Dropout probability, used both inside the attention
                module and after each FFN linear-activation stage.
        """
        super().__init__()

        # NOTE: submodule construction order is kept as-is; it determines
        # parameter registration order and weight-init RNG consumption.
        self.norm1 = nn.LayerNorm(embed_dim)  # pre-norm for the attention sub-layer
        self.norm2 = nn.LayerNorm(embed_dim)  # pre-norm for the FFN sub-layer

        self.attention = MultiHeadLinearAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            causal=True,
            use_rotary=True,
        )

        # Position-wise FFN: expand to ff_dim, GELU, project back to embed_dim.
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Forward pass through the transformer block.

        Args:
            x: Input tensor [batch, seq, embed_dim].
            mask: Optional attention mask [batch, seq]. It is passed through
                unchanged to the attention module, which defines its semantics.

        Returns:
            Dict with:
              - 'output': residual stream after the attention and FFN sub-layers.
              - 'attention_weights': always ``None``. This block does not
                surface attention weights, and any extra entries returned by
                the attention module are discarded.
        """
        # Attention sub-layer (pre-norm + residual).
        normed = self.norm1(x)
        attn_out = self.attention(normed, mask=mask)
        x = x + attn_out["output"]

        # Feed-forward sub-layer (pre-norm + residual).
        normed = self.norm2(x)
        x = x + self.ffn(normed)

        return {
            "output": x,
            "attention_weights": None,
        }


class InformationHorizonEncoder(nn.Module):
    """
    Multi-layer causal transformer for encoding player action sequences.

    Stacks ``num_layers`` ``IHEBlock`` instances, whose attention is built
    with ``causal=True`` and ``use_rotary=True``, then applies a final
    LayerNorm. The causal setting is intended to stop actions from seeing
    future information.
    """

    def __init__(
        self,
        embed_dim: int = 256,
        num_layers: int = 4,
        num_heads: int = 8,
        ff_dim: int = 1024,
        dropout: float = 0.1,
        max_seq_len: int = 128,
    ) -> None:
        """
        Args:
            embed_dim: Model (residual stream) width.
            num_layers: Number of stacked ``IHEBlock`` layers.
            num_heads: Attention heads per layer.
            ff_dim: Hidden width of each block's feed-forward network.
            dropout: Dropout probability passed to every block.
            max_seq_len: Maximum sequence length. Only used to size
                ``self.pos_encoding``; it is not forwarded to the blocks.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len

        # Integer division; no check is made that embed_dim % num_heads == 0.
        head_dim = embed_dim // num_heads

        # NOTE(review): this module is never used in forward(). Rotary encoding
        # is requested inside each block's attention via use_rotary=True.
        # It is kept so the module's state_dict/parameter layout stays
        # unchanged for existing checkpoints. Confirm before removing.
        self.pos_encoding = RotaryPositionEncoding(
            dim=head_dim,
            max_seq_len=max_seq_len,
        )

        self.layers = nn.ModuleList([
            IHEBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                ff_dim=ff_dim,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """
        Encode an action sequence through the causal transformer layers.

        Args:
            x: Input tensor [batch, seq, embed_dim].
            mask: Optional attention mask [batch, seq]. The same mask is
                passed to every layer.

        Returns:
            Dict with:
              - 'encoding': last layer's output after ``final_norm``.
              - 'all_layer_outputs': list with one tensor per layer, in layer
                order, each taken *before* ``final_norm``.
        """
        all_layer_outputs: List[torch.Tensor] = []

        for layer in self.layers:
            layer_out = layer(x, mask=mask)
            x = layer_out["output"]
            all_layer_outputs.append(x)

        # Only the final layer's output is normalized.
        encoding = self.final_norm(x)

        return {
            "encoding": encoding,
            "all_layer_outputs": all_layer_outputs,
        }

    @classmethod
    def from_config(cls, config: Any) -> "InformationHorizonEncoder":
        """
        Create an InformationHorizonEncoder from a ModelConfig.

        Reads these attributes from ``config``:
          - ``embed_dim``
          - ``ihe_layers``
          - ``ihe_heads``
          - ``ihe_ff_dim``
          - ``ihe_dropout``
          - ``sequence_length``, used as ``max_seq_len``

        Args:
            config: ModelConfig-like object exposing the attributes above.

        Returns:
            Configured InformationHorizonEncoder instance.
        """
        return cls(
            embed_dim=config.embed_dim,
            num_layers=config.ihe_layers,
            num_heads=config.ihe_heads,
            ff_dim=config.ihe_ff_dim,
            dropout=config.ihe_dropout,
            max_seq_len=config.sequence_length,
        )