| """ |
| Prediction Model Module |
| ======================== |
| Multi-horizon Transformer-based prediction model. |
| |
| Architecture: PatchTST-inspired with Kronos-style multi-resolution encoding. |
| - Patch embedding for temporal features |
| - Multi-head self-attention across patches |
| - Multi-task heads for direction, return, and uncertainty |
| |
| Key design decisions (from literature): |
| 1. PatchTST (2211.14730): Channel-independent patching reduces O(L²) to O((L/S)²) |
| 2. Chronos (2403.07815): Probabilistic outputs via distributional heads |
| 3. Kronos (2508.02739): Coarse-to-fine hierarchical predictions for financial data |
| 4. iTransformer: Inverted attention on variate dimension |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import math |
| from typing import Dict, List, Optional, Tuple |
|
|
|
|
class PatchEmbedding(nn.Module):
    """
    PatchTST-style patch embedding for time series.

    Splits each channel's sequence into windows of ``patch_len`` steps taken
    every ``stride`` steps (overlapping when ``stride < patch_len``), then
    projects each window to ``d_model`` and layer-normalises it.
    """

    def __init__(self, patch_len: int = 8, stride: int = 4, d_model: int = 128):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride
        self.projection = nn.Linear(patch_len, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, channels, seq_len)
        Returns:
            patches: (batch, channels, num_patches, d_model) where
                num_patches = ceil(max(seq_len - patch_len, 0) / stride) + 1
        """
        _, _, seq_len = x.shape

        if seq_len < self.patch_len:
            # Too short for even a single window: extend to exactly one patch
            # (stride-based padding alone may not reach patch_len).
            pad_len = self.patch_len - seq_len
        else:
            # Extend so the last window ends exactly on the final step and the
            # tail of the series is not silently dropped by unfold.
            pad_len = (self.stride - (seq_len - self.patch_len) % self.stride) % self.stride

        if pad_len > 0:
            # Repeat the last observed value rather than zero-padding.
            x = F.pad(x, (0, pad_len), mode='replicate')

        # (B, C, num_patches, patch_len)
        patches = x.unfold(dimension=2, size=self.patch_len, step=self.stride)

        # Project each patch to d_model, then normalise.
        return self.layer_norm(self.projection(patches))
|
|
|
|
class PositionalEncoding(nn.Module):
    """
    Learnable positional encoding for patches.

    Holds one table of ``max_patches`` position vectors, shared across batch
    and channel dimensions, and adds its first ``num_patches`` rows to the input.
    """

    def __init__(self, d_model: int, max_patches: int = 200):
        super().__init__()
        # Small random init (std 0.02); leading singleton dims broadcast over (B, C).
        self.pos_embed = nn.Parameter(torch.randn(1, 1, max_patches, d_model) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, num_patches, d_model)
        Returns:
            x plus positional embeddings, same shape.
        Raises:
            ValueError: if num_patches exceeds the size of the learned table.
        """
        num_patches = x.size(2)
        max_patches = self.pos_embed.size(2)
        if num_patches > max_patches:
            # Without this check the slice silently shrinks and the addition
            # fails later with an opaque broadcasting error.
            raise ValueError(
                f"PositionalEncoding supports at most {max_patches} patches, "
                f"got {num_patches}; increase max_patches"
            )
        return x + self.pos_embed[:, :, :num_patches, :]
|
|
|
|
class MultiHeadAttention(nn.Module):
    """
    Standard multi-head scaled dot-product self-attention.

    Args:
        d_model: embedding width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        if d_model % n_heads != 0:
            # Otherwise the head split in forward() fails with an opaque view() error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (B, N, d_model)
            mask: optional tensor broadcastable to (B, n_heads, N, N); entries
                equal to 0 exclude the corresponding key from attention.
        Returns:
            (B, N, d_model)
        """
        B, N, D = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, N, D) -> (B, n_heads, N, d_k)
            return t.view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        Q = split_heads(self.W_q(x))
        K = split_heads(self.W_k(x))
        V = split_heads(self.W_v(x))

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Use the most negative finite value of the scores' dtype: a fixed
            # -1e9 overflows float16 and raises under mixed precision.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn = self.dropout(F.softmax(scores, dim=-1))

        out = torch.matmul(attn, V)
        # (B, n_heads, N, d_k) -> (B, N, D)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)
|
|
|
|
class TransformerBlock(nn.Module):
    """
    Pre-norm Transformer encoder layer (the variant PatchTST favours for time series).

    Each sub-layer (self-attention, then a GELU feed-forward network) sees a
    LayerNorm-ed copy of its input, and its output is added back residually.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        ff_layers = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*ff_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention sub-layer with residual connection.
        attended = self.attn(self.norm1(x))
        hidden = x + attended
        # Feed-forward sub-layer with residual connection.
        return hidden + self.ff(self.norm2(hidden))
