# Author: avinashhm
# Commit 00ee97b (verified): "Fix: trading_intelligence/prediction_model.py - all 174 tests passing"
"""
Prediction Model Module
========================
Multi-horizon Transformer-based prediction model.
Architecture: PatchTST-inspired with Kronos-style multi-resolution encoding.
- Patch embedding for temporal features
- Multi-head self-attention across patches
- Multi-task heads for direction, return, and uncertainty
Key design decisions (from literature):
1. PatchTST (2211.14730): Channel-independent patching reduces O(L²) to O((L/S)²)
2. Chronos (2403.07815): Probabilistic outputs via distributional heads
3. Kronos (2508.02739): Coarse-to-fine hierarchical predictions for financial data
4. iTransformer: Inverted attention on variate dimension
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Dict, List, Optional, Tuple
class PatchEmbedding(nn.Module):
    """
    PatchTST-style patch embedding for time series.

    Splits each channel's sequence into patches of ``patch_len`` steps taken
    every ``stride`` steps (overlapping when ``stride < patch_len``), projects
    each patch to ``d_model`` dimensions and layer-normalises the result.
    """

    def __init__(self, patch_len: int = 8, stride: int = 4, d_model: int = 128):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride
        self.projection = nn.Linear(patch_len, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, channels, seq_len)
        Returns:
            patches: (batch, channels, num_patches, d_model), where
            num_patches = (padded_len - patch_len) // stride + 1 and the
            sequence is right-padded (if needed) by repeating its last value.
        """
        _, _, seq_len = x.shape
        if seq_len < self.patch_len:
            # Too short for even one patch: pad up to exactly one full patch.
            # (The stride-alignment formula below pads by < stride, which is
            # not enough when seq_len < patch_len - stride.)
            pad_len = self.patch_len - seq_len
        else:
            # Pad so the last patch ends exactly on the end of the sequence.
            pad_len = (self.stride - (seq_len - self.patch_len) % self.stride) % self.stride
        if pad_len > 0:
            # Replicate the last observation rather than injecting zeros.
            x = F.pad(x, (0, pad_len), mode='replicate')
        # Unfold into patches: (B, C, num_patches, patch_len)
        patches = x.unfold(dimension=2, size=self.patch_len, step=self.stride)
        # Project and normalise: (B, C, num_patches, d_model)
        return self.layer_norm(self.projection(patches))
class PositionalEncoding(nn.Module):
    """Learnable additive position embedding: one d_model vector per patch index."""

    def __init__(self, d_model: int, max_patches: int = 200):
        super().__init__()
        # Leading singleton dims let the table broadcast over batch and channel.
        initial = 0.02 * torch.randn(1, 1, max_patches, d_model)
        self.pos_embed = nn.Parameter(initial)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the first ``x.size(2)`` position vectors to x of shape (B, C, num_patches, d_model)."""
        n_patches = x.size(2)
        table = self.pos_embed[..., :n_patches, :]
        return x + table
class MultiHeadAttention(nn.Module):
    """
    Standard multi-head scaled dot-product self-attention.

    Args:
        d_model: Model width; must be a positive multiple of ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads`` (the head
            split in ``forward`` would otherwise fail with an opaque view error).
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch, seq, d_model)
            mask: optional tensor broadcastable to (batch, n_heads, seq, seq);
                score positions where ``mask == 0`` are excluded from attention.
        Returns:
            (batch, seq, d_model)
        """
        B, N, D = x.shape
        # Project and split heads: (B, N, D) -> (B, n_heads, N, d_k)
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Most negative finite value of the scores' dtype: a literal -1e9
            # is not representable in float16 and makes masked_fill raise.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, V)
        # Merge heads back: (B, n_heads, N, d_k) -> (B, N, D)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)
class TransformerBlock(nn.Module):
    """
    Pre-norm Transformer encoder block (LayerNorm applied before each
    sub-layer, the variant PatchTST favours for time series).

    Self-attention followed by a GELU feed-forward network, each added back
    to its input through a residual connection.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        ff_layers = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*ff_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention sub-layer on the normalised input, then residual add.
        hidden = x + self.attn(self.norm1(x))
        # Feed-forward sub-layer on the normalised hidden state, then residual add.
        return hidden + self.ff(self.norm2(hidden))
