exx
TimesFM2 grad issue debug
c20b869
"""Self-contained TimesFM 2.x wrapper compatible with the TimesFM interface."""
from __future__ import annotations
import dataclasses
import math
import torch
import torch.nn.functional as F
from torch import nn
try:
from safetensors.torch import load_file as _load_safetensors
except ImportError: # pragma: no cover - optional dependency
_load_safetensors = None
# Standard deviations smaller than this are treated as degenerate: `revin`
# divides by 1.0 instead of such a sigma when normalizing.
_TOLERANCE = 1e-6
@dataclasses.dataclass(frozen=True)
class ResidualBlockConfig:
    """Immutable hyper-parameters consumed by ``ResidualBlock``."""

    # Input feature width (input size of the hidden and residual linear layers).
    input_dims: int
    # Width of the hidden linear layer.
    hidden_dims: int
    # Output feature width (output size of the output and residual linear layers).
    output_dims: int
    # Whether the three linear layers of the block carry a bias term.
    use_bias: bool
    # Activation name; ``ResidualBlock`` accepts "relu", "swish" or "none".
    activation: str
@dataclasses.dataclass(frozen=True)
class TransformerConfig:
    """Immutable hyper-parameters for a single ``Transformer`` block."""

    # Embedding width shared by the attention and normalization layers.
    model_dims: int
    # Width of the feed-forward hidden layer.
    hidden_dims: int
    # Number of attention heads; must divide ``model_dims``.
    num_heads: int
    # Normalization used around attention; ``Transformer`` only accepts "rms".
    attention_norm: str
    # Normalization used around the feed-forward layers; only "rms" is accepted.
    feedforward_norm: str
    # Query/key normalization; "rms" enables RMSNorm, any other value disables it.
    qk_norm: str
    # Whether the feed-forward linear layers carry a bias term.
    use_bias: bool
    # Whether rotary position embeddings are applied to queries and keys.
    use_rotary_position_embeddings: bool
    # Feed-forward activation name ("relu", "swish" or "none").
    ff_activation: str
    # Whether Q, K and V are produced by one fused linear projection.
    fuse_qkv: bool
@dataclasses.dataclass(frozen=True)
class StackedTransformersConfig:
    """Immutable description of a stack of identical transformer blocks."""

    # Number of transformer blocks in the stack.
    num_layers: int
    # Configuration shared by every block in the stack.
    transformer: TransformerConfig
@dataclasses.dataclass(frozen=True)
class TimesFM2Definition:
    """Framework-agnostic description of TimesFM 2.5 (200M parameters)."""

    # Maximum number of context steps; ``TimesFM2`` rejects a larger lookback
    # and truncates prepared contexts to this length.
    context_limit: int = 16384
    # Number of time steps grouped into one input patch.
    input_patch_len: int = 32
    # Number of time steps produced per output patch by the point head.
    output_patch_len: int = 128
    # Number of time steps produced per patch by the quantile head.
    output_quantile_len: int = 1024
    # Quantile levels; the model emits ``len(quantiles) + 1`` channels per step.
    quantiles: tuple[float, ...] = (
        0.1,
        0.2,
        0.3,
        0.4,
        0.5,
        0.6,
        0.7,
        0.8,
        0.9,
    )
    # Output channel that is fed back during autoregressive decoding and
    # returned by ``TimesFM2`` as the point forecast.
    decode_index: int = 5
    # Input embedding block. Its 64 inputs are one patch of values concatenated
    # with that patch's mask (2 * input_patch_len).
    tokenizer: ResidualBlockConfig = dataclasses.field(
        default_factory=lambda: ResidualBlockConfig(
            input_dims=64,
            hidden_dims=1280,
            output_dims=1280,
            use_bias=True,
            activation="swish",
        )
    )
    # Backbone: 20 identical transformer blocks of width 1280.
    stacked_transformers: StackedTransformersConfig = dataclasses.field(
        default_factory=lambda: StackedTransformersConfig(
            num_layers=20,
            transformer=TransformerConfig(
                model_dims=1280,
                hidden_dims=1280,
                num_heads=16,
                attention_norm="rms",
                feedforward_norm="rms",
                qk_norm="rms",
                use_bias=False,
                use_rotary_position_embeddings=True,
                ff_activation="swish",
                fuse_qkv=True,
            ),
        )
    )
    # Point head: 1280 outputs = output_patch_len (128) * 10 channels.
    output_projection_point: ResidualBlockConfig = dataclasses.field(
        default_factory=lambda: ResidualBlockConfig(
            input_dims=1280,
            hidden_dims=1280,
            output_dims=1280,
            use_bias=False,
            activation="swish",
        )
    )
    # Quantile head: 10240 outputs = output_quantile_len (1024) * 10 channels.
    output_projection_quantiles: ResidualBlockConfig = dataclasses.field(
        default_factory=lambda: ResidualBlockConfig(
            input_dims=1280,
            hidden_dims=1280,
            output_dims=10240,
            use_bias=False,
            activation="swish",
        )
    )
