exx committed on
Commit
c20b869
·
1 Parent(s): 93517af

TimesFM2 grad issue debug

Browse files
Files changed (1) hide show
  1. models/TimesFM2.py +81 -63
models/TimesFM2.py CHANGED
@@ -515,81 +515,99 @@ class TimesFM2Core(nn.Module):
515
  masks: torch.Tensor,
516
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
517
  """Autoregressively decodes a batch of sequences."""
518
- with torch.no_grad():
519
- batch_size, context = inputs.shape
520
- num_decode_steps = (horizon - 1) // self.o
521
- num_input_patches = context // self.p
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
  decode_cache_size = num_input_patches + num_decode_steps * self.m
523
-
524
- patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
525
- patched_masks = torch.reshape(masks, (batch_size, -1, self.p))
526
-
527
- n = torch.zeros(batch_size, device=inputs.device)
528
- mu = torch.zeros(batch_size, device=inputs.device)
529
- sigma = torch.zeros(batch_size, device=inputs.device)
530
- patch_mu: list[torch.Tensor] = []
531
- patch_sigma: list[torch.Tensor] = []
532
- for i in range(num_input_patches):
533
- (n, mu, sigma), _ = update_running_stats(n, mu, sigma, patched_inputs[:, i], patched_masks[:, i])
534
- patch_mu.append(mu)
535
- patch_sigma.append(sigma)
536
-
537
- last_n, last_mu, last_sigma = n, mu, sigma
538
- context_mu = torch.stack(patch_mu, dim=1)
539
- context_sigma = torch.stack(patch_sigma, dim=1)
540
-
541
  decode_caches = [
542
  DecodeCache(
543
  next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
544
  num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
545
- key=torch.zeros(batch_size, decode_cache_size, self.h, self.hd, device=inputs.device),
546
- value=torch.zeros(batch_size, decode_cache_size, self.h, self.hd, device=inputs.device),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  )
548
  for _ in range(self.x)
549
  ]
 
 
550
 
551
- normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
552
- normed_inputs = torch.where(patched_masks, 0.0, normed_inputs)
553
- (_, _, normed_outputs, normed_quantile_spread), decode_caches = self(normed_inputs, patched_masks, decode_caches)
554
 
555
- renormed_outputs = torch.reshape(
556
- revin(normed_outputs, context_mu, context_sigma, reverse=True),
557
- (batch_size, -1, self.o, self.q),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
  )
559
- renormed_quantile_spread = torch.reshape(
560
- revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
561
- (batch_size, -1, self.os, self.q),
562
- )[:, -1, ...]
563
-
564
- ar_outputs: list[torch.Tensor] = []
565
- last_renormed_output = renormed_outputs[:, -1, :, self.aridx]
566
-
567
- for _ in range(num_decode_steps):
568
- new_patched_input = torch.reshape(last_renormed_output, (batch_size, self.m, self.p))
569
- new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)
570
-
571
- n, mu, sigma = last_n, last_mu, last_sigma
572
- new_mus: list[torch.Tensor] = []
573
- new_sigmas: list[torch.Tensor] = []
574
- for i in range(self.m):
575
- (n, mu, sigma), _ = update_running_stats(n, mu, sigma, new_patched_input[:, i], new_mask[:, i])
576
- new_mus.append(mu)
577
- new_sigmas.append(sigma)
578
- last_n, last_mu, last_sigma = n, mu, sigma
579
- new_mu = torch.stack(new_mus, dim=1)
580
- new_sigma = torch.stack(new_sigmas, dim=1)
581
-
582
- new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
583
- (_, _, new_normed_output, _), decode_caches = self(new_normed_input, new_mask, decode_caches)
584
-
585
- new_renormed_output = torch.reshape(
586
- revin(new_normed_output, new_mu, new_sigma, reverse=True),
587
- (batch_size, self.m, self.o, self.q),
588
- )
589
- ar_outputs.append(new_renormed_output[:, -1, ...])
590
- last_renormed_output = new_renormed_output[:, -1, :, self.aridx]
591
 
592
- ar_renormed_outputs = torch.stack(ar_outputs, dim=1) if num_decode_steps > 0 else None
593
 
594
  return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
595
 
 
515
  masks: torch.Tensor,