|
|
|
|
class ChannelMixer(nn.Module):
    """
    Cross-channel attention along the variate axis, in the spirit of iTransformer.

    Each channel is summarised by averaging its patch embeddings. The summaries
    attend to one another, and the per-channel result is broadcast back onto
    every patch of that channel as a residual update.

    Note: ``num_channels`` is accepted for interface symmetry but is not used.
    """

    def __init__(self, num_channels: int, d_model: int, n_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.channel_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, num_patches, d_model)
        Returns:
            Tensor of the same shape with cross-channel information added.
        """
        batch, channels, _, dim = x.shape

        # One token per channel: mean over the patch axis -> (B, C, d_model).
        summaries = torch.mean(x, dim=2)
        mixed = self.channel_attn(self.norm(summaries))

        # Broadcast each channel's update over all of its patches.
        return x + mixed.reshape(batch, channels, 1, dim)
|
|
|
|
class PredictionHead(nn.Module):
    """
    Multi-task prediction head on top of a pooled representation.

    A shared Linear-GELU-Dropout trunk feeds three independent two-layer
    branches. Each branch emits one value per horizon:
      * ``direction_logits``: logits for up/down classification,
      * ``expected_return``: regressed return,
      * ``log_variance``: learned aleatoric log-variance.
    """

    def __init__(self, d_model: int, num_horizons: int = 3, dropout: float = 0.1):
        super().__init__()
        self.num_horizons = num_horizons
        hidden = d_model // 2

        def make_branch() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(d_model, hidden),
                nn.GELU(),
                nn.Linear(hidden, num_horizons),
            )

        self.shared = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.direction_head = make_branch()
        self.return_head = make_branch()
        self.uncertainty_head = make_branch()

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: (B, d_model) pooled representation.
        Returns:
            Dict mapping 'direction_logits', 'expected_return', 'log_variance'
            to (B, num_horizons) tensors.
        """
        trunk = self.shared(x)
        outputs: Dict[str, torch.Tensor] = {}
        outputs['direction_logits'] = self.direction_head(trunk)
        outputs['expected_return'] = self.return_head(trunk)
        outputs['log_variance'] = self.uncertainty_head(trunk)
        return outputs
|
|
|
|
class TradingTransformer(nn.Module):
    """
    Main prediction model: Patch-based Transformer for multi-horizon trading predictions.

    Architecture:
        1. PatchEmbedding → patches per channel (PatchTST)
        2. Intra-channel Transformer blocks (temporal patterns)
        3. ChannelMixer (cross-feature dependencies, iTransformer-inspired)
        4. Global pooling → PredictionHead (multi-task)

    Designed to be modular and accept varying numbers of input features.
    """

    def __init__(
        self,
        num_channels: int,
        seq_len: int = 60,
        patch_len: int = 8,
        stride: int = 4,
        d_model: int = 128,
        n_heads: int = 8,
        n_layers: int = 3,
        d_ff: int = 256,
        num_horizons: int = 3,
        dropout: float = 0.1,
        use_channel_mixer: bool = True,
    ):
        super().__init__()

        self.num_channels = num_channels
        self.seq_len = seq_len
        self.d_model = d_model
        self.use_channel_mixer = use_channel_mixer

        # Per-sample, per-channel normalisation over the time axis, with learned affine.
        self.instance_norm = nn.InstanceNorm1d(num_channels, affine=True)

        self.patch_embed = PatchEmbedding(patch_len, stride, d_model)

        # Size the positional table so that inputs of length `seq_len` always fit.
        # Keep the historical 200-slot default (and checkpoint shape) whenever it suffices.
        expected_patches = math.ceil(max(seq_len - patch_len, 0) / stride) + 1
        self.pos_enc = PositionalEncoding(d_model, max_patches=max(200, expected_patches))

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        if use_channel_mixer:
            self.channel_mixer = ChannelMixer(num_channels, d_model, n_heads=4, dropout=dropout)

        self.pool_norm = nn.LayerNorm(d_model)
        self.prediction_head = PredictionHead(d_model, num_horizons, dropout)

        # Xavier-uniform weights / zero biases for every Linear layer.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: (batch, num_channels, seq_len)
        Returns:
            Dict with 'direction_logits', 'expected_return', 'log_variance',
            each of shape (batch, num_horizons).
        """
        x = self.instance_norm(x)

        # (B, C, N, d_model)
        x = self.pos_enc(self.patch_embed(x))

        # Channel independence: fold channels into the batch so the temporal
        # encoder processes each channel's patch sequence separately.
        B, C, N, D = x.shape
        tokens = x.reshape(B * C, N, D)
        for block in self.transformer_blocks:
            tokens = block(tokens)
        x = tokens.reshape(B, C, N, D)

        if self.use_channel_mixer:
            x = self.channel_mixer(x)

        # Mean-pool over channels and patches -> (B, d_model).
        pooled = self.pool_norm(x.mean(dim=[1, 2]))

        return self.prediction_head(pooled)

    def predict_with_confidence(self, x: torch.Tensor) -> Dict[str, np.ndarray]:
        """
        Make predictions with confidence scores, in eval mode and without gradients.

        The module's previous train/eval mode is restored afterwards, so calling
        this mid-training does not silently disable dropout.

        Returns:
            direction_probs: Probability of up move per horizon
            expected_returns: Expected return per horizon
            confidence: Score in (0, 1), equal to 1 / (1 + exp(log_variance))
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward(x)
        finally:
            self.train(was_training)

        direction_probs = torch.sigmoid(outputs['direction_logits']).cpu().numpy()
        expected_returns = outputs['expected_return'].cpu().numpy()

        # 1 / (1 + exp(v)) == sigmoid(-v). torch.sigmoid is numerically stable,
        # whereas np.exp overflows (with a warning) for large log-variances.
        confidence = torch.sigmoid(-outputs['log_variance']).cpu().numpy()

        return {
            'direction_probs': direction_probs,
            'expected_returns': expected_returns,
            'confidence': confidence,
        }