class ChannelMixer(nn.Module):
    """
    Cross-channel attention for inter-feature dependencies (iTransformer-inspired).

    Each channel is summarised by its mean over patches. The summaries attend
    to one another, and the mixed vector is added back to every patch of its
    channel.
    """

    def __init__(self, num_channels: int, d_model: int, n_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        # num_channels is accepted for interface symmetry; it is not used here.
        self.channel_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, num_patches, d_model)
        Returns:
            (B, C, num_patches, d_model) with cross-channel information added
        """
        _b, _c, _n, _d = x.shape  # unpacking enforces a 4-D input
        summary = torch.mean(x, dim=2)                   # (B, C, D)
        mixed = self.channel_attn(self.norm(summary))    # (B, C, D)
        return x + mixed[:, :, None, :]                  # broadcast over patches
class PredictionHead(nn.Module):
    """
    Multi-task prediction head on a pooled (B, d_model) representation.

    A shared Linear -> GELU -> Dropout trunk feeds three independent
    Linear -> GELU -> Linear branches, each emitting one value per horizon:
      1. direction logits (binary up/down classification per horizon)
      2. expected return (regression per horizon)
      3. log-variance (heteroscedastic Gaussian uncertainty per horizon)
    """

    def __init__(self, d_model: int, num_horizons: int = 3, dropout: float = 0.1):
        super().__init__()
        self.num_horizons = num_horizons
        hidden = d_model // 2

        def branch() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(d_model, hidden),
                nn.GELU(),
                nn.Linear(hidden, num_horizons),
            )

        self.shared = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Dropout(dropout))
        # Creation order is fixed so state_dict keys and seeded init order stay stable.
        self.direction_head = branch()
        self.return_head = branch()
        self.uncertainty_head = branch()

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: (B, d_model) pooled representation
        Returns:
            dict with 'direction_logits', 'expected_return', 'log_variance',
            each of shape (B, num_horizons)
        """
        features = self.shared(x)
        outputs = {}
        for key, head in (
            ('direction_logits', self.direction_head),
            ('expected_return', self.return_head),
            ('log_variance', self.uncertainty_head),
        ):
            outputs[key] = head(features)
        return outputs
class TradingTransformer(nn.Module):
    """
    Main prediction model: patch-based Transformer for multi-horizon trading predictions.

    Architecture:
    1. InstanceNorm1d -> per-sample, per-channel normalisation (PatchTST)
    2. PatchEmbedding + learnable positional encoding -> patches per channel
    3. Channel-independent Transformer blocks (temporal patterns)
    4. Optional ChannelMixer (cross-feature dependencies, iTransformer-inspired)
    5. Mean pooling over channels and patches -> PredictionHead (multi-task)

    ``num_channels`` is configurable per instance. However, the affine
    InstanceNorm1d holds one scale/shift per channel, so each instance expects
    inputs with exactly ``num_channels`` channels.
    """

    def __init__(
        self,
        num_channels: int,       # Number of input features
        seq_len: int = 60,       # Lookback window (stored as self.seq_len; forward does not read it)
        patch_len: int = 8,      # Patch length
        stride: int = 4,         # Patch stride
        d_model: int = 128,      # Model dimension
        n_heads: int = 8,        # Attention heads in the temporal blocks
        n_layers: int = 3,       # Number of transformer layers
        d_ff: int = 256,         # FFN hidden dimension
        num_horizons: int = 3,   # Number of prediction horizons
        dropout: float = 0.1,
        use_channel_mixer: bool = True,
        channel_mixer_heads: int = 4,  # Attention heads in the ChannelMixer
    ):
        super().__init__()
        self.num_channels = num_channels
        self.seq_len = seq_len
        self.d_model = d_model
        self.use_channel_mixer = use_channel_mixer
        # Instance normalisation (PatchTST: mitigate distribution shift)
        self.instance_norm = nn.InstanceNorm1d(num_channels, affine=True)
        self.patch_embed = PatchEmbedding(patch_len, stride, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        # Transformer encoder blocks, shared across channels (channel-independent, per PatchTST)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        if use_channel_mixer:
            self.channel_mixer = ChannelMixer(
                num_channels, d_model, n_heads=channel_mixer_heads, dropout=dropout
            )
        self.pool_norm = nn.LayerNorm(d_model)
        self.prediction_head = PredictionHead(d_model, num_horizons, dropout)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights and zero biases for every nn.Linear."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: (batch, num_channels, seq_len)
        Returns:
            Dict with 'direction_logits', 'expected_return', 'log_variance',
            each (batch, num_horizons)
        Raises:
            ValueError: if x is not 3-D.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (batch, channels, seq_len), got {tuple(x.shape)}"
            )
        x = self.instance_norm(x)
        # Patch embedding + positions: (B, C, num_patches, d_model)
        x = self.pos_enc(self.patch_embed(x))
        B, C, N, D = x.shape
        # Fold channels into the batch so each channel is encoded independently.
        h = x.reshape(B * C, N, D)
        for block in self.transformer_blocks:
            h = block(h)
        x = h.reshape(B, C, N, D)
        if self.use_channel_mixer:
            x = self.channel_mixer(x)
        # Mean-pool over channels and patches: (B, D)
        pooled = self.pool_norm(x.mean(dim=(1, 2)))
        return self.prediction_head(pooled)

    def predict_with_confidence(self, x: torch.Tensor) -> Dict[str, np.ndarray]:
        """
        Run inference and return NumPy arrays of probabilities and confidence.

        The forward pass runs in eval mode (dropout off) without gradient
        tracking. Every submodule's previous train/eval flag is restored
        afterwards, so calling this mid-training does not switch training off.

        Returns:
            direction_probs: sigmoid(direction_logits), probability of an up move per horizon
            expected_returns: predicted return per horizon
            confidence: 1 / (1 + exp(log_variance)), in [0, 1]; higher = lower predicted variance
        """
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward(x)
                direction_probs = torch.sigmoid(outputs['direction_logits'])
                # 1 / (1 + exp(v)) == sigmoid(-v); the sigmoid form cannot overflow.
                confidence = torch.sigmoid(-outputs['log_variance'])
        finally:
            for module, mode in previous_modes:
                module.training = mode
        return {
            'direction_probs': direction_probs.cpu().numpy(),
            'expected_returns': outputs['expected_return'].cpu().numpy(),
            'confidence': confidence.cpu().numpy(),
        }
