"""
Prediction Model Module
========================
Multi-horizon Transformer-based prediction model.
Architecture: PatchTST-inspired with Kronos-style multi-resolution encoding.
- Patch embedding for temporal features
- Multi-head self-attention across patches
- Multi-task heads for direction, return, and uncertainty
Key design decisions (from literature):
1. PatchTST (2211.14730): Channel-independent patching reduces O(L²) to O((L/S)²)
2. Chronos (2403.07815): Probabilistic outputs via distributional heads
3. Kronos (2508.02739): Coarse-to-fine hierarchical predictions for financial data
4. iTransformer: Inverted attention on variate dimension
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Dict, List, Optional, Tuple
class PatchEmbedding(nn.Module):
    """
    PatchTST-style patch embedding for time series.

    Splits each channel's sequence into windows of ``patch_len`` steps taken
    every ``stride`` steps (overlapping when ``stride < patch_len``), then
    linearly projects each window to ``d_model`` and layer-normalizes it.

    Args:
        patch_len: Number of time steps in one patch.
        stride: Step between the starts of consecutive patches.
        d_model: Embedding dimension of each projected patch.

    Raises:
        ValueError: If ``patch_len`` or ``stride`` is not positive.
    """

    def __init__(self, patch_len: int = 8, stride: int = 4, d_model: int = 128):
        super().__init__()
        if patch_len <= 0 or stride <= 0:
            raise ValueError(
                f"patch_len and stride must be positive, got "
                f"patch_len={patch_len}, stride={stride}"
            )
        self.patch_len = patch_len
        self.stride = stride
        self.projection = nn.Linear(patch_len, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, channels, seq_len)
        Returns:
            patches: (batch, channels, num_patches, d_model)
        """
        _, _, seq_len = x.shape

        if seq_len < self.patch_len:
            # Too short for even one patch: extend to exactly one full patch
            # so that unfold() below has something to extract.
            pad_len = self.patch_len - seq_len
        else:
            # Extend the tail so the last patch ends exactly at the sequence
            # end instead of dropping the trailing samples.
            pad_len = (-(seq_len - self.patch_len)) % self.stride

        if pad_len > 0:
            # Replicate padding repeats the last observed value.
            x = F.pad(x, (0, pad_len), mode='replicate')

        # Unfold into patches: (B, C, num_patches, patch_len)
        patches = x.unfold(dimension=2, size=self.patch_len, step=self.stride)

        # Project: (B, C, num_patches, d_model)
        patches = self.projection(patches)
        return self.layer_norm(patches)
class PositionalEncoding(nn.Module):
    """Learnable position offsets, one per patch slot, added to patch embeddings."""

    def __init__(self, d_model: int, max_patches: int = 200):
        super().__init__()
        # One d_model-sized vector per patch position; the two leading
        # singleton axes broadcast over batch and channel.
        table_shape = (1, 1, max_patches, d_model)
        initial_table = torch.randn(*table_shape) * 0.02
        self.pos_embed = nn.Parameter(initial_table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, num_patches, d_model)
        Returns:
            Tensor of the same shape with position offsets added.
        """
        num_patches = x.size(2)
        offsets = self.pos_embed[:, :, :num_patches, :]
        return x + offsets
class MultiHeadAttention(nn.Module):
    """
    Standard multi-head scaled dot-product self-attention.

    Args:
        d_model: Input/output feature dimension.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        # Fail early with a readable message; otherwise the mismatch only
        # surfaces as a shape error inside view() during forward().
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (B, N, d_model) sequence of N tokens.
            mask: Optional tensor broadcastable to (B, n_heads, N, N);
                positions where ``mask == 0`` are excluded from attention.
        Returns:
            (B, N, d_model) attended representation.
        """
        B, N, D = x.shape

        # Split the feature dimension into heads: (B, n_heads, N, d_k)
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Use the most negative value representable in the score dtype;
            # a literal such as -1e9 does not fit in float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, V)
        # Merge heads back: (B, N, d_model)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)
class TransformerBlock(nn.Module):
    """
    Pre-norm Transformer encoder block: LayerNorm is applied before the
    attention and feed-forward sub-layers, each wrapped in a residual
    connection (pre-norm chosen following PatchTST).

    Args:
        d_model: Token feature dimension.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used in attention and feed-forward.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Expand -> GELU -> contract, with dropout after each linear stage.
        feed_forward_layers = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and feed-forward residual branches to ``x``."""
        # Residual branch 1: normalize, then self-attend.
        attended = self.attn(self.norm1(x))
        x = x + attended

        # Residual branch 2: normalize, then position-wise feed-forward.
        transformed = self.ff(self.norm2(x))
        return x + transformed
class ChannelMixer(nn.Module):
    """
    Cross-channel attention for inter-feature dependencies
    (iTransformer-inspired: attention runs over the variate axis).

    Each channel is summarized by the mean of its patch embeddings; the
    summaries attend to one another and the result is added back to every
    patch of the corresponding channel.

    Args:
        num_channels: Number of input channels (accepted but not used here).
        d_model: Embedding dimension.
        n_heads: Number of attention heads for the channel attention.
        dropout: Dropout probability inside the attention.
    """

    def __init__(self, num_channels: int, d_model: int, n_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.channel_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, num_patches, d_model)
        Returns:
            (B, C, num_patches, d_model) with cross-channel information added.
        """
        # Unpacking doubles as a check that the input is 4-D.
        _batch, _channels, _patches, _dim = x.shape

        # One summary vector per channel: (B, C, D)
        pooled = torch.mean(x, dim=2)

        # Channels attend to each other: (B, C, D)
        mixed = self.channel_attn(self.norm(pooled))

        # Re-insert the patch axis so the update broadcasts over all patches.
        return x + mixed[:, :, None, :]
class PredictionHead(nn.Module):
    """
    Multi-task prediction head.

    A shared Linear/GELU/Dropout trunk feeds three parallel two-layer heads,
    each producing one value per horizon:
        1. Direction logits (binary classification per horizon)
        2. Expected return (regression per horizon)
        3. Log variance (learned heteroscedastic uncertainty)

    Args:
        d_model: Dimension of the pooled input representation.
        num_horizons: Number of prediction horizons (outputs per head).
        dropout: Dropout probability in the shared trunk.
    """

    def __init__(self, d_model: int, num_horizons: int = 3, dropout: float = 0.1):
        super().__init__()
        self.num_horizons = num_horizons

        # Shared representation
        self.shared = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # The three task heads share one architecture; build them in a fixed
        # order (direction, return, uncertainty).
        self.direction_head = self._task_head(d_model, num_horizons)
        self.return_head = self._task_head(d_model, num_horizons)
        self.uncertainty_head = self._task_head(d_model, num_horizons)

    @staticmethod
    def _task_head(d_model: int, num_horizons: int) -> nn.Sequential:
        """Build one d_model -> d_model // 2 -> num_horizons head."""
        hidden = d_model // 2
        return nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, num_horizons),
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: (B, d_model) - pooled representation
        Returns:
            dict with 'direction_logits', 'expected_return', 'log_variance',
            each of shape (B, num_horizons)
        """
        features = self.shared(x)
        named_heads = (
            ('direction_logits', self.direction_head),
            ('expected_return', self.return_head),
            ('log_variance', self.uncertainty_head),
        )
        return {key: head(features) for key, head in named_heads}
class TradingTransformer(nn.Module):
    """
    Main prediction model: Patch-based Transformer for multi-horizon trading predictions.

    Architecture:
        1. PatchEmbedding -> patches per channel (PatchTST)
        2. Intra-channel Transformer blocks (temporal patterns)
        3. ChannelMixer (cross-feature dependencies, iTransformer-inspired)
        4. Global pooling -> PredictionHead (multi-task)

    Args:
        num_channels: Number of input features (channels).
        seq_len: Lookback window length the model is configured for.
        patch_len: Patch length.
        stride: Patch stride.
        d_model: Model (embedding) dimension.
        n_heads: Number of attention heads in the transformer blocks.
        n_layers: Number of transformer blocks.
        d_ff: Hidden width of the feed-forward sub-layers.
        num_horizons: Number of prediction horizons.
        dropout: Dropout probability used throughout.
        use_channel_mixer: Whether to apply cross-channel attention.
    """

    def __init__(
        self,
        num_channels: int,              # Number of input features
        seq_len: int = 60,              # Lookback window
        patch_len: int = 8,             # Patch length
        stride: int = 4,                # Patch stride
        d_model: int = 128,             # Model dimension
        n_heads: int = 8,               # Number of attention heads
        n_layers: int = 3,              # Number of transformer layers
        d_ff: int = 256,                # FFN hidden dimension
        num_horizons: int = 3,          # Number of prediction horizons
        dropout: float = 0.1,
        use_channel_mixer: bool = True,
    ):
        super().__init__()
        self.num_channels = num_channels
        self.seq_len = seq_len
        self.d_model = d_model
        self.use_channel_mixer = use_channel_mixer

        # Instance normalization (PatchTST: mitigate distribution shift)
        self.instance_norm = nn.InstanceNorm1d(num_channels, affine=True)

        # Patch embedding
        self.patch_embed = PatchEmbedding(patch_len, stride, d_model)

        # Positional encoding. The table keeps its historical size of 200
        # slots and only grows when the configured window yields more patches
        # than that, which previously failed with a shape mismatch in forward().
        expected_patches = -(-max(seq_len - patch_len, 0) // stride) + 1  # ceil division
        self.pos_enc = PositionalEncoding(d_model, max_patches=max(200, expected_patches))

        # Transformer encoder blocks (channel-independent, per PatchTST)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Channel mixer (optional cross-channel attention)
        if use_channel_mixer:
            self.channel_mixer = ChannelMixer(num_channels, d_model, n_heads=4, dropout=dropout)

        # Global pooling + prediction head
        self.pool_norm = nn.LayerNorm(d_model)
        self.prediction_head = PredictionHead(d_model, num_horizons, dropout)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights and zero biases for every nn.Linear."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: (batch, num_channels, seq_len)
        Returns:
            Dict with 'direction_logits', 'expected_return', 'log_variance'
        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (batch, num_channels, seq_len), got {tuple(x.shape)}"
            )

        # Instance normalization
        x = self.instance_norm(x)

        # Patch embedding: (B, C, num_patches, d_model)
        x = self.patch_embed(x)
        x = self.pos_enc(x)

        # Channel-independent transformer (per PatchTST): fold channels into
        # the batch axis so each channel's patch sequence is attended separately.
        B, C, N, D = x.shape
        x_flat = x.reshape(B * C, N, D)
        for block in self.transformer_blocks:
            x_flat = block(x_flat)
        x = x_flat.reshape(B, C, N, D)

        # Channel mixing
        if self.use_channel_mixer:
            x = self.channel_mixer(x)

        # Global average pooling across channels and patches
        x = x.mean(dim=[1, 2])  # (B, D)
        x = self.pool_norm(x)

        # Multi-task prediction
        return self.prediction_head(x)

    def predict_with_confidence(self, x: torch.Tensor) -> Dict[str, np.ndarray]:
        """
        Run inference and convert the raw outputs into NumPy arrays.

        The model is evaluated in eval mode without gradients; the mode it was
        in before the call (train or eval) is restored afterwards.

        Args:
            x: (batch, num_channels, seq_len)
        Returns:
            direction_probs: sigmoid of the direction logits, per horizon
            expected_returns: predicted return per horizon
            confidence: score in [0, 1] equal to 1 / (1 + exp(log_variance))
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self(x)
                direction_probs = torch.sigmoid(outputs['direction_logits']).cpu().numpy()
                expected_returns = outputs['expected_return'].cpu().numpy()
                # sigmoid(-v) == 1 / (1 + exp(v)), but does not overflow for
                # large log-variances the way an explicit exp() does.
                confidence = torch.sigmoid(-outputs['log_variance']).cpu().numpy()
        finally:
            # Do not leave a model that was being trained stuck in eval mode.
            self.train(was_training)

        return {
            'direction_probs': direction_probs,
            'expected_returns': expected_returns,
            'confidence': confidence,
        }
class MultiTaskLoss(nn.Module):
    """
    Multi-task loss combining:
        1. Direction loss (BCE with logits)
        2. Return prediction loss (Gaussian NLL for uncertainty-aware regression)
        3. Risk-adjusted loss (Sharpe-like penalty)

    The three terms are combined with learned task weights (uncertainty
    weighting from Kendall et al. 2018): each term is scaled by
    ``alpha * exp(-log_sigma)`` and ``log_sigma`` is added as a regularizer.

    Args:
        num_horizons: Number of prediction horizons (stored, not used in the math).
        alpha_direction: Fixed multiplier on the direction term.
        alpha_return: Fixed multiplier on the return term.
        alpha_risk: Fixed multiplier on the risk term.
    """

    def __init__(self, num_horizons: int = 3, alpha_direction: float = 1.0,
                 alpha_return: float = 1.0, alpha_risk: float = 0.5):
        super().__init__()
        self.num_horizons = num_horizons
        self.alpha_direction = alpha_direction
        self.alpha_return = alpha_return
        self.alpha_risk = alpha_risk

        # Learned task uncertainty weights (Kendall et al.)
        self.log_sigma_direction = nn.Parameter(torch.tensor(0.0))
        self.log_sigma_return = nn.Parameter(torch.tensor(0.0))
        self.log_sigma_risk = nn.Parameter(torch.tensor(0.0))

    def forward(self, predictions: Dict[str, torch.Tensor],
                targets: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Args:
            predictions: model outputs with 'direction_logits',
                'expected_return' and 'log_variance', each (B, H)
            targets: dict with 'direction' (B, H) holding 0/1 labels of any
                numeric or boolean dtype, and 'returns' (B, H)
        Returns:
            dict with scalar tensors 'total_loss', 'direction_loss',
            'return_loss', 'risk_loss'
        """
        direction_logits = predictions['direction_logits']
        pred_returns = predictions['expected_return']
        log_var = predictions['log_variance']

        # BCE-with-logits needs floating-point targets; labels frequently
        # arrive as integer or boolean tensors, so align them with the logits.
        direction_target = targets['direction'].to(dtype=direction_logits.dtype)

        # Direction loss (BCE)
        direction_loss = F.binary_cross_entropy_with_logits(
            direction_logits, direction_target,
            reduction='mean'
        )

        # Return prediction loss (Gaussian NLL - heteroscedastic)
        return_loss = 0.5 * (
            torch.exp(-log_var) * (pred_returns - targets['returns']) ** 2
            + log_var
        ).mean()

        # Risk-adjusted loss: negative Sharpe-like ratio of a simulated PnL.
        # NOTE(review): the PnL is built from predicted returns and predicted
        # direction only; realized returns are not used here - confirm intent.
        pred_direction = torch.sigmoid(direction_logits)
        simulated_pnl = pred_returns * (2 * pred_direction - 1)  # Long if bullish, short if bearish

        # The sample std of a single value is NaN and would poison the total
        # loss; treat the dispersion as zero in that case.
        if simulated_pnl.numel() > 1:
            pnl_std = simulated_pnl.std()
        else:
            pnl_std = simulated_pnl.new_zeros(())

        risk_loss = -simulated_pnl.mean() / (pnl_std + 1e-8)  # Negative Sharpe
        risk_loss = F.relu(risk_loss)  # Only penalize negative Sharpe

        # Uncertainty-weighted combination
        total_loss = (
            self.alpha_direction * torch.exp(-self.log_sigma_direction) * direction_loss
            + self.log_sigma_direction
            + self.alpha_return * torch.exp(-self.log_sigma_return) * return_loss
            + self.log_sigma_return
            + self.alpha_risk * torch.exp(-self.log_sigma_risk) * risk_loss
            + self.log_sigma_risk
        )

        return {
            'total_loss': total_loss,
            'direction_loss': direction_loss,
            'return_loss': return_loss,
            'risk_loss': risk_loss,
        }
|