File size: 3,512 Bytes
d5b7ee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""AI Fusion Model — BitNet-Transformer for hybrid trading signal generation."""

import torch
import torch.nn as nn
from .bitlinear import BitLinear, BitRMSNorm


class BitNetAttention(nn.Module):
    """Multi-head self-attention with ternary-quantized (BitLinear) projections.

    No attention mask is applied, so every position attends to every other
    position of the input sequence.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads (must be positive).

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and n_heads == 0 would otherwise surface as an
        # unhelpful ZeroDivisionError.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.q_proj = BitLinear(d_model, d_model)
        self.k_proj = BitLinear(d_model, d_model)
        self.v_proj = BitLinear(d_model, d_model)
        self.out_proj = BitLinear(d_model, d_model)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[B, T, C]`` (C must equal d_model)."""
        B, T, C = x.shape
        # Ternary projections, split into heads: [B, T, C] -> [B, n_heads, T, d_head]
        q = self.q_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        # Scaled dot-product attention weights: [B, n_heads, T, T], softmax over keys
        attn = (q @ k.transpose(-2, -1)) * (self.d_head ** -0.5)
        attn = torch.softmax(attn, dim=-1)

        # Merge heads back: [B, n_heads, T, d_head] -> [B, T, C]
        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        return self.out_proj(out)


class BitNetTransformerLayer(nn.Module):
    """Pre-norm Transformer encoder block assembled from BitNet components.

    The block has two residual sub-layers, applied in order:
    self-attention, then a SiLU feed-forward network. Each one receives a
    ``BitRMSNorm``-normalized copy of its input, and its output is added
    back onto that input.

    Args:
        d_model: Width of the residual stream.
        n_heads: Number of heads for ``BitNetAttention``.
        d_ff: Hidden width of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.norm1 = BitRMSNorm(d_model)
        self.attn = BitNetAttention(d_model, n_heads)
        self.norm2 = BitRMSNorm(d_model)
        # Expand -> SiLU -> contract; built in the same order as before so
        # parameter initialisation consumes the RNG identically.
        ffn_stages = [BitLinear(d_model, d_ff), nn.SiLU(), BitLinear(d_ff, d_model)]
        self.ffn = nn.Sequential(*ffn_stages)

    def forward(self, x):
        """Run the attention sub-layer, then the feed-forward sub-layer."""
        after_attn = x + self.attn(self.norm1(x))
        return after_attn + self.ffn(self.norm2(after_attn))


class BitNetTransformer(nn.Module):
    """
    High-capacity sequence model for market pattern recognition.
    Utilizes 1.58-bit (ternary) weights for all projections.

    Pipeline:
    1. A BitLinear input projection plus a learned positional embedding.
    2. ``n_layers`` ``BitNetTransformerLayer`` blocks.
    3. A final ``BitRMSNorm``.
    4. A BitLinear head applied to the last timestep only.

    Args:
        input_dim: Number of features per timestep.
        d_model: Model width.
        n_heads: Attention heads per layer (must divide ``d_model``).
        n_layers: Number of encoder layers.
        seq_len: Maximum sequence length, i.e. rows in the positional table.
        output_dim: Number of output logits. The default of 3 matches the
            action encoding 0=HOLD, 1=BUY, 2=SELL.
    """

    def __init__(self, input_dim=9, d_model=512, n_heads=8, n_layers=6, seq_len=30,
                 output_dim=3):
        super().__init__()
        self.input_proj = BitLinear(input_dim, d_model)
        # Learned absolute positional embedding (zero-initialised), one row per position.
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, d_model))

        self.layers = nn.ModuleList([
            BitNetTransformerLayer(d_model, n_heads, d_model * 4)
            for _ in range(n_layers)
        ])

        self.norm = BitRMSNorm(d_model)
        self.head = BitLinear(d_model, output_dim)  # default 3: 0=HOLD, 1=BUY, 2=SELL

    def forward(self, x):
        """
        Input x: [batch, T, input_dim] with 1 <= T <= seq_len
                 (an unbatched [T, input_dim] tensor is treated as batch 1).
        Output: Logits [batch, output_dim]

        Raises:
            ValueError: if x has fewer than 2 dims or T exceeds seq_len.
        """
        if x.dim() < 2:
            raise ValueError(
                f"expected input of shape [batch, T, input_dim], got {tuple(x.shape)}"
            )
        steps = x.shape[-2]
        max_steps = self.pos_embed.shape[1]
        if steps > max_steps:
            raise ValueError(
                f"input has {steps} timesteps but the model supports at most {max_steps}"
            )

        # Embed and add positional information. Slice the table to the actual
        # length: adding the full table would silently broadcast a 1-step
        # input into max_steps steps, or fail for other short inputs.
        x = self.input_proj(x) + self.pos_embed[:, :steps]

        for layer in self.layers:
            x = layer(x)

        # Decision based on the most recent state
        x = self.norm(x[:, -1, :])
        return self.head(x)

    @torch.no_grad()
    def predict_action(self, x):
        """Return the argmax action index for a single sequence.

        Args:
            x: One sequence, either [T, input_dim] or [1, T, input_dim].

        Returns:
            int: index of the highest logit (0=HOLD, 1=BUY, 2=SELL with the
            default head).

        Note: does not switch the module to eval mode. ``.item()`` raises
        RuntimeError if a batch of more than one sequence is passed.
        """
        if x.dim() == 2:
            x = x.unsqueeze(0)
        logits = self.forward(x)
        # softmax is monotonic, so argmax over raw logits picks the same action.
        return torch.argmax(logits, dim=-1).item()


def create_model(input_dim=9, hidden_dim=512, output_dim=3, layers=6, seq_len=30, n_heads=8):
    """Instantiate a ``BitNetTransformer`` from the helper's argument names.

    Args:
        input_dim: Features per timestep.
        hidden_dim: Model width (``d_model``).
        output_dim: Number of output logits.
        layers: Number of encoder layers (``n_layers``).
        seq_len: Maximum sequence length.
        n_heads: Attention heads per layer; must divide ``hidden_dim``.

    Returns:
        A freshly initialised ``BitNetTransformer``.
    """
    return BitNetTransformer(
        input_dim=input_dim,
        d_model=hidden_dim,
        n_heads=n_heads,
        n_layers=layers,
        seq_len=seq_len,
        output_dim=output_dim,
    )