|
|
|
|
class MultiTaskLoss(nn.Module):
    """
    Multi-task loss combining:
        1. Direction loss (BCE with logits)
        2. Return prediction loss (Gaussian NLL for uncertainty-aware regression)
        3. Risk-adjusted loss: hinge on negative Sharpe ratio of a simulated
           strategy that holds position ``2 * sigmoid(direction_logits) - 1``
           against the realized target returns.

    Uses learned task weights (uncertainty weighting from Kendall et al. 2018).
    """

    def __init__(self, num_horizons: int = 3, alpha_direction: float = 1.0,
                 alpha_return: float = 1.0, alpha_risk: float = 0.5):
        super().__init__()
        self.num_horizons = num_horizons
        self.alpha_direction = alpha_direction
        self.alpha_return = alpha_return
        self.alpha_risk = alpha_risk

        # Learned per-task log-scales for uncertainty weighting.
        self.log_sigma_direction = nn.Parameter(torch.tensor(0.0))
        self.log_sigma_return = nn.Parameter(torch.tensor(0.0))
        self.log_sigma_risk = nn.Parameter(torch.tensor(0.0))

    def forward(self, predictions: Dict[str, torch.Tensor],
                targets: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Args:
            predictions: model outputs with 'direction_logits', 'expected_return',
                'log_variance' (each (B, H)).
            targets: dict with 'direction' (B, H) float labels and 'returns' (B, H).
        Returns:
            Dict with 'total_loss' and the unweighted per-task losses.
        """
        direction_loss = F.binary_cross_entropy_with_logits(
            predictions['direction_logits'], targets['direction'],
            reduction='mean'
        )

        # Heteroscedastic Gaussian NLL (constant term dropped).
        log_var = predictions['log_variance']
        return_loss = 0.5 * (
            torch.exp(-log_var) * (predictions['expected_return'] - targets['returns']) ** 2
            + log_var
        ).mean()

        # Simulated strategy: position in [-1, 1] from the direction probability,
        # scored against REALIZED returns. (Scoring against the model's own
        # predicted returns would make this term independent of the data.)
        position = 2 * torch.sigmoid(predictions['direction_logits']) - 1
        simulated_pnl = position * targets['returns']
        if simulated_pnl.numel() > 1:
            sharpe = simulated_pnl.mean() / (simulated_pnl.std() + 1e-8)
            # Only penalize a negative Sharpe ratio.
            risk_loss = F.relu(-sharpe)
        else:
            # The sample std of a single element is NaN, which would poison total_loss.
            risk_loss = simulated_pnl.new_zeros(())

        # NOTE(review): return_loss can be negative (log_var < 0). When it is,
        # exp(-log_sigma_return) * return_loss + log_sigma_return is unbounded below
        # as log_sigma_return -> -inf. Consider clamping the log-sigma parameters.
        total_loss = (
            self.alpha_direction * torch.exp(-self.log_sigma_direction) * direction_loss
            + self.log_sigma_direction
            + self.alpha_return * torch.exp(-self.log_sigma_return) * return_loss
            + self.log_sigma_return
            + self.alpha_risk * torch.exp(-self.log_sigma_risk) * risk_loss
            + self.log_sigma_risk
        )

        return {
            'total_loss': total_loss,
            'direction_loss': direction_loss,
            'return_loss': return_loss,
            'risk_loss': risk_loss,
        }
|
|