patchtst-wavelet-sp500-research / patchtst_config_reference.py
tbukuai's picture
Add PatchTST config reference
4eb628b verified
"""
PatchTST Configuration Reference — All Parameters
====================================================
Annotated configuration for PatchTSTForPrediction with all available parameters.
Sources: ibm-granite/granite-timeseries-patchtst config.json + HF Transformers docs.
Reference papers:
- PatchTST: arxiv:2211.14730 (ICLR 2023)
- Wavelet recipe: arxiv:2408.12408 (2024)
"""
from transformers import PatchTSTConfig, PatchTSTForPrediction
import torch
def get_sp500_config(
    *,
    num_input_channels: int = 5,
    context_length: int = 512,
    prediction_length: int = 1,
) -> PatchTSTConfig:
    """
    Build a PatchTST config for S&P 500 OHLCV forecasting.

    Based on Recipe 1 (PatchTST + DWT wavelet denoising). Called with no
    arguments, this returns the original fixed configuration (5 channels,
    512-step look-back, next-day horizon). The keyword-only arguments expose
    the three data dimensions that are commonly varied.

    Args:
        num_input_channels: Number of input series (5 for OHLCV, 1 for Close-only).
        context_length: Look-back window length in time steps.
        prediction_length: Forecast horizon in time steps (1 = next day, 5 = next week).

    Returns:
        A ``PatchTSTConfig`` that can be passed to ``PatchTSTForPrediction``.
    """
    config = PatchTSTConfig(
        # === Data dimensions ===
        num_input_channels=num_input_channels,  # OHLCV = 5 channels (or 1 for Close-only)
        context_length=context_length,  # Paper variants: 512 -> PatchTST/64, 336 -> PatchTST/42
        prediction_length=prediction_length,  # 1 = next-day, 5 = next-week
        # === Patching ===
        patch_length=16,  # P: patch size (paper: 16; IBM: 12)
        patch_stride=8,  # S: stride between patches (overlapping when S < P)
        # Patch count in the Hugging Face implementation (no end-padding):
        #   num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
        # The paper's "+ 2" variant assumes the series end is padded by one stride first.
        # With use_cls_token=True the encoder sees one additional token on top of the patches.
        # === Transformer architecture ===
        d_model=128,  # Hidden dimension
        num_attention_heads=16,  # Multi-head attention heads (128 / 16 = 8 dims per head)
        num_hidden_layers=3,  # Transformer encoder layers
        ffn_dim=512,  # Feed-forward network dimension
        dropout=0.2,  # Main dropout
        ff_dropout=0.0,  # FFN-specific dropout
        attention_dropout=0.0,  # Attention dropout
        head_dropout=0.2,  # Prediction head dropout
        path_dropout=0.0,  # Stochastic depth dropout
        # === Normalization ===
        norm_type="batchnorm",  # "batchnorm" or "layernorm" (paper: batchnorm)
        norm_eps=1e-5,
        pre_norm=True,  # Pre-norm vs post-norm
        # === Loss & output ===
        loss="mse",  # "mse" or "nll"
        distribution_output="student_t",  # For NLL: "student_t", "normal", "negative_binomial"
        scaling="std",  # "std" (instance norm), "mean", or None
        # === Positional encoding ===
        positional_encoding_type="sincos",  # "sincos" (fixed) or "random" (learnable)
        positional_dropout=0.0,
        # === Channel settings ===
        share_embedding=True,  # Shared embedding across channels (channel-independence)
        share_projection=True,  # Shared projection head
        channel_attention=False,  # Cross-channel attention (False = channel-independent)
        # === Self-supervised masking (for pre-training only) ===
        do_mask_input=False,
        mask_type="random",  # "random" or "forecast"
        random_mask_ratio=0.5,  # 40-50% masking
        num_forecast_mask_patches=[2],
        # === Misc ===
        activation_function="gelu",
        init_std=0.02,
        use_cls_token=True,
        num_parallel_samples=100,  # For probabilistic inference
    )
    return config
def get_ibm_granite_config():
    """
    Return a config mirroring ibm-granite/granite-timeseries-patchtst (ETTh1 pre-trained).

    Note: this setup uses 7 channels (ETT), so it is NOT directly compatible
    with 5-channel OHLCV input.
    """
    # NOTE(review): these values are transcribed by hand. Verify them
    # (use_cls_token in particular) against the published config.json
    # before loading pre-trained weights -- TODO confirm.
    data_shape = dict(
        num_input_channels=7,
        context_length=512,
        prediction_length=96,
    )
    # Stride equals patch length here, so patches do not overlap.
    patching = dict(patch_length=12, patch_stride=12)
    encoder = dict(
        d_model=128,
        num_attention_heads=16,
        num_hidden_layers=3,
        ffn_dim=512,
        norm_type="batchnorm",
    )
    regularization = dict(dropout=0.2, head_dropout=0.2)
    representation = dict(
        scaling="std",
        positional_encoding_type="sincos",
        use_cls_token=True,
        share_embedding=True,
        channel_attention=False,
    )
    return PatchTSTConfig(
        **data_shape,
        **patching,
        **encoder,
        **regularization,
        **representation,
    )
# === Quick model instantiation examples ===
def count_patches(context_length: int, patch_length: int, patch_stride: int) -> int:
    """
    Return how many patches are cut from one series of ``context_length`` steps.

    Uses the unpadded sliding-window count that the Hugging Face PatchTST
    patchifier applies:
    ``(max(context_length, patch_length) - patch_length) // patch_stride + 1``.
    The ``max`` guard means a series shorter than one patch still counts as
    a single patch.

    Args:
        context_length: Length of the look-back window in time steps.
        patch_length: Number of time steps per patch.
        patch_stride: Step between the starts of consecutive patches.

    Returns:
        The number of patches (excluding any CLS token).
    """
    usable_steps = max(context_length, patch_length) - patch_length
    return usable_steps // patch_stride + 1


def main() -> None:
    """Instantiate the S&P 500 model, print a summary and run one forward pass."""
    # S&P 500 config
    config = get_sp500_config()
    model = PatchTSTForPrediction(config)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Exact patch count for this config; a CLS token, if enabled, is extra.
    num_patches = count_patches(
        config.context_length, config.patch_length, config.patch_stride
    )

    print("S&P 500 PatchTST Config:")
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")
    print(f" Context length: {config.context_length}")
    print(f" Prediction length: {config.prediction_length}")
    print(f" Channels: {config.num_input_channels}")
    print(f" Patches: {num_patches}")

    # Test forward pass on random data shaped (batch, time, channels).
    batch_size = 4
    past_values = torch.randn(batch_size, config.context_length, config.num_input_channels)
    future_values = torch.randn(batch_size, config.prediction_length, config.num_input_channels)
    # Passing future_values makes the model return a loss alongside predictions.
    outputs = model(past_values=past_values, future_values=future_values)
    print(f"\n Loss: {outputs.loss.item():.6f}")
    print(f" Prediction shape: {outputs.prediction_outputs.shape}")
    print(f" Expected shape: ({batch_size}, {config.prediction_length}, {config.num_input_channels})")


if __name__ == "__main__":
    main()