exx
committed on
Commit
·
c20b869
1
Parent(s):
93517af
TimesFM2 grad issue debug
Browse files- models/TimesFM2.py +81 -63
models/TimesFM2.py
CHANGED
|
@@ -515,81 +515,99 @@ class TimesFM2Core(nn.Module):
|
|
| 515 |
masks: torch.Tensor,
|
| 516 |
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
| 517 |
"""Autoregressively decodes a batch of sequences."""
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
decode_cache_size = num_input_patches + num_decode_steps * self.m
|
| 523 |
-
|
| 524 |
-
patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
|
| 525 |
-
patched_masks = torch.reshape(masks, (batch_size, -1, self.p))
|
| 526 |
-
|
| 527 |
-
n = torch.zeros(batch_size, device=inputs.device)
|
| 528 |
-
mu = torch.zeros(batch_size, device=inputs.device)
|
| 529 |
-
sigma = torch.zeros(batch_size, device=inputs.device)
|
| 530 |
-
patch_mu: list[torch.Tensor] = []
|
| 531 |
-
patch_sigma: list[torch.Tensor] = []
|
| 532 |
-
for i in range(num_input_patches):
|
| 533 |
-
(n, mu, sigma), _ = update_running_stats(n, mu, sigma, patched_inputs[:, i], patched_masks[:, i])
|
| 534 |
-
patch_mu.append(mu)
|
| 535 |
-
patch_sigma.append(sigma)
|
| 536 |
-
|
| 537 |
-
last_n, last_mu, last_sigma = n, mu, sigma
|
| 538 |
-
context_mu = torch.stack(patch_mu, dim=1)
|
| 539 |
-
context_sigma = torch.stack(patch_sigma, dim=1)
|
| 540 |
-
|
| 541 |
decode_caches = [
|
| 542 |
DecodeCache(
|
| 543 |
next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
|
| 544 |
num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
|
| 545 |
-
key=torch.zeros(
|
| 546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
)
|
| 548 |
for _ in range(self.x)
|
| 549 |
]
|
|
|
|
|
|
|
| 550 |
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
)
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
(batch_size, -1, self.os, self.q),
|
| 562 |
-
)[:, -1, ...]
|
| 563 |
-
|
| 564 |
-
ar_outputs: list[torch.Tensor] = []
|
| 565 |
-
last_renormed_output = renormed_outputs[:, -1, :, self.aridx]
|
| 566 |
-
|
| 567 |
-
for _ in range(num_decode_steps):
|
| 568 |
-
new_patched_input = torch.reshape(last_renormed_output, (batch_size, self.m, self.p))
|
| 569 |
-
new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)
|
| 570 |
-
|
| 571 |
-
n, mu, sigma = last_n, last_mu, last_sigma
|
| 572 |
-
new_mus: list[torch.Tensor] = []
|
| 573 |
-
new_sigmas: list[torch.Tensor] = []
|
| 574 |
-
for i in range(self.m):
|
| 575 |
-
(n, mu, sigma), _ = update_running_stats(n, mu, sigma, new_patched_input[:, i], new_mask[:, i])
|
| 576 |
-
new_mus.append(mu)
|
| 577 |
-
new_sigmas.append(sigma)
|
| 578 |
-
last_n, last_mu, last_sigma = n, mu, sigma
|
| 579 |
-
new_mu = torch.stack(new_mus, dim=1)
|
| 580 |
-
new_sigma = torch.stack(new_sigmas, dim=1)
|
| 581 |
-
|
| 582 |
-
new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
|
| 583 |
-
(_, _, new_normed_output, _), decode_caches = self(new_normed_input, new_mask, decode_caches)
|
| 584 |
-
|
| 585 |
-
new_renormed_output = torch.reshape(
|
| 586 |
-
revin(new_normed_output, new_mu, new_sigma, reverse=True),
|
| 587 |
-
(batch_size, self.m, self.o, self.q),
|
| 588 |
-
)
|
| 589 |
-
ar_outputs.append(new_renormed_output[:, -1, ...])
|
| 590 |
-
last_renormed_output = new_renormed_output[:, -1, :, self.aridx]
|
| 591 |
|
| 592 |
-
|
| 593 |
|
| 594 |
return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
|
| 595 |
|
|
|
|
| 515 |
masks: torch.Tensor,
|
| 516 |
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
| 517 |
"""Autoregressively decodes a batch of sequences."""
|
| 518 |
+
batch_size, context = inputs.shape
|
| 519 |
+
num_decode_steps = (horizon - 1) // self.o
|
| 520 |
+
num_input_patches = context // self.p
|
| 521 |
+
use_cache = not torch.is_grad_enabled()
|
| 522 |
+
|
| 523 |
+
patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
|
| 524 |
+
patched_masks = torch.reshape(masks, (batch_size, -1, self.p))
|
| 525 |
+
|
| 526 |
+
n = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
|
| 527 |
+
mu = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
|
| 528 |
+
sigma = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
|
| 529 |
+
patch_mu: list[torch.Tensor] = []
|
| 530 |
+
patch_sigma: list[torch.Tensor] = []
|
| 531 |
+
for i in range(num_input_patches):
|
| 532 |
+
(n, mu, sigma), _ = update_running_stats(n, mu, sigma, patched_inputs[:, i], patched_masks[:, i])
|
| 533 |
+
patch_mu.append(mu)
|
| 534 |
+
patch_sigma.append(sigma)
|
| 535 |
+
|
| 536 |
+
last_n, last_mu, last_sigma = n, mu, sigma
|
| 537 |
+
context_mu = torch.stack(patch_mu, dim=1)
|
| 538 |
+
context_sigma = torch.stack(patch_sigma, dim=1)
|
| 539 |
+
|
| 540 |
+
decode_caches: list[DecodeCache] | None
|
| 541 |
+
if use_cache:
|
| 542 |
decode_cache_size = num_input_patches + num_decode_steps * self.m
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
decode_caches = [
|
| 544 |
DecodeCache(
|
| 545 |
next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
|
| 546 |
num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
|
| 547 |
+
key=torch.zeros(
|
| 548 |
+
batch_size,
|
| 549 |
+
decode_cache_size,
|
| 550 |
+
self.h,
|
| 551 |
+
self.hd,
|
| 552 |
+
device=inputs.device,
|
| 553 |
+
dtype=inputs.dtype,
|
| 554 |
+
),
|
| 555 |
+
value=torch.zeros(
|
| 556 |
+
batch_size,
|
| 557 |
+
decode_cache_size,
|
| 558 |
+
self.h,
|
| 559 |
+
self.hd,
|
| 560 |
+
device=inputs.device,
|
| 561 |
+
dtype=inputs.dtype,
|
| 562 |
+
),
|
| 563 |
)
|
| 564 |
for _ in range(self.x)
|
| 565 |
]
|
| 566 |
+
else:
|
| 567 |
+
decode_caches = None
|
| 568 |
|
| 569 |
+
normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
|
| 570 |
+
normed_inputs = torch.where(patched_masks, torch.zeros((), device=inputs.device, dtype=inputs.dtype), normed_inputs)
|
| 571 |
+
(_, _, normed_outputs, normed_quantile_spread), decode_caches = self(normed_inputs, patched_masks, decode_caches)
|
| 572 |
|
| 573 |
+
renormed_outputs = torch.reshape(
|
| 574 |
+
revin(normed_outputs, context_mu, context_sigma, reverse=True),
|
| 575 |
+
(batch_size, -1, self.o, self.q),
|
| 576 |
+
)
|
| 577 |
+
renormed_quantile_spread = torch.reshape(
|
| 578 |
+
revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
|
| 579 |
+
(batch_size, -1, self.os, self.q),
|
| 580 |
+
)[:, -1, ...]
|
| 581 |
+
|
| 582 |
+
ar_outputs: list[torch.Tensor] = []
|
| 583 |
+
last_renormed_output = renormed_outputs[:, -1, :, self.aridx]
|
| 584 |
+
|
| 585 |
+
for _ in range(num_decode_steps):
|
| 586 |
+
new_patched_input = torch.reshape(last_renormed_output, (batch_size, self.m, self.p))
|
| 587 |
+
new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)
|
| 588 |
+
|
| 589 |
+
n, mu, sigma = last_n, last_mu, last_sigma
|
| 590 |
+
new_mus: list[torch.Tensor] = []
|
| 591 |
+
new_sigmas: list[torch.Tensor] = []
|
| 592 |
+
for i in range(self.m):
|
| 593 |
+
(n, mu, sigma), _ = update_running_stats(n, mu, sigma, new_patched_input[:, i], new_mask[:, i])
|
| 594 |
+
new_mus.append(mu)
|
| 595 |
+
new_sigmas.append(sigma)
|
| 596 |
+
last_n, last_mu, last_sigma = n, mu, sigma
|
| 597 |
+
new_mu = torch.stack(new_mus, dim=1)
|
| 598 |
+
new_sigma = torch.stack(new_sigmas, dim=1)
|
| 599 |
+
|
| 600 |
+
new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
|
| 601 |
+
(_, _, new_normed_output, _), decode_caches = self(new_normed_input, new_mask, decode_caches)
|
| 602 |
+
|
| 603 |
+
new_renormed_output = torch.reshape(
|
| 604 |
+
revin(new_normed_output, new_mu, new_sigma, reverse=True),
|
| 605 |
+
(batch_size, self.m, self.o, self.q),
|
| 606 |
)
|
| 607 |
+
ar_outputs.append(new_renormed_output[:, -1, ...])
|
| 608 |
+
last_renormed_output = new_renormed_output[:, -1, :, self.aridx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
|
| 610 |
+
ar_renormed_outputs = torch.stack(ar_outputs, dim=1) if num_decode_steps > 0 else None
|
| 611 |
|
| 612 |
return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
|
| 613 |
|