File size: 5,250 Bytes
4eb628b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
PatchTST Configuration Reference — All Parameters
====================================================
Annotated configuration for PatchTSTForPrediction covering the commonly tuned parameters
(not exhaustive — see the HF PatchTSTConfig docs for the rest, e.g. pooling_type, bias).
Sources: ibm-granite/granite-timeseries-patchtst config.json + HF Transformers docs.

Reference papers:
  - PatchTST: arxiv:2211.14730 (ICLR 2023)
  - Wavelet recipe: arxiv:2408.12408 (2024)
"""

from transformers import PatchTSTConfig, PatchTSTForPrediction
import torch


def get_sp500_config(**overrides):
    """
    PatchTST config optimized for S&P 500 OHLCV next-day forecasting.
    Based on Recipe 1 (PatchTST + DWT wavelet denoising).

    Args:
        **overrides: Any PatchTSTConfig keyword argument. Each one replaces the
            matching default listed below, e.g. ``num_input_channels=1`` for
            Close-only data or ``prediction_length=5`` for a next-week horizon.
            Calling with no arguments gives exactly the defaults below.

    Returns:
        PatchTSTConfig built from the defaults below merged with ``overrides``.
    """
    params = dict(
        # === Data dimensions ===
        num_input_channels=5,        # OHLCV = 5 channels (or 1 for Close-only)
        context_length=512,          # Look-back window (paper naming: 512 -> PatchTST/64, 336 -> PatchTST/42)
        prediction_length=1,         # Forecast horizon (1=next-day, 5=next-week)

        # === Patching ===
        patch_length=16,             # P: patch size (paper: 16; IBM: 12)
        patch_stride=8,              # S: stride between patches (overlapping: S<P)
        # HF PatchTST does not pad the end of the series, so
        #   num_patches = (context_length - patch_length) // patch_stride + 1   (= 63 here)
        # The paper first pads by S, giving "+ 2" instead (= 64, hence "PatchTST/64").

        # === Transformer architecture ===
        d_model=128,                 # Hidden dimension
        num_attention_heads=16,      # Multi-head attention heads
        num_hidden_layers=3,         # Transformer encoder layers
        ffn_dim=512,                 # Feed-forward network dimension
        dropout=0.2,                 # Main dropout (NOTE(review): may not be a named PatchTSTConfig arg in
                                     # recent transformers releases, i.e. stored but unused — verify)
        ff_dropout=0.0,              # FFN-specific dropout
        attention_dropout=0.0,       # Attention dropout
        head_dropout=0.2,            # Prediction head dropout
        path_dropout=0.0,            # Stochastic depth dropout

        # === Normalization ===
        norm_type="batchnorm",       # "batchnorm" or "layernorm" (paper: batchnorm)
        norm_eps=1e-5,
        pre_norm=True,               # Pre-norm vs post-norm

        # === Loss & output ===
        loss="mse",                  # "mse" or "nll"
        distribution_output="student_t",  # Only used with loss="nll": "student_t", "normal", "negative_binomial"
        scaling="std",               # "std" (instance norm), "mean", or None

        # === Positional encoding ===
        positional_encoding_type="sincos",   # "sincos" or "learned"
        positional_dropout=0.0,

        # === Channel settings ===
        share_embedding=True,        # Shared embedding across channels (channel-independence)
        share_projection=True,       # Shared projection head
        channel_attention=False,     # Cross-channel attention (False=channel-independent)

        # === Self-supervised masking (for pre-training only) ===
        do_mask_input=False,
        mask_type="random",          # "random" or "forecast"
        random_mask_ratio=0.5,       # 40-50% masking
        num_forecast_mask_patches=[2],  # fresh list per call, so callers may mutate it safely

        # === Misc ===
        activation_function="gelu",
        init_std=0.02,
        use_cls_token=True,          # NOTE(review): the paper's supervised forecasting setup has no CLS
                                     # token — confirm this is intended
        num_parallel_samples=100,    # For probabilistic inference
        # NOTE(review): pooling_type is not set here, so the library default applies; the paper's
        # flatten head corresponds to pooling_type=None — verify which head is wanted.
    )
    params.update(overrides)
    return PatchTSTConfig(**params)


def get_ibm_granite_config(**overrides):
    """
    Config matching ibm-granite/granite-timeseries-patchtst (ETTh1 pre-trained).
    Note: 7 channels (ETT) — NOT directly compatible with 5-channel OHLCV.

    Args:
        **overrides: Any PatchTSTConfig keyword argument. Each one replaces the
            matching value below (e.g. ``head_dropout=0.1``). Calling with no
            arguments gives exactly the values below. Changing architecture- or
            shape-related values (channels, patching, d_model, prediction_length,
            ...) will likely make the config incompatible with the published
            checkpoint.

    Returns:
        PatchTSTConfig built from the values below merged with ``overrides``.
    """
    params = dict(
        context_length=512,
        prediction_length=96,
        num_input_channels=7,        # ETT datasets have 7 variables
        patch_length=12,
        patch_stride=12,             # stride == patch_length -> non-overlapping patches (42 per window)
        d_model=128,
        num_attention_heads=16,
        num_hidden_layers=3,
        ffn_dim=512,
        dropout=0.2,
        head_dropout=0.2,
        norm_type="batchnorm",
        scaling="std",
        positional_encoding_type="sincos",
        use_cls_token=True,
        share_embedding=True,
        channel_attention=False,
    )
    params.update(overrides)
    return PatchTSTConfig(**params)


def compute_num_patches(context_length, patch_length, patch_stride):
    """Return how many patches HF PatchTST extracts from one look-back window.

    The transformers implementation does not pad the end of the series, so the
    count is ``(context_length - patch_length) // patch_stride + 1``. The
    PatchTST paper pads by ``patch_stride`` first, which yields ``+ 2`` instead.

    Args:
        context_length: Look-back window length L.
        patch_length: Patch size P.
        patch_stride: Stride S between consecutive patches (must be >= 1).

    Returns:
        Number of patches per channel.

    Raises:
        ValueError: If ``patch_stride < 1`` or ``patch_length > context_length``.

    >>> compute_num_patches(512, 16, 8)
    63
    >>> compute_num_patches(512, 12, 12)
    42
    """
    if patch_stride < 1:
        raise ValueError(f"patch_stride must be >= 1, got {patch_stride}")
    if patch_length > context_length:
        raise ValueError(
            f"patch_length ({patch_length}) must not exceed context_length ({context_length})"
        )
    return (context_length - patch_length) // patch_stride + 1


# === Quick model instantiation examples ===

if __name__ == "__main__":
    # S&P 500 config
    config = get_sp500_config()
    model = PatchTSTForPrediction(config)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_patches = compute_num_patches(
        config.context_length, config.patch_length, config.patch_stride
    )
    print("S&P 500 PatchTST Config:")
    print(f"  Total parameters:     {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Context length:       {config.context_length}")
    print(f"  Prediction length:    {config.prediction_length}")
    print(f"  Channels:             {config.num_input_channels}")
    print(f"  Patches:              {num_patches}")

    # Smoke-test a forward pass on random data; no gradients are needed here.
    batch_size = 4
    past_values = torch.randn(batch_size, config.context_length, config.num_input_channels)
    future_values = torch.randn(batch_size, config.prediction_length, config.num_input_channels)

    with torch.no_grad():
        outputs = model(past_values=past_values, future_values=future_values)
    print(f"\n  Loss:              {outputs.loss.item():.6f}")
    print(f"  Prediction shape:  {outputs.prediction_outputs.shape}")
    print(f"  Expected shape:    ({batch_size}, {config.prediction_length}, {config.num_input_channels})")