516
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
517
  """Autoregressively decodes a batch of sequences."""
518
+ batch_size, context = inputs.shape
519
+ num_decode_steps = (horizon - 1) // self.o
520
+ num_input_patches = context // self.p
521
+ use_cache = not torch.is_grad_enabled()
522
+
523
+ patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
524
+ patched_masks = torch.reshape(masks, (batch_size, -1, self.p))
525
+
526
+ n = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
527
+ mu = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
528
+ sigma = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
529
+ patch_mu: list[torch.Tensor] = []
530
+ patch_sigma: list[torch.Tensor] = []
531
+ for i in range(num_input_patches):
532
+ (n, mu, sigma), _ = update_running_stats(n, mu, sigma, patched_inputs[:, i], patched_masks[:, i])
533
+ patch_mu.append(mu)
534
+ patch_sigma.append(sigma)
535
+
536
+ last_n, last_mu, last_sigma = n, mu, sigma
537
+ context_mu = torch.stack(patch_mu, dim=1)
538
+ context_sigma = torch.stack(patch_sigma, dim=1)
539
+
540
+ decode_caches: list[DecodeCache] | None
541
+ if use_cache:
542
  decode_cache_size = num_input_patches + num_decode_steps * self.m
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  decode_caches = [
544
  DecodeCache(
545
  next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
546
  num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
547
+ key=torch.zeros(
548
+ batch_size,
549
+ decode_cache_size,
550
+ self.h,
551
+ self.hd,
552
+ device=inputs.device,
553
+ dtype=inputs.dtype,
554
+ ),
555
+ value=torch.zeros(
556
+ batch_size,
557
+ decode_cache_size,
558
+ self.h,
559
+ self.hd,
560
+ device=inputs.device,
561
+ dtype=inputs.dtype,
562
+ ),
563
  )
564
  for _ in range(self.x)
565
  ]
566
+ else:
567
+ decode_caches = None
568
 
569
+ normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
570
+ normed_inputs = torch.where(patched_masks, torch.zeros((), device=inputs.device, dtype=inputs.dtype), normed_inputs)
571
+ (_, _, normed_outputs, normed_quantile_spread), decode_caches = self(normed_inputs, patched_masks, decode_caches)
572
 
573
+ renormed_outputs = torch.reshape(
574
+ revin(normed_outputs, context_mu, context_sigma, reverse=True),
575
+ (batch_size, -1, self.o, self.q),
576
+ )
577
+ renormed_quantile_spread = torch.reshape(
578
+ revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
579
+ (batch_size, -1, self.os, self.q),
580
+ )[:, -1, ...]
581
+
582
+ ar_outputs: list[torch.Tensor] = []
583
+ last_renormed_output = renormed_outputs[:, -1, :, self.aridx]
584
+
585
+ for _ in range(num_decode_steps):
586
+ new_patched_input = torch.reshape(last_renormed_output, (batch_size, self.m, self.p))
587
+ new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)
588
+
589
+ n, mu, sigma = last_n, last_mu, last_sigma
590
+ new_mus: list[torch.Tensor] = []
591
+ new_sigmas: list[torch.Tensor] = []
592
+ for i in range(self.m):
593
+ (n, mu, sigma), _ = update_running_stats(n, mu, sigma, new_patched_input[:, i], new_mask[:, i])
594
+ new_mus.append(mu)
595
+ new_sigmas.append(sigma)
596
+ last_n, last_mu, last_sigma = n, mu, sigma
597
+ new_mu = torch.stack(new_mus, dim=1)
598
+ new_sigma = torch.stack(new_sigmas, dim=1)
599
+
600
+ new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
601
+ (_, _, new_normed_output, _), decode_caches = self(new_normed_input, new_mask, decode_caches)
602
+
603
+ new_renormed_output = torch.reshape(
604
+ revin(new_normed_output, new_mu, new_sigma, reverse=True),
605
+ (batch_size, self.m, self.o, self.q),
606
  )
607
+ ar_outputs.append(new_renormed_output[:, -1, ...])
608
+ last_renormed_output = new_renormed_output[:, -1, :, self.aridx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609
 
610
+ ar_renormed_outputs = torch.stack(ar_outputs, dim=1) if num_decode_steps > 0 else None
611
 
612
  return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
613