"""AI Fusion Model — BitNet-Transformer for hybrid trading signal generation.""" import torch import torch.nn as nn from .bitlinear import BitLinear, BitRMSNorm class BitNetAttention(nn.Module): """Multi-head Attention with ternary-quantized projections.""" def __init__(self, d_model, n_heads): super().__init__() assert d_model % n_heads == 0 self.n_heads = n_heads self.d_head = d_model // n_heads self.q_proj = BitLinear(d_model, d_model) self.k_proj = BitLinear(d_model, d_model) self.v_proj = BitLinear(d_model, d_model) self.out_proj = BitLinear(d_model, d_model) def forward(self, x): B, T, C = x.shape # Ternary projections q = self.q_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2) k = self.k_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2) v = self.v_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2) # Scaled dot-product attention attn = (q @ k.transpose(-2, -1)) * (self.d_head ** -0.5) attn = torch.softmax(attn, dim=-1) out = (attn @ v).transpose(1, 2).reshape(B, T, C) return self.out_proj(out) class BitNetTransformerLayer(nn.Module): """Single Transformer Encoder layer with BitNet components.""" def __init__(self, d_model, n_heads, d_ff): super().__init__() self.norm1 = BitRMSNorm(d_model) self.attn = BitNetAttention(d_model, n_heads) self.norm2 = BitRMSNorm(d_model) self.ffn = nn.Sequential( BitLinear(d_model, d_ff), nn.SiLU(), BitLinear(d_ff, d_model) ) def forward(self, x): x = x + self.attn(self.norm1(x)) x = x + self.ffn(self.norm2(x)) return x class BitNetTransformer(nn.Module): """ High-capacity sequence model for market pattern recognition. Utilizes 1.58-bit (ternary) weights for all projections. """ def __init__(self, input_dim=9, d_model=512, n_heads=8, n_layers=6, seq_len=30): super().__init__() self.input_proj = BitLinear(input_dim, d_model) self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, d_model)) self.layers = nn.ModuleList([ BitNetTransformerLayer(d_model, n_heads, d_model * 4) for _ in range(n_layers) ]) self.norm = BitRMSNorm(d_model) self.head = BitLinear(d_model, 3) # 0=HOLD, 1=BUY, 2=SELL def forward(self, x): """ Input x: [batch, seq_len, input_dim] Output: Logits [batch, 3] """ # Embed and add positional information x = self.input_proj(x) + self.pos_embed for layer in self.layers: x = layer(x) # Decision based on the most recent state x = self.norm(x[:, -1, :]) return self.head(x) @torch.no_grad() def predict_action(self, x): """Perform inference on a sequence and return discrete action.""" logits = self.forward(x) probs = torch.softmax(logits, dim=-1) return torch.argmax(probs, dim=-1).item() def create_model(input_dim=9, hidden_dim=512, output_dim=3, layers=6, seq_len=30): """Helper to instantiate the SOTA BitNet Transformer.""" return BitNetTransformer( input_dim=input_dim, d_model=hidden_dim, n_layers=layers, seq_len=seq_len )