class MultiTaskLoss(nn.Module):
    """
    Multi-task loss combining:
    1. Direction loss (BCE with logits)
    2. Return prediction loss (heteroscedastic Gaussian NLL, constant term dropped)
    3. Risk-adjusted loss: relu(-Sharpe) of a simulated strategy that holds
       position ``2 * sigmoid(direction_logit) - 1`` in the *realised* return

    Terms are combined with learned task-uncertainty weights (Kendall et al. 2018):
    ``alpha_k * exp(-s_k) * L_k + s_k`` per task k.

    NOTE(review): this weighting assumes each L_k stays positive. The Gaussian
    NLL term can go negative (small errors with negative log-variance), and the
    risk term is exactly 0 whenever the simulated Sharpe is positive. In either
    case d/ds_k > 0 everywhere, so s_k has no lower bound and can drift toward
    -inf. Consider clamping the log-sigmas or monitoring them during training.
    """

    def __init__(self, num_horizons: int = 3, alpha_direction: float = 1.0,
                 alpha_return: float = 1.0, alpha_risk: float = 0.5):
        super().__init__()
        self.num_horizons = num_horizons
        self.alpha_direction = alpha_direction
        self.alpha_return = alpha_return
        self.alpha_risk = alpha_risk
        # Learned log task-uncertainty weights s_k (Kendall et al.); 0 => weight 1.
        self.log_sigma_direction = nn.Parameter(torch.tensor(0.0))
        self.log_sigma_return = nn.Parameter(torch.tensor(0.0))
        self.log_sigma_risk = nn.Parameter(torch.tensor(0.0))

    @staticmethod
    def _negative_sharpe_penalty(direction_logits: torch.Tensor,
                                 realised_returns: torch.Tensor) -> torch.Tensor:
        """Return relu(-Sharpe) of the PnL from position 2*sigmoid(logits)-1 on realised returns."""
        position = 2 * torch.sigmoid(direction_logits) - 1  # long if bullish, short if bearish
        pnl = position * realised_returns
        if pnl.numel() < 2:
            # The sample std of fewer than two values is NaN, which would poison the
            # total loss; a single observation carries no Sharpe information.
            return pnl.new_zeros(())
        sharpe = pnl.mean() / (pnl.std() + 1e-8)
        return F.relu(-sharpe)  # only penalise negative Sharpe

    def forward(self, predictions: Dict[str, torch.Tensor],
                targets: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Args:
            predictions: model outputs 'direction_logits', 'expected_return',
                'log_variance', each (B, H)
            targets: dict with 'direction' (B, H) binary labels (cast to the
                logits' dtype) and 'returns' (B, H) realised returns
        Returns:
            dict of scalar tensors: 'total_loss', 'direction_loss',
            'return_loss', 'risk_loss'
        """
        logits = predictions['direction_logits']
        pred_returns = predictions['expected_return']
        log_var = predictions['log_variance']
        realised = targets['returns']

        direction_loss = F.binary_cross_entropy_with_logits(
            logits, targets['direction'].to(dtype=logits.dtype), reduction='mean'
        )

        # Gaussian NLL with predicted log-variance (heteroscedastic regression).
        return_loss = 0.5 * (
            torch.exp(-log_var) * (pred_returns - realised) ** 2 + log_var
        ).mean()

        # PnL must be measured against realised returns. Using the model's own
        # predicted returns would make this term independent of the targets.
        risk_loss = self._negative_sharpe_penalty(logits, realised)

        total_loss = (
            self.alpha_direction * torch.exp(-self.log_sigma_direction) * direction_loss
            + self.log_sigma_direction
            + self.alpha_return * torch.exp(-self.log_sigma_return) * return_loss
            + self.log_sigma_return
            + self.alpha_risk * torch.exp(-self.log_sigma_risk) * risk_loss
            + self.log_sigma_risk
        )
        return {
            'total_loss': total_loss,
            'direction_loss': direction_loss,
            'return_loss': return_loss,
            'risk_loss': risk_loss,
        }