@dataclasses.dataclass(frozen=False)
class DecodeCache:
    """Mutable per-layer key/value cache used during autoregressive decoding.

    ``MultiHeadAttention.forward`` updates the fields in place on every call.
    """

    # Per-batch index of the next free slot along the cache's patch axis.
    next_index: torch.Tensor
    # Per-batch running count of masked patches seen so far.
    num_masked: torch.Tensor
    # Cached keys; ``TimesFM2Core.decode`` allocates them as
    # (batch, cache_size, num_heads, head_dim).
    key: torch.Tensor
    # Cached values, allocated with the same shape as ``key``.
    value: torch.Tensor
def _zero_safe_sqrt(values: torch.Tensor) -> torch.Tensor:
    """Return ``sqrt(values)`` with a zero (not inf/NaN) gradient where ``values == 0``."""
    is_zero = values == 0
    # sqrt receives a dummy 1.0 at the zero positions so its backward pass never
    # divides by zero; the outer ``where`` restores the exact 0.0 result there.
    roots = torch.sqrt(torch.where(is_zero, 1.0, values))
    return torch.where(is_zero, 0.0, roots)


def update_running_stats(
    n: torch.Tensor,
    mu: torch.Tensor,
    sigma: torch.Tensor,
    x: torch.Tensor,
    mask: torch.Tensor,
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Update running count/mean/std with one new patch (reversible normalization stats).

    The statistics of the unmasked entries of ``x`` are merged with the running
    statistics ``(n, mu, sigma)`` using the parallel mean/variance combination.

    Square roots are taken with a zero-safe variant: forward values are the same
    as with a plain ``sqrt``, but zero-variance inputs (constant or fully masked
    patches) get a zero gradient instead of poisoning backpropagation with NaN.

    Args:
        n: Running count of valid points, reduced over the last axis of ``x``.
        mu: Running mean, same shape as ``n``.
        sigma: Running standard deviation, same shape as ``n``.
        x: New patch; statistics are reduced over its last axis.
        mask: Boolean tensor shaped like ``x``; ``True`` marks entries to ignore.

    Returns:
        ``((new_n, new_mu, new_sigma), (new_n, new_mu, new_sigma))`` - the
        updated statistics, returned twice (carry and per-step output).
    """
    is_legit = torch.logical_not(mask)

    # Statistics of the incoming patch alone, guarded against empty patches.
    inc_n = torch.sum(is_legit.to(x.dtype), dim=-1)
    inc_is_empty = inc_n == 0
    inc_n_safe = torch.where(inc_is_empty, 1.0, inc_n)
    inc_mu = torch.sum(x * is_legit, dim=-1) / inc_n_safe
    inc_mu = torch.where(inc_is_empty, 0.0, inc_mu)
    inc_var = torch.sum(((x - inc_mu.unsqueeze(-1)) ** 2) * is_legit, dim=-1) / inc_n_safe
    inc_var = torch.where(inc_is_empty, 0.0, inc_var)
    inc_sigma = _zero_safe_sqrt(inc_var)

    # Merge with the running statistics.
    new_n = n + inc_n
    new_is_empty = new_n == 0
    new_n_safe = torch.where(new_is_empty, 1.0, new_n)
    new_mu = (n * mu + inc_mu * inc_n) / new_n_safe
    new_mu = torch.where(new_is_empty, 0.0, new_mu)

    term1 = n * sigma.pow(2)
    term2 = inc_n * inc_sigma.pow(2)
    term3 = n * (mu - new_mu).pow(2)
    term4 = inc_n * (inc_mu - new_mu).pow(2)
    new_var = (term1 + term2 + term3 + term4) / new_n_safe
    new_var = torch.where(new_is_empty, 0.0, new_var)
    new_sigma = _zero_safe_sqrt(torch.clamp(new_var, min=0.0))

    return (new_n, new_mu, new_sigma), (new_n, new_mu, new_sigma)
def revin(x: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    """Apply (or undo) reversible instance normalization.

    ``mu`` and ``sigma`` gain one or two trailing singleton axes when their rank
    is one or two lower than that of ``x``; otherwise they are used as given.

    Args:
        x: Values to normalize, or normalized values to restore.
        mu: Mean statistics.
        sigma: Standard-deviation statistics.
        reverse: When ``True`` compute ``x * sigma + mu``; otherwise compute
            ``(x - mu) / sigma`` with sigmas below ``_TOLERANCE`` replaced by 1.

    Returns:
        The normalized (or de-normalized) tensor.
    """
    missing_axes = len(x.shape) - len(mu.shape)
    if missing_axes == 1:
        mu, sigma = mu[..., None], sigma[..., None]
    elif missing_axes == 2:
        mu, sigma = mu[..., None, None], sigma[..., None, None]

    if not reverse:
        # Guard against division by a (near) zero standard deviation.
        divisor = torch.where(sigma < _TOLERANCE, torch.ones_like(sigma), sigma)
        return (x - mu) / divisor
    return x * sigma + mu
class ResidualBlock(nn.Module):
    """Two-layer MLP with a parallel linear skip connection.

    The output is ``output_layer(activation(hidden_layer(x))) + residual_layer(x)``.
    """

    def __init__(self, config: ResidualBlockConfig):
        super().__init__()
        in_dims, mid_dims, out_dims = config.input_dims, config.hidden_dims, config.output_dims
        with_bias = config.use_bias
        self.activation = self._resolve_activation(config.activation)
        self.hidden_layer = nn.Linear(in_dims, mid_dims, bias=with_bias)
        self.output_layer = nn.Linear(mid_dims, out_dims, bias=with_bias)
        self.residual_layer = nn.Linear(in_dims, out_dims, bias=with_bias)

    @staticmethod
    def _resolve_activation(name: str) -> nn.Module:
        """Map an activation name ("relu", "swish", "none") to a fresh module.

        Raises:
            ValueError: If ``name`` is not one of the supported activations.
        """
        known = (("relu", nn.ReLU), ("swish", nn.SiLU), ("none", nn.Identity))
        for label, factory in known:
            if name == label:
                return factory()
        raise ValueError(f"Unsupported activation: {name}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the MLP branch plus the linear skip branch for ``x``."""
        mlp_branch = self.output_layer(self.activation(self.hidden_layer(x)))
        skip_branch = self.residual_layer(x)
        return mlp_branch + skip_branch
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned scale.

    The scale starts at zero, so the layer outputs zeros until its weights are
    loaded or trained.
    """

    def __init__(self, num_features: int, epsilon: float = 1e-6):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(num_features))
        self.epsilon = epsilon

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Divide ``inputs`` by their RMS (stabilized by epsilon), then scale."""
        mean_square = torch.square(inputs).mean(dim=-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.epsilon)
        return (inputs * inverse_rms) * self.scale
def make_attn_mask(
    query_length: int,
    num_all_masked_kv: torch.Tensor,
    query_index_offset: torch.Tensor | None = None,
    kv_length: int = 0,
) -> torch.Tensor:
    """Build a boolean causal attention mask that also hides leading masked keys.

    Args:
        query_length: Number of query positions.
        num_all_masked_kv: Per-batch count of leading key/value slots to hide.
        query_index_offset: Optional per-batch offset added to the query
            positions (used when queries continue a cached sequence).
        kv_length: Number of key/value slots; ``0`` means "same as queries".

    Returns:
        Boolean tensor broadcastable to ``(batch, 1, query_length, kv_length)``
        where ``True`` marks key slots a query may attend to.
    """
    device = num_all_masked_kv.device
    num_keys = query_length if kv_length == 0 else kv_length

    query_pos = torch.arange(query_length, device=device)[None, None, :, None]
    if query_index_offset is not None:
        query_pos = query_pos + query_index_offset[:, None, None, None]
    key_pos = torch.arange(num_keys, device=device)[None, None, None, :]

    is_causal = query_pos >= key_pos
    is_unmasked = key_pos >= num_all_masked_kv[:, None, None, None]
    return torch.logical_and(is_causal, is_unmasked)
class RotaryPositionalEmbedding(nn.Module):
    """Rotates query/key features by position-dependent angles (RoPE)."""

    def __init__(self, embedding_dims: int, min_timescale: float = 1.0, max_timescale: float = 10000.0):
        super().__init__()
        self.embedding_dims = embedding_dims
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale

    def forward(self, inputs: torch.Tensor, position: torch.Tensor | None = None) -> torch.Tensor:
        """Apply the rotation to ``inputs``.

        Args:
            inputs: Rank-3 or rank-4 tensor whose last axis equals
                ``embedding_dims`` and whose axis 1 indexes positions.
            position: Optional positions; defaults to ``0..inputs.shape[1] - 1``.

        Returns:
            Tensor shaped like ``inputs`` with both feature halves rotated.

        Raises:
            ValueError: On a feature-size mismatch or an unsupported rank.
        """
        if inputs.shape[-1] != self.embedding_dims:
            raise ValueError("Rotary embedding dimension must equal the head dimension.")

        device = inputs.device
        num_freqs = self.embedding_dims // 2
        exponent = 2 * torch.arange(num_freqs, device=device) / self.embedding_dims
        # Geometric progression of timescales from min_timescale to max_timescale.
        timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** exponent

        if position is None:
            position = torch.arange(inputs.shape[1], dtype=torch.float32, device=device)[None, :]

        rank = len(inputs.shape)
        if rank == 4:
            position = position[..., None, None]
            timescale = timescale[None, None, None, :]
        elif rank == 3:
            position = position[..., None]
            timescale = timescale[None, None, :]
        else:
            raise ValueError("Expected rank-3 or rank-4 tensor for rotary embeddings.")

        angle = position / timescale
        sin_part, cos_part = torch.sin(angle), torch.cos(angle)
        lower, upper = torch.chunk(inputs, 2, dim=-1)
        return torch.cat(
            [lower * cos_part - upper * sin_part, upper * cos_part + lower * sin_part],
            dim=-1,
        )
class PerDimScale(nn.Module):
    """Learned per-feature scaling applied to queries before attention."""

    def __init__(self, num_dims: int):
        super().__init__()
        self.num_dims = num_dims
        self.per_dim_scale = nn.Parameter(torch.zeros(num_dims))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by ``1.442695041 / sqrt(num_dims) * softplus(per_dim_scale)``."""
        # 1.442695041 ~= 1 / ln(2), so a zero parameter yields 1 / sqrt(num_dims).
        base_scale = 1.442695041 / math.sqrt(self.num_dims)
        learned_scale = base_scale * F.softplus(self.per_dim_scale)
        return x * learned_scale
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention supporting fused QKV projections and caching.

    Args:
        num_heads: Number of attention heads; must divide ``in_features``.
        in_features: Model (embedding) width.
        use_per_dim_scale: Scale queries with a learned ``PerDimScale``.
        use_rotary_position_embeddings: Apply rotary embeddings to Q and K.
        use_bias: Whether the linear projections carry a bias term.
        attention_fn: Attention kernel. It is called with
            ``(query, key, value, attn_mask=..., scale=1.0)`` on tensors laid
            out as ``(batch, heads, length, head_dim)``, i.e. with the calling
            convention of ``F.scaled_dot_product_attention``.
        qk_norm: ``"rms"`` applies RMSNorm to Q and K; anything else disables it.
        fuse_qkv: Use a single linear layer to produce Q, K and V.

    Raises:
        ValueError: If ``in_features`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        num_heads: int,
        in_features: int,
        *,
        use_per_dim_scale: bool = True,
        use_rotary_position_embeddings: bool = True,
        use_bias: bool = False,
        attention_fn=F.scaled_dot_product_attention,
        qk_norm: str = "rms",
        fuse_qkv: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.in_features = in_features
        self.head_dim = in_features // num_heads
        self.use_bias = use_bias
        self.attention_fn = attention_fn
        self.qk_norm = qk_norm
        self.fuse_qkv = fuse_qkv

        if in_features % num_heads != 0:
            raise ValueError(f"Model dimension {in_features} must be divisible by {num_heads} heads.")

        if fuse_qkv:
            self.qkv_proj = nn.Linear(in_features, 3 * in_features, bias=use_bias)
        else:
            self.query = nn.Linear(in_features, in_features, bias=use_bias)
            self.key = nn.Linear(in_features, in_features, bias=use_bias)
            self.value = nn.Linear(in_features, in_features, bias=use_bias)
        self.out = nn.Linear(in_features, in_features, bias=use_bias)

        if qk_norm == "rms":
            self.query_ln = RMSNorm(self.head_dim)
            self.key_ln = RMSNorm(self.head_dim)
        else:
            self.query_ln = nn.Identity()
            self.key_ln = nn.Identity()

        self.use_rotary_position_embeddings = use_rotary_position_embeddings
        if use_rotary_position_embeddings:
            self.rotary_position_embedding = RotaryPositionalEmbedding(self.head_dim)

        self.use_per_dim_scale = use_per_dim_scale
        if use_per_dim_scale:
            self.per_dim_scale = PerDimScale(self.head_dim)

    def forward(
        self,
        inputs_q: torch.Tensor,
        *,
        decode_cache: DecodeCache | None = None,
        patch_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, DecodeCache | None]:
        """Run self-attention over ``inputs_q``.

        Args:
            inputs_q: Tensor of shape ``(batch, num_patches, in_features)``.
            decode_cache: Optional key/value cache. When given it is updated
                IN PLACE: the new keys/values are written at ``next_index`` and
                attention runs over the whole cache.
            patch_mask: Optional boolean ``(batch, num_patches)`` tensor where
                ``True`` marks masked patches; defaults to nothing masked.

        Returns:
            ``(output, decode_cache)`` where ``output`` has the shape of
            ``inputs_q`` and ``decode_cache`` is the (mutated) input cache.
        """
        batch, num_patches, _ = inputs_q.shape
        if patch_mask is None:
            patch_mask = torch.zeros(batch, num_patches, dtype=torch.bool, device=inputs_q.device)

        if self.fuse_qkv:
            qkv = self.qkv_proj(inputs_q)
            query, key, value = torch.chunk(qkv, 3, dim=-1)
            query = query.view(batch, num_patches, self.num_heads, self.head_dim)
            key = key.view(batch, num_patches, self.num_heads, self.head_dim)
            value = value.view(batch, num_patches, self.num_heads, self.head_dim)
        else:
            query = self.query(inputs_q).view(batch, num_patches, self.num_heads, self.head_dim)
            key = self.key(inputs_q).view(batch, num_patches, self.num_heads, self.head_dim)
            value = self.value(inputs_q).view(batch, num_patches, self.num_heads, self.head_dim)

        if decode_cache is None:
            num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1)
            next_index = torch.zeros_like(num_masked, dtype=torch.int32)
        else:
            num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1) + decode_cache.num_masked
            # Clone: the cache's next_index is advanced in place further down.
            next_index = decode_cache.next_index.clone()

        if self.use_rotary_position_embeddings:
            # Positions count from the first unmasked patch of the whole sequence.
            position = (
                torch.arange(num_patches, device=inputs_q.device)[None, :]
                + next_index[:, None]
                - num_masked[:, None]
            )
            query = self.rotary_position_embedding(query, position)
            key = self.rotary_position_embedding(key, position)

        query = self.query_ln(query)
        key = self.key_ln(key)

        if self.use_per_dim_scale:
            query = self.per_dim_scale(query)

        if decode_cache is not None:
            _, cache_size, _, _ = decode_cache.value.shape
            # The write offset is taken from batch element 0, i.e. all batch
            # elements are assumed to share the same next_index.
            start = decode_cache.next_index[0]
            end = start + num_patches
            decode_cache.key[:, start:end] = key
            decode_cache.value[:, start:end] = value
            key = decode_cache.key
            value = decode_cache.value
            decode_cache.next_index += num_patches
            decode_cache.num_masked = num_masked
            attn_mask = make_attn_mask(
                query_length=num_patches,
                num_all_masked_kv=num_masked,
                query_index_offset=next_index,
                kv_length=cache_size,
            )
        else:
            attn_mask = make_attn_mask(query_length=num_patches, num_all_masked_kv=num_masked)

        # Use the configured kernel (previously the constructor argument was
        # stored but ignored). scale=1.0 because any query scaling has already
        # been applied above.
        attn_output = self.attention_fn(
            query.permute(0, 2, 1, 3),
            key.permute(0, 2, 1, 3),
            value.permute(0, 2, 1, 3),
            attn_mask=attn_mask,
            scale=1.0,
        )
        attn_output = attn_output.permute(0, 2, 1, 3)
        attn_output = attn_output.reshape(batch, num_patches, self.in_features)
        return self.out(attn_output), decode_cache
class Transformer(nn.Module):
    """One TimesFM transformer block: attention and feed-forward sub-layers.

    Each sub-layer is wrapped in RMSNorm before and after, followed by a
    residual connection.

    Raises:
        ValueError: If the config requests a normalization other than "rms".
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        if not (config.attention_norm == "rms" and config.feedforward_norm == "rms"):
            raise ValueError("Only RMS normalization is supported.")

        width = config.model_dims
        self.pre_attn_ln = RMSNorm(width)
        self.post_attn_ln = RMSNorm(width)
        self.attn = MultiHeadAttention(
            num_heads=config.num_heads,
            in_features=width,
            use_per_dim_scale=True,
            use_rotary_position_embeddings=config.use_rotary_position_embeddings,
            qk_norm=config.qk_norm,
            fuse_qkv=config.fuse_qkv,
        )
        self.pre_ff_ln = RMSNorm(width)
        self.post_ff_ln = RMSNorm(width)
        self.ff0 = nn.Linear(width, config.hidden_dims, bias=config.use_bias)
        self.ff1 = nn.Linear(config.hidden_dims, width, bias=config.use_bias)
        self.activation = ResidualBlock._resolve_activation(config.ff_activation)

    def forward(
        self,
        input_embeddings: torch.Tensor,
        patch_mask: torch.Tensor,
        decode_cache: DecodeCache | None = None,
    ) -> tuple[torch.Tensor, DecodeCache | None]:
        """Apply the block and return ``(output_embeddings, decode_cache)``."""
        # Attention sub-layer with residual connection.
        attended, decode_cache = self.attn(
            inputs_q=self.pre_attn_ln(input_embeddings),
            decode_cache=decode_cache,
            patch_mask=patch_mask,
        )
        after_attention = self.post_attn_ln(attended) + input_embeddings

        # Feed-forward sub-layer with residual connection.
        hidden = self.activation(self.ff0(self.pre_ff_ln(after_attention)))
        projected = self.ff1(hidden)
        return self.post_ff_ln(projected) + after_attention, decode_cache
class TimesFM2Core(nn.Module):
    """Core TimesFM 2.x backbone without external dependencies.

    Shorthand attributes derived from the definition:
        p: input patch length.            o: output patch length.
        os: quantile-head output length.  m: input patches per output patch (o // p).
        x: number of transformer layers.  h: attention heads.
        md: model width.                  hd: per-head width (md // h).
        q: output channels per step (len(quantiles) + 1).
        aridx: channel fed back during autoregressive decoding.
    """

    def __init__(self, definition: TimesFM2Definition | None = None):
        super().__init__()
        self.config = definition or TimesFM2Definition()
        self.p = self.config.input_patch_len
        self.o = self.config.output_patch_len
        self.os = self.config.output_quantile_len
        self.m = self.o // self.p
        self.x = self.config.stacked_transformers.num_layers
        self.h = self.config.stacked_transformers.transformer.num_heads
        self.md = self.config.stacked_transformers.transformer.model_dims
        self.hd = self.md // self.h
        self.q = len(self.config.quantiles) + 1
        self.aridx = self.config.decode_index
        self.tokenizer = ResidualBlock(self.config.tokenizer)
        self.stacked_xf = nn.ModuleList(
            [Transformer(self.config.stacked_transformers.transformer) for _ in range(self.x)]
        )
        self.output_projection_point = ResidualBlock(self.config.output_projection_point)
        self.output_projection_quantiles = ResidualBlock(self.config.output_projection_quantiles)

    def load_safetensors(self, path: str, strict: bool = True) -> None:
        """Load weights from a safetensors file and switch the module to eval mode.

        Raises:
            ImportError: If the optional ``safetensors`` package is missing.
        """
        if _load_safetensors is None:
            raise ImportError("Install safetensors to load TimesFM2 checkpoints.")
        tensors = _load_safetensors(path)
        self.load_state_dict(tensors, strict=strict)
        self.eval()

    def forward(
        self,
        inputs: torch.Tensor,
        masks: torch.Tensor,
        decode_caches: list[DecodeCache | None] | None = None,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], list[DecodeCache | None]]:
        """Run the backbone on patched inputs.

        Args:
            inputs: Normalized patches, ``(batch, num_patches, p)``.
            masks: Boolean patch masks of the same shape; ``True`` = masked.
            decode_caches: Optional per-layer caches (mutated in place by the
                attention layers). ``None`` runs without caching.

        Returns:
            ``((input_embeddings, output_embeddings, output_ts,
            output_quantile_spread), caches)`` where ``caches`` holds one entry
            per layer (``None`` entries when running without caching).
        """
        # The tokenizer sees each patch's values concatenated with its mask.
        tokenizer_inputs = torch.cat([inputs, masks.to(inputs.dtype)], dim=-1)
        input_embeddings = self.tokenizer(tokenizer_inputs)

        if decode_caches is None:
            decode_caches = [None] * self.x

        output_embeddings = input_embeddings
        new_decode_caches: list[DecodeCache | None] = []
        for layer, cache in zip(self.stacked_xf, decode_caches):
            # A patch counts as masked when its last element is masked.
            output_embeddings, new_cache = layer(output_embeddings, masks[..., -1], cache)
            new_decode_caches.append(new_cache)

        output_ts = self.output_projection_point(output_embeddings)
        output_quantile_spread = self.output_projection_quantiles(output_embeddings)
        return (input_embeddings, output_embeddings, output_ts, output_quantile_spread), new_decode_caches

    def decode(
        self,
        horizon: int,
        inputs: torch.Tensor,
        masks: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """Autoregressively decodes a batch of sequences.

        When gradients are disabled, a key/value cache is used so every
        autoregressive step only processes the newly generated patches. When
        gradients are enabled the cache is skipped (its in-place updates do not
        mix with autograd) and every step re-runs the network on the full
        sequence generated so far, so each step still attends to the whole
        context.

        Args:
            horizon: Number of future steps required.
            inputs: ``(batch, context)`` values; ``context`` must be a multiple
                of the input patch length.
            masks: Boolean tensor shaped like ``inputs``; ``True`` = masked.

        Returns:
            ``(renormed_outputs, renormed_quantile_spread, ar_renormed_outputs)``:
            de-normalized outputs for every context patch
            ``(batch, patches, o, q)``, the quantile-head output of the last
            context patch ``(batch, os, q)``, and the stacked autoregressive
            outputs ``(batch, steps, o, q)`` or ``None`` when no extra step is
            needed.
        """
        batch_size, context = inputs.shape
        num_decode_steps = (horizon - 1) // self.o
        num_input_patches = context // self.p
        use_cache = not torch.is_grad_enabled()

        patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
        patched_masks = torch.reshape(masks, (batch_size, -1, self.p))

        # Running normalization statistics after each context patch.
        n = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
        mu = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
        sigma = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
        patch_mu: list[torch.Tensor] = []
        patch_sigma: list[torch.Tensor] = []
        for i in range(num_input_patches):
            (n, mu, sigma), _ = update_running_stats(n, mu, sigma, patched_inputs[:, i], patched_masks[:, i])
            patch_mu.append(mu)
            patch_sigma.append(sigma)
        last_n, last_mu, last_sigma = n, mu, sigma
        context_mu = torch.stack(patch_mu, dim=1)
        context_sigma = torch.stack(patch_sigma, dim=1)

        decode_caches: list[DecodeCache | None] | None
        if use_cache:
            # max(..., 0) keeps the cache large enough for the context even
            # when horizon <= 0 makes the step count negative.
            decode_cache_size = num_input_patches + max(num_decode_steps, 0) * self.m
            decode_caches = [
                DecodeCache(
                    next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
                    num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
                    key=torch.zeros(
                        batch_size,
                        decode_cache_size,
                        self.h,
                        self.hd,
                        device=inputs.device,
                        dtype=inputs.dtype,
                    ),
                    value=torch.zeros(
                        batch_size,
                        decode_cache_size,
                        self.h,
                        self.hd,
                        device=inputs.device,
                        dtype=inputs.dtype,
                    ),
                )
                for _ in range(self.x)
            ]
        else:
            decode_caches = None

        normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
        normed_inputs = torch.where(
            patched_masks,
            torch.zeros((), device=inputs.device, dtype=inputs.dtype),
            normed_inputs,
        )
        (_, _, normed_outputs, normed_quantile_spread), decode_caches = self(
            normed_inputs, patched_masks, decode_caches
        )
        renormed_outputs = torch.reshape(
            revin(normed_outputs, context_mu, context_sigma, reverse=True),
            (batch_size, -1, self.o, self.q),
        )
        renormed_quantile_spread = torch.reshape(
            revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
            (batch_size, -1, self.os, self.q),
        )[:, -1, ...]

        # Without a cache the full normalized sequence is kept so that each
        # autoregressive step can be recomputed with the whole history.
        history_inputs = normed_inputs
        history_masks = patched_masks

        ar_outputs: list[torch.Tensor] = []
        last_renormed_output = renormed_outputs[:, -1, :, self.aridx]
        for _ in range(num_decode_steps):
            # Feed the last forecast back in as m new input patches.
            new_patched_input = torch.reshape(last_renormed_output, (batch_size, self.m, self.p))
            new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)

            n, mu, sigma = last_n, last_mu, last_sigma
            new_mus: list[torch.Tensor] = []
            new_sigmas: list[torch.Tensor] = []
            for i in range(self.m):
                (n, mu, sigma), _ = update_running_stats(n, mu, sigma, new_patched_input[:, i], new_mask[:, i])
                new_mus.append(mu)
                new_sigmas.append(sigma)
            last_n, last_mu, last_sigma = n, mu, sigma
            new_mu = torch.stack(new_mus, dim=1)
            new_sigma = torch.stack(new_sigmas, dim=1)

            new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
            if use_cache:
                (_, _, new_normed_output, _), decode_caches = self(new_normed_input, new_mask, decode_caches)
            else:
                history_inputs = torch.cat([history_inputs, new_normed_input], dim=1)
                history_masks = torch.cat([history_masks, new_mask], dim=1)
                (_, _, full_normed_output, _), _ = self(history_inputs, history_masks, None)
                # Only the outputs of the m newly appended patches are new.
                new_normed_output = full_normed_output[:, -self.m :]

            new_renormed_output = torch.reshape(
                revin(new_normed_output, new_mu, new_sigma, reverse=True),
                (batch_size, self.m, self.o, self.q),
            )
            ar_outputs.append(new_renormed_output[:, -1, ...])
            last_renormed_output = new_renormed_output[:, -1, :, self.aridx]

        ar_renormed_outputs = torch.stack(ar_outputs, dim=1) if num_decode_steps > 0 else None
        return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
