""" Prediction Model Module ======================== Multi-horizon Transformer-based prediction model. Architecture: PatchTST-inspired with Kronos-style multi-resolution encoding. - Patch embedding for temporal features - Multi-head self-attention across patches - Multi-task heads for direction, return, and uncertainty Key design decisions (from literature): 1. PatchTST (2211.14730): Channel-independent patching reduces O(L²) to O((L/S)²) 2. Chronos (2403.07815): Probabilistic outputs via distributional heads 3. Kronos (2508.02739): Coarse-to-fine hierarchical predictions for financial data 4. iTransformer: Inverted attention on variate dimension """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import math from typing import Dict, List, Optional, Tuple class PatchEmbedding(nn.Module): """ PatchTST-style patch embedding for time series. Splits each channel's sequence into overlapping patches, then projects to embedding dimension. """ def __init__(self, patch_len: int = 8, stride: int = 4, d_model: int = 128): super().__init__() self.patch_len = patch_len self.stride = stride self.projection = nn.Linear(patch_len, d_model) self.layer_norm = nn.LayerNorm(d_model) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: (batch, channels, seq_len) Returns: patches: (batch, channels, num_patches, d_model) """ B, C, L = x.shape # Pad if necessary pad_len = (self.stride - (L - self.patch_len) % self.stride) % self.stride if pad_len > 0: x = F.pad(x, (0, pad_len), mode='replicate') L = L + pad_len # Unfold into patches: (B, C, num_patches, patch_len) num_patches = (L - self.patch_len) // self.stride + 1 patches = x.unfold(dimension=2, size=self.patch_len, step=self.stride) # Project: (B, C, num_patches, d_model) patches = self.projection(patches) patches = self.layer_norm(patches) return patches class PositionalEncoding(nn.Module): """Learnable positional encoding for patches.""" def __init__(self, d_model: int, max_patches: int = 200): super().__init__() self.pos_embed = nn.Parameter(torch.randn(1, 1, max_patches, d_model) * 0.02) def forward(self, x: torch.Tensor) -> torch.Tensor: """x: (B, C, num_patches, d_model)""" return x + self.pos_embed[:, :, :x.size(2), :] class MultiHeadAttention(nn.Module): """Standard multi-head self-attention.""" def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1): super().__init__() self.n_heads = n_heads self.d_k = d_model // n_heads self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: B, N, D = x.shape Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2) K = self.W_k(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2) V = self.W_v(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2) scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) attn = F.softmax(scores, dim=-1) attn = self.dropout(attn) out = torch.matmul(attn, V) out = out.transpose(1, 2).contiguous().view(B, N, D) return self.W_o(out) class TransformerBlock(nn.Module): """Transformer encoder block with pre-norm (better for time series per PatchTST).""" def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1): super().__init__() self.attn = MultiHeadAttention(d_model, n_heads, dropout) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.ff = nn.Sequential( nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model), nn.Dropout(dropout) ) def forward(self, x: torch.Tensor) -> torch.Tensor: # Pre-norm attention x = x + self.attn(self.norm1(x)) # Pre-norm FFN x = x + self.ff(self.norm2(x)) return x class ChannelMixer(nn.Module): """ Cross-channel attention for capturing inter-feature dependencies. Inspired by iTransformer - applies attention across variate dimension. """ def __init__(self, num_channels: int, d_model: int, n_heads: int = 4, dropout: float = 0.1): super().__init__() self.channel_attn = MultiHeadAttention(d_model, n_heads, dropout) self.norm = nn.LayerNorm(d_model) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: (B, C, num_patches, d_model) Returns: x: (B, C, num_patches, d_model) with cross-channel info """ B, C, N, D = x.shape # Pool across patches for channel representation channel_repr = x.mean(dim=2) # (B, C, D) # Cross-channel attention channel_out = self.channel_attn(self.norm(channel_repr)) # (B, C, D) # Broadcast back and add x = x + channel_out.unsqueeze(2) return x class PredictionHead(nn.Module): """ Multi-task prediction head. Outputs: 1. Direction probability (binary classification per horizon) 2. Expected return (regression per horizon) 3. Uncertainty/confidence (learned aleatoric uncertainty) """ def __init__(self, d_model: int, num_horizons: int = 3, dropout: float = 0.1): super().__init__() self.num_horizons = num_horizons # Shared representation self.shared = nn.Sequential( nn.Linear(d_model, d_model), nn.GELU(), nn.Dropout(dropout), ) # Direction head (classification) self.direction_head = nn.Sequential( nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Linear(d_model // 2, num_horizons), ) # Return prediction head (regression) self.return_head = nn.Sequential( nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Linear(d_model // 2, num_horizons), ) # Uncertainty head (log variance - Gaussian heteroscedastic) self.uncertainty_head = nn.Sequential( nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Linear(d_model // 2, num_horizons), ) def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """ Args: x: (B, d_model) - pooled representation Returns: dict with 'direction_logits', 'expected_return', 'log_variance' """ shared = self.shared(x) return { 'direction_logits': self.direction_head(shared), # (B, num_horizons) 'expected_return': self.return_head(shared), # (B, num_horizons) 'log_variance': self.uncertainty_head(shared), # (B, num_horizons) } class TradingTransformer(nn.Module): """ Main prediction model: Patch-based Transformer for multi-horizon trading predictions. Architecture: 1. PatchEmbedding → patches per channel (PatchTST) 2. Intra-channel Transformer blocks (temporal patterns) 3. ChannelMixer (cross-feature dependencies, iTransformer-inspired) 4. Global pooling → PredictionHead (multi-task) Designed to be modular and accept varying numbers of input features. """ def __init__( self, num_channels: int, # Number of input features seq_len: int = 60, # Lookback window patch_len: int = 8, # Patch length stride: int = 4, # Patch stride d_model: int = 128, # Model dimension n_heads: int = 8, # Number of attention heads n_layers: int = 3, # Number of transformer layers d_ff: int = 256, # FFN hidden dimension num_horizons: int = 3, # Number of prediction horizons dropout: float = 0.1, use_channel_mixer: bool = True, ): super().__init__() self.num_channels = num_channels self.seq_len = seq_len self.d_model = d_model self.use_channel_mixer = use_channel_mixer # Instance normalization (PatchTST: mitigate distribution shift) self.instance_norm = nn.InstanceNorm1d(num_channels, affine=True) # Patch embedding self.patch_embed = PatchEmbedding(patch_len, stride, d_model) # Positional encoding self.pos_enc = PositionalEncoding(d_model) # Transformer encoder blocks (channel-independent, per PatchTST) self.transformer_blocks = nn.ModuleList([ TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers) ]) # Channel mixer (optional cross-channel attention) if use_channel_mixer: self.channel_mixer = ChannelMixer(num_channels, d_model, n_heads=4, dropout=dropout) # Global pooling + prediction head self.pool_norm = nn.LayerNorm(d_model) self.prediction_head = PredictionHead(d_model, num_horizons, dropout) # Initialize weights self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """ Args: x: (batch, num_channels, seq_len) Returns: Dict with 'direction_logits', 'expected_return', 'log_variance' """ B, C, L = x.shape # Instance normalization x = self.instance_norm(x) # Patch embedding: (B, C, num_patches, d_model) x = self.patch_embed(x) x = self.pos_enc(x) # Channel-independent transformer (per PatchTST) B, C, N, D = x.shape x_flat = x.reshape(B * C, N, D) for block in self.transformer_blocks: x_flat = block(x_flat) x = x_flat.reshape(B, C, N, D) # Channel mixing if self.use_channel_mixer: x = self.channel_mixer(x) # Global average pooling across channels and patches x = x.mean(dim=[1, 2]) # (B, D) x = self.pool_norm(x) # Multi-task prediction predictions = self.prediction_head(x) return predictions def predict_with_confidence(self, x: torch.Tensor) -> Dict[str, np.ndarray]: """ Make predictions with calibrated confidence scores. Returns: direction_probs: Probability of up move per horizon expected_returns: Expected return per horizon confidence: Confidence score (0-1) derived from uncertainty """ self.eval() with torch.no_grad(): outputs = self.forward(x) direction_probs = torch.sigmoid(outputs['direction_logits']).cpu().numpy() expected_returns = outputs['expected_return'].cpu().numpy() log_var = outputs['log_variance'].cpu().numpy() # Confidence = 1 / (1 + exp(log_variance)) confidence = 1.0 / (1.0 + np.exp(log_var)) return { 'direction_probs': direction_probs, 'expected_returns': expected_returns, 'confidence': confidence, } class MultiTaskLoss(nn.Module): """ Multi-task loss combining: 1. Direction loss (BCE with logits) 2. Return prediction loss (Gaussian NLL for uncertainty-aware regression) 3. Risk-adjusted loss (Sharpe-like penalty) Uses learned task weights (uncertainty weighting from Kendall et al. 2018). """ def __init__(self, num_horizons: int = 3, alpha_direction: float = 1.0, alpha_return: float = 1.0, alpha_risk: float = 0.5): super().__init__() self.num_horizons = num_horizons self.alpha_direction = alpha_direction self.alpha_return = alpha_return self.alpha_risk = alpha_risk # Learned task uncertainty weights (Kendall et al.) self.log_sigma_direction = nn.Parameter(torch.tensor(0.0)) self.log_sigma_return = nn.Parameter(torch.tensor(0.0)) self.log_sigma_risk = nn.Parameter(torch.tensor(0.0)) def forward(self, predictions: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Args: predictions: model outputs targets: dict with 'direction' (B, H), 'returns' (B, H) """ # Direction loss (BCE) direction_loss = F.binary_cross_entropy_with_logits( predictions['direction_logits'], targets['direction'], reduction='mean' ) # Return prediction loss (Gaussian NLL - heteroscedastic) log_var = predictions['log_variance'] return_loss = 0.5 * ( torch.exp(-log_var) * (predictions['expected_return'] - targets['returns'])**2 + log_var ).mean() # Risk-adjusted loss: penalize predictions that would lead to large drawdowns # Simulates a simple PnL and penalizes negative Sharpe-like ratio pred_returns = predictions['expected_return'] pred_direction = torch.sigmoid(predictions['direction_logits']) simulated_pnl = pred_returns * (2 * pred_direction - 1) # Long if bullish, short if bearish risk_loss = -simulated_pnl.mean() / (simulated_pnl.std() + 1e-8) # Negative Sharpe risk_loss = F.relu(risk_loss) # Only penalize negative Sharpe # Uncertainty-weighted combination total_loss = ( self.alpha_direction * torch.exp(-self.log_sigma_direction) * direction_loss + self.log_sigma_direction + self.alpha_return * torch.exp(-self.log_sigma_return) * return_loss + self.log_sigma_return + self.alpha_risk * torch.exp(-self.log_sigma_risk) * risk_loss + self.log_sigma_risk ) return { 'total_loss': total_loss, 'direction_loss': direction_loss, 'return_loss': return_loss, 'risk_loss': risk_loss, }