"""Self-contained TimesFM 2.x wrapper compatible with the TimesFM interface.""" from __future__ import annotations import dataclasses import math import torch import torch.nn.functional as F from torch import nn try: from safetensors.torch import load_file as _load_safetensors except ImportError: # pragma: no cover - optional dependency _load_safetensors = None _TOLERANCE = 1e-6 @dataclasses.dataclass(frozen=True) class ResidualBlockConfig: input_dims: int hidden_dims: int output_dims: int use_bias: bool activation: str @dataclasses.dataclass(frozen=True) class TransformerConfig: model_dims: int hidden_dims: int num_heads: int attention_norm: str feedforward_norm: str qk_norm: str use_bias: bool use_rotary_position_embeddings: bool ff_activation: str fuse_qkv: bool @dataclasses.dataclass(frozen=True) class StackedTransformersConfig: num_layers: int transformer: TransformerConfig @dataclasses.dataclass(frozen=True) class TimesFM2Definition: """Framework-agnostic description of TimesFM 2.5 (200M parameters).""" context_limit: int = 16384 input_patch_len: int = 32 output_patch_len: int = 128 output_quantile_len: int = 1024 quantiles: tuple[float, ...] = ( 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, ) decode_index: int = 5 tokenizer: ResidualBlockConfig = dataclasses.field( default_factory=lambda: ResidualBlockConfig( input_dims=64, hidden_dims=1280, output_dims=1280, use_bias=True, activation="swish", ) ) stacked_transformers: StackedTransformersConfig = dataclasses.field( default_factory=lambda: StackedTransformersConfig( num_layers=20, transformer=TransformerConfig( model_dims=1280, hidden_dims=1280, num_heads=16, attention_norm="rms", feedforward_norm="rms", qk_norm="rms", use_bias=False, use_rotary_position_embeddings=True, ff_activation="swish", fuse_qkv=True, ), ) ) output_projection_point: ResidualBlockConfig = dataclasses.field( default_factory=lambda: ResidualBlockConfig( input_dims=1280, hidden_dims=1280, output_dims=1280, use_bias=False, activation="swish", ) ) output_projection_quantiles: ResidualBlockConfig = dataclasses.field( default_factory=lambda: ResidualBlockConfig( input_dims=1280, hidden_dims=1280, output_dims=10240, use_bias=False, activation="swish", ) ) @dataclasses.dataclass(frozen=False) class DecodeCache: next_index: torch.Tensor num_masked: torch.Tensor key: torch.Tensor value: torch.Tensor def update_running_stats( n: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor, x: torch.Tensor, mask: torch.Tensor, ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Updates reversible normalization statistics for a new patch.""" is_legit = torch.logical_not(mask) inc_n = torch.sum(is_legit.to(x.dtype), dim=-1) inc_mu_numerator = torch.sum(x * is_legit, dim=-1) inc_n_safe = torch.where(inc_n == 0, 1.0, inc_n) inc_mu = inc_mu_numerator / inc_n_safe inc_mu = torch.where(inc_n == 0, 0.0, inc_mu) inc_var_numerator = torch.sum(((x - inc_mu.unsqueeze(-1)) ** 2) * is_legit, dim=-1) inc_var = inc_var_numerator / inc_n_safe inc_var = torch.where(inc_n == 0, 0.0, inc_var) inc_sigma = torch.sqrt(inc_var) new_n = n + inc_n new_n_safe = torch.where(new_n == 0, 1.0, new_n) new_mu = (n * mu + inc_mu * inc_n) / new_n_safe new_mu = torch.where(new_n == 0, 0.0, new_mu) term1 = n * sigma.pow(2) term2 = inc_n * inc_sigma.pow(2) term3 = n * (mu - new_mu).pow(2) term4 = inc_n * (inc_mu - new_mu).pow(2) new_var = (term1 + term2 + term3 + term4) / new_n_safe new_var = torch.where(new_n == 0, 0.0, new_var) new_sigma = torch.sqrt(torch.clamp(new_var, min=0.0)) return (new_n, new_mu, new_sigma), (new_n, new_mu, new_sigma) def revin(x: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor, reverse: bool = False) -> torch.Tensor: """Reversible instance normalization.""" if len(mu.shape) == len(x.shape) - 1: mu = mu[..., None] sigma = sigma[..., None] elif len(mu.shape) == len(x.shape) - 2: mu = mu[..., None, None] sigma = sigma[..., None, None] if reverse: return x * sigma + mu sigma_safe = torch.where(sigma < _TOLERANCE, torch.ones_like(sigma), sigma) return (x - mu) / sigma_safe class ResidualBlock(nn.Module): """Residual block composed of a pair of linear layers.""" def __init__(self, config: ResidualBlockConfig): super().__init__() self.activation = self._resolve_activation(config.activation) self.hidden_layer = nn.Linear(config.input_dims, config.hidden_dims, bias=config.use_bias) self.output_layer = nn.Linear(config.hidden_dims, config.output_dims, bias=config.use_bias) self.residual_layer = nn.Linear(config.input_dims, config.output_dims, bias=config.use_bias) @staticmethod def _resolve_activation(name: str) -> nn.Module: if name == "relu": return nn.ReLU() if name == "swish": return nn.SiLU() if name == "none": return nn.Identity() raise ValueError(f"Unsupported activation: {name}") def forward(self, x: torch.Tensor) -> torch.Tensor: hidden = self.activation(self.hidden_layer(x)) return self.output_layer(hidden) + self.residual_layer(x) class RMSNorm(nn.Module): """Root-mean-square normalization.""" def __init__(self, num_features: int, epsilon: float = 1e-6): super().__init__() self.scale = nn.Parameter(torch.zeros(num_features)) self.epsilon = epsilon def forward(self, inputs: torch.Tensor) -> torch.Tensor: var = torch.mean(torch.square(inputs), dim=-1, keepdim=True) normed_inputs = inputs * torch.rsqrt(var + self.epsilon) return normed_inputs * self.scale def make_attn_mask( query_length: int, num_all_masked_kv: torch.Tensor, query_index_offset: torch.Tensor | None = None, kv_length: int = 0, ) -> torch.Tensor: """Creates a causal mask consistent with cached decoding.""" if kv_length == 0: kv_length = query_length q_index = torch.arange(query_length, device=num_all_masked_kv.device)[None, None, :, None] if query_index_offset is not None: q_index = q_index + query_index_offset[:, None, None, None] kv_index = torch.arange(kv_length, device=num_all_masked_kv.device)[None, None, None, :] return torch.logical_and(q_index >= kv_index, kv_index >= num_all_masked_kv[:, None, None, None]) class RotaryPositionalEmbedding(nn.Module): """Applies rotary position embeddings to query/key projections.""" def __init__(self, embedding_dims: int, min_timescale: float = 1.0, max_timescale: float = 10000.0): super().__init__() self.embedding_dims = embedding_dims self.min_timescale = min_timescale self.max_timescale = max_timescale def forward(self, inputs: torch.Tensor, position: torch.Tensor | None = None) -> torch.Tensor: if self.embedding_dims != inputs.shape[-1]: raise ValueError("Rotary embedding dimension must equal the head dimension.") half_dim = self.embedding_dims // 2 fraction = 2 * torch.arange(half_dim, device=inputs.device) / self.embedding_dims timescale = (self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction).to(inputs.device) if position is None: position = torch.arange(inputs.shape[1], dtype=torch.float32, device=inputs.device)[None, :] if len(inputs.shape) == 4: position = position[..., None, None] timescale = timescale[None, None, None, :] elif len(inputs.shape) == 3: position = position[..., None] timescale = timescale[None, None, :] else: raise ValueError("Expected rank-3 or rank-4 tensor for rotary embeddings.") sinusoid = position / timescale sin = torch.sin(sinusoid) cos = torch.cos(sinusoid) first_half, second_half = torch.chunk(inputs, 2, dim=-1) rotated_first = first_half * cos - second_half * sin rotated_second = second_half * cos + first_half * sin return torch.cat([rotated_first, rotated_second], dim=-1) class PerDimScale(nn.Module): """Learned per-dimension scaling used prior to attention.""" def __init__(self, num_dims: int): super().__init__() self.num_dims = num_dims self.per_dim_scale = nn.Parameter(torch.zeros(num_dims)) def forward(self, x: torch.Tensor) -> torch.Tensor: scale_factor = 1.442695041 / math.sqrt(self.num_dims) * F.softplus(self.per_dim_scale) return x * scale_factor class MultiHeadAttention(nn.Module): """Multi-head attention supporting fused QKV projections and caching.""" def __init__( self, num_heads: int, in_features: int, *, use_per_dim_scale: bool = True, use_rotary_position_embeddings: bool = True, use_bias: bool = False, attention_fn=F.scaled_dot_product_attention, qk_norm: str = "rms", fuse_qkv: bool = False, ): super().__init__() self.num_heads = num_heads self.in_features = in_features self.head_dim = in_features // num_heads self.use_bias = use_bias self.attention_fn = attention_fn self.qk_norm = qk_norm self.fuse_qkv = fuse_qkv if in_features % num_heads != 0: raise ValueError(f"Model dimension {in_features} must be divisible by {num_heads} heads.") if fuse_qkv: self.qkv_proj = nn.Linear(in_features, 3 * in_features, bias=use_bias) else: self.query = nn.Linear(in_features, in_features, bias=use_bias) self.key = nn.Linear(in_features, in_features, bias=use_bias) self.value = nn.Linear(in_features, in_features, bias=use_bias) self.out = nn.Linear(in_features, in_features, bias=use_bias) if qk_norm == "rms": self.query_ln = RMSNorm(self.head_dim) self.key_ln = RMSNorm(self.head_dim) else: self.query_ln = nn.Identity() self.key_ln = nn.Identity() self.use_rotary_position_embeddings = use_rotary_position_embeddings if use_rotary_position_embeddings: self.rotary_position_embedding = RotaryPositionalEmbedding(self.head_dim) self.use_per_dim_scale = use_per_dim_scale if use_per_dim_scale: self.per_dim_scale = PerDimScale(self.head_dim) def forward( self, inputs_q: torch.Tensor, *, decode_cache: DecodeCache | None = None, patch_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, DecodeCache | None]: batch, num_patches, _ = inputs_q.shape if patch_mask is None: patch_mask = torch.zeros(batch, num_patches, dtype=torch.bool, device=inputs_q.device) if self.fuse_qkv: qkv = self.qkv_proj(inputs_q) query, key, value = torch.chunk(qkv, 3, dim=-1) query = query.view(batch, num_patches, self.num_heads, self.head_dim) key = key.view(batch, num_patches, self.num_heads, self.head_dim) value = value.view(batch, num_patches, self.num_heads, self.head_dim) else: query = self.query(inputs_q).view(batch, num_patches, self.num_heads, self.head_dim) key = self.key(inputs_q).view(batch, num_patches, self.num_heads, self.head_dim) value = self.value(inputs_q).view(batch, num_patches, self.num_heads, self.head_dim) if decode_cache is None: num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1) next_index = torch.zeros_like(num_masked, dtype=torch.int32) else: num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1) + decode_cache.num_masked next_index = decode_cache.next_index.clone() if self.use_rotary_position_embeddings: position = ( torch.arange(num_patches, device=inputs_q.device)[None, :] + next_index[:, None] - num_masked[:, None] ) query = self.rotary_position_embedding(query, position) key = self.rotary_position_embedding(key, position) query = self.query_ln(query) key = self.key_ln(key) if self.use_per_dim_scale: query = self.per_dim_scale(query) if decode_cache is not None: _, cache_size, _, _ = decode_cache.value.shape start = decode_cache.next_index[0] end = start + num_patches decode_cache.key[:, start:end] = key decode_cache.value[:, start:end] = value key = decode_cache.key value = decode_cache.value decode_cache.next_index += num_patches decode_cache.num_masked = num_masked attn_mask = make_attn_mask( query_length=num_patches, num_all_masked_kv=num_masked, query_index_offset=next_index, kv_length=cache_size, ) else: attn_mask = make_attn_mask(query_length=num_patches, num_all_masked_kv=num_masked) attn_output = F.scaled_dot_product_attention( query.permute(0, 2, 1, 3), key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3), attn_mask=attn_mask, scale=1.0, ) attn_output = attn_output.permute(0, 2, 1, 3) attn_output = attn_output.reshape(batch, num_patches, self.in_features) return self.out(attn_output), decode_cache class Transformer(nn.Module): """Transformer block used by TimesFM.""" def __init__(self, config: TransformerConfig): super().__init__() if config.attention_norm != "rms" or config.feedforward_norm != "rms": raise ValueError("Only RMS normalization is supported.") self.pre_attn_ln = RMSNorm(config.model_dims) self.post_attn_ln = RMSNorm(config.model_dims) self.attn = MultiHeadAttention( num_heads=config.num_heads, in_features=config.model_dims, use_per_dim_scale=True, use_rotary_position_embeddings=config.use_rotary_position_embeddings, qk_norm=config.qk_norm, fuse_qkv=config.fuse_qkv, ) self.pre_ff_ln = RMSNorm(config.model_dims) self.post_ff_ln = RMSNorm(config.model_dims) self.ff0 = nn.Linear(config.model_dims, config.hidden_dims, bias=config.use_bias) self.ff1 = nn.Linear(config.hidden_dims, config.model_dims, bias=config.use_bias) self.activation = ResidualBlock._resolve_activation(config.ff_activation) def forward( self, input_embeddings: torch.Tensor, patch_mask: torch.Tensor, decode_cache: DecodeCache | None = None, ) -> tuple[torch.Tensor, DecodeCache | None]: attn_output, decode_cache = self.attn( inputs_q=self.pre_attn_ln(input_embeddings), decode_cache=decode_cache, patch_mask=patch_mask, ) attn_output = self.post_attn_ln(attn_output) + input_embeddings feedforward = self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output)))) output_embeddings = self.post_ff_ln(feedforward) + attn_output return output_embeddings, decode_cache class TimesFM2Core(nn.Module): """Core TimesFM 2.x backbone without external dependencies.""" def __init__(self, definition: TimesFM2Definition | None = None): super().__init__() self.config = definition or TimesFM2Definition() self.p = self.config.input_patch_len self.o = self.config.output_patch_len self.os = self.config.output_quantile_len self.m = self.o // self.p self.x = self.config.stacked_transformers.num_layers self.h = self.config.stacked_transformers.transformer.num_heads self.md = self.config.stacked_transformers.transformer.model_dims self.hd = self.md // self.h self.q = len(self.config.quantiles) + 1 self.aridx = self.config.decode_index self.tokenizer = ResidualBlock(self.config.tokenizer) self.stacked_xf = nn.ModuleList( [Transformer(self.config.stacked_transformers.transformer) for _ in range(self.x)] ) self.output_projection_point = ResidualBlock(self.config.output_projection_point) self.output_projection_quantiles = ResidualBlock(self.config.output_projection_quantiles) def load_safetensors(self, path: str, strict: bool = True) -> None: if _load_safetensors is None: raise ImportError("Install safetensors to load TimesFM2 checkpoints.") tensors = _load_safetensors(path) self.load_state_dict(tensors, strict=strict) self.eval() def forward( self, inputs: torch.Tensor, masks: torch.Tensor, decode_caches: list[DecodeCache] | None = None, ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], list[DecodeCache]]: tokenizer_inputs = torch.cat([inputs, masks.to(inputs.dtype)], dim=-1) input_embeddings = self.tokenizer(tokenizer_inputs) if decode_caches is None: decode_caches = [None] * self.x # type: ignore[list-item] output_embeddings = input_embeddings new_decode_caches: list[DecodeCache] = [] for layer, cache in zip(self.stacked_xf, decode_caches): output_embeddings, new_cache = layer(output_embeddings, masks[..., -1], cache) new_decode_caches.append(new_cache) output_ts = self.output_projection_point(output_embeddings) output_quantile_spread = self.output_projection_quantiles(output_embeddings) return (input_embeddings, output_embeddings, output_ts, output_quantile_spread), new_decode_caches def decode( self, horizon: int, inputs: torch.Tensor, masks: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """Autoregressively decodes a batch of sequences.""" batch_size, context = inputs.shape num_decode_steps = (horizon - 1) // self.o num_input_patches = context // self.p use_cache = not torch.is_grad_enabled() patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p)) patched_masks = torch.reshape(masks, (batch_size, -1, self.p)) n = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype) mu = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype) sigma = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype) patch_mu: list[torch.Tensor] = [] patch_sigma: list[torch.Tensor] = [] for i in range(num_input_patches): (n, mu, sigma), _ = update_running_stats(n, mu, sigma, patched_inputs[:, i], patched_masks[:, i]) patch_mu.append(mu) patch_sigma.append(sigma) last_n, last_mu, last_sigma = n, mu, sigma context_mu = torch.stack(patch_mu, dim=1) context_sigma = torch.stack(patch_sigma, dim=1) decode_caches: list[DecodeCache] | None if use_cache: decode_cache_size = num_input_patches + num_decode_steps * self.m decode_caches = [ DecodeCache( next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device), num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device), key=torch.zeros( batch_size, decode_cache_size, self.h, self.hd, device=inputs.device, dtype=inputs.dtype, ), value=torch.zeros( batch_size, decode_cache_size, self.h, self.hd, device=inputs.device, dtype=inputs.dtype, ), ) for _ in range(self.x) ] else: decode_caches = None normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False) normed_inputs = torch.where(patched_masks, torch.zeros((), device=inputs.device, dtype=inputs.dtype), normed_inputs) (_, _, normed_outputs, normed_quantile_spread), decode_caches = self(normed_inputs, patched_masks, decode_caches) renormed_outputs = torch.reshape( revin(normed_outputs, context_mu, context_sigma, reverse=True), (batch_size, -1, self.o, self.q), ) renormed_quantile_spread = torch.reshape( revin(normed_quantile_spread, context_mu, context_sigma, reverse=True), (batch_size, -1, self.os, self.q), )[:, -1, ...] ar_outputs: list[torch.Tensor] = [] last_renormed_output = renormed_outputs[:, -1, :, self.aridx] for _ in range(num_decode_steps): new_patched_input = torch.reshape(last_renormed_output, (batch_size, self.m, self.p)) new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool) n, mu, sigma = last_n, last_mu, last_sigma new_mus: list[torch.Tensor] = [] new_sigmas: list[torch.Tensor] = [] for i in range(self.m): (n, mu, sigma), _ = update_running_stats(n, mu, sigma, new_patched_input[:, i], new_mask[:, i]) new_mus.append(mu) new_sigmas.append(sigma) last_n, last_mu, last_sigma = n, mu, sigma new_mu = torch.stack(new_mus, dim=1) new_sigma = torch.stack(new_sigmas, dim=1) new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False) (_, _, new_normed_output, _), decode_caches = self(new_normed_input, new_mask, decode_caches) new_renormed_output = torch.reshape( revin(new_normed_output, new_mu, new_sigma, reverse=True), (batch_size, self.m, self.o, self.q), ) ar_outputs.append(new_renormed_output[:, -1, ...]) last_renormed_output = new_renormed_output[:, -1, :, self.aridx] ar_renormed_outputs = torch.stack(ar_outputs, dim=1) if num_decode_steps > 0 else None return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs class TimesFM2(nn.Module): """High-level TimesFM 2.x wrapper mirroring the TimesFM interface.""" def __init__(self, lookback: int = 512, lookahead: int = 96): super().__init__() self.lookback = lookback self.lookahead = lookahead self.core = TimesFM2Core() if lookback > self.core.config.context_limit: raise ValueError( f"lookback ({lookback}) exceeds maximum context limit ({self.core.config.context_limit})." ) def load_state_dict(self, state_dict, strict: bool = True): return self.core.load_state_dict(state_dict, strict=strict) def state_dict(self, *args, **kwargs): return self.core.state_dict(*args, **kwargs) def load_safetensors(self, path: str, strict: bool = True) -> None: self.core.load_safetensors(path, strict=strict) def _prepare_inputs(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if x.shape[1] < self.lookback: raise ValueError(f"Expected at least {self.lookback} context steps, received {x.shape[1]}.") context = x[:, -self.lookback:] pad_len = (-context.shape[1]) % self.core.p if pad_len > 0: context = F.pad(context, (pad_len, 0)) pad_mask = torch.ones(context.shape[0], pad_len, dtype=torch.bool, device=context.device) mask = torch.cat( [pad_mask, torch.zeros(context.shape[0], self.lookback, dtype=torch.bool, device=context.device)], dim=1, ) else: mask = torch.zeros_like(context, dtype=torch.bool) if context.shape[1] > self.core.config.context_limit: context = context[:, -self.core.config.context_limit :] mask = mask[:, -self.core.config.context_limit :] return context, mask def forward( self, x: torch.Tensor, *, return_quantiles: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if x.dim() != 2: raise ValueError(f"Expected input tensor of shape (batch, time), received {tuple(x.shape)}.") inputs, mask = self._prepare_inputs(x.to(dtype=torch.float32)) renormed_outputs, _, ar_outputs = self.core.decode(self.lookahead, inputs, mask) batch_size = inputs.shape[0] to_cat = [renormed_outputs[:, -1, ...]] if ar_outputs is not None: to_cat.append(ar_outputs.reshape(batch_size, -1, self.core.q)) full_forecast = torch.cat(to_cat, dim=1)[:, : self.lookahead, :] point_forecast = full_forecast[..., self.core.aridx] if return_quantiles: return point_forecast, full_forecast return point_forecast __all__ = ["TimesFM2", "TimesFM2Core", "TimesFM2Definition"]