# Tarang_v2 / model.py
# Committed by: unknownfriend00007
# Commit message: Create model.py
# Commit: 85a1a1d (verified)
import torch
import torch.nn as nn
class TinyTransformerForecaster(nn.Module):
    """Small Transformer encoder for time-series forecasting (CPU-friendly).

    Each timestep is linearly projected to ``d_model`` and passed through a
    stack of pre-LayerNorm (``norm_first=True``) Transformer encoder layers.
    The hidden state at the final timestep is then mapped to one scalar.

    Input:  (B, T, F) with F == ``feature_dim`` (default 4).
    Output: (B, 1) predicted log-return.

    Args:
        feature_dim: Number of features per timestep.
        d_model: Hidden width; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability inside each encoder layer.
        use_positional_encoding: If True, add a fixed sinusoidal positional
            encoding right after the input projection. It is off by default
            so existing checkpoints keep their exact behaviour.
            Without it the encoder has no notion of time order. No attention
            mask is used, so the output at the last step depends on the last
            input and on the *set* of earlier inputs, but not on their order.
            The encoding has no parameters, so ``state_dict`` keys are the
            same either way.
    """

    def __init__(
        self,
        feature_dim: int = 4,
        d_model: int = 64,
        nhead: int = 4,
        num_layers: int = 2,
        dropout: float = 0.1,
        *,
        use_positional_encoding: bool = False,
    ):
        super().__init__()
        self.use_positional_encoding = use_positional_encoding
        self.in_proj = nn.Linear(feature_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        # The stack is built without a final ``norm``. With pre-LN layers the
        # residual stream leaves the encoder un-normalized, so the head
        # starts with its own LayerNorm.
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1),
        )

    @staticmethod
    def _sinusoidal_positions(
        seq_len: int, dim: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return a (seq_len, dim) sinusoidal position table (sin on even, cos on odd columns)."""
        positions = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
        inv_freq = 1.0 / (10000.0 ** exponents)
        angles = positions * inv_freq  # (seq_len, ceil(dim / 2))
        table = torch.zeros(seq_len, dim, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        # For odd ``dim`` there is one fewer odd column than even column.
        table[:, 1::2] = torch.cos(angles[:, : dim // 2])
        return table.to(dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of sequences (B, T, feature_dim) to predictions (B, 1)."""
        h = self.in_proj(x)
        # Use getattr with a default so that whole-object pickles saved before
        # this attribute existed still run, with the original behaviour.
        if getattr(self, "use_positional_encoding", False):
            h = h + self._sinusoidal_positions(h.size(1), h.size(2), h.device, h.dtype)
        h = self.encoder(h)
        # Read out only the final timestep's representation.
        return self.head(h[:, -1, :])