"""Information Horizon Encoder - Causal transformer with linear attention."""
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any, List

from manifold.models.layers.attention import MultiHeadLinearAttention, RotaryPositionEncoding
class IHEBlock(nn.Module):
    """
    Single IHE transformer block: linear attention followed by a
    position-wise feed-forward network.

    Pre-norm residual layout — ``x + Attn(LN(x))`` then ``x + FFN(LN(x))``
    — for training stability.
    """

    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 8,
        ff_dim: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        # One LayerNorm per sub-layer (attention, FFN); attribute names are
        # part of the checkpoint format, so they stay norm1/norm2/attention/ffn.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attention = MultiHeadLinearAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            causal=True,
            use_rotary=True,
        )
        ffn_stack = [
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*ffn_stack)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Run one attention + FFN block with residual connections.

        Args:
            x: Input tensor [batch, seq, embed_dim]
            mask: Optional attention mask [batch, seq]

        Returns:
            Dict with 'output' (transformed tensor, same shape as ``x``)
            and 'attention_weights' (always ``None`` here; linear attention
            does not materialize a weight matrix).
        """
        attn_result = self.attention(self.norm1(x), mask=mask)
        x = x + attn_result["output"]
        x = x + self.ffn(self.norm2(x))
        return {"output": x, "attention_weights": None}
class InformationHorizonEncoder(nn.Module):
    """
    Multi-layer causal transformer for encoding player action sequences.

    Uses linear attention (O(T) in sequence length) and rotary position
    encoding. Causal masking inside each attention layer ensures actions
    can't see future information.
    """

    def __init__(
        self,
        embed_dim: int = 256,
        num_layers: int = 4,
        num_heads: int = 8,
        ff_dim: int = 1024,
        dropout: float = 0.1,
        max_seq_len: int = 128,
    ):
        super().__init__()
        # Guard against silent truncation: head_dim = embed_dim // num_heads
        # would otherwise produce a rotary encoding dim that mismatches the
        # per-head dim used by the attention layers.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len

        head_dim = embed_dim // num_heads
        # NOTE(review): the blocks also set use_rotary=True internally; this
        # module-level encoding is kept for checkpoint compatibility.
        self.pos_encoding = RotaryPositionEncoding(
            dim=head_dim,
            max_seq_len=max_seq_len,
        )
        self.layers = nn.ModuleList([
            IHEBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                ff_dim=ff_dim,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Encode action sequence through causal transformer layers.

        Args:
            x: Input tensor [batch, seq, embed_dim]
            mask: Optional attention mask [batch, seq]

        Returns:
            Dict with 'encoding' (final-normed output, [batch, seq, embed_dim])
            and 'all_layer_outputs' (list of each layer's pre-final-norm output,
            one entry per layer).
        """
        all_layer_outputs: List[torch.Tensor] = []
        for layer in self.layers:
            layer_out = layer(x, mask=mask)
            x = layer_out["output"]
            all_layer_outputs.append(x)
        encoding = self.final_norm(x)
        return {
            "encoding": encoding,
            "all_layer_outputs": all_layer_outputs,
        }

    @classmethod
    def from_config(cls, config: Any) -> "InformationHorizonEncoder":
        """
        Create InformationHorizonEncoder from ModelConfig.

        Fixed: this was defined with a ``cls`` parameter but lacked the
        ``@classmethod`` decorator, so ``InformationHorizonEncoder.from_config(cfg)``
        bound ``cfg`` to ``cls`` and failed.

        Args:
            config: ModelConfig instance with IHE parameters
                (embed_dim, ihe_layers, ihe_heads, ihe_ff_dim,
                ihe_dropout, sequence_length).

        Returns:
            Configured InformationHorizonEncoder instance
        """
        return cls(
            embed_dim=config.embed_dim,
            num_layers=config.ihe_layers,
            num_heads=config.ihe_heads,
            ff_dim=config.ihe_ff_dim,
            dropout=config.ihe_dropout,
            max_seq_len=config.sequence_length,
        )