|
|
"""Self-contained TimesFM 2.x wrapper compatible with the TimesFM interface.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import dataclasses |
|
|
import math |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch import nn |
|
|
|
|
|
# safetensors is an optional dependency: when it cannot be imported the loader
# is set to None, and TimesFM2Core.load_safetensors raises ImportError on use.
try:
    from safetensors.torch import load_file as _load_safetensors
except ImportError:
    _load_safetensors = None

# Standard deviations below this threshold are replaced by 1 in `revin` when
# normalizing, so the division never uses a (near-)zero denominator.
_TOLERANCE = 1e-6
|
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class ResidualBlockConfig:
    """Immutable hyper-parameters consumed by ``ResidualBlock``."""

    # Size of the last dimension of the block input.
    input_dims: int
    # Width of the hidden linear layer.
    hidden_dims: int
    # Size of the last dimension of the block output.
    output_dims: int
    # Whether the block's linear layers carry a bias term.
    use_bias: bool
    # Activation name; ``ResidualBlock`` accepts "relu", "swish" or "none".
    activation: str
|
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class TransformerConfig:
    """Immutable hyper-parameters for a single ``Transformer`` block."""

    # Width of the residual stream (input/output of every sub-layer).
    model_dims: int
    # Width of the feed-forward hidden layer.
    hidden_dims: int
    # Number of attention heads; must divide ``model_dims``.
    num_heads: int
    # Normalization around attention; ``Transformer`` only accepts "rms".
    attention_norm: str
    # Normalization around the feed-forward; ``Transformer`` only accepts "rms".
    feedforward_norm: str
    # "rms" applies RMSNorm to per-head queries/keys; any other value disables it.
    qk_norm: str
    # Bias flag for the feed-forward linear layers.
    use_bias: bool
    # Whether queries and keys are rotated with rotary position embeddings.
    use_rotary_position_embeddings: bool
    # Feed-forward activation name ("relu", "swish" or "none").
    ff_activation: str
    # Whether Q, K and V come from one fused linear projection.
    fuse_qkv: bool
|
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class StackedTransformersConfig:
    """Immutable description of a stack of identically configured transformers."""

    # How many transformer blocks are stacked.
    num_layers: int
    # Configuration shared by every block in the stack.
    transformer: TransformerConfig
|
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class TimesFM2Definition:
    """Framework-agnostic description of TimesFM 2.5 (200M parameters).

    The projection widths below are tied to the patch lengths: each output
    position carries ``len(quantiles) + 1`` values, so the point head emits
    ``output_patch_len * 10 = 1280`` features and the quantile head
    ``output_quantile_len * 10 = 10240``.
    """

    # Longest context (in time steps) the wrapper accepts.
    context_limit: int = 16384
    # Number of time steps per input patch.
    input_patch_len: int = 32
    # Number of time steps predicted per output patch.
    output_patch_len: int = 128
    # Number of time steps covered by the quantile-spread head.
    output_quantile_len: int = 1024
    # Quantile levels predicted alongside one extra leading channel.
    quantiles: tuple[float, ...] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
    # Channel of the last output axis used as the point forecast and fed back
    # during autoregressive decoding.
    decode_index: int = 5

    # Patch tokenizer: 64 inputs = 32 patch values concatenated with 32 mask flags.
    tokenizer: ResidualBlockConfig = dataclasses.field(
        default_factory=lambda: ResidualBlockConfig(
            input_dims=64,
            hidden_dims=1280,
            output_dims=1280,
            use_bias=True,
            activation="swish",
        )
    )

    # Backbone: 20 identical transformer blocks.
    stacked_transformers: StackedTransformersConfig = dataclasses.field(
        default_factory=lambda: StackedTransformersConfig(
            num_layers=20,
            transformer=TransformerConfig(
                model_dims=1280,
                hidden_dims=1280,
                num_heads=16,
                attention_norm="rms",
                feedforward_norm="rms",
                qk_norm="rms",
                use_bias=False,
                use_rotary_position_embeddings=True,
                ff_activation="swish",
                fuse_qkv=True,
            ),
        )
    )

    # Head producing the per-patch forecast (128 steps x 10 channels).
    output_projection_point: ResidualBlockConfig = dataclasses.field(
        default_factory=lambda: ResidualBlockConfig(
            input_dims=1280,
            hidden_dims=1280,
            output_dims=1280,
            use_bias=False,
            activation="swish",
        )
    )

    # Head producing the long-horizon quantile spread (1024 steps x 10 channels).
    output_projection_quantiles: ResidualBlockConfig = dataclasses.field(
        default_factory=lambda: ResidualBlockConfig(
            input_dims=1280,
            hidden_dims=1280,
            output_dims=10240,
            use_bias=False,
            activation="swish",
        )
    )
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class DecodeCache:
    """Mutable per-layer key/value cache used during autoregressive decoding."""

    # Per-sequence index of the next cache slot to write.
    next_index: torch.Tensor
    # Per-sequence running count of masked patches seen so far.
    num_masked: torch.Tensor
    # Cached keys, laid out as (batch, cache_size, heads, head_dim).
    key: torch.Tensor
    # Cached values, same layout as ``key``.
    value: torch.Tensor
|
|
|
|
|
|
|
|
def update_running_stats(
    n: torch.Tensor,
    mu: torch.Tensor,
    sigma: torch.Tensor,
    x: torch.Tensor,
    mask: torch.Tensor,
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Merge one patch into running count / mean / standard-deviation statistics.

    Args:
        n: Number of valid points accumulated so far.
        mu: Running mean of those points.
        sigma: Running (population) standard deviation of those points.
        x: New patch; statistics are reduced over its last dimension.
        mask: Same shape as ``x``; truthy entries are excluded from the update.

    Returns:
        The updated ``(n, mu, sigma)`` triple, returned twice (carry and output).
        Wherever the total count is still zero, mean and sigma are zero.
    """

    def _zero_if_empty(count: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        # Force the statistic to 0 wherever no valid point contributed.
        return torch.where(count == 0, 0.0, value)

    valid = torch.logical_not(mask)

    # Statistics of the incoming patch alone (denominators guarded against 0).
    patch_n = torch.sum(valid.to(x.dtype), dim=-1)
    patch_n_div = torch.where(patch_n == 0, 1.0, patch_n)

    patch_mu = _zero_if_empty(patch_n, torch.sum(x * valid, dim=-1) / patch_n_div)

    squared_dev = (x - patch_mu.unsqueeze(-1)) ** 2
    patch_var = _zero_if_empty(patch_n, torch.sum(squared_dev * valid, dim=-1) / patch_n_div)
    patch_sigma = torch.sqrt(patch_var)

    # Pool the previous statistics with the patch statistics.
    total_n = n + patch_n
    total_n_div = torch.where(total_n == 0, 1.0, total_n)

    total_mu = _zero_if_empty(total_n, (n * mu + patch_mu * patch_n) / total_n_div)

    # Within-group variances plus the shift of each group mean to the pooled mean.
    old_within = n * sigma.pow(2)
    patch_within = patch_n * patch_sigma.pow(2)
    old_shift = n * (mu - total_mu).pow(2)
    patch_shift = patch_n * (patch_mu - total_mu).pow(2)

    total_var = _zero_if_empty(
        total_n,
        (old_within + patch_within + old_shift + patch_shift) / total_n_div,
    )
    # Clamp guards against tiny negative values from floating-point round-off.
    total_sigma = torch.sqrt(torch.clamp(total_var, min=0.0))

    return (total_n, total_mu, total_sigma), (total_n, total_mu, total_sigma)
|
|
|
|
|
|
|
|
def revin(x: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    """Apply (or undo) reversible instance normalization.

    ``mu`` and ``sigma`` may have one or two fewer dimensions than ``x``; the
    missing trailing axes are added so the statistics broadcast over them.

    Args:
        x: Values to normalize, or normalized values to restore.
        mu: Mean statistics.
        sigma: Standard-deviation statistics.
        reverse: When True compute ``x * sigma + mu``; otherwise compute
            ``(x - mu) / sigma`` with sigmas below ``_TOLERANCE`` replaced by 1.
    """
    missing_axes = len(x.shape) - len(mu.shape)
    if missing_axes == 1:
        mu, sigma = mu[..., None], sigma[..., None]
    elif missing_axes == 2:
        mu, sigma = mu[..., None, None], sigma[..., None, None]

    if not reverse:
        # A (near-)constant series would otherwise divide by ~0.
        denominator = torch.where(sigma < _TOLERANCE, torch.ones_like(sigma), sigma)
        return (x - mu) / denominator

    return x * sigma + mu
|
|
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Two-layer MLP with a parallel linear skip connection.

    Computes ``output_layer(act(hidden_layer(x))) + residual_layer(x)``.
    """

    def __init__(self, config: ResidualBlockConfig):
        """Build the three linear layers described by ``config``."""
        super().__init__()
        bias = config.use_bias
        self.activation = self._resolve_activation(config.activation)
        # Creation order is kept stable: it determines parameter order and the
        # random initialization each layer receives under a fixed seed.
        self.hidden_layer = nn.Linear(config.input_dims, config.hidden_dims, bias=bias)
        self.output_layer = nn.Linear(config.hidden_dims, config.output_dims, bias=bias)
        self.residual_layer = nn.Linear(config.input_dims, config.output_dims, bias=bias)

    @staticmethod
    def _resolve_activation(name: str) -> nn.Module:
        """Return a new activation module for ``name``.

        Raises:
            ValueError: If ``name`` is not "relu", "swish" or "none".
        """
        factories = {"relu": nn.ReLU, "swish": nn.SiLU, "none": nn.Identity}
        factory = factories.get(name)
        if factory is None:
            raise ValueError(f"Unsupported activation: {name}")
        return factory()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP branch and add the linear skip branch."""
        mlp_branch = self.output_layer(self.activation(self.hidden_layer(x)))
        skip_branch = self.residual_layer(x)
        return mlp_branch + skip_branch
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The learned ``scale`` multiplies the normalized input directly and is
    initialized to zeros, so a freshly constructed layer outputs zeros until
    weights are loaded or trained.
    """

    def __init__(self, num_features: int, epsilon: float = 1e-6):
        """Create the per-feature scale and store the stabilizing ``epsilon``."""
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(num_features))
        self.epsilon = epsilon

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return ``inputs / sqrt(mean(inputs**2) + epsilon) * scale``."""
        mean_square = torch.square(inputs).mean(dim=-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.epsilon)
        return inputs * inverse_rms * self.scale
|
|
|
|
|
|
|
|
def make_attn_mask(
    query_length: int,
    num_all_masked_kv: torch.Tensor,
    query_index_offset: torch.Tensor | None = None,
    kv_length: int = 0,
) -> torch.Tensor:
    """Build a boolean attention mask for (optionally cached) causal decoding.

    Args:
        query_length: Number of query positions.
        num_all_masked_kv: Per-sequence count of leading key/value positions
            that must never be attended to.
        query_index_offset: Optional per-sequence absolute index of the first
            query (used when the keys live in a longer cache).
        kv_length: Number of key/value positions; 0 means ``query_length``.

    Returns:
        Bool tensor of shape (batch, 1, query_length, kv_length), True where
        attention is allowed: the key is not in the future of the query and is
        not one of the leading masked positions.
    """
    device = num_all_masked_kv.device
    num_keys = kv_length or query_length

    query_pos = torch.arange(query_length, device=device)[None, None, :, None]
    if query_index_offset is not None:
        query_pos = query_pos + query_index_offset[:, None, None, None]
    key_pos = torch.arange(num_keys, device=device)[None, None, None, :]

    is_causal = query_pos >= key_pos
    is_unmasked = key_pos >= num_all_masked_kv[:, None, None, None]
    return torch.logical_and(is_causal, is_unmasked)
|
|
|
|
|
|
|
|
class RotaryPositionalEmbedding(nn.Module):
    """Rotate the two halves of the last dimension by position-dependent angles."""

    def __init__(self, embedding_dims: int, min_timescale: float = 1.0, max_timescale: float = 10000.0):
        """Store the feature width and the geometric timescale range."""
        super().__init__()
        self.embedding_dims = embedding_dims
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale

    def forward(self, inputs: torch.Tensor, position: torch.Tensor | None = None) -> torch.Tensor:
        """Apply the rotation to ``inputs``.

        Args:
            inputs: Rank-3 or rank-4 tensor whose last dimension equals
                ``embedding_dims`` and whose dimension 1 indexes positions.
            position: Optional (batch, length) positions; defaults to
                ``0..length-1`` shared across the batch.

        Raises:
            ValueError: On a last-dimension mismatch or unsupported rank.
        """
        dims = self.embedding_dims
        if inputs.shape[-1] != dims:
            raise ValueError("Rotary embedding dimension must equal the head dimension.")

        # Timescales grow geometrically from min_timescale towards max_timescale.
        exponent = 2 * torch.arange(dims // 2, device=inputs.device) / dims
        ratio = self.max_timescale / self.min_timescale
        timescale = (self.min_timescale * ratio**exponent).to(inputs.device)

        if position is None:
            position = torch.arange(inputs.shape[1], dtype=torch.float32, device=inputs.device)[None, :]

        # Add trailing axes so position broadcasts against (heads,) features.
        rank = len(inputs.shape)
        if rank == 4:
            position = position[..., None, None]
            timescale = timescale[None, None, None, :]
        elif rank == 3:
            position = position[..., None]
            timescale = timescale[None, None, :]
        else:
            raise ValueError("Expected rank-3 or rank-4 tensor for rotary embeddings.")

        angle = position / timescale
        sin_part = torch.sin(angle)
        cos_part = torch.cos(angle)

        lower, upper = torch.chunk(inputs, 2, dim=-1)
        rotated = (
            lower * cos_part - upper * sin_part,
            upper * cos_part + lower * sin_part,
        )
        return torch.cat(rotated, dim=-1)
|
|
|
|
|
|
|
|
class PerDimScale(nn.Module):
    """Learned per-feature scaling applied to queries before attention."""

    def __init__(self, num_dims: int):
        """Create one learnable scale logit per feature, initialized to zero."""
        super().__init__()
        self.num_dims = num_dims
        self.per_dim_scale = nn.Parameter(torch.zeros(num_dims))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale the last dimension of ``x``.

        1.442695041 is approximately 1 / ln(2) and softplus(0) = ln(2), so a
        zero-initialized parameter yields the usual 1 / sqrt(num_dims) factor.
        """
        learned = F.softplus(self.per_dim_scale)
        factor = 1.442695041 / math.sqrt(self.num_dims) * learned
        return x * factor
|
|
|
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention supporting fused QKV projections and caching.

    Queries and keys are optionally rotated (rotary embeddings), normalized
    per head (RMSNorm) and, for queries, rescaled by a learned per-dimension
    factor. The attention kernel is then called with ``scale=1.0`` because the
    scaling has already been applied to the queries.
    """

    def __init__(
        self,
        num_heads: int,
        in_features: int,
        *,
        use_per_dim_scale: bool = True,
        use_rotary_position_embeddings: bool = True,
        use_bias: bool = False,
        attention_fn=F.scaled_dot_product_attention,
        qk_norm: str = "rms",
        fuse_qkv: bool = False,
    ):
        """Build the projections and optional query/key transforms.

        Args:
            num_heads: Number of attention heads.
            in_features: Model width; must be divisible by ``num_heads``.
            use_per_dim_scale: Apply ``PerDimScale`` to the queries.
            use_rotary_position_embeddings: Rotate queries and keys by position.
            use_bias: Whether the linear projections carry a bias.
            attention_fn: Attention kernel, call-compatible with
                ``F.scaled_dot_product_attention(q, k, v, attn_mask=..., scale=...)``
                on (batch, heads, length, head_dim) tensors.
            qk_norm: "rms" enables per-head RMSNorm on queries and keys; any
                other value disables it.
            fuse_qkv: Use one ``3 * in_features`` projection instead of three.

        Raises:
            ValueError: If ``in_features`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.num_heads = num_heads
        self.in_features = in_features
        self.head_dim = in_features // num_heads
        self.use_bias = use_bias
        self.attention_fn = attention_fn
        self.qk_norm = qk_norm
        self.fuse_qkv = fuse_qkv

        if in_features % num_heads != 0:
            raise ValueError(f"Model dimension {in_features} must be divisible by {num_heads} heads.")

        if fuse_qkv:
            self.qkv_proj = nn.Linear(in_features, 3 * in_features, bias=use_bias)
        else:
            self.query = nn.Linear(in_features, in_features, bias=use_bias)
            self.key = nn.Linear(in_features, in_features, bias=use_bias)
            self.value = nn.Linear(in_features, in_features, bias=use_bias)

        self.out = nn.Linear(in_features, in_features, bias=use_bias)

        if qk_norm == "rms":
            self.query_ln = RMSNorm(self.head_dim)
            self.key_ln = RMSNorm(self.head_dim)
        else:
            self.query_ln = nn.Identity()
            self.key_ln = nn.Identity()

        self.use_rotary_position_embeddings = use_rotary_position_embeddings
        if use_rotary_position_embeddings:
            self.rotary_position_embedding = RotaryPositionalEmbedding(self.head_dim)

        self.use_per_dim_scale = use_per_dim_scale
        if use_per_dim_scale:
            self.per_dim_scale = PerDimScale(self.head_dim)

    def forward(
        self,
        inputs_q: torch.Tensor,
        *,
        decode_cache: DecodeCache | None = None,
        patch_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, DecodeCache | None]:
        """Run self-attention over ``inputs_q``.

        Args:
            inputs_q: (batch, num_patches, in_features) activations.
            decode_cache: Optional cache; when given, the new keys/values are
                written into it in place and attention runs over the whole cache.
            patch_mask: Optional (batch, num_patches) bool tensor, True for
                masked patches. Defaults to nothing masked.

        Returns:
            The (batch, num_patches, in_features) attention output and the
            (mutated) ``decode_cache`` that was passed in.
        """
        batch, num_patches, _ = inputs_q.shape
        if patch_mask is None:
            patch_mask = torch.zeros(batch, num_patches, dtype=torch.bool, device=inputs_q.device)

        head_shape = (batch, num_patches, self.num_heads, self.head_dim)
        if self.fuse_qkv:
            query, key, value = torch.chunk(self.qkv_proj(inputs_q), 3, dim=-1)
        else:
            query, key, value = self.query(inputs_q), self.key(inputs_q), self.value(inputs_q)
        query = query.view(head_shape)
        key = key.view(head_shape)
        value = value.view(head_shape)

        # Masked patches accumulate across decode steps so positions and the
        # attention mask stay consistent with what is already in the cache.
        num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1)
        if decode_cache is None:
            next_index = torch.zeros_like(num_masked, dtype=torch.int32)
        else:
            num_masked = num_masked + decode_cache.num_masked
            # Clone: the cache's own index is advanced in place further down.
            next_index = decode_cache.next_index.clone()

        if self.use_rotary_position_embeddings:
            # Positions count only unmasked patches, offset by the cache index.
            position = (
                torch.arange(num_patches, device=inputs_q.device)[None, :]
                + next_index[:, None]
                - num_masked[:, None]
            )
            query = self.rotary_position_embedding(query, position)
            key = self.rotary_position_embedding(key, position)

        query = self.query_ln(query)
        key = self.key_ln(key)

        if self.use_per_dim_scale:
            query = self.per_dim_scale(query)

        if decode_cache is not None:
            _, cache_size, _, _ = decode_cache.value.shape
            # The write offset is taken from batch element 0: every sequence in
            # the batch is assumed to advance through the cache in lockstep.
            start = decode_cache.next_index[0]
            end = start + num_patches

            decode_cache.key[:, start:end] = key
            decode_cache.value[:, start:end] = value

            key = decode_cache.key
            value = decode_cache.value
            decode_cache.next_index += num_patches
            decode_cache.num_masked = num_masked
            attn_mask = make_attn_mask(
                query_length=num_patches,
                num_all_masked_kv=num_masked,
                query_index_offset=next_index,
                kv_length=cache_size,
            )
        else:
            attn_mask = make_attn_mask(query_length=num_patches, num_all_masked_kv=num_masked)

        # Use the configured kernel (previously the constructor argument was
        # stored but ignored). Inputs are moved to (batch, heads, length, dim).
        attn_output = self.attention_fn(
            query.permute(0, 2, 1, 3),
            key.permute(0, 2, 1, 3),
            value.permute(0, 2, 1, 3),
            attn_mask=attn_mask,
            scale=1.0,
        )
        attn_output = attn_output.permute(0, 2, 1, 3)
        attn_output = attn_output.reshape(batch, num_patches, self.in_features)
        return self.out(attn_output), decode_cache
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
    """Transformer block used by TimesFM.

    Both sub-layers (attention and feed-forward) are wrapped as
    ``post_norm(sublayer(pre_norm(x))) + x``.
    """

    def __init__(self, config: TransformerConfig):
        """Build the attention and feed-forward sub-layers.

        Raises:
            ValueError: If either norm setting in ``config`` is not "rms".
        """
        super().__init__()
        requested_norms = (config.attention_norm, config.feedforward_norm)
        if any(norm != "rms" for norm in requested_norms):
            raise ValueError("Only RMS normalization is supported.")

        width = config.model_dims

        # Attention sub-layer.
        self.pre_attn_ln = RMSNorm(width)
        self.post_attn_ln = RMSNorm(width)
        self.attn = MultiHeadAttention(
            num_heads=config.num_heads,
            in_features=width,
            use_per_dim_scale=True,
            use_rotary_position_embeddings=config.use_rotary_position_embeddings,
            qk_norm=config.qk_norm,
            fuse_qkv=config.fuse_qkv,
        )

        # Feed-forward sub-layer.
        self.pre_ff_ln = RMSNorm(width)
        self.post_ff_ln = RMSNorm(width)
        self.ff0 = nn.Linear(width, config.hidden_dims, bias=config.use_bias)
        self.ff1 = nn.Linear(config.hidden_dims, width, bias=config.use_bias)
        self.activation = ResidualBlock._resolve_activation(config.ff_activation)

    def forward(
        self,
        input_embeddings: torch.Tensor,
        patch_mask: torch.Tensor,
        decode_cache: DecodeCache | None = None,
    ) -> tuple[torch.Tensor, DecodeCache | None]:
        """Apply attention then feed-forward, each with a residual connection.

        Returns:
            The block output and the decode cache returned by the attention layer.
        """
        attended, decode_cache = self.attn(
            inputs_q=self.pre_attn_ln(input_embeddings),
            decode_cache=decode_cache,
            patch_mask=patch_mask,
        )
        after_attention = self.post_attn_ln(attended) + input_embeddings

        hidden = self.activation(self.ff0(self.pre_ff_ln(after_attention)))
        after_feedforward = self.post_ff_ln(self.ff1(hidden)) + after_attention
        return after_feedforward, decode_cache
|
|
|
|
|
|
|
|
class TimesFM2Core(nn.Module):
    """Core TimesFM 2.x backbone without external dependencies.

    Pipeline: patch tokenizer -> stack of transformer blocks -> two output
    heads (per-patch forecast and long-horizon quantile spread).
    """

    def __init__(self, definition: TimesFM2Definition | None = None):
        """Build the model from ``definition`` (defaults to ``TimesFM2Definition()``)."""
        super().__init__()
        self.config = definition or TimesFM2Definition()

        # Short aliases for frequently used sizes.
        self.p = self.config.input_patch_len  # input patch length
        self.o = self.config.output_patch_len  # output patch length
        self.os = self.config.output_quantile_len  # quantile-head horizon
        self.m = self.o // self.p  # input patches per output patch
        self.x = self.config.stacked_transformers.num_layers
        self.h = self.config.stacked_transformers.transformer.num_heads
        self.md = self.config.stacked_transformers.transformer.model_dims
        self.hd = self.md // self.h  # per-head width
        self.q = len(self.config.quantiles) + 1  # channels per output step
        self.aridx = self.config.decode_index  # channel fed back when decoding

        self.tokenizer = ResidualBlock(self.config.tokenizer)
        self.stacked_xf = nn.ModuleList(
            [Transformer(self.config.stacked_transformers.transformer) for _ in range(self.x)]
        )
        self.output_projection_point = ResidualBlock(self.config.output_projection_point)
        self.output_projection_quantiles = ResidualBlock(self.config.output_projection_quantiles)

    def load_safetensors(self, path: str, strict: bool = True) -> None:
        """Load weights from a safetensors file and switch to eval mode.

        Raises:
            ImportError: If the optional ``safetensors`` package is missing.
        """
        if _load_safetensors is None:
            raise ImportError("Install safetensors to load TimesFM2 checkpoints.")
        tensors = _load_safetensors(path)
        self.load_state_dict(tensors, strict=strict)
        self.eval()

    def forward(
        self,
        inputs: torch.Tensor,
        masks: torch.Tensor,
        decode_caches: list[DecodeCache] | None = None,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], list[DecodeCache | None]]:
        """Run the backbone on already patched and normalized inputs.

        Args:
            inputs: (batch, num_patches, input_patch_len) values.
            masks: Same shape as ``inputs``; truthy entries mark masked points.
            decode_caches: Optional list with one cache per layer.

        Returns:
            ``(input_embeddings, output_embeddings, output_ts,
            output_quantile_spread)`` and the per-layer caches (entries are
            None when no caches were supplied).
        """
        # The tokenizer sees each patch's values concatenated with its mask.
        tokenizer_inputs = torch.cat([inputs, masks.to(inputs.dtype)], dim=-1)
        input_embeddings = self.tokenizer(tokenizer_inputs)

        if decode_caches is None:
            decode_caches = [None] * self.x

        # A patch counts as masked when its last point is masked.
        patch_mask = masks[..., -1]

        output_embeddings = input_embeddings
        new_decode_caches: list[DecodeCache | None] = []
        for layer, cache in zip(self.stacked_xf, decode_caches):
            output_embeddings, new_cache = layer(output_embeddings, patch_mask, cache)
            new_decode_caches.append(new_cache)

        output_ts = self.output_projection_point(output_embeddings)
        output_quantile_spread = self.output_projection_quantiles(output_embeddings)
        return (input_embeddings, output_embeddings, output_ts, output_quantile_spread), new_decode_caches

    @staticmethod
    def _accumulate_patch_stats(
        n: torch.Tensor,
        mu: torch.Tensor,
        sigma: torch.Tensor,
        patches: torch.Tensor,
        masks: torch.Tensor,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        """Fold patches into the running stats one by one.

        Returns the final ``(n, mu, sigma)`` plus the running mean and sigma
        after each patch, stacked along dim 1.
        """
        running_mu: list[torch.Tensor] = []
        running_sigma: list[torch.Tensor] = []
        for i in range(patches.shape[1]):
            (n, mu, sigma), _ = update_running_stats(n, mu, sigma, patches[:, i], masks[:, i])
            running_mu.append(mu)
            running_sigma.append(sigma)
        return (n, mu, sigma), torch.stack(running_mu, dim=1), torch.stack(running_sigma, dim=1)

    def _new_decode_caches(
        self,
        batch_size: int,
        cache_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> list[DecodeCache]:
        """Allocate one zero-filled key/value cache per transformer layer."""

        def _index() -> torch.Tensor:
            return torch.zeros(batch_size, dtype=torch.int32, device=device)

        def _kv() -> torch.Tensor:
            return torch.zeros(batch_size, cache_size, self.h, self.hd, device=device, dtype=dtype)

        return [
            DecodeCache(next_index=_index(), num_masked=_index(), key=_kv(), value=_kv())
            for _ in range(self.x)
        ]

    def decode(
        self,
        horizon: int,
        inputs: torch.Tensor,
        masks: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """Autoregressively decodes a batch of sequences.

        Args:
            horizon: Number of future steps required; determines how many
                extra autoregressive steps follow the first forward pass.
            inputs: (batch, context) series; ``context`` must be a multiple of
                the input patch length.
            masks: Bool tensor with the same shape; True marks masked points.

        Returns:
            ``renormed_outputs`` of shape (batch, num_input_patches,
            output_patch_len, q), ``renormed_quantile_spread`` of shape
            (batch, output_quantile_len, q) for the last context patch, and
            the stacked autoregressive outputs of shape (batch, steps,
            output_patch_len, q), or None when no extra step is needed.

        A key/value cache is used only when gradients are disabled, because
        it is updated in place. With gradients enabled the normalized history
        is replayed through the backbone on every step instead, so each step
        still attends to the full context.
        """
        batch_size, context = inputs.shape
        num_decode_steps = (horizon - 1) // self.o
        num_input_patches = context // self.p
        use_cache = not torch.is_grad_enabled()

        patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
        patched_masks = torch.reshape(masks, (batch_size, -1, self.p))

        # Running normalization statistics: patch i is normalized with the
        # statistics of patches 0..i.
        n = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
        mu = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
        sigma = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
        (last_n, last_mu, last_sigma), context_mu, context_sigma = self._accumulate_patch_stats(
            n, mu, sigma, patched_inputs, patched_masks
        )

        decode_caches: list[DecodeCache] | None
        if use_cache:
            decode_cache_size = num_input_patches + num_decode_steps * self.m
            decode_caches = self._new_decode_caches(
                batch_size, decode_cache_size, inputs.device, inputs.dtype
            )
        else:
            decode_caches = None

        normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
        normed_inputs = torch.where(
            patched_masks,
            torch.zeros((), device=inputs.device, dtype=inputs.dtype),
            normed_inputs,
        )
        (_, _, normed_outputs, normed_quantile_spread), decode_caches = self(
            normed_inputs, patched_masks, decode_caches
        )

        renormed_outputs = torch.reshape(
            revin(normed_outputs, context_mu, context_sigma, reverse=True),
            (batch_size, -1, self.o, self.q),
        )
        renormed_quantile_spread = torch.reshape(
            revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
            (batch_size, -1, self.os, self.q),
        )[:, -1, ...]

        ar_outputs: list[torch.Tensor] = []
        last_renormed_output = renormed_outputs[:, -1, :, self.aridx]

        # Only needed without a cache: everything the backbone has seen so far.
        history_inputs = normed_inputs
        history_masks = patched_masks

        for _ in range(num_decode_steps):
            # Feed the previous forecast back in as `m` fully observed patches.
            new_patched_input = torch.reshape(last_renormed_output, (batch_size, self.m, self.p))
            new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)

            (last_n, last_mu, last_sigma), new_mu, new_sigma = self._accumulate_patch_stats(
                last_n, last_mu, last_sigma, new_patched_input, new_mask
            )

            new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
            if use_cache:
                (_, _, new_normed_output, _), decode_caches = self(
                    new_normed_input, new_mask, decode_caches
                )
            else:
                # Without a cache the backbone keeps no memory of earlier
                # patches, so replay the whole history and keep the outputs
                # belonging to the newly appended patches.
                history_inputs = torch.cat([history_inputs, new_normed_input], dim=1)
                history_masks = torch.cat([history_masks, new_mask], dim=1)
                (_, _, replayed_output, _), _ = self(history_inputs, history_masks, None)
                new_normed_output = replayed_output[:, -self.m :]

            new_renormed_output = torch.reshape(
                revin(new_normed_output, new_mu, new_sigma, reverse=True),
                (batch_size, self.m, self.o, self.q),
            )
            ar_outputs.append(new_renormed_output[:, -1, ...])
            last_renormed_output = new_renormed_output[:, -1, :, self.aridx]

        ar_renormed_outputs = torch.stack(ar_outputs, dim=1) if num_decode_steps > 0 else None

        return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
|
|
|
|
|
|
|
|
class TimesFM2(nn.Module):
    """High-level TimesFM 2.x wrapper mirroring the TimesFM interface.

    Takes a (batch, time) tensor, uses the last ``lookback`` steps as context
    and returns a ``lookahead``-step forecast.
    """

    def __init__(self, lookback: int = 512, lookahead: int = 96):
        """Create the wrapper around a default-configured ``TimesFM2Core``.

        Raises:
            ValueError: If ``lookback`` exceeds the core's context limit.
        """
        super().__init__()
        self.lookback = lookback
        self.lookahead = lookahead
        self.core = TimesFM2Core()

        limit = self.core.config.context_limit
        if lookback > limit:
            raise ValueError(
                f"lookback ({lookback}) exceeds maximum context limit ({self.core.config.context_limit})."
            )

    # The three methods below delegate to the core, so checkpoint keys are the
    # core's own parameter names (no "core." prefix).
    def load_state_dict(self, state_dict, strict: bool = True):
        """Load ``state_dict`` directly into the core model."""
        return self.core.load_state_dict(state_dict, strict=strict)

    def state_dict(self, *args, **kwargs):
        """Return the core model's state dict."""
        return self.core.state_dict(*args, **kwargs)

    def load_safetensors(self, path: str, strict: bool = True) -> None:
        """Load a safetensors checkpoint into the core model."""
        self.core.load_safetensors(path, strict=strict)

    def _prepare_inputs(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Crop to the last ``lookback`` steps and left-pad to whole patches.

        Returns:
            The context tensor and a bool mask of the same shape that is True
            on the left padding.

        Raises:
            ValueError: If ``x`` has fewer than ``lookback`` time steps.
        """
        available = x.shape[1]
        if available < self.lookback:
            raise ValueError(f"Expected at least {self.lookback} context steps, received {x.shape[1]}.")

        context = x[:, -self.lookback:]
        batch = context.shape[0]
        device = context.device

        # Number of steps missing to reach a multiple of the patch length.
        pad_len = (-context.shape[1]) % self.core.p
        if pad_len > 0:
            context = F.pad(context, (pad_len, 0))
            padded = torch.ones(batch, pad_len, dtype=torch.bool, device=device)
            observed = torch.zeros(batch, self.lookback, dtype=torch.bool, device=device)
            mask = torch.cat([padded, observed], dim=1)
        else:
            mask = torch.zeros_like(context, dtype=torch.bool)

        limit = self.core.config.context_limit
        if context.shape[1] > limit:
            context = context[:, -limit:]
            mask = mask[:, -limit:]

        return context, mask

    def forward(
        self,
        x: torch.Tensor,
        *,
        return_quantiles: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Forecast ``lookahead`` steps for each series in ``x``.

        Args:
            x: (batch, time) tensor; converted to float32 before use.
            return_quantiles: Also return the full multi-channel forecast.

        Returns:
            The (batch, lookahead) point forecast, or a tuple of it and the
            (batch, lookahead, q) full forecast when ``return_quantiles`` is set.

        Raises:
            ValueError: If ``x`` is not two-dimensional.
        """
        if x.dim() != 2:
            raise ValueError(f"Expected input tensor of shape (batch, time), received {tuple(x.shape)}.")

        inputs, mask = self._prepare_inputs(x.to(dtype=torch.float32))
        renormed_outputs, _, ar_outputs = self.core.decode(self.lookahead, inputs, mask)

        # Forecast from the last context patch, followed by any
        # autoregressive continuation flattened along the time axis.
        segments = [renormed_outputs[:, -1, ...]]
        if ar_outputs is not None:
            segments.append(ar_outputs.reshape(inputs.shape[0], -1, self.core.q))
        full_forecast = torch.cat(segments, dim=1)[:, : self.lookahead, :]

        point_forecast = full_forecast[..., self.core.aridx]
        return (point_forecast, full_forecast) if return_quantiles else point_forecast
|
|
|
|
|
|
|
|
# Public API: the high-level wrapper, the backbone and its configuration.
__all__ = ["TimesFM2", "TimesFM2Core", "TimesFM2Definition"]
|
|
|