class TimesFM2(nn.Module):
    """High-level TimesFM 2.x wrapper mirroring the TimesFM interface.

    Args:
        lookback: Number of most recent context steps fed to the model.
        lookahead: Number of future steps to forecast.

    Raises:
        ValueError: If ``lookback`` exceeds the model's context limit.
    """

    def __init__(self, lookback: int = 512, lookahead: int = 96):
        super().__init__()
        self.lookback = lookback
        self.lookahead = lookahead
        self.core = TimesFM2Core()
        limit = self.core.config.context_limit
        if lookback > limit:
            raise ValueError(
                f"lookback ({lookback}) exceeds maximum context limit ({self.core.config.context_limit})."
            )

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load weights directly into the core (keys carry no ``core.`` prefix)."""
        return self.core.load_state_dict(state_dict, strict=strict)

    def state_dict(self, *args, **kwargs):
        """Return the core's state dict (keys carry no ``core.`` prefix)."""
        return self.core.state_dict(*args, **kwargs)

    def load_safetensors(self, path: str, strict: bool = True) -> None:
        """Load a safetensors checkpoint into the core."""
        self.core.load_safetensors(path, strict=strict)

    def _prepare_inputs(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Crop to the lookback window, left-pad to whole patches, build the mask.

        Returns:
            ``(context, mask)`` where ``mask`` is ``True`` on padded positions.

        Raises:
            ValueError: If ``x`` has fewer than ``lookback`` time steps.
        """
        if x.shape[1] < self.lookback:
            raise ValueError(f"Expected at least {self.lookback} context steps, received {x.shape[1]}.")

        context = x[:, -self.lookback:]
        # Number of steps missing to reach a multiple of the patch length.
        pad_len = (-context.shape[1]) % self.core.p
        if pad_len > 0:
            context = F.pad(context, (pad_len, 0))
            rows, device = context.shape[0], context.device
            padded_part = torch.ones(rows, pad_len, dtype=torch.bool, device=device)
            observed_part = torch.zeros(rows, self.lookback, dtype=torch.bool, device=device)
            mask = torch.cat([padded_part, observed_part], dim=1)
        else:
            mask = torch.zeros_like(context, dtype=torch.bool)

        limit = self.core.config.context_limit
        if context.shape[1] > limit:
            context, mask = context[:, -limit:], mask[:, -limit:]
        return context, mask

    def forward(
        self,
        x: torch.Tensor,
        *,
        return_quantiles: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Forecast ``lookahead`` steps for a ``(batch, time)`` input.

        Returns:
            The point forecast ``(batch, lookahead)``; with
            ``return_quantiles=True`` a tuple of the point forecast and the
            full ``(batch, lookahead, channels)`` forecast.

        Raises:
            ValueError: If ``x`` is not two-dimensional.
        """
        if x.dim() != 2:
            raise ValueError(f"Expected input tensor of shape (batch, time), received {tuple(x.shape)}.")

        inputs, mask = self._prepare_inputs(x.to(dtype=torch.float32))
        renormed_outputs, _, ar_outputs = self.core.decode(self.lookahead, inputs, mask)

        # Forecast of the last context patch, then any autoregressive patches.
        pieces = [renormed_outputs[:, -1, ...]]
        if ar_outputs is not None:
            pieces.append(ar_outputs.reshape(inputs.shape[0], -1, self.core.q))
        full_forecast = torch.cat(pieces, dim=1)[:, : self.lookahead, :]
        point_forecast = full_forecast[..., self.core.aridx]

        if not return_quantiles:
            return point_forecast
        return point_forecast, full_forecast
# Names exported by `from <module> import *`.
__all__ = ["TimesFM2", "TimesFM2Core", "TimesFM2Definition"]