File size: 5,250 Bytes
"""
PatchTST Configuration Reference — All Parameters
====================================================
Annotated configuration for PatchTSTForPrediction with all available parameters.
Sources: ibm-granite/granite-timeseries-patchtst config.json + HF Transformers docs.
Reference papers:
- PatchTST: arxiv:2211.14730 (ICLR 2023)
- Wavelet recipe: arxiv:2408.12408 (2024)
"""
from transformers import PatchTSTConfig, PatchTSTForPrediction
import torch
def get_sp500_config(
    *,
    num_input_channels=5,
    context_length=512,
    prediction_length=1,
) -> "PatchTSTConfig":
    """
    Build a PatchTST config for S&P 500 OHLCV forecasting.

    Based on Recipe 1 (PatchTST + DWT wavelet denoising). Called without
    arguments it produces the 5-channel, 512-step look-back, next-day setup.
    The three data-dimension values are keyword-only parameters so the common
    variants can be built without copying this function.

    Args:
        num_input_channels: Number of input series. 5 for OHLCV, 1 for
            Close-only.
        context_length: Look-back window in time steps
            (512 = PatchTST/64, 336 = PatchTST/42).
        prediction_length: Forecast horizon in time steps
            (1 = next-day, 5 = next-week).

    Returns:
        A PatchTSTConfig with every available parameter set explicitly.
    """
    return PatchTSTConfig(
        # === Data dimensions ===
        num_input_channels=num_input_channels,
        context_length=context_length,
        prediction_length=prediction_length,
        # === Patching ===
        patch_length=16,  # P: patch size (paper: 16; IBM: 12)
        patch_stride=8,  # S: stride between patches (overlapping: S<P)
        # HF Transformers does not pad the end of the series, so
        #   num_patches = (context_length - patch_length) // patch_stride + 1
        # (the paper pads by S and therefore reports "+ 2").
        # With use_cls_token=True the encoder sees num_patches + 1 tokens.
        # === Transformer architecture ===
        d_model=128,  # Hidden dimension
        num_attention_heads=16,  # Multi-head attention heads
        num_hidden_layers=3,  # Transformer encoder layers
        ffn_dim=512,  # Feed-forward network dimension
        dropout=0.2,  # Main dropout
        ff_dropout=0.0,  # FFN-specific dropout
        attention_dropout=0.0,  # Attention dropout
        head_dropout=0.2,  # Prediction head dropout
        path_dropout=0.0,  # Stochastic depth dropout
        # === Normalization ===
        norm_type="batchnorm",  # "batchnorm" or "layernorm" (paper: batchnorm)
        norm_eps=1e-5,
        pre_norm=True,  # Pre-norm vs post-norm
        # === Loss & output ===
        loss="mse",  # "mse" or "nll"
        # Only used when loss="nll": "student_t", "normal", "negative_binomial"
        distribution_output="student_t",
        scaling="std",  # "std" (instance norm), "mean", or None
        # === Positional encoding ===
        positional_encoding_type="sincos",  # "sincos" or "random"
        positional_dropout=0.0,
        # === Channel settings ===
        share_embedding=True,  # Shared embedding across channels (channel-independence)
        share_projection=True,  # Shared projection head
        channel_attention=False,  # Cross-channel attention (False=channel-independent)
        # === Self-supervised masking (for pre-training only) ===
        do_mask_input=False,
        mask_type="random",  # "random" or "forecast"
        random_mask_ratio=0.5,  # 40-50% masking
        num_forecast_mask_patches=[2],
        # === Misc ===
        activation_function="gelu",
        init_std=0.02,
        use_cls_token=True,
        num_parallel_samples=100,  # For probabilistic inference
    )
def get_ibm_granite_config():
    """
    Return a PatchTSTConfig mirroring ibm-granite/granite-timeseries-patchtst.

    That checkpoint is the ETTh1 pre-trained model, which uses 7 channels
    (ETT), so it is NOT directly compatible with 5-channel OHLCV data.
    Parameters not listed here are left to PatchTSTConfig.
    """
    # NOTE(review): values were transcribed by hand; confirm them (notably
    # use_cls_token) against the checkpoint's config.json before loading weights.
    data_shape = {
        "context_length": 512,
        "prediction_length": 96,
        "num_input_channels": 7,
    }
    patching = {
        "patch_length": 12,
        "patch_stride": 12,
    }
    encoder = {
        "d_model": 128,
        "num_attention_heads": 16,
        "num_hidden_layers": 3,
        "ffn_dim": 512,
        "dropout": 0.2,
        "head_dropout": 0.2,
        "norm_type": "batchnorm",
    }
    remaining = {
        "scaling": "std",
        "positional_encoding_type": "sincos",
        "use_cls_token": True,
        "share_embedding": True,
        "channel_attention": False,
    }

    granite_kwargs = {**data_shape, **patching, **encoder, **remaining}
    return PatchTSTConfig(**granite_kwargs)
# === Quick model instantiation examples ===
if __name__ == "__main__":
    # Smoke test: build the S&P 500 model from its config, report its size,
    # and run one forward pass on random data to check the output shape.
    config = get_sp500_config()
    model = PatchTSTForPrediction(config)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # HF Transformers does not pad the end of the series before patching, so
    # the count is "+ 1" here; the paper's end-padded formula ("+ 2") would
    # over-report by one patch.
    num_patches = (config.context_length - config.patch_length) // config.patch_stride + 1

    print("S&P 500 PatchTST Config:")
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")
    print(f" Context length: {config.context_length}")
    print(f" Prediction length: {config.prediction_length}")
    print(f" Channels: {config.num_input_channels}")
    print(f" Patches: ~{num_patches}")

    # Test forward pass: inputs are (batch, time, channels); passing
    # future_values makes the model return a loss as well.
    batch_size = 4
    past_values = torch.randn(batch_size, config.context_length, config.num_input_channels)
    future_values = torch.randn(batch_size, config.prediction_length, config.num_input_channels)
    outputs = model(past_values=past_values, future_values=future_values)

    print(f"\n Loss: {outputs.loss.item():.6f}")
    print(f" Prediction shape: {outputs.prediction_outputs.shape}")
    print(f" Expected shape: ({batch_size}, {config.prediction_length}, {config.num_input_channels})")
|