File size: 23,833 Bytes
93517af c20b869 93517af c20b869 93517af c20b869 93517af c20b869 93517af c20b869 93517af c20b869 93517af c20b869 93517af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 |
"""Self-contained TimesFM 2.x wrapper compatible with the TimesFM interface."""
from __future__ import annotations
import dataclasses
import math
import torch
import torch.nn.functional as F
from torch import nn
try:
from safetensors.torch import load_file as _load_safetensors
except ImportError: # pragma: no cover - optional dependency
_load_safetensors = None
# Scales (standard deviations) below this threshold are replaced by 1 when
# normalizing in ``revin``, to avoid dividing by a (near-)zero value.
_TOLERANCE = 1.0e-6
@dataclasses.dataclass(frozen=True)
class ResidualBlockConfig:
    """Immutable hyper-parameters for a ``ResidualBlock``.

    The block maps ``input_dims -> hidden_dims -> output_dims`` through two
    linear layers and adds a linear skip projection ``input_dims -> output_dims``.
    """

    input_dims: int  # width of the incoming features
    hidden_dims: int  # width of the inner (activated) layer
    output_dims: int  # width produced by the block
    use_bias: bool  # whether all three linear layers carry a bias term
    activation: str  # "relu", "swish" or "none"
@dataclasses.dataclass(frozen=True)
class TransformerConfig:
    """Immutable hyper-parameters for a single ``Transformer`` layer."""

    model_dims: int  # residual-stream width (attention input/output size)
    hidden_dims: int  # inner width of the feed-forward sub-layer
    num_heads: int  # attention heads; must divide ``model_dims``
    attention_norm: str  # only "rms" is accepted by ``Transformer``
    feedforward_norm: str  # only "rms" is accepted by ``Transformer``
    qk_norm: str  # "rms" enables per-head RMSNorm of queries/keys
    use_bias: bool  # bias for the feed-forward projections
    use_rotary_position_embeddings: bool  # rotate queries/keys by position
    ff_activation: str  # "relu", "swish" or "none"
    fuse_qkv: bool  # one fused QKV projection instead of three
@dataclasses.dataclass(frozen=True)
class StackedTransformersConfig:
    """A stack of ``num_layers`` identical transformer layers."""

    num_layers: int  # how many layers are stacked
    transformer: TransformerConfig  # configuration shared by every layer
def _default_tokenizer_config() -> ResidualBlockConfig:
    """Tokenizer: embeds a 32-value patch plus its 32-value mask into 1280 dims."""
    return ResidualBlockConfig(
        input_dims=64,
        hidden_dims=1280,
        output_dims=1280,
        use_bias=True,
        activation="swish",
    )


def _default_stacked_transformers_config() -> StackedTransformersConfig:
    """Backbone: 20 transformer layers of width 1280 with 16 heads."""
    layer = TransformerConfig(
        model_dims=1280,
        hidden_dims=1280,
        num_heads=16,
        attention_norm="rms",
        feedforward_norm="rms",
        qk_norm="rms",
        use_bias=False,
        use_rotary_position_embeddings=True,
        ff_activation="swish",
        fuse_qkv=True,
    )
    return StackedTransformersConfig(num_layers=20, transformer=layer)


def _default_point_head_config() -> ResidualBlockConfig:
    """Point head: 1280 outputs = 128 horizon steps x 10 output channels."""
    return ResidualBlockConfig(
        input_dims=1280,
        hidden_dims=1280,
        output_dims=1280,
        use_bias=False,
        activation="swish",
    )


def _default_quantile_head_config() -> ResidualBlockConfig:
    """Quantile head: 10240 outputs = 1024 horizon steps x 10 output channels."""
    return ResidualBlockConfig(
        input_dims=1280,
        hidden_dims=1280,
        output_dims=10240,
        use_bias=False,
        activation="swish",
    )


@dataclasses.dataclass(frozen=True)
class TimesFM2Definition:
    """Framework-agnostic description of TimesFM 2.5 (200M parameters).

    The scalar fields fix the patching geometry. The nested configs describe
    the tokenizer, the transformer backbone and the two output heads.
    """

    context_limit: int = 16384  # longest accepted context, in time steps
    input_patch_len: int = 32  # time steps per input patch
    output_patch_len: int = 128  # steps emitted per patch by the point head
    output_quantile_len: int = 1024  # steps emitted by the quantile head
    quantiles: tuple[float, ...] = tuple(k / 10 for k in range(1, 10))  # 0.1 ... 0.9
    decode_index: int = 5  # output channel used as point forecast / AR feedback
    tokenizer: ResidualBlockConfig = dataclasses.field(
        default_factory=_default_tokenizer_config
    )
    stacked_transformers: StackedTransformersConfig = dataclasses.field(
        default_factory=_default_stacked_transformers_config
    )
    output_projection_point: ResidualBlockConfig = dataclasses.field(
        default_factory=_default_point_head_config
    )
    output_projection_quantiles: ResidualBlockConfig = dataclasses.field(
        default_factory=_default_quantile_head_config
    )
