avinashhm commited on
Commit
fbc0d3d
·
verified ·
1 Parent(s): 0e8833f

Add trading_intelligence/prediction_model.py

Browse files
trading_intelligence/prediction_model.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prediction Model Module
3
+ ========================
4
+ Multi-horizon Transformer-based prediction model.
5
+
6
+ Architecture: PatchTST-inspired with Kronos-style multi-resolution encoding.
7
+ - Patch embedding for temporal features
8
+ - Multi-head self-attention across patches
9
+ - Multi-task heads for direction, return, and uncertainty
10
+
11
+ Key design decisions (from literature):
12
+ 1. PatchTST (2211.14730): Channel-independent patching reduces O(L²) to O((L/S)²)
13
+ 2. Chronos (2403.07815): Probabilistic outputs via distributional heads
14
+ 3. Kronos (2508.02739): Coarse-to-fine hierarchical predictions for financial data
15
+ 4. iTransformer: Inverted attention on variate dimension
16
+ """
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import numpy as np
22
+ import math
23
+ from typing import Dict, List, Optional, Tuple
24
+
25
+
26
+ class PatchEmbedding(nn.Module):
27
+ """
28
+ PatchTST-style patch embedding for time series.
29
+
30
+ Splits each channel's sequence into overlapping patches,
31
+ then projects to embedding dimension.
32
+ """
33
+
34
+ def __init__(self, patch_len: int = 8, stride: int = 4, d_model: int = 128):
35
+ super().__init__()
36
+ self.patch_len = patch_len
37
+ self.stride = stride
38
+ self.projection = nn.Linear(patch_len, d_model)
39
+ self.layer_norm = nn.LayerNorm(d_model)
40
+
41
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
42
+ """
43
+ Args:
44
+ x: (batch, channels, seq_len)
45
+ Returns:
46
+ patches: (batch, channels, num_patches, d_model)
47
+ """
48
+ B, C, L = x.shape
49
+
50
+ # Pad if necessary
51
+ pad_len = (self.stride - (L - self.patch_len) % self.stride) % self.stride
52
+ if pad_len > 0:
53
+ x = F.pad(x, (0, pad_len), mode='replicate')
54
+ L = L + pad_len
55
+
56
+ # Unfold into patches: (B, C, num_patches, patch_len)
57
+ num_patches = (L - self.patch_len) // self.stride + 1
58
+ patches = x.unfold(dimension=2, size=self.patch_len, step=self.stride)
59
+
60
+ # Project: (B, C, num_patches, d_model)
61
+ patches = self.projection(patches)
62
+ patches = self.layer_norm(patches)
63
+
64
+ return patches
65
+
66
+
67
+ class PositionalEncoding(nn.Module):
68
+ """Learnable positional encoding for patches."""
69
+
70
+ def __init__(self, d_model: int, max_patches: int = 200):
71
+ super().__init__()
72
+ self.pos_embed = nn.Parameter(torch.randn(1, 1, max_patches, d_model) * 0.02)
73
+
74
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
75
+ """x: (B, C, num_patches, d_model)"""
76
+ return x + self.pos_embed[:, :, :x.size(2), :]
77
+
78
+
79
+ class MultiHeadAttention(nn.Module):
80
+ """Standard multi-head self-attention."""
81
+
82
+ def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
83
+ super().__init__()
84
+ self.n_heads = n_heads
85
+ self.d_k = d_model // n_heads
86
+
87
+ self.W_q = nn.Linear(d_model, d_model)
88
+ self.W_k = nn.Linear(d_model, d_model)
89
+ self.W_v = nn.Linear(d_model, d_model)
90
+ self.W_o = nn.Linear(d_model, d_model)
91
+ self.dropout = nn.Dropout(dropout)
92
+
93
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
94
+ B, N, D = x.shape
95
+
96
+ Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
97
+ K = self.W_k(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
98
+ V = self.W_v(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
99
+
100
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
101
+ if mask is not None:
102
+ scores = scores.masked_fill(mask == 0, -1e9)
103
+
104
+ attn = F.softmax(scores, dim=-1)
105
+ attn = self.dropout(attn)
106
+
107
+ out = torch.matmul(attn, V)
108
+ out = out.transpose(1, 2).contiguous().view(B, N, D)
109
+ return self.W_o(out)
110
+
111
+
112
+ class TransformerBlock(nn.Module):
113
+ """Transformer encoder block with pre-norm (better for time series per PatchTST)."""
114
+
115
+ def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
116
+ super().__init__()
117
+ self.attn = MultiHeadAttention(d_model, n_heads, dropout)
118
+ self.norm1 = nn.LayerNorm(d_model)
119
+ self.norm2 = nn.LayerNorm(d_model)
120
+ self.ff = nn.Sequential(
121
+ nn.Linear(d_model, d_ff),
122
+ nn.GELU(),
123
+ nn.Dropout(dropout),
124
+ nn.Linear(d_ff, d_model),
125
+ nn.Dropout(dropout)
126
+ )
127
+
128
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
129
+ # Pre-norm attention
130
+ x = x + self.attn(self.norm1(x))
131
+ # Pre-norm FFN
132
+ x = x + self.ff(self.norm2(x))
133
+ return x
134
+
135
+
136
+ class ChannelMixer(nn.Module):
137
+ """
138
+ Cross-channel attention for capturing inter-feature dependencies.
139
+ Inspired by iTransformer - applies attention across variate dimension.
140
+ """
141
+
142
+ def __init__(self, num_channels: int, d_model: int, n_heads: int = 4, dropout: float = 0.1):
143
+ super().__init__()
144
+ self.channel_attn = MultiHeadAttention(d_model, n_heads, dropout)
145
+ self.norm = nn.LayerNorm(d_model)
146
+
147
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
148
+ """
149
+ Args:
150
+ x: (B, C, num_patches, d_model)
151
+ Returns:
152
+ x: (B, C, num_patches, d_model) with cross-channel info
153
+ """
154
+ B, C, N, D = x.shape
155
+
156
+ # Pool across patches for channel representation
157
+ channel_repr = x.mean(dim=2) # (B, C, D)
158
+
159
+ # Cross-channel attention
160
+ channel_out = self.channel_attn(self.norm(channel_repr)) # (B, C, D)
161
+
162
+ # Broadcast back and add
163
+ x = x + channel_out.unsqueeze(2)
164
+
165
+ return x
166
+
167
+
168
+ class PredictionHead(nn.Module):
169
+ """
170
+ Multi-task prediction head.
171
+
172
+ Outputs:
173
+ 1. Direction probability (binary classification per horizon)
174
+ 2. Expected return (regression per horizon)
175
+ 3. Uncertainty/confidence (learned aleatoric uncertainty)
176
+ """
177
+
178
+ def __init__(self, d_model: int, num_horizons: int = 3, dropout: float = 0.1):
179
+ super().__init__()
180
+ self.num_horizons = num_horizons
181
+
182
+ # Shared representation
183
+ self.shared = nn.Sequential(
184
+ nn.Linear(d_model, d_model),
185
+ nn.GELU(),
186
+ nn.Dropout(dropout),
187
+ )
188
+
189
+ # Direction head (classification)
190
+ self.direction_head = nn.Sequential(
191
+ nn.Linear(d_model, d_model // 2),
192
+ nn.GELU(),
193
+ nn.Linear(d_model // 2, num_horizons),
194
+ )
195
+
196
+ # Return prediction head (regression)
197
+ self.return_head = nn.Sequential(
198
+ nn.Linear(d_model, d_model // 2),
199
+ nn.GELU(),
200
+ nn.Linear(d_model // 2, num_horizons),
201
+ )
202
+
203
+ # Uncertainty head (log variance - Gaussian heteroscedastic)
204
+ self.uncertainty_head = nn.Sequential(
205
+ nn.Linear(d_model, d_model // 2),
206
+ nn.GELU(),
207
+ nn.Linear(d_model // 2, num_horizons),
208
+ )
209
+
210
+ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
211
+ """
212
+ Args:
213
+ x: (B, d_model) - pooled representation
214
+ Returns:
215
+ dict with 'direction_logits', 'expected_return', 'log_variance'
216
+ """
217
+ shared = self.shared(x)
218
+
219
+ return {
220
+ 'direction_logits': self.direction_head(shared), # (B, num_horizons)
221
+ 'expected_return': self.return_head(shared), # (B, num_horizons)
222
+ 'log_variance': self.uncertainty_head(shared), # (B, num_horizons)
223
+ }
224
+
225
+
226
+ class TradingTransformer(nn.Module):
227
+ """
228
+ Main prediction model: Patch-based Transformer for multi-horizon trading predictions.
229
+
230
+ Architecture:
231
+ 1. PatchEmbedding → patches per channel (PatchTST)
232
+ 2. Intra-channel Transformer blocks (temporal patterns)
233
+ 3. ChannelMixer (cross-feature dependencies, iTransformer-inspired)
234
+ 4. Global pooling → PredictionHead (multi-task)
235
+
236
+ Designed to be modular and accept varying numbers of input features.
237
+ """
238
+
239
+ def __init__(
240
+ self,
241
+ num_channels: int, # Number of input features
242
+ seq_len: int = 60, # Lookback window
243
+ patch_len: int = 8, # Patch length
244
+ stride: int = 4, # Patch stride
245
+ d_model: int = 128, # Model dimension
246
+ n_heads: int = 8, # Number of attention heads
247
+ n_layers: int = 3, # Number of transformer layers
248
+ d_ff: int = 256, # FFN hidden dimension
249
+ num_horizons: int = 3, # Number of prediction horizons
250
+ dropout: float = 0.1,
251
+ use_channel_mixer: bool = True,
252
+ ):
253
+ super().__init__()
254
+
255
+ self.num_channels = num_channels
256
+ self.seq_len = seq_len
257
+ self.d_model = d_model
258
+ self.use_channel_mixer = use_channel_mixer
259
+
260
+ # Instance normalization (PatchTST: mitigate distribution shift)
261
+ self.instance_norm = nn.InstanceNorm1d(num_channels, affine=True)
262
+
263
+ # Patch embedding
264
+ self.patch_embed = PatchEmbedding(patch_len, stride, d_model)
265
+
266
+ # Positional encoding
267
+ self.pos_enc = PositionalEncoding(d_model)
268
+
269
+ # Transformer encoder blocks (channel-independent, per PatchTST)
270
+ self.transformer_blocks = nn.ModuleList([
271
+ TransformerBlock(d_model, n_heads, d_ff, dropout)
272
+ for _ in range(n_layers)
273
+ ])
274
+
275
+ # Channel mixer (optional cross-channel attention)
276
+ if use_channel_mixer:
277
+ self.channel_mixer = ChannelMixer(num_channels, d_model, n_heads=4, dropout=dropout)
278
+
279
+ # Global pooling + prediction head
280
+ self.pool_norm = nn.LayerNorm(d_model)
281
+ self.prediction_head = PredictionHead(d_model, num_horizons, dropout)
282
+
283
+ # Initialize weights
284
+ self.apply(self._init_weights)
285
+
286
+ def _init_weights(self, module):
287
+ if isinstance(module, nn.Linear):
288
+ nn.init.xavier_uniform_(module.weight)
289
+ if module.bias is not None:
290
+ nn.init.zeros_(module.bias)
291
+
292
+ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
293
+ """
294
+ Args:
295
+ x: (batch, num_channels, seq_len)
296
+ Returns:
297
+ Dict with 'direction_logits', 'expected_return', 'log_variance'
298
+ """
299
+ B, C, L = x.shape
300
+
301
+ # Instance normalization
302
+ x = self.instance_norm(x)
303
+
304
+ # Patch embedding: (B, C, num_patches, d_model)
305
+ x = self.patch_embed(x)
306
+ x = self.pos_enc(x)
307
+
308
+ # Channel-independent transformer (per PatchTST)
309
+ B, C, N, D = x.shape
310
+ x_flat = x.reshape(B * C, N, D)
311
+
312
+ for block in self.transformer_blocks:
313
+ x_flat = block(x_flat)
314
+
315
+ x = x_flat.reshape(B, C, N, D)
316
+
317
+ # Channel mixing
318
+ if self.use_channel_mixer:
319
+ x = self.channel_mixer(x)
320
+
321
+ # Global average pooling across channels and patches
322
+ x = x.mean(dim=[1, 2]) # (B, D)
323
+ x = self.pool_norm(x)
324
+
325
+ # Multi-task prediction
326
+ predictions = self.prediction_head(x)
327
+
328
+ return predictions
329
+
330
+ def predict_with_confidence(self, x: torch.Tensor) -> Dict[str, np.ndarray]:
331
+ """
332
+ Make predictions with calibrated confidence scores.
333
+
334
+ Returns:
335
+ direction_probs: Probability of up move per horizon
336
+ expected_returns: Expected return per horizon
337
+ confidence: Confidence score (0-1) derived from uncertainty
338
+ """
339
+ self.eval()
340
+ with torch.no_grad():
341
+ outputs = self.forward(x)
342
+
343
+ direction_probs = torch.sigmoid(outputs['direction_logits']).cpu().numpy()
344
+ expected_returns = outputs['expected_return'].cpu().numpy()
345
+ log_var = outputs['log_variance'].cpu().numpy()
346
+
347
+ # Confidence = 1 / (1 + exp(log_variance))
348
+ confidence = 1.0 / (1.0 + np.exp(log_var))
349
+
350
+ return {
351
+ 'direction_probs': direction_probs,
352
+ 'expected_returns': expected_returns,
353
+ 'confidence': confidence,
354
+ }
355
+
356
+
357
+ class MultiTaskLoss(nn.Module):
358
+ """
359
+ Multi-task loss combining:
360
+ 1. Direction loss (BCE with logits)
361
+ 2. Return prediction loss (Gaussian NLL for uncertainty-aware regression)
362
+ 3. Risk-adjusted loss (Sharpe-like penalty)
363
+
364
+ Uses learned task weights (uncertainty weighting from Kendall et al. 2018).
365
+ """
366
+
367
+ def __init__(self, num_horizons: int = 3, alpha_direction: float = 1.0,
368
+ alpha_return: float = 1.0, alpha_risk: float = 0.5):
369
+ super().__init__()
370
+ self.num_horizons = num_horizons
371
+ self.alpha_direction = alpha_direction
372
+ self.alpha_return = alpha_return
373
+ self.alpha_risk = alpha_risk
374
+
375
+ # Learned task uncertainty weights (Kendall et al.)
376
+ self.log_sigma_direction = nn.Parameter(torch.zeros(1))
377
+ self.log_sigma_return = nn.Parameter(torch.zeros(1))
378
+ self.log_sigma_risk = nn.Parameter(torch.zeros(1))
379
+
380
+ def forward(self, predictions: Dict[str, torch.Tensor],
381
+ targets: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
382
+ """
383
+ Args:
384
+ predictions: model outputs
385
+ targets: dict with 'direction' (B, H), 'returns' (B, H)
386
+ """
387
+ # Direction loss (BCE)
388
+ direction_loss = F.binary_cross_entropy_with_logits(
389
+ predictions['direction_logits'], targets['direction'],
390
+ reduction='mean'
391
+ )
392
+
393
+ # Return prediction loss (Gaussian NLL - heteroscedastic)
394
+ log_var = predictions['log_variance']
395
+ return_loss = 0.5 * (
396
+ torch.exp(-log_var) * (predictions['expected_return'] - targets['returns'])**2
397
+ + log_var
398
+ ).mean()
399
+
400
+ # Risk-adjusted loss: penalize predictions that would lead to large drawdowns
401
+ # Simulates a simple PnL and penalizes negative Sharpe-like ratio
402
+ pred_returns = predictions['expected_return']
403
+ pred_direction = torch.sigmoid(predictions['direction_logits'])
404
+ simulated_pnl = pred_returns * (2 * pred_direction - 1) # Long if bullish, short if bearish
405
+ risk_loss = -simulated_pnl.mean() / (simulated_pnl.std() + 1e-8) # Negative Sharpe
406
+ risk_loss = F.relu(risk_loss) # Only penalize negative Sharpe
407
+
408
+ # Uncertainty-weighted combination
409
+ total_loss = (
410
+ self.alpha_direction * torch.exp(-self.log_sigma_direction) * direction_loss
411
+ + self.log_sigma_direction
412
+ + self.alpha_return * torch.exp(-self.log_sigma_return) * return_loss
413
+ + self.log_sigma_return
414
+ + self.alpha_risk * torch.exp(-self.log_sigma_risk) * risk_loss
415
+ + self.log_sigma_risk
416
+ )
417
+
418
+ return {
419
+ 'total_loss': total_loss,
420
+ 'direction_loss': direction_loss,
421
+ 'return_loss': return_loss,
422
+ 'risk_loss': risk_loss,
423
+ }