@dataclasses.dataclass
class DecodeCache:
    """Mutable per-layer key/value cache for cached (prefill + AR) decoding.

    ``MultiHeadAttention`` writes new keys/values into ``key``/``value`` in
    place, starting at ``next_index[0]``, and then advances ``next_index``
    and replaces ``num_masked``.
    """

    # Next write position in the cache, one int per batch row.
    next_index: torch.Tensor
    # Running count of masked patches seen so far, one int per batch row.
    num_masked: torch.Tensor
    # Cached keys, allocated as (batch, cache_size, num_heads, head_dim).
    key: torch.Tensor
    # Cached values, same layout as ``key``.
    value: torch.Tensor
def update_running_stats(
    n: torch.Tensor,
    mu: torch.Tensor,
    sigma: torch.Tensor,
    x: torch.Tensor,
    mask: torch.Tensor,
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Updates reversible normalization statistics for a new patch.

    Statistics of the unmasked entries of ``x`` are reduced over its last
    axis. They are merged with the running count ``n``, mean ``mu`` and
    (population) standard deviation ``sigma`` using the pooled-variance
    formula. Empty groups contribute zero mean and variance.

    Args:
        n, mu, sigma: running statistics so far.
        x: new values; ``mask`` has the same shape and is ``True`` where an
            entry must be ignored.

    Returns:
        Two identical ``(n, mu, sigma)`` tuples with the updated statistics.
    """
    valid = torch.logical_not(mask)
    count = torch.sum(valid.to(x.dtype), dim=-1)
    count_empty = count == 0
    safe_count = torch.where(count_empty, 1.0, count)

    patch_mean = torch.where(count_empty, 0.0, torch.sum(x * valid, dim=-1) / safe_count)
    squared_dev = (x - patch_mean.unsqueeze(-1)) ** 2
    patch_var = torch.where(count_empty, 0.0, torch.sum(squared_dev * valid, dim=-1) / safe_count)
    patch_std = torch.sqrt(patch_var)

    total = n + count
    total_empty = total == 0
    safe_total = torch.where(total_empty, 1.0, total)
    merged_mean = torch.where(total_empty, 0.0, (n * mu + patch_mean * count) / safe_total)

    # Within-group spread of both parts plus the spread of each group's mean
    # around the merged mean.
    spread = (
        n * sigma.pow(2)
        + count * patch_std.pow(2)
        + n * (mu - merged_mean).pow(2)
        + count * (patch_mean - merged_mean).pow(2)
    )
    merged_var = torch.where(total_empty, 0.0, spread / safe_total)
    merged_std = torch.sqrt(torch.clamp(merged_var, min=0.0))
    return (total, merged_mean, merged_std), (total, merged_mean, merged_std)
def revin(x: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    """Reversible instance normalization.

    ``mu`` and ``sigma`` are broadcast against ``x`` by appending one or two
    trailing singleton axes when they have one or two fewer dimensions.

    Forward: ``(x - mu) / sigma``, where ``sigma`` below ``_TOLERANCE`` is
    replaced by 1. Reverse: ``x * sigma + mu``, with no guard.
    """
    rank_gap = len(x.shape) - len(mu.shape)
    if rank_gap == 1:
        mu, sigma = mu[..., None], sigma[..., None]
    elif rank_gap == 2:
        mu, sigma = mu[..., None, None], sigma[..., None, None]

    if reverse:
        return x * sigma + mu

    guarded_sigma = torch.where(sigma < _TOLERANCE, torch.ones_like(sigma), sigma)
    return (x - mu) / guarded_sigma
class ResidualBlock(nn.Module):
    """Residual block composed of a pair of linear layers.

    Computes ``output_layer(act(hidden_layer(x))) + residual_layer(x)``.
    """

    def __init__(self, config: ResidualBlockConfig):
        super().__init__()
        # Registration order is kept as-is: it fixes parameter order and the
        # RNG draws made during construction.
        self.activation = self._resolve_activation(config.activation)
        d_in, d_hidden, d_out = config.input_dims, config.hidden_dims, config.output_dims
        with_bias = config.use_bias
        self.hidden_layer = nn.Linear(d_in, d_hidden, bias=with_bias)
        self.output_layer = nn.Linear(d_hidden, d_out, bias=with_bias)
        self.residual_layer = nn.Linear(d_in, d_out, bias=with_bias)

    @staticmethod
    def _resolve_activation(name: str) -> nn.Module:
        """Returns a fresh activation module for ``name``.

        Raises:
            ValueError: if ``name`` is not "relu", "swish" or "none".
        """
        for key, factory in (("relu", nn.ReLU), ("swish", nn.SiLU), ("none", nn.Identity)):
            if name == key:
                return factory()
        raise ValueError(f"Unsupported activation: {name}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = self.activation(self.hidden_layer(x))
        main_path = self.output_layer(activated)
        return main_path + self.residual_layer(x)
class RMSNorm(nn.Module):
    """Root-mean-square normalization.

    Divides each vector by the root-mean-square over its last axis (with
    ``epsilon`` inside the square root) and multiplies by a learned
    per-feature ``scale``. ``scale`` starts at zeros, so outputs are zero
    until weights are loaded.
    """

    def __init__(self, num_features: int, epsilon: float = 1e-6):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(num_features))
        self.epsilon = epsilon

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        mean_square = inputs.square().mean(dim=-1, keepdim=True)
        return inputs * torch.rsqrt(mean_square + self.epsilon) * self.scale
def make_attn_mask(
    query_length: int,
    num_all_masked_kv: torch.Tensor,
    query_index_offset: torch.Tensor | None = None,
    kv_length: int = 0,
) -> torch.Tensor:
    """Creates a causal mask consistent with cached decoding.

    Returns a boolean tensor of shape ``(batch, 1, query_length, kv_length)``
    that is ``True`` where query position ``q`` may attend to key position
    ``k``. This requires ``k <= q`` and ``k >= num_all_masked_kv`` for that
    batch row. Query positions are shifted by ``query_index_offset`` when
    given. A ``kv_length`` of 0 means ``kv_length == query_length``.
    """
    device = num_all_masked_kv.device
    total_kv = kv_length if kv_length != 0 else query_length

    q_pos = torch.arange(query_length, device=device).view(1, 1, query_length, 1)
    if query_index_offset is not None:
        q_pos = q_pos + query_index_offset[:, None, None, None]
    kv_pos = torch.arange(total_kv, device=device).view(1, 1, 1, total_kv)

    causal = q_pos >= kv_pos
    past_masked_prefix = kv_pos >= num_all_masked_kv[:, None, None, None]
    return causal & past_masked_prefix
class RotaryPositionalEmbedding(nn.Module):
    """Applies rotary position embeddings to query/key projections.

    The last axis is split into halves ``(a, b)`` that are rotated by the
    angle ``t = position / timescale`` as ``(a cos t - b sin t, b cos t + a sin t)``.
    The ``embedding_dims // 2`` timescales are geometrically spaced, starting
    at ``min_timescale`` and growing towards ``max_timescale``.
    """

    def __init__(self, embedding_dims: int, min_timescale: float = 1.0, max_timescale: float = 10000.0):
        super().__init__()
        self.embedding_dims = embedding_dims
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale

    def forward(self, inputs: torch.Tensor, position: torch.Tensor | None = None) -> torch.Tensor:
        """Rotates a rank-3 or rank-4 ``inputs`` whose last axis is ``embedding_dims``.

        ``position`` defaults to ``0 .. inputs.shape[1] - 1`` (shape ``(1, seq)``).

        Raises:
            ValueError: on a last-axis size mismatch or an unsupported rank.
        """
        dims = self.embedding_dims
        if inputs.shape[-1] != dims:
            raise ValueError("Rotary embedding dimension must equal the head dimension.")
        device = inputs.device
        half = dims // 2
        exponent = 2 * torch.arange(half, device=device) / dims
        ratio = self.max_timescale / self.min_timescale
        timescale = (self.min_timescale * ratio ** exponent).to(device)

        if position is None:
            position = torch.arange(inputs.shape[1], dtype=torch.float32, device=device)[None, :]

        # Number of singleton axes needed so position/timescale broadcast
        # against (batch, seq, [heads,] dims).
        trailing = {4: 2, 3: 1}.get(len(inputs.shape))
        if trailing is None:
            raise ValueError("Expected rank-3 or rank-4 tensor for rotary embeddings.")
        position = position.reshape(*position.shape, *([1] * trailing))
        timescale = timescale.reshape(*([1] * (trailing + 1)), half)

        angle = position / timescale
        cos_t, sin_t = torch.cos(angle), torch.sin(angle)
        a, b = inputs.chunk(2, dim=-1)
        return torch.cat((a * cos_t - b * sin_t, b * cos_t + a * sin_t), dim=-1)
class PerDimScale(nn.Module):
    """Learned per-dimension scaling used prior to attention.

    Multiplies by ``1.442695041 / sqrt(num_dims) * softplus(per_dim_scale)``.
    Because ``1.442695041 ≈ 1 / ln 2 = 1 / softplus(0)``, the zero-initialised
    parameter gives approximately ``1 / sqrt(num_dims)``.
    """

    def __init__(self, num_dims: int):
        super().__init__()
        self.num_dims = num_dims
        self.per_dim_scale = nn.Parameter(torch.zeros(num_dims))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = 1.442695041 / math.sqrt(self.num_dims)
        return x * (base * F.softplus(self.per_dim_scale))
class MultiHeadAttention(nn.Module):
    """Multi-head attention supporting fused QKV projections and caching.

    Per call:
      1. Project the input to per-head queries, keys and values.
      2. Optionally apply rotary position embeddings. Positions are shifted
         by the cache write index and by the number of masked patches.
      3. Optionally apply per-head RMSNorm to queries and keys.
      4. Optionally scale queries with ``PerDimScale``.
      5. Run ``attention_fn`` with ``scale=1.0``. When ``PerDimScale`` is
         enabled, it already contains a ``1/sqrt(head_dim)`` factor.

    When a ``DecodeCache`` is supplied, the new keys/values are written into
    it in place and attention runs over the whole cache.
    """

    def __init__(
        self,
        num_heads: int,
        in_features: int,
        *,
        use_per_dim_scale: bool = True,
        use_rotary_position_embeddings: bool = True,
        use_bias: bool = False,
        attention_fn=F.scaled_dot_product_attention,
        qk_norm: str = "rms",
        fuse_qkv: bool = False,
    ):
        """
        Args:
            num_heads: number of attention heads; must divide ``in_features``.
            in_features: model width (input and output size).
            use_per_dim_scale: apply ``PerDimScale`` to the queries.
            use_rotary_position_embeddings: rotate queries/keys by position.
            use_bias: add bias terms to all projections.
            attention_fn: callable invoked as
                ``attention_fn(q, k, v, attn_mask=mask, scale=1.0)``. ``q``,
                ``k`` and ``v`` have layout ``(batch, heads, seq, head_dim)``.
                ``mask`` is boolean and ``True`` where attending is allowed.
                Defaults to ``F.scaled_dot_product_attention``.
            qk_norm: ``"rms"`` for per-head RMSNorm of queries and keys; any
                other value disables it.
            fuse_qkv: use a single ``qkv_proj`` instead of three projections.

        Raises:
            ValueError: if ``in_features`` is not divisible by ``num_heads``.
        """
        super().__init__()
        if in_features % num_heads != 0:
            raise ValueError(f"Model dimension {in_features} must be divisible by {num_heads} heads.")
        self.num_heads = num_heads
        self.in_features = in_features
        self.head_dim = in_features // num_heads
        self.use_bias = use_bias
        self.attention_fn = attention_fn
        self.qk_norm = qk_norm
        self.fuse_qkv = fuse_qkv
        # Submodule creation order is unchanged (it defines state_dict order
        # and RNG consumption at construction).
        if fuse_qkv:
            self.qkv_proj = nn.Linear(in_features, 3 * in_features, bias=use_bias)
        else:
            self.query = nn.Linear(in_features, in_features, bias=use_bias)
            self.key = nn.Linear(in_features, in_features, bias=use_bias)
            self.value = nn.Linear(in_features, in_features, bias=use_bias)
        self.out = nn.Linear(in_features, in_features, bias=use_bias)
        if qk_norm == "rms":
            self.query_ln = RMSNorm(self.head_dim)
            self.key_ln = RMSNorm(self.head_dim)
        else:
            self.query_ln = nn.Identity()
            self.key_ln = nn.Identity()
        self.use_rotary_position_embeddings = use_rotary_position_embeddings
        if use_rotary_position_embeddings:
            self.rotary_position_embedding = RotaryPositionalEmbedding(self.head_dim)
        self.use_per_dim_scale = use_per_dim_scale
        if use_per_dim_scale:
            self.per_dim_scale = PerDimScale(self.head_dim)

    def forward(
        self,
        inputs_q: torch.Tensor,
        *,
        decode_cache: DecodeCache | None = None,
        patch_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, DecodeCache | None]:
        """Self-attention over ``inputs_q`` of shape ``(batch, num_patches, in_features)``.

        Args:
            inputs_q: input sequence.
            decode_cache: optional cache. When given, its ``key``/``value``
                tensors, ``next_index`` and ``num_masked`` are updated in
                place. Only ``next_index[0]`` is used as the write offset, so
                all batch rows are assumed to be in lockstep.
            patch_mask: ``(batch, num_patches)``, ``True`` for masked patches.
                Defaults to no masking. Masked patches are assumed to be
                leading, because only the first ``num_masked`` key positions
                are excluded by the attention mask.

        Returns:
            The attention output projected back to
            ``(batch, num_patches, in_features)``, and the (possibly updated)
            cache.
        """
        batch, num_patches, _ = inputs_q.shape
        if patch_mask is None:
            patch_mask = torch.zeros(batch, num_patches, dtype=torch.bool, device=inputs_q.device)

        if self.fuse_qkv:
            query, key, value = torch.chunk(self.qkv_proj(inputs_q), 3, dim=-1)
        else:
            query = self.query(inputs_q)
            key = self.key(inputs_q)
            value = self.value(inputs_q)
        heads_shape = (batch, num_patches, self.num_heads, self.head_dim)
        query = query.view(heads_shape)
        key = key.view(heads_shape)
        value = value.view(heads_shape)

        masked_now = torch.sum(patch_mask.to(torch.int32), dim=-1)
        if decode_cache is None:
            num_masked = masked_now
            next_index = torch.zeros_like(num_masked, dtype=torch.int32)
        else:
            num_masked = masked_now + decode_cache.num_masked
            next_index = decode_cache.next_index.clone()

        if self.use_rotary_position_embeddings:
            # Absolute slot index minus masked prefix: the first unmasked
            # patch gets position 0.
            position = (
                torch.arange(num_patches, device=inputs_q.device)[None, :]
                + next_index[:, None]
                - num_masked[:, None]
            )
            query = self.rotary_position_embedding(query, position)
            key = self.rotary_position_embedding(key, position)

        query = self.query_ln(query)
        key = self.key_ln(key)
        if self.use_per_dim_scale:
            query = self.per_dim_scale(query)

        if decode_cache is not None:
            cache_size = decode_cache.value.shape[1]
            start = decode_cache.next_index[0]
            end = start + num_patches
            decode_cache.key[:, start:end] = key
            decode_cache.value[:, start:end] = value
            key = decode_cache.key
            value = decode_cache.value
            decode_cache.next_index += num_patches
            decode_cache.num_masked = num_masked
            attn_mask = make_attn_mask(
                query_length=num_patches,
                num_all_masked_kv=num_masked,
                query_index_offset=next_index,
                kv_length=cache_size,
            )
        else:
            attn_mask = make_attn_mask(query_length=num_patches, num_all_masked_kv=num_masked)

        # Honour the configured attention implementation (previously the
        # stored ``attention_fn`` was ignored in favour of a hard-coded SDPA).
        attn_output = self.attention_fn(
            query.permute(0, 2, 1, 3),
            key.permute(0, 2, 1, 3),
            value.permute(0, 2, 1, 3),
            attn_mask=attn_mask,
            scale=1.0,
        )
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch, num_patches, self.in_features)
        return self.out(attn_output), decode_cache
class Transformer(nn.Module):
    """Transformer block used by TimesFM.

    Residual layout with normalization both before and after each sub-layer::

        h = x + post_attn_ln(attn(pre_attn_ln(x)))
        y = h + post_ff_ln(ff1(act(ff0(pre_ff_ln(h)))))
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        if not (config.attention_norm == "rms" and config.feedforward_norm == "rms"):
            raise ValueError("Only RMS normalization is supported.")
        width = config.model_dims
        # Submodules are registered in the original order (state_dict order).
        self.pre_attn_ln = RMSNorm(width)
        self.post_attn_ln = RMSNorm(width)
        self.attn = MultiHeadAttention(
            num_heads=config.num_heads,
            in_features=width,
            use_per_dim_scale=True,
            use_rotary_position_embeddings=config.use_rotary_position_embeddings,
            qk_norm=config.qk_norm,
            fuse_qkv=config.fuse_qkv,
        )
        self.pre_ff_ln = RMSNorm(width)
        self.post_ff_ln = RMSNorm(width)
        self.ff0 = nn.Linear(width, config.hidden_dims, bias=config.use_bias)
        self.ff1 = nn.Linear(config.hidden_dims, width, bias=config.use_bias)
        self.activation = ResidualBlock._resolve_activation(config.ff_activation)

    def forward(
        self,
        input_embeddings: torch.Tensor,
        patch_mask: torch.Tensor,
        decode_cache: DecodeCache | None = None,
    ) -> tuple[torch.Tensor, DecodeCache | None]:
        """Applies attention then feed-forward, returning outputs and the cache."""
        attended, decode_cache = self.attn(
            inputs_q=self.pre_attn_ln(input_embeddings),
            decode_cache=decode_cache,
            patch_mask=patch_mask,
        )
        hidden = self.post_attn_ln(attended) + input_embeddings

        ff_inner = self.activation(self.ff0(self.pre_ff_ln(hidden)))
        ff_update = self.post_ff_ln(self.ff1(ff_inner))
        return ff_update + hidden, decode_cache
class TimesFM2Core(nn.Module):
    """Core TimesFM 2.x backbone without external dependencies.

    Components:
      * a ``ResidualBlock`` tokenizer that embeds each patch together with its
        mask;
      * a stack of ``Transformer`` layers;
      * two ``ResidualBlock`` heads producing point outputs and quantile
        outputs per patch.

    ``decode`` adds reversible per-patch normalization and autoregressive
    rollout on top of ``forward``.
    """

    def __init__(self, definition: TimesFM2Definition | None = None):
        super().__init__()
        self.config = definition or TimesFM2Definition()
        self.p = self.config.input_patch_len  # time steps per input patch
        self.o = self.config.output_patch_len  # steps per patch from the point head
        self.os = self.config.output_quantile_len  # steps from the quantile head
        self.m = self.o // self.p  # input patches covered by one output patch
        self.x = self.config.stacked_transformers.num_layers
        self.h = self.config.stacked_transformers.transformer.num_heads
        self.md = self.config.stacked_transformers.transformer.model_dims
        self.hd = self.md // self.h  # per-head width, sizes the KV caches
        self.q = len(self.config.quantiles) + 1  # output channels per step
        self.aridx = self.config.decode_index  # channel fed back during AR decoding
        self.tokenizer = ResidualBlock(self.config.tokenizer)
        self.stacked_xf = nn.ModuleList(
            [Transformer(self.config.stacked_transformers.transformer) for _ in range(self.x)]
        )
        self.output_projection_point = ResidualBlock(self.config.output_projection_point)
        self.output_projection_quantiles = ResidualBlock(self.config.output_projection_quantiles)

    def load_safetensors(self, path: str, strict: bool = True) -> None:
        """Loads weights from a ``.safetensors`` file and switches to eval mode.

        Raises:
            ImportError: if the optional ``safetensors`` package is unavailable.
        """
        if _load_safetensors is None:
            raise ImportError("Install safetensors to load TimesFM2 checkpoints.")
        tensors = _load_safetensors(path)
        self.load_state_dict(tensors, strict=strict)
        self.eval()

    def forward(
        self,
        inputs: torch.Tensor,
        masks: torch.Tensor,
        decode_caches: list[DecodeCache] | None = None,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], list[DecodeCache]]:
        """Runs the tokenizer, the transformer stack and both output heads.

        Args:
            inputs: patched values; ``decode`` passes
                ``(batch, num_patches, input_patch_len)``.
            masks: mask shaped like ``inputs``. The mask of the last step in
                each patch is used as that patch's attention mask.
            decode_caches: one cache per layer, or ``None`` for cache-less
                attention.

        Returns:
            ``(input_embeddings, output_embeddings, point_outputs, quantile_outputs)``
            and the per-layer caches. The caches are ``None`` entries when
            none were given.
        """
        tokenizer_inputs = torch.cat([inputs, masks.to(inputs.dtype)], dim=-1)
        input_embeddings = self.tokenizer(tokenizer_inputs)
        if decode_caches is None:
            decode_caches = [None] * self.x  # type: ignore[list-item]
        output_embeddings = input_embeddings
        new_decode_caches: list[DecodeCache] = []
        for layer, cache in zip(self.stacked_xf, decode_caches):
            output_embeddings, new_cache = layer(output_embeddings, masks[..., -1], cache)
            new_decode_caches.append(new_cache)
        output_ts = self.output_projection_point(output_embeddings)
        output_quantile_spread = self.output_projection_quantiles(output_embeddings)
        return (input_embeddings, output_embeddings, output_ts, output_quantile_spread), new_decode_caches

    @staticmethod
    def _running_patch_stats(
        n: torch.Tensor,
        mu: torch.Tensor,
        sigma: torch.Tensor,
        patches: torch.Tensor,
        patch_masks: torch.Tensor,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        """Folds ``patches`` (dim 1) into the running statistics one patch at a time.

        Returns the final ``(n, mu, sigma)``, plus the mean and standard
        deviation seen after each patch, stacked along dim 1.
        """
        mus: list[torch.Tensor] = []
        sigmas: list[torch.Tensor] = []
        for i in range(patches.shape[1]):
            (n, mu, sigma), _ = update_running_stats(n, mu, sigma, patches[:, i], patch_masks[:, i])
            mus.append(mu)
            sigmas.append(sigma)
        return (n, mu, sigma), torch.stack(mus, dim=1), torch.stack(sigmas, dim=1)

    def decode(
        self,
        horizon: int,
        inputs: torch.Tensor,
        masks: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """Autoregressively decodes a batch of sequences.

        Args:
            horizon: number of future steps requested. The prefill yields
                ``output_patch_len`` steps, and each autoregressive step adds
                ``output_patch_len`` more.
            inputs: ``(batch, context)`` values. ``context`` must be a
                multiple of ``input_patch_len``, because it is reshaped into
                patches.
            masks: boolean tensor shaped like ``inputs``. ``True`` marks
                masked (e.g. padded) positions, which are zeroed after
                normalization.

        Returns:
            Three values:
              * per-context-patch outputs ``(batch, num_patches, output_patch_len, q)``;
              * the quantile-head output of the last context patch
                ``(batch, output_quantile_len, q)``;
              * the autoregressive continuations
                ``(batch, steps, output_patch_len, q)``, or ``None`` when no
                extra step is needed.

        With autograd disabled, per-layer KV caches are used. With autograd
        enabled, no cache is kept, so every autoregressive step re-runs the
        full (context + generated) patch sequence. Causal attention makes
        this equivalent to the cached path. Previously, the cache-less path
        fed only the new patches and dropped the context.
        """
        batch_size, context = inputs.shape
        # Clamped so that a non-positive horizon cannot shrink the cache below
        # the prefill length.
        num_decode_steps = max(0, (horizon - 1) // self.o)
        num_input_patches = context // self.p
        # The cache is filled by in-place slice writes inside attention; it is
        # only used without autograd (presumably so those writes cannot clash
        # with tensors saved for backward).
        use_cache = not torch.is_grad_enabled()

        patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
        patched_masks = torch.reshape(masks, (batch_size, -1, self.p))

        zeros = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
        (last_n, last_mu, last_sigma), context_mu, context_sigma = self._running_patch_stats(
            zeros, zeros, zeros, patched_inputs, patched_masks
        )

        decode_caches: list[DecodeCache] | None = None
        if use_cache:
            cache_shape = (batch_size, num_input_patches + num_decode_steps * self.m, self.h, self.hd)
            decode_caches = [
                DecodeCache(
                    next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
                    num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
                    key=torch.zeros(cache_shape, device=inputs.device, dtype=inputs.dtype),
                    value=torch.zeros(cache_shape, device=inputs.device, dtype=inputs.dtype),
                )
                for _ in range(self.x)
            ]

        normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
        normed_inputs = torch.where(
            patched_masks, torch.zeros((), device=inputs.device, dtype=inputs.dtype), normed_inputs
        )
        (_, _, normed_outputs, normed_quantile_spread), decode_caches = self(
            normed_inputs, patched_masks, decode_caches
        )
        renormed_outputs = torch.reshape(
            revin(normed_outputs, context_mu, context_sigma, reverse=True),
            (batch_size, -1, self.o, self.q),
        )
        renormed_quantile_spread = torch.reshape(
            revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
            (batch_size, -1, self.os, self.q),
        )[:, -1, ...]

        # Normalized history, only needed by the cache-less path.
        history_inputs, history_masks = normed_inputs, patched_masks
        ar_outputs: list[torch.Tensor] = []
        last_renormed_output = renormed_outputs[:, -1, :, self.aridx]
        for _ in range(num_decode_steps):
            new_patched_input = torch.reshape(last_renormed_output, (batch_size, self.m, self.p))
            new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)
            (last_n, last_mu, last_sigma), new_mu, new_sigma = self._running_patch_stats(
                last_n, last_mu, last_sigma, new_patched_input, new_mask
            )
            new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
            if use_cache:
                (_, _, new_normed_output, _), decode_caches = self(new_normed_input, new_mask, decode_caches)
            else:
                history_inputs = torch.cat([history_inputs, new_normed_input], dim=1)
                history_masks = torch.cat([history_masks, new_mask], dim=1)
                (_, _, full_normed_output, _), _ = self(history_inputs, history_masks, None)
                new_normed_output = full_normed_output[:, -self.m :]
            new_renormed_output = torch.reshape(
                revin(new_normed_output, new_mu, new_sigma, reverse=True),
                (batch_size, self.m, self.o, self.q),
            )
            ar_outputs.append(new_renormed_output[:, -1, ...])
            last_renormed_output = new_renormed_output[:, -1, :, self.aridx]

        ar_renormed_outputs = torch.stack(ar_outputs, dim=1) if ar_outputs else None
        return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
class TimesFM2(nn.Module):
    """High-level TimesFM 2.x wrapper mirroring the TimesFM interface.

    ``forward`` works as follows:
      1. Take a ``(batch, time)`` tensor and keep its last ``lookback`` steps.
      2. Left-pad those steps to a whole number of patches (the padding is
         masked).
      3. Return a ``(batch, lookahead)`` point forecast.
    """

    def __init__(self, lookback: int = 512, lookahead: int = 96):
        """
        Args:
            lookback: trailing observations fed to the model, in
                ``1 .. context_limit``.
            lookahead: forecast horizon in steps; must be positive.

        Raises:
            ValueError: for a non-positive ``lookback``/``lookahead``, or a
                ``lookback`` above the model's context limit.
        """
        super().__init__()
        # Validated before the (large) core is built: a zero lookback has no
        # meaning and a non-positive lookahead produces no forecast steps.
        if lookback <= 0:
            raise ValueError(f"lookback must be positive, received {lookback}.")
        if lookahead <= 0:
            raise ValueError(f"lookahead must be positive, received {lookahead}.")
        self.lookback = lookback
        self.lookahead = lookahead
        self.core = TimesFM2Core()
        if lookback > self.core.config.context_limit:
            raise ValueError(
                f"lookback ({lookback}) exceeds maximum context limit ({self.core.config.context_limit})."
            )

    def load_state_dict(self, state_dict, strict: bool = True):
        """Delegates to the core, so checkpoint keys carry no ``core.`` prefix."""
        return self.core.load_state_dict(state_dict, strict=strict)

    def state_dict(self, *args, **kwargs):
        """Delegates to the core, so keys carry no ``core.`` prefix."""
        return self.core.state_dict(*args, **kwargs)

    def load_safetensors(self, path: str, strict: bool = True) -> None:
        """Loads core weights from a ``.safetensors`` checkpoint."""
        self.core.load_safetensors(path, strict=strict)

    def _prepare_inputs(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Keeps the last ``lookback`` steps and left-pads them to whole patches.

        Returns:
            The (possibly zero-padded) context, and a boolean mask of the same
            shape that is ``True`` on the padding.

        Raises:
            ValueError: if ``x`` has fewer than ``lookback`` time steps.
        """
        total_steps = x.shape[1]
        if total_steps < self.lookback:
            raise ValueError(f"Expected at least {self.lookback} context steps, received {x.shape[1]}.")
        # Explicit start index: ``x[:, -k:]`` would select the whole series
        # for k == 0.
        context = x[:, total_steps - self.lookback :]
        pad_len = (-context.shape[1]) % self.core.p
        mask = torch.zeros(
            context.shape[0], context.shape[1] + pad_len, dtype=torch.bool, device=context.device
        )
        if pad_len > 0:
            context = F.pad(context, (pad_len, 0))
            mask[:, :pad_len] = True
        limit = self.core.config.context_limit
        if context.shape[1] > limit:
            context = context[:, -limit:]
            mask = mask[:, -limit:]
        return context, mask

    def forward(
        self,
        x: torch.Tensor,
        *,
        return_quantiles: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Forecasts ``lookahead`` steps for every row of ``x``.

        Args:
            x: ``(batch, time)`` history with at least ``lookback`` steps; it
                is cast to float32.
            return_quantiles: also return the full ``(batch, lookahead, q)``
                output, which includes every output channel.

        Returns:
            The ``(batch, lookahead)`` point forecast, or ``(point, full)``
            when ``return_quantiles`` is set.

        Raises:
            ValueError: if ``x`` is not 2-D or is shorter than ``lookback``.
        """
        if x.dim() != 2:
            raise ValueError(f"Expected input tensor of shape (batch, time), received {tuple(x.shape)}.")
        inputs, mask = self._prepare_inputs(x.to(dtype=torch.float32))
        renormed_outputs, _, ar_outputs = self.core.decode(self.lookahead, inputs, mask)
        batch_size = inputs.shape[0]
        # The prediction of the last context patch comes first, followed by
        # any autoregressive continuations.
        pieces = [renormed_outputs[:, -1, ...]]
        if ar_outputs is not None:
            pieces.append(ar_outputs.reshape(batch_size, -1, self.core.q))
        full_forecast = torch.cat(pieces, dim=1)[:, : self.lookahead, :]
        point_forecast = full_forecast[..., self.core.aridx]
        if return_quantiles:
            return point_forecast, full_forecast
        return point_forecast
# Names exported by ``from <module> import *``.
__all__ = [
    "TimesFM2",
    "TimesFM2Core",
    "TimesFM2Definition",
]
|