Harryx2025 commited on
Commit
6c620ed
·
verified ·
1 Parent(s): bb7d52f

Rename modeling_patch_moe.py to modeling_FalconTST.py

Browse files
modeling_patch_moe.py → modeling_FalconTST.py RENAMED
@@ -1,14 +1,20 @@
1
  import torch
2
- from typing import Optional
 
3
  import torch.nn as nn
4
  import torch.nn.functional as F
 
5
  from torch import Tensor
6
  import math
 
7
  from functools import reduce
8
  from abc import ABC, abstractmethod
9
- from .configuration_patch_moe import PatchMoeConfig
10
- from .ts_generation_mixin import PatchMoEGenerationMixin
11
- from transformers import PreTrainedModel
 
 
 
12
 
13
 
14
  def _rotate_half(x: Tensor, rotary_interleaved: bool) -> Tensor:
@@ -31,12 +37,12 @@ def _rotate_half(x: Tensor, rotary_interleaved: bool) -> Tensor:
31
 
32
 
33
  def _apply_rotary_pos_emb_bshd(
34
- t: Tensor,
35
- freqs: Tensor,
36
- rotary_interleaved: bool = False,
37
- multi_latent_attention: bool = False,
38
- mscale: float = 1.0,
39
- ) -> Tensor:
40
  """Apply rotary positional embedding to input tensor T.
41
 
42
  check https://kexue.fm/archives/8265 for detailed formulas
@@ -94,39 +100,24 @@ def topk_softmax_with_capacity(
94
  """
95
  assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
96
 
97
- def compute_topk(
98
- scores,
99
- topk,
100
- ):
101
  return torch.topk(scores, k=topk, dim=1)
102
 
103
  if score_function == "softmax":
104
  if use_pre_softmax:
105
  scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
106
- probs, top_indices = compute_topk(
107
- scores,
108
- topk,
109
- )
110
  else:
111
- scores, top_indices = compute_topk(
112
- logits,
113
- topk,
114
- )
115
  probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
116
  elif score_function == "sigmoid":
117
  scores = torch.sigmoid(logits.float()).type_as(logits)
118
  if expert_bias is not None:
119
  scores_for_routing = scores + expert_bias
120
- _, top_indices = compute_topk(
121
- scores_for_routing,
122
- topk,
123
- )
124
  scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
125
  else:
126
- scores, top_indices = compute_topk(
127
- scores,
128
- topk,
129
- )
130
  probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
131
  else:
132
  raise ValueError(f"Invalid score_function: {score_function}")
@@ -165,7 +156,7 @@ class RotaryEmbedding(nn.Module):
165
 
166
  dim = kv_channels
167
  self.rotary_interleaved = rotary_interleaved
168
- device = "cpu" if use_cpu_initialization else torch.cuda.current_device()
169
  self.inv_freq = 1.0 / (
170
  rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
171
  )
@@ -180,9 +171,8 @@ class RotaryEmbedding(nn.Module):
180
  freqs = torch.outer(seq, self.inv_freq) # [seq len, dim]
181
  return freqs
182
 
183
- def forward(
184
- self, max_seq_len: int, offset: int = 0, packed_seq: bool = False, device=None
185
- ) -> Tensor:
186
  """Forward pass of RoPE embedding.
187
 
188
  Args:
@@ -195,7 +185,7 @@ class RotaryEmbedding(nn.Module):
195
  """
196
  if device is None:
197
  device = self.inv_freq.device
198
- if self.inv_freq.device.type == "cpu":
199
  # move `inv_freq` to GPU once at the first micro-batch forward pass
200
  self.inv_freq = self.inv_freq.to(device=device)
201
 
@@ -213,7 +203,7 @@ class RotaryEmbedding(nn.Module):
213
  return emb.to(device)
214
 
215
  def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
216
- state_dict.pop(f"{prefix}inv_freq", None)
217
  return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
218
 
219
  def get_rotary_seq_len(
@@ -247,9 +237,9 @@ class RMSNorm(nn.Module):
247
  self.variance_epsilon = eps
248
 
249
  def forward(self, hidden_states):
250
- """
251
- hidden_states [bs, patch_num, d_model]
252
- """
253
  input_dtype = hidden_states.dtype
254
  hidden_states = hidden_states.to(torch.float32)
255
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
@@ -257,7 +247,7 @@ class RMSNorm(nn.Module):
257
  return self.weight * hidden_states.to(input_dtype)
258
 
259
 
260
- class TEDotProductAttention(nn.Module):
261
  """Implement the scaled dot product attention with softmax.
262
  Arguments
263
  ---------
@@ -274,14 +264,7 @@ class TEDotProductAttention(nn.Module):
274
  self.softmax_scale = softmax_scale
275
  self.drop = nn.Dropout(attention_dropout)
276
 
277
- def forward(
278
- self,
279
- q,
280
- k,
281
- v,
282
- attention_mask,
283
- causal=None,
284
- ):
285
  """Implements the multihead softmax attention.
286
  Arguments
287
  ---------
@@ -292,45 +275,47 @@ class TEDotProductAttention(nn.Module):
292
  """
293
  causal = self.causal if causal is None else causal
294
 
295
- q = q.transpose(0, 1).contiguous()
296
- k = k.transpose(0, 1).contiguous()
297
- v = v.transpose(0, 1).contiguous()
298
 
299
  batch_size, seq_len = q.shape[0], q.shape[1]
300
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
301
- # scores
302
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
303
- scores = scores.masked_fill(attention_mask == 0, float("-1e9"))
304
  # Softmax
305
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
306
  # Dropout
307
  attention_drop = self.drop(attention)
308
  output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
309
- output = output.reshape(batch_size, seq_len, -1).transpose(0, 1).contiguous()
310
  return output
311
 
312
 
 
 
 
 
 
 
 
 
 
 
 
313
  class SelfAttention(nn.Module):
314
- def __init__(
315
- self,
316
- config,
317
- ):
318
  super().__init__()
319
  self.config = config
320
- q_layernorm = config.q_layernorm
321
- k_layernorm = config.k_layernorm
322
  self.hidden_size = config.hidden_size
323
- self.core_attention = TEDotProductAttention()
324
- self.linear_proj = nn.Linear(
325
- self.hidden_size,
326
- self.hidden_size,
327
- bias=config.add_bias_linear,
328
- )
329
- self.linear_qkv = nn.Linear(
330
- self.hidden_size,
331
- 3 * self.hidden_size,
332
- bias=config.add_bias_linear,
333
  )
 
 
334
  if q_layernorm:
335
  self.q_layernorm = RMSNorm(self.hidden_size)
336
  else:
@@ -340,48 +325,38 @@ class SelfAttention(nn.Module):
340
  else:
341
  self.k_layernorm = IdentityOp()
342
 
343
- def forward(self, x, attention_mask, rotary_pos_emb):
344
  qkv = self.linear_qkv(x)
345
- qkv = qkv.view(qkv.size(0), qkv.size(1), self.config.num_attention_heads, -1)
346
  q, k, v = qkv.chunk(3, dim=-1)
347
-
348
- # q/k norm
349
- q = self.q_layernorm(q)
350
- k = self.k_layernorm(k)
351
-
352
  # Apply rotary encoding to q and k
353
  rotary_pos_emb = (rotary_pos_emb,) * 2
354
  q_pos_emb, k_pos_emb = rotary_pos_emb
355
  q = _apply_rotary_pos_emb_bshd(q, q_pos_emb)
356
  k = _apply_rotary_pos_emb_bshd(k, k_pos_emb)
357
 
358
- # attention
 
 
359
  attn_output = self.core_attention(q, k, v, attention_mask)
360
  output = self.linear_proj(attn_output)
361
  return output
362
 
363
 
 
364
  class MLP(nn.Module):
365
- def __init__(self, config, in_features):
366
  super().__init__()
367
- self.config = config
368
- self.linear_fc1 = nn.Linear(
369
- in_features,
370
- self.config.moe_ffn_hidden_size * 2,
371
- bias=self.config.add_bias_linear,
372
- )
373
- self.linear_fc2 = nn.Linear(
374
- self.config.moe_ffn_hidden_size,
375
- self.config.hidden_size,
376
- bias=self.config.add_bias_linear,
377
- )
378
 
379
  def forward(self, x):
380
  x = self.swiglu(self.linear_fc1(x))
381
  x = self.linear_fc2(x)
382
  return x
383
 
384
- def swiglu(self, y):
385
  """Performs SwiGLU (Swish-Gated Linear Unit) activation function.
386
 
387
  Args:
@@ -404,9 +379,9 @@ class TransformerLayer(nn.Module):
404
  self.input_layernorm = IdentityOp()
405
  self.self_attention = SelfAttention(config)
406
  self.pre_mlp_layernorm = RMSNorm(self.config.hidden_size)
407
- self.mlp = MLP(config, self.config.hidden_size)
408
 
409
- def forward(self, x, attention_mask, rotary_pos_emb):
410
  residual = x
411
  x = self.input_layernorm(x)
412
  x = self.self_attention(x, attention_mask, rotary_pos_emb)
@@ -418,113 +393,84 @@ class TransformerLayer(nn.Module):
418
  return x
419
 
420
 
421
- class PatchMoEExpert_v2(nn.Module):
422
- def __init__(self, config, patch_input_size=32, expert_output_size=336, final_layernorm=True):
423
  super().__init__()
424
  self.config = config
425
- self.patch_size = patch_input_size
426
  self.seq_length = config.seq_length
427
- assert (
428
- self.seq_length % self.patch_size == 0
429
- ), f"invalid patch_size: {self.patch_size} when seq_length={self.seq_length}"
430
  self.patch_num = self.seq_length // self.patch_size
431
  self.flatten_size = self.patch_num * self.config.hidden_size
432
 
433
- self.layers = nn.ModuleList(
434
- [
435
- TransformerLayer(config, input_layernorm=config.transformer_input_layernorm)
436
- for _ in range(self.config.expert_num_layers)
437
- ]
438
- )
439
  if final_layernorm:
440
  self.final_layernorm = RMSNorm(self.config.hidden_size)
441
  else:
442
  self.final_layernorm = IdentityOp()
443
  self.patch_embedding = MLP(config, in_features=patch_input_size)
444
- self.output_layer = nn.Linear(
445
- in_features=self.flatten_size,
446
- out_features=expert_output_size,
447
- bias=False,
448
- )
449
 
450
  def _forward_patch_embedding(
451
  self,
452
- input: Tensor, # [batch_size, seq_len]
453
  ):
454
  """
455
  Perform patch embedding on the input time series.
456
 
457
- This method applies a linear transformation to the input tensor to
458
  convert it into patches and then embeds these patches using a linear layer.
459
  """
460
  batch_size, seq_len = input.shape
461
- assert (
462
- seq_len == self.seq_length
463
- ), f"Expected sequence length {self.seq_length}, but got {seq_len}"
464
 
465
  # Create input_mask based on pad_length
466
  # When a time point is masked, its value is mask_pad_value(default:255.)
467
- input_mask = (
468
- input != self.config.mask_pad_value
469
- ) # 0: mask, 1: unmask [batch_size, seq_len]
470
 
471
  # so whether the masked value 0 has the same effective of attention_mask
472
- input_data = input * input_mask # [batch_size, seq_len]
473
 
474
  # Patchify the input
475
- input_data = input_data.unfold(
476
- dimension=-1, size=self.patch_size, step=self.patch_size
477
- ).contiguous() # input [batch_size, patch_num, patch_size]
478
- hidden_states = self.patch_embedding(
479
- input_data
480
- ) # hidden_states [batch_size, patch_num, hidden_size]
481
- hidden_states = hidden_states.transpose(
482
- 0, 1
483
- ).contiguous() # hidden_states [patch_num, batch_size, hidden_size], To adapt to the Megatron
484
 
485
  # Patchify the mask: only the entire time points in a patch are masked then this patch is masked
486
- attention_mask = input_mask.unfold(
487
- dimension=-1, size=self.patch_size, step=self.patch_size
488
- ).contiguous() # [batch_size, patch_num, patch_size]
489
- attention_mask = (
490
- attention_mask.sum(-1) == self.patch_size
491
- ) # [batch_size, patch_num] # 0: mask, 1: unmask
492
- attention_mask[:, -1] = True # The last patch is not masked
493
  _, patch_num = attention_mask.shape
494
- attention_mask = attention_mask.unsqueeze(2).repeat(
495
- 1, 1, patch_num
496
- ) * attention_mask.unsqueeze(1).repeat(
497
- 1, patch_num, 1
498
- ) # [batch_size, patch_num, patch_num]
499
- attention_mask = attention_mask.unsqueeze(
500
- 1
501
- ).contiguous() # [batch_size, 1, patch_num, patch_num]
502
 
503
  return hidden_states, attention_mask, input_mask
504
 
505
- def _forward_output(
506
- self, hidden_states, output_scale=None, input_mask=None, inference_context=None
507
- ):
508
  """
509
- Perform a forward pass through the output layer.
510
 
511
- Args:
512
- expert_input (Tensor): Expert input of shape [batch_size, seq_len]
513
- hidden_states (Tensor): Transformed hidden states of shape [patch_num, batch_size, hidden_size]
514
- output_scale (Tensor, optional): Expert probabilities for the output layer [batch_size]
515
- input_mask (Tensor, optional): Expert input mask of shape [batch_size, seq_len], 0:mask, 1:unmask
516
 
517
- Returns:
518
- expert_output (Tensor): Expert output of shape [batch_size, expert_output_size]
519
  """
520
 
521
  # [patch_num, batch_size, hidden_size] -> [batch_size, flatten_size (patch_num * hidden_size)]
522
  patch_num, batch_size, hidden_size = hidden_states.shape
523
- assert (
524
- patch_num * hidden_size
525
- ) == self.flatten_size, f"patch_num ({patch_num}) * hidden_size ({hidden_size}) != flatten_size ({self.flatten_size})"
526
  hidden_states = hidden_states.transpose(0, 1).reshape(-1, self.flatten_size).contiguous()
527
- expert_output = self.output_layer(hidden_states) # [batch_size, expert_output_size]
528
  if output_scale is not None:
529
  original_dtype = expert_output.dtype
530
  expert_output = expert_output * output_scale.unsqueeze(-1)
@@ -532,33 +478,29 @@ class PatchMoEExpert_v2(nn.Module):
532
 
533
  return expert_output
534
 
535
- def forward(self, expert_input, rotary_pos_emb, expert_probs=None):
536
  hidden_states, attention_mask, input_mask = self._forward_patch_embedding(expert_input)
537
  for layer in self.layers:
538
- hidden_states = layer(
539
- hidden_states, attention_mask, rotary_pos_emb[: hidden_states.shape[0]]
540
- )
541
  hidden_states = self.final_layernorm(hidden_states)
542
  expert_output = self._forward_output(hidden_states, expert_probs, input_mask)
543
  return expert_output
544
 
545
 
546
- class SequentialPatchMoE(nn.Module):
547
- def __init__(self, config, expert_output_size=336):
548
  super().__init__()
549
  self.config = config
550
  self.expert_output_size = expert_output_size
551
- self.local_experts = nn.ModuleList(
552
- [
553
- PatchMoEExpert_v2(
554
- config,
555
- expert_output_size=expert_output_size,
556
- patch_input_size=config.patch_size_list[expert_id],
557
- final_layernorm=config.moe_expert_final_layernorm,
558
- )
559
- for expert_id in range(config.num_moe_experts)
560
- ]
561
- )
562
 
563
  def forward(self, input, routing_map, rotary_pos_emb, expert_probs):
564
  expert_output_list = []
@@ -566,19 +508,15 @@ class SequentialPatchMoE(nn.Module):
566
 
567
  for i, expert in enumerate(self.local_experts):
568
  token_mask = routing_map[:, i].bool() # shape (batch,)
569
- current_inputs = input[token_mask] # (num_tokens_for_expert, seq_len)
570
- current_probs = expert_probs[token_mask, i]
571
 
572
  if current_inputs.numel() == 0:
573
- expert_output = torch.zeros(
574
- 0, self.expert_output_size, device=input.device, dtype=input.dtype
575
- )
576
  else:
577
  expert_output = expert(current_inputs, rotary_pos_emb, current_probs)
578
 
579
- full_output = torch.zeros(
580
- batch_size, self.expert_output_size, device=input.device, dtype=input.dtype
581
- )
582
  full_output[token_mask] = expert_output
583
  expert_output_list.append(full_output)
584
 
@@ -601,7 +539,7 @@ class RouterGatingLinearFunction(torch.autograd.Function):
601
  ctx.weight_dtype = weight.dtype
602
  inp_shape = inp.shape
603
  inp = inp.view(-1, inp_shape[-1])
604
-
605
  output = torch.mm(inp.to(router_dtype), weight.to(router_dtype).t())
606
 
607
  output = output.view(*inp_shape[:-1], -1)
@@ -617,12 +555,11 @@ def router_gating_linear(inp: torch.Tensor, weight: torch.Tensor, router_dtype:
617
  return RouterGatingLinearFunction.apply(inp, weight, router_dtype)
618
 
619
 
620
- class Router(ABC, nn.Module):
621
  """Base Router class"""
622
 
623
  def __init__(
624
- self,
625
- config: PatchMoeConfig,
626
  ) -> None:
627
  """
628
  Initialize the Router module.
@@ -635,28 +572,24 @@ class Router(ABC, nn.Module):
635
  self.config = config
636
 
637
  # Initialize the gate weights.
638
-
639
  if self.config.patch_size_list is not None:
640
  assert self.config.moe_router_input_size is not None
641
  self.weight = torch.nn.Parameter(
642
- torch.empty(
643
- (self.config.num_moe_experts, self.config.moe_router_input_size),
644
- dtype=torch.float32,
645
- )
646
  )
647
  else:
648
  self.weight = torch.nn.Parameter(
649
- torch.empty(
650
- (self.config.num_moe_experts, self.config.hidden_size), dtype=torch.float32
651
- )
652
  )
653
  self.reset_parameters()
654
-
655
  def reset_parameters(self):
656
  """Reset the router parameters."""
657
- torch.nn.init.normal_(self.weight, mean=0, std=self.config.init_method_std)
658
  self.weight.data = self.weight.data.to(dtype=self.config.torch_dtype)
659
 
 
660
  def gating(self, input: torch.Tensor):
661
  """Forward pass of the router gate.
662
 
@@ -700,8 +633,7 @@ class TopKRouter(Router):
700
  """Route each token to the top-k experts."""
701
 
702
  def __init__(
703
- self,
704
- config: PatchMoeConfig,
705
  ) -> None:
706
  """Initialize the zero token dropping router.
707
 
@@ -716,17 +648,18 @@ class TopKRouter(Router):
716
  self.enable_expert_bias = self.config.moe_router_enable_expert_bias
717
  if self.enable_expert_bias:
718
  self.register_buffer(
719
- "local_tokens_per_expert",
720
  torch.zeros(self.config.num_moe_experts, dtype=torch.float32),
721
  persistent=False,
722
  )
723
  self.register_buffer(
724
- "expert_bias", torch.zeros(self.config.num_moe_experts, dtype=torch.float32)
725
  )
726
  else:
727
  self.local_tokens_per_expert = None
728
  self.expert_bias = None
729
 
 
730
  def routing(self, logits: torch.Tensor):
731
  """Top-k routing function
732
 
@@ -763,7 +696,7 @@ class TopKRouter(Router):
763
  return scores, routing_map
764
 
765
 
766
- class PatchMoEMoELayer(nn.Module):
767
  def __init__(self, config, layer_number):
768
  super().__init__()
769
  self.config = config
@@ -781,50 +714,46 @@ class PatchMoEMoELayer(nn.Module):
781
  self.expert_output_size = config.seq_length
782
 
783
  if self.is_last_layer and self.config.heterogeneous_moe_layer:
784
- # If heterogeneous_moe_layer is True, the backcast will be None
785
- self.backcast_layernorm = None
786
  else:
787
  self.backcast_layernorm = RMSNorm(self.seq_length)
788
 
789
- self.experts = SequentialPatchMoE(
790
- config,
791
- expert_output_size=self.expert_output_size,
792
- )
793
- self.shared_experts = PatchMoEExpert_v2(
794
- config,
795
- expert_output_size=self.expert_output_size,
796
- patch_input_size=config.shared_patch_size,
797
- final_layernorm=config.moe_expert_final_layernorm,
798
- )
799
 
800
  def time_series_preprocess(self, input: torch.Tensor):
801
  """
802
- Preprocess time series(sample) for dispatch.
803
 
804
- Applies RevIN to input time series(sample), and process the input mask (0: mask, 1: unmask)
805
 
806
- Args:
807
- input (torch.Tensor): The input time series (samples) to the MoE layer. [batch_size, seq_len]
808
 
809
- Returns:
810
- input (torch.Tensor): The (RevIN) backcast time series (samples). [batch_size, seq_len]
811
- means (torch.Tensor): The means of the non-masked backcast time series (samples). [batch_size, 1]
812
- stdev (torch.Tensor): The standard deviation of the non-masked backcast time series (samples). [batch_size, 1]
813
  """
814
 
815
  batch_size, seq_len = input.shape
816
- assert seq_len == self.seq_length, f"seq_len {seq_len} != self.seq_length {self.seq_length}"
817
 
818
  # Create input_mask based on pad_length
819
  # When a time point is masked, its value is mask_pad_value(default:255.)
820
- input_mask = (
821
- input != self.config.mask_pad_value
822
- ) # 0: mask, 1: unmask [batch_size, seq_len]
823
-
824
  self.input_mask = input_mask
825
-
826
  return input
827
-
828
  def router_and_preprocess(self, backcast: torch.Tensor):
829
  """Compute and preprocess time series(sample) routing for dispatch.
830
 
@@ -836,22 +765,20 @@ class PatchMoEMoELayer(nn.Module):
836
  # backcast [batch_size, seq_len] means/stdev [batch_size, 1]
837
  backcast = self.time_series_preprocess(backcast)
838
 
839
- residual = backcast # residual: [batch_size, seq_len], the input to the shared experts
840
 
841
  # TODO: Check the effective of the masked value to the router
842
- probs, routing_map = self.router(
843
- backcast * self.input_mask
844
- ) # probs/routing_map: [batch_size, num_experts]
845
 
846
  return backcast, probs, residual, routing_map
847
 
848
  def experts_compute(
849
  self,
850
- input: torch.Tensor, # [num_permuted_samples_after_dispatch, seq_len]
851
- probs: torch.Tensor, # [num_permuted_samples_after_dispatch]
852
- residual: torch.Tensor, # [batch_size, seq_len]
853
  rotary_pos_emb: torch.Tensor,
854
- routing_map: torch.Tensor, # [seq_len, 1, 1, kv_channels(hidden_size // num_heads)]
855
  ):
856
  """Computes the output of the experts on the dispatched time series(sample).
857
 
@@ -863,19 +790,20 @@ class PatchMoEMoELayer(nn.Module):
863
  """
864
  # shared_expert_output: [batch_size, seq_len (+ pred_len)]
865
  shared_experts_output = self.shared_experts(residual, rotary_pos_emb)
866
-
867
  # dispatched_input (global_input_tokens): [num_permuted_samples_after_dispatch_postprocess(sorted), seq_len]
868
  # tokens_per_expert (global_probs): [num_experts]
869
  # permuted_probs (global_probs): [num_permuted_samples_after_dispatch_postprocess(sorted)]
870
-
871
  experts_output = self.experts(input, routing_map, rotary_pos_emb, probs)
872
 
 
873
  return experts_output, shared_experts_output
874
-
875
  def postprocess(
876
- self,
877
- backcast: torch.Tensor, # [batch_size, seq_len]
878
- forecast: torch.Tensor, # [batch_size, pred_len]
879
  output_backcast: torch.Tensor, # [batch_size, seq_len]
880
  output_forecast: torch.Tensor, # [batch_size, pred_len]
881
  ):
@@ -889,21 +817,20 @@ class PatchMoEMoELayer(nn.Module):
889
  stdev (torch.Tensor): The standard deviation of the non-masked backcast time series (samples). [batch_size, 1]
890
  backcast_mask (torch.Tensor): The previous layer's backcast mask of time series (samples) . [batch_size, seq_len]
891
  """
892
- if output_backcast is not None:
893
- output_backcast = self.backcast_layernorm(output_backcast) # LayerNorm
 
 
894
  if self.config.residual_backcast:
895
  output_backcast = backcast - output_backcast
896
 
897
- output_backcast[~self.input_mask] = (
898
- self.config.mask_pad_value
899
- ) # Important! Recover the mask time point back to mask_pad_value(default:255.)
900
-
901
- if (
902
- self.config.do_expert_forecast and forecast is not None
903
- ): # The first layer's forecast is None
904
  output_forecast = forecast + output_forecast
905
-
906
  return output_backcast, output_forecast
 
907
 
908
  def combine(
909
  self,
@@ -916,67 +843,60 @@ class PatchMoEMoELayer(nn.Module):
916
  experts (e.g., via an All-to-All communication). It then adds the output
917
  from the shared expert if it exists.
918
  """
919
- assert (
920
- experts_output.shape == shared_experts_output.shape
921
- ), f"experts_output shape {experts_output.shape} doesn't equal to shared_experts_output shape:{shared_experts_output.shape}"
922
  output = experts_output + shared_experts_output
923
 
924
  if self.is_last_layer and self.config.heterogeneous_moe_layer:
925
  output_backcast = None
926
  output_forecast = output
927
- assert (
928
- output_forecast.shape[1] == self.pred_length
929
- ), f"heterogeneous_moe_layer=True, expected the last moe layer's output pred len: {self.pred_length}, but got {output_forecast.shape[1]}"
930
  else:
931
  # Noting: the mask time point there maybe not mask_pad_value(default:255.), it will be postprocessed
932
- output_backcast = output[:, : self.seq_length] # [batch_size, seq_len]
933
-
934
  if self.config.do_expert_forecast:
935
- output_forecast = output[:, self.seq_length :] # [batch_size, pred_len]
936
- assert (
937
- output_forecast.shape[1] == self.pred_length
938
- ), f"do_expert_forecast=True, expected the last moe layer's output pred len: {self.pred_length}, but got {output_forecast.shape[1]}"
939
  else:
940
  output_forecast = None
941
-
942
  return output_backcast, output_forecast
943
 
944
- def forward(self, backcast, forecast, rotary_pos_emb):
 
945
  inputs, probs, residual, routing_map = self.router_and_preprocess(backcast)
946
- experts_output, shared_experts_output = self.experts_compute(
947
- inputs, probs, residual, rotary_pos_emb, routing_map
948
- )
949
  output_backcast, output_forecast = self.combine(experts_output, shared_experts_output)
950
- output_backcast, output_forecast = self.postprocess(
951
- backcast, forecast, output_backcast, output_forecast
952
- )
953
  return output_backcast, output_forecast
954
 
955
 
956
- class PatchMoEBlock(nn.Module):
957
- def __init__(self, config):
 
958
  super().__init__()
959
  self.config = config
960
- self.layers = nn.ModuleList(
961
- [
962
- PatchMoEMoELayer(config, layer_num + 1)
963
  for layer_num in range(self.config.num_hidden_layers)
964
- ]
965
- )
966
-
967
- def forward(self, x, rotary_pos_emb):
968
  backcast = x
969
  forecast = None
970
  for layer in self.layers:
971
- backcast, forecast = layer(backcast, forecast, rotary_pos_emb)
972
- return backcast, forecast
973
 
974
 
975
- class PatchMoEPreTrainedModel(PreTrainedModel):
976
- config_class = PatchMoeConfig
 
977
  base_model_prefix = "model"
978
  supports_gradient_checkpointing = True
979
- _no_split_modules = ["PatchMoEMoELayer"]
980
  _skip_keys_device_placement = "past_key_values"
981
  _supports_flash_attn_2 = True
982
  _supports_sdpa = False
@@ -992,77 +912,73 @@ class PatchMoEPreTrainedModel(PreTrainedModel):
992
  if module.padding_idx is not None:
993
  module.weight.data[module.padding_idx].zero_()
994
 
995
-
996
- class PatchMoEModel(PatchMoEPreTrainedModel):
997
- def __init__(self, config: PatchMoeConfig):
998
  super().__init__(config)
999
  self.config = config
1000
  self.seq_length = config.seq_length
1001
  self.rotary_pos_emb = RotaryEmbedding(
1002
- kv_channels=self.config.kv_channels,
1003
- rotary_base=config.rotary_base,
1004
- use_cpu_initialization=self.config.use_cpu_initialization,
1005
- rotary_interleaved=self.config.rotary_interleaved,
1006
  )
1007
- self.decoder = PatchMoEBlock(config=config)
 
 
1008
  if self.config.do_expert_forecast and self.config.heterogeneous_moe_layer:
1009
  self.output_layer = IdentityOp()
1010
  else:
1011
- self.output_layer = nn.Linear(
1012
- in_features=self.seq_length,
1013
- out_features=self.config.pred_length,
1014
- bias=self.config.add_bias_linear,
1015
- )
1016
 
1017
  def revin(
1018
  self,
1019
- input: Tensor, # [batch_size, seq_len]
1020
- input_mask: Tensor, # [batch_size, seq_len] 0:mask, 1:unmask
1021
  ):
1022
- """Normalization from Non-stationary Transformer"""
1023
 
1024
  input_data = input * input_mask
1025
- sum_per_sample = torch.sum(
1026
- input_data, dim=1, keepdim=True
1027
- ).detach() # [batch_size, 1], torch.bfloat16
1028
- count_per_sample = torch.sum(
1029
- input_mask, dim=1, keepdim=True
1030
- ).detach() # [batch_size, 1], torch.int64
1031
- assert (
1032
- torch.any(count_per_sample == 0) == False
1033
- ), f"There is zero in count_per_sample, shape: {input[torch.where(count_per_sample.squeeze(1) == 0)[0]]}"
1034
- means = sum_per_sample / count_per_sample # [batch_size, 1]
1035
  input_data = input_data - means
1036
  input_data = input_data * input_mask
1037
- var_per_sample = (
1038
- torch.sum(input_data**2, dim=1, keepdim=True).detach() / count_per_sample
1039
- ) # [batch_size, 1]
1040
  stdev = torch.sqrt(var_per_sample + 1e-9)
1041
  input_data = input_data / stdev
1042
  input_data = input_data * input_mask
1043
 
1044
- # recover the mask_pad_value(default:255.)
1045
  input = input * ~(input_mask) + input_data
1046
 
1047
  return input, means, stdev
1048
 
1049
  def forward(self, input, revin):
 
 
 
 
 
1050
  batch_size, input_len = input.shape
 
 
1051
  if input_len > self.seq_length:
1052
- input = input[:, -self.seq_length :]
1053
  elif input_len < self.seq_length:
1054
  pad_len = self.seq_length - input_len
1055
- input = F.pad(
1056
- input, pad=(pad_len, 0), mode="constant", value=self.config.mask_pad_value
1057
- )
1058
  input_len = self.seq_length
1059
 
1060
- input_mask = input != self.config.mask_pad_value
1061
 
1062
  # Step1. RevIN
1063
  if revin:
1064
  input, means, stdev = self.revin(input, input_mask)
1065
-
1066
  # Step2. Get rotary_pos_emb
1067
  # rotary_pos_emb [input_len, 1, 1, kv_channels(hidden_size // num_heads)]
1068
  rotary_pos_emb = self.rotary_pos_emb(input_len, device=input.device)
@@ -1070,21 +986,23 @@ class PatchMoEModel(PatchMoEPreTrainedModel):
1070
  # Step3. Do one-step inference to get mixed forecasts from multiple forecast heads
1071
  # mixed_pred: [batch_size, sum(multi_forecast_head)]
1072
  mixed_pred = self._inference_step(
1073
- input=input, input_mask=input_mask, rotary_pos_emb=rotary_pos_emb
 
 
1074
  )
1075
 
1076
- # Step4. Based on the mixed forecasts, do auto-regressive inference according to
1077
  # the step list of each forecast head
1078
- if self.config.multi_forecast_head_type == "single":
1079
  final_output = self._auto_regressive_single_head(
1080
- input=input,
1081
- input_mask=input_mask,
1082
- patchmoe_forecast=mixed_pred,
1083
- rotary_pos_emb=rotary_pos_emb,
1084
  )
1085
  else:
1086
  raise NotImplementedError
1087
-
1088
  # Step5. RevIN
1089
  if revin:
1090
  final_output = final_output * (stdev.repeat(1, self.config.inference_length))
@@ -1093,58 +1011,57 @@ class PatchMoEModel(PatchMoEPreTrainedModel):
1093
  return final_output.detach().float()
1094
 
1095
  def _inference_step(
1096
- self,
1097
- input,
1098
- input_mask,
1099
  rotary_pos_emb,
1100
- ):
1101
  if self.config.do_base_forecast:
1102
  base_forecast, _ = self.base_output_layer(input)
1103
  else:
1104
  base_forecast = None
1105
 
1106
  decoder_backcast, decoder_forecast = self.decoder(
1107
- input, # [batch_size, seq_len]
1108
- rotary_pos_emb, # [input_len, 1, 1, kv_channels(hidden_size // num_heads)]
1109
  )
1110
 
1111
  if self.config.do_expert_forecast:
1112
- assert decoder_forecast is not None, f"decoder_forecast is None"
1113
  if self.config.heterogeneous_moe_layer:
1114
  decoder_forecast = self.output_layer(decoder_forecast) # IdentityOp
1115
  else:
1116
- final_forecast = self.output_layer(decoder_backcast * input_mask)
1117
  decoder_forecast = decoder_forecast + final_forecast
1118
  else:
1119
  # The decoder_backcast contains the mask_pad_val(default:255.)
1120
  decoder_forecast, _ = self.output_layer(decoder_backcast * input_mask)
1121
-
1122
  if self.config.do_base_forecast:
1123
- assert base_forecast is not None, f"base_forecast is None"
1124
- patchmoe_forecast = base_forecast + decoder_forecast
1125
  else:
1126
- patchmoe_forecast = decoder_forecast
1127
-
1128
- return patchmoe_forecast
1129
 
1130
  def _auto_regressive_single_head(
1131
  self,
1132
- input, # [batch_size, seq_len]
1133
- input_mask, # [batch_size, seq_len]
1134
- patchmoe_forecast, # [batch_size, max(multi_forecast_head)]
1135
- rotary_pos_emb, # [seq_len, 1, 1, kv_channels(hidden_size // num_heads)]
1136
- auto_regressive_strategy="from_long_to_short",
1137
  ):
1138
  """auto regressive prediction with [single] head"""
1139
- assert (
1140
- self.config.multi_forecast_head_type == "single"
1141
- ), f"_auto_regressive_single_head only support multi_forecast_head_type==single "
1142
 
1143
- if auto_regressive_strategy == "from_long_to_short":
1144
  # From long to short
1145
  multi_forecast_head_list = sorted(self.config.multi_forecast_head_list, reverse=True)
1146
 
1147
- final_output = patchmoe_forecast
1148
  while final_output.shape[1] < self.config.inference_length:
1149
  # adaptive choose the forecast head
1150
  remain_pred_len = self.config.inference_length - final_output.shape[1]
@@ -1154,39 +1071,28 @@ class PatchMoEModel(PatchMoEPreTrainedModel):
1154
  if idx == len(multi_forecast_head_list):
1155
  idx = len(multi_forecast_head_list) - 1
1156
  head_pred_len = multi_forecast_head_list[idx]
1157
-
1158
  # one-step model prediction
1159
- input = torch.cat([input, patchmoe_forecast], dim=1)[
1160
- :, -self.seq_length :
1161
- ].contiguous()
1162
  input_mask = torch.cat(
1163
- [
1164
- input_mask,
1165
- torch.ones(
1166
- patchmoe_forecast.shape,
1167
- dtype=input_mask.dtype,
1168
- device=input_mask.device,
1169
- ),
1170
- ],
1171
- dim=1,
1172
- )[
1173
- :, -self.seq_length :
1174
- ].contiguous() # 0:mask, 1:unmask
1175
-
1176
- patchmoe_forecast = self._inference_step(
1177
- input=input,
1178
- input_mask=input_mask,
1179
- rotary_pos_emb=rotary_pos_emb,
1180
  )
1181
 
1182
  # the core idea of multi forecast head type of [single]
1183
- patchmoe_forecast = patchmoe_forecast[:, :head_pred_len]
1184
-
1185
- final_output = torch.cat([final_output, patchmoe_forecast], dim=1)
1186
-
1187
- final_output = final_output[:, : self.config.inference_length]
1188
 
1189
- elif auto_regressive_strategy == "from_short_to_long":
1190
  # From short to long
1191
  # in validate_args, it has been sorted, and check the valid config
1192
  multi_forecast_head_list = sorted(self.config.multi_forecast_head_list)
@@ -1197,15 +1103,14 @@ class PatchMoEModel(PatchMoEPreTrainedModel):
1197
  else:
1198
  ar_step = min(
1199
  self.config.autoregressive_step_list[idx],
1200
- self.config.multi_forecast_head_list[idx + 1]
1201
- // self.config.multi_forecast_head_list[idx],
1202
  )
1203
  # ar_step = multi_forecast_head_list[idx + 1] // multi_forecast_head_list[idx]
1204
-
1205
  multi_forecast_head_dict[head_pred_len] = ar_step
1206
-
1207
  # the core idea of strategy [from_short_to_long]
1208
- mixed_pred = patchmoe_forecast
1209
  output_list = []
1210
  cur_pred = None
1211
  cur_pred_len = 0
@@ -1219,62 +1124,50 @@ class PatchMoEModel(PatchMoEPreTrainedModel):
1219
  if ar_step == 0:
1220
  # Ignore the current forecast head
1221
  continue
1222
-
1223
  # Add current head's first auto-regressive step of prediction
1224
- head_pred = mixed_pred[:, :head_pred_len] # [single]
1225
  output_list.append(head_pred[:, cur_pred_len:])
1226
  cur_pred = torch.cat(output_list, dim=1)
1227
  cur_pred_len = cur_pred.shape[1]
1228
  if cur_pred_len >= self.config.inference_length:
1229
  break
1230
-
1231
  # Do auto-regressive of the rest of the steps
1232
  for _ in range(1, ar_step + 1):
1233
  # one-step model prediction
1234
- cur_input = torch.cat([input, cur_pred], dim=1)[
1235
- :, -self.seq_length :
1236
- ].contiguous()
1237
  cur_input_mask = torch.cat(
1238
- [
1239
- input_mask,
1240
- torch.ones(
1241
- cur_pred.shape, dtype=input_mask.dtype, device=input_mask.device
1242
- ),
1243
- ],
1244
- dim=1,
1245
- )[
1246
- :, -self.seq_length :
1247
- ].contiguous() # 0:mask, 1:unmask
1248
-
1249
- patchmoe_forecast = self._inference_step(
1250
- input=cur_input,
1251
- input_mask=cur_input_mask,
1252
- rotary_pos_emb=rotary_pos_emb,
1253
  )
1254
 
1255
- head_pred = patchmoe_forecast[:, :head_pred_len]
1256
  output_list.append(head_pred)
1257
  cur_pred = torch.cat(output_list, dim=1)
1258
  cur_pred_len = cur_pred.shape[1]
1259
  if cur_pred_len >= self.config.inference_length:
1260
  break
1261
-
1262
  if cur_pred_len >= self.config.inference_length:
1263
  break
1264
-
1265
- final_output = cur_pred[
1266
- :, : self.config.inference_length
1267
- ] # [batch_size, inference_len]
1268
 
1269
  assert final_output.shape[1] == self.config.inference_length
1270
  return final_output
1271
 
1272
-
1273
- class PatchMoEForPrediction(PatchMoEPreTrainedModel, PatchMoEGenerationMixin):
1274
- def __init__(self, config: PatchMoeConfig):
1275
  super().__init__(config)
1276
  self.config = config
1277
- self.model = PatchMoEModel(self.config)
1278
  self.post_init()
1279
 
1280
  def forward(
@@ -1287,7 +1180,10 @@ class PatchMoEForPrediction(PatchMoEPreTrainedModel, PatchMoEGenerationMixin):
1287
  revin: Optional[bool] = False,
1288
  ):
1289
  self.model.config.inference_length = max_output_length
1290
- outputs = self.model(input=input_ids, revin=revin)
 
 
 
1291
 
1292
  loss = None
1293
  logits = outputs
@@ -1309,7 +1205,7 @@ class PatchMoEForPrediction(PatchMoEPreTrainedModel, PatchMoEGenerationMixin):
1309
  attention_mask=None,
1310
  inputs_embeds=None,
1311
  revin=False,
1312
- **kwargs,
1313
  ):
1314
  """
1315
  Prepare model inputs for autoregressive generation.
@@ -1317,10 +1213,8 @@ class PatchMoEForPrediction(PatchMoEPreTrainedModel, PatchMoEGenerationMixin):
1317
 
1318
  model_inputs = {"input_ids": input_ids}
1319
 
1320
- model_inputs.update(
1321
- {
1322
- "revin": revin,
1323
- }
1324
- )
1325
 
1326
- return model_inputs
 
1
  import torch
2
+ from torch._dynamo import config
3
+ from typing import List, Optional, Union
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
+ # import transformer_engine as te
7
  from torch import Tensor
8
  import math
9
+ from einops import rearrange, repeat
10
  from functools import reduce
11
  from abc import ABC, abstractmethod
12
+ from configuration_FalconTST import FalconTSTConfig
13
+ from ts_generation_mixin import FalconTSTGenerationMixin
14
+ from transformers import PreTrainedModel, Cache, DynamicCache
15
+ from transformers.activations import ACT2FN
16
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
17
+ from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
18
 
19
 
20
  def _rotate_half(x: Tensor, rotary_interleaved: bool) -> Tensor:
 
37
 
38
 
39
  def _apply_rotary_pos_emb_bshd(
40
+ t: Tensor,
41
+ freqs: Tensor,
42
+ rotary_interleaved: bool = False,
43
+ multi_latent_attention: bool = False,
44
+ mscale: float = 1.0,
45
+ ) -> Tensor:
46
  """Apply rotary positional embedding to input tensor T.
47
 
48
  check https://kexue.fm/archives/8265 for detailed formulas
 
100
  """
101
  assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
102
 
103
+ def compute_topk(scores, topk,):
 
 
 
104
  return torch.topk(scores, k=topk, dim=1)
105
 
106
  if score_function == "softmax":
107
  if use_pre_softmax:
108
  scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
109
+ probs, top_indices = compute_topk(scores, topk, )
 
 
 
110
  else:
111
+ scores, top_indices = compute_topk(logits, topk, )
 
 
 
112
  probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
113
  elif score_function == "sigmoid":
114
  scores = torch.sigmoid(logits.float()).type_as(logits)
115
  if expert_bias is not None:
116
  scores_for_routing = scores + expert_bias
117
+ _, top_indices = compute_topk(scores_for_routing, topk, )
 
 
 
118
  scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
119
  else:
120
+ scores, top_indices = compute_topk(scores, topk,)
 
 
 
121
  probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
122
  else:
123
  raise ValueError(f"Invalid score_function: {score_function}")
 
156
 
157
  dim = kv_channels
158
  self.rotary_interleaved = rotary_interleaved
159
+ device = 'cpu' if use_cpu_initialization else torch.cuda.current_device()
160
  self.inv_freq = 1.0 / (
161
  rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
162
  )
 
171
  freqs = torch.outer(seq, self.inv_freq) # [seq len, dim]
172
  return freqs
173
 
174
+
175
+ def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False, device=None) -> Tensor:
 
176
  """Forward pass of RoPE embedding.
177
 
178
  Args:
 
185
  """
186
  if device is None:
187
  device = self.inv_freq.device
188
+ if self.inv_freq.device.type == 'cpu':
189
  # move `inv_freq` to GPU once at the first micro-batch forward pass
190
  self.inv_freq = self.inv_freq.to(device=device)
191
 
 
203
  return emb.to(device)
204
 
205
  def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
206
+ state_dict.pop(f'{prefix}inv_freq', None)
207
  return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
208
 
209
  def get_rotary_seq_len(
 
237
  self.variance_epsilon = eps
238
 
239
  def forward(self, hidden_states):
240
+ '''
241
+ hidden_states [bs, patch_num, d_model]
242
+ '''
243
  input_dtype = hidden_states.dtype
244
  hidden_states = hidden_states.to(torch.float32)
245
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
 
247
  return self.weight * hidden_states.to(input_dtype)
248
 
249
 
250
+ class FlashAttention(nn.Module):
251
  """Implement the scaled dot product attention with softmax.
252
  Arguments
253
  ---------
 
264
  self.softmax_scale = softmax_scale
265
  self.drop = nn.Dropout(attention_dropout)
266
 
267
+ def forward(self, q,k,v,attention_mask,causal=None, ):
 
 
 
 
 
 
 
268
  """Implements the multihead softmax attention.
269
  Arguments
270
  ---------
 
275
  """
276
  causal = self.causal if causal is None else causal
277
 
278
+ q = q.transpose(0,1).contiguous()
279
+ k = k.transpose(0,1).contiguous()
280
+ v = v.transpose(0,1).contiguous()
281
 
282
  batch_size, seq_len = q.shape[0], q.shape[1]
283
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
284
+ # scores
285
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
286
+ scores = scores.masked_fill(attention_mask == 0, float('-1e9'))
287
  # Softmax
288
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
289
  # Dropout
290
  attention_drop = self.drop(attention)
291
  output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
292
+ output = output.reshape(batch_size, seq_len, -1).transpose(0,1).contiguous()
293
  return output
294
 
295
 
296
+
297
+ class TEDotProductAttention(nn.Module):
298
+ def __init__(self, flash_attention,):
299
+ super().__init__()
300
+ self.flash_attention = flash_attention
301
+
302
+ def forward(self, q, k, v, mask=None):
303
+ # Prioritize using FlashAttention
304
+ return self.flash_attention(q, k, v, mask)
305
+
306
+
307
  class SelfAttention(nn.Module):
308
+ def __init__(self,config,):
 
 
 
309
  super().__init__()
310
  self.config = config
311
+ q_layernorm=config.q_layernorm
312
+ k_layernorm=config.k_layernorm
313
  self.hidden_size = config.hidden_size
314
+ self.core_attention = TEDotProductAttention(
315
+ flash_attention=FlashAttention(),
 
 
 
 
 
 
 
 
316
  )
317
+ self.linear_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.add_bias_linear,)
318
+ self.linear_qkv = nn.Linear(self.hidden_size, 3*self.hidden_size, bias=config.add_bias_linear,)
319
  if q_layernorm:
320
  self.q_layernorm = RMSNorm(self.hidden_size)
321
  else:
 
325
  else:
326
  self.k_layernorm = IdentityOp()
327
 
328
+ def forward(self, x, attention_mask,rotary_pos_emb):
329
  qkv = self.linear_qkv(x)
330
+ qkv = qkv.view(qkv.size(0), qkv.size(1), self.config.num_attention_heads,-1)
331
  q, k, v = qkv.chunk(3, dim=-1)
 
 
 
 
 
332
  # Apply rotary encoding to q and k
333
  rotary_pos_emb = (rotary_pos_emb,) * 2
334
  q_pos_emb, k_pos_emb = rotary_pos_emb
335
  q = _apply_rotary_pos_emb_bshd(q, q_pos_emb)
336
  k = _apply_rotary_pos_emb_bshd(k, k_pos_emb)
337
 
338
+ q = self.q_layernorm(q)
339
+ k = self.k_layernorm(k)
340
+ # attention
341
  attn_output = self.core_attention(q, k, v, attention_mask)
342
  output = self.linear_proj(attn_output)
343
  return output
344
 
345
 
346
+
347
  class MLP(nn.Module):
348
+ def __init__(self,config,in_features):
349
  super().__init__()
350
+ self.config= config
351
+ self.linear_fc1 = nn.Linear(in_features, self.config.moe_ffn_hidden_size*2, bias=self.config.add_bias_linear,)
352
+ self.linear_fc2 = nn.Linear(self.config.moe_ffn_hidden_size, self.config.hidden_size, bias=self.config.add_bias_linear,)
 
 
 
 
 
 
 
 
353
 
354
  def forward(self, x):
355
  x = self.swiglu(self.linear_fc1(x))
356
  x = self.linear_fc2(x)
357
  return x
358
 
359
+ def swiglu(self,y):
360
  """Performs SwiGLU (Swish-Gated Linear Unit) activation function.
361
 
362
  Args:
 
379
  self.input_layernorm = IdentityOp()
380
  self.self_attention = SelfAttention(config)
381
  self.pre_mlp_layernorm = RMSNorm(self.config.hidden_size)
382
+ self.mlp = MLP(config,self.config.hidden_size)
383
 
384
+ def forward(self, x, attention_mask,rotary_pos_emb):
385
  residual = x
386
  x = self.input_layernorm(x)
387
  x = self.self_attention(x, attention_mask, rotary_pos_emb)
 
393
  return x
394
 
395
 
396
+ class FalconTSTExpert(nn.Module):
397
+ def __init__(self, config, patch_input_size=32,expert_output_size=336,final_layernorm=True):
398
  super().__init__()
399
  self.config = config
400
+ self.patch_size= patch_input_size
401
  self.seq_length = config.seq_length
402
+ assert self.seq_length % self.patch_size == 0, f'invalid patch_size: {self.patch_size} when seq_length={self.seq_length}'
 
 
403
  self.patch_num = self.seq_length // self.patch_size
404
  self.flatten_size = self.patch_num * self.config.hidden_size
405
 
406
+ self.layers = nn.ModuleList([
407
+ TransformerLayer(config,input_layernorm=config.transformer_input_layernorm)
408
+ for _ in range(self.config.expert_num_layers)
409
+ ])
 
 
410
  if final_layernorm:
411
  self.final_layernorm = RMSNorm(self.config.hidden_size)
412
  else:
413
  self.final_layernorm = IdentityOp()
414
  self.patch_embedding = MLP(config, in_features=patch_input_size)
415
+ self.output_layer = nn.Linear(in_features=self.flatten_size, out_features=expert_output_size, bias=False,)
416
+
 
 
 
417
 
418
  def _forward_patch_embedding(
419
  self,
420
+ input: Tensor, # [batch_size, seq_len]
421
  ):
422
  """
423
  Perform patch embedding on the input time series.
424
 
425
+ This method applies a linear transformation to the input tensor to
426
  convert it into patches and then embeds these patches using a linear layer.
427
  """
428
  batch_size, seq_len = input.shape
429
+ assert seq_len == self.seq_length, f'Expected sequence length {self.seq_length}, but got {seq_len}'
 
 
430
 
431
  # Create input_mask based on pad_length
432
  # When a time point is masked, its value is mask_pad_value(default:255.)
433
+ input_mask = (input != self.config.mask_pad_value) # 0: mask, 1: unmask [batch_size, seq_len]
 
 
434
 
435
  # so whether the masked value 0 has the same effective of attention_mask
436
+ input_data = input * input_mask # [batch_size, seq_len]
437
 
438
  # Patchify the input
439
+ input_data = input_data.unfold(dimension=-1, size=self.patch_size, step=self.patch_size).contiguous() # input [batch_size, patch_num, patch_size]
440
+ hidden_states= self.patch_embedding(input_data) # hidden_states [batch_size, patch_num, hidden_size]
441
+ hidden_states = hidden_states.transpose(0, 1).contiguous() # hidden_states [patch_num, batch_size, hidden_size], To adapt to the Megatron
 
 
 
 
 
 
442
 
443
  # Patchify the mask: only the entire time points in a patch are masked then this patch is masked
444
+ attention_mask = input_mask.unfold(dimension=-1, size=self.patch_size, step=self.patch_size).contiguous() # [batch_size, patch_num, patch_size]
445
+ attention_mask = (attention_mask.sum(-1) == self.patch_size) # [batch_size, patch_num] # 0: mask, 1: unmask
446
+ attention_mask[:, -1] = True # The last patch is not masked
 
 
 
 
447
  _, patch_num = attention_mask.shape
448
+ attention_mask = attention_mask.unsqueeze(2).repeat(1,1,patch_num) * attention_mask.unsqueeze(1).repeat(1,patch_num,1) # [batch_size, patch_num, patch_num]
449
+ attention_mask = attention_mask.unsqueeze(1).contiguous() # [batch_size, 1, patch_num, patch_num]
450
+
 
 
 
 
 
451
 
452
  return hidden_states, attention_mask, input_mask
453
 
454
+
455
+ def _forward_output(self, hidden_states, output_scale=None, input_mask=None, inference_context=None):
 
456
  """
457
+ Perform a forward pass through the output layer.
458
 
459
+ Args:
460
+ expert_input (Tensor): Expert input of shape [batch_size, seq_len]
461
+ hidden_states (Tensor): Transformed hidden states of shape [patch_num, batch_size, hidden_size]
462
+ output_scale (Tensor, optional): Expert probabilities for the output layer [batch_size]
463
+ input_mask (Tensor, optional): Expert input mask of shape [batch_size, seq_len], 0:mask, 1:unmask
464
 
465
+ Returns:
466
+ expert_output (Tensor): Expert output of shape [batch_size, expert_output_size]
467
  """
468
 
469
  # [patch_num, batch_size, hidden_size] -> [batch_size, flatten_size (patch_num * hidden_size)]
470
  patch_num, batch_size, hidden_size = hidden_states.shape
471
+ assert (patch_num * hidden_size) == self.flatten_size, f'patch_num ({patch_num}) * hidden_size ({hidden_size}) != flatten_size ({self.flatten_size})'
 
 
472
  hidden_states = hidden_states.transpose(0, 1).reshape(-1, self.flatten_size).contiguous()
473
+ expert_output = self.output_layer(hidden_states) # [batch_size, expert_output_size]
474
  if output_scale is not None:
475
  original_dtype = expert_output.dtype
476
  expert_output = expert_output * output_scale.unsqueeze(-1)
 
478
 
479
  return expert_output
480
 
481
+ def forward(self, expert_input, rotary_pos_emb,expert_probs=None):
482
  hidden_states, attention_mask, input_mask = self._forward_patch_embedding(expert_input)
483
  for layer in self.layers:
484
+ hidden_states = layer(hidden_states,attention_mask,rotary_pos_emb[:hidden_states.shape[0]])
 
 
485
  hidden_states = self.final_layernorm(hidden_states)
486
  expert_output = self._forward_output(hidden_states, expert_probs, input_mask)
487
  return expert_output
488
 
489
 
490
+ class SequentialFalconTST(nn.Module):
491
+ def __init__(self, config,expert_output_size=336):
492
  super().__init__()
493
  self.config = config
494
  self.expert_output_size = expert_output_size
495
+ self.local_experts = nn.ModuleList([
496
+ FalconTSTExpert(
497
+ config,
498
+ expert_output_size=expert_output_size,
499
+ patch_input_size=config.patch_size_list[expert_id],
500
+ final_layernorm=config.moe_expert_final_layernorm
501
+ )
502
+ for expert_id in range(config.num_moe_experts)
503
+ ])
 
 
504
 
505
  def forward(self, input, routing_map, rotary_pos_emb, expert_probs):
506
  expert_output_list = []
 
508
 
509
  for i, expert in enumerate(self.local_experts):
510
  token_mask = routing_map[:, i].bool() # shape (batch,)
511
+ current_inputs = input[token_mask] # (num_tokens_for_expert, seq_len)
512
+ current_probs = expert_probs[token_mask, i]
513
 
514
  if current_inputs.numel() == 0:
515
+ expert_output = torch.zeros(0, self.expert_output_size, device=input.device, dtype=input.dtype)
 
 
516
  else:
517
  expert_output = expert(current_inputs, rotary_pos_emb, current_probs)
518
 
519
+ full_output = torch.zeros(batch_size, self.expert_output_size, device=input.device, dtype=input.dtype)
 
 
520
  full_output[token_mask] = expert_output
521
  expert_output_list.append(full_output)
522
 
 
539
  ctx.weight_dtype = weight.dtype
540
  inp_shape = inp.shape
541
  inp = inp.view(-1, inp_shape[-1])
542
+
543
  output = torch.mm(inp.to(router_dtype), weight.to(router_dtype).t())
544
 
545
  output = output.view(*inp_shape[:-1], -1)
 
555
  return RouterGatingLinearFunction.apply(inp, weight, router_dtype)
556
 
557
 
558
+ class Router(ABC,nn.Module):
559
  """Base Router class"""
560
 
561
  def __init__(
562
+ self, config: FalconTSTConfig,
 
563
  ) -> None:
564
  """
565
  Initialize the Router module.
 
572
  self.config = config
573
 
574
  # Initialize the gate weights.
575
+
576
  if self.config.patch_size_list is not None:
577
  assert self.config.moe_router_input_size is not None
578
  self.weight = torch.nn.Parameter(
579
+ torch.empty((self.config.num_moe_experts, self.config.moe_router_input_size), dtype=torch.float32)
 
 
 
580
  )
581
  else:
582
  self.weight = torch.nn.Parameter(
583
+ torch.empty((self.config.num_moe_experts, self.config.hidden_size), dtype=torch.float32)
 
 
584
  )
585
  self.reset_parameters()
586
+
587
  def reset_parameters(self):
588
  """Reset the router parameters."""
589
+ torch.nn.init.normal_(self.weight,mean=0,std=self.config.init_method_std)
590
  self.weight.data = self.weight.data.to(dtype=self.config.torch_dtype)
591
 
592
+
593
  def gating(self, input: torch.Tensor):
594
  """Forward pass of the router gate.
595
 
 
633
  """Route each token to the top-k experts."""
634
 
635
  def __init__(
636
+ self, config: FalconTSTConfig,
 
637
  ) -> None:
638
  """Initialize the zero token dropping router.
639
 
 
648
  self.enable_expert_bias = self.config.moe_router_enable_expert_bias
649
  if self.enable_expert_bias:
650
  self.register_buffer(
651
+ 'local_tokens_per_expert',
652
  torch.zeros(self.config.num_moe_experts, dtype=torch.float32),
653
  persistent=False,
654
  )
655
  self.register_buffer(
656
+ 'expert_bias', torch.zeros(self.config.num_moe_experts, dtype=torch.float32)
657
  )
658
  else:
659
  self.local_tokens_per_expert = None
660
  self.expert_bias = None
661
 
662
+
663
  def routing(self, logits: torch.Tensor):
664
  """Top-k routing function
665
 
 
696
  return scores, routing_map
697
 
698
 
699
+ class FalconTSTMoELayer(nn.Module):
700
  def __init__(self, config, layer_number):
701
  super().__init__()
702
  self.config = config
 
714
  self.expert_output_size = config.seq_length
715
 
716
  if self.is_last_layer and self.config.heterogeneous_moe_layer:
717
+ # If heterogeneous_moe_layer is True, the backcast will be None
718
+ self.backcast_layernorm = None
719
  else:
720
  self.backcast_layernorm = RMSNorm(self.seq_length)
721
 
722
+ self.experts = SequentialFalconTST(
723
+ config,
724
+ expert_output_size=self.expert_output_size,
725
+ )
726
+ self.shared_experts = FalconTSTExpert(config,
727
+ expert_output_size=self.expert_output_size,
728
+ patch_input_size=config.shared_patch_size,
729
+ final_layernorm=config.moe_expert_final_layernorm)
 
 
730
 
731
  def time_series_preprocess(self, input: torch.Tensor):
732
  """
733
+ Preprocess time series(sample) for dispatch.
734
 
735
+ Applies RevIN to input time series(sample), and process the input mask (0: mask, 1: unmask)
736
 
737
+ Args:
738
+ input (torch.Tensor): The input time series (samples) to the MoE layer. [batch_size, seq_len]
739
 
740
+ Returns:
741
+ input (torch.Tensor): The (RevIN) backcast time series (samples). [batch_size, seq_len]
742
+ means (torch.Tensor): The means of the non-masked backcast time series (samples). [batch_size, 1]
743
+ stdev (torch.Tensor): The standard deviation of the non-masked backcast time series (samples). [batch_size, 1]
744
  """
745
 
746
  batch_size, seq_len = input.shape
747
+ assert seq_len == self.seq_length, f'seq_len {seq_len} != self.seq_length {self.seq_length}'
748
 
749
  # Create input_mask based on pad_length
750
  # When a time point is masked, its value is mask_pad_value(default:255.)
751
+ input_mask = (input != self.config.mask_pad_value) # 0: mask, 1: unmask [batch_size, seq_len]
752
+
 
 
753
  self.input_mask = input_mask
754
+
755
  return input
756
+
757
  def router_and_preprocess(self, backcast: torch.Tensor):
758
  """Compute and preprocess time series(sample) routing for dispatch.
759
 
 
765
  # backcast [batch_size, seq_len] means/stdev [batch_size, 1]
766
  backcast = self.time_series_preprocess(backcast)
767
 
768
+ residual = backcast # residual: [batch_size, seq_len], the input to the shared experts
769
 
770
  # TODO: Check the effective of the masked value to the router
771
+ probs, routing_map = self.router(backcast * self.input_mask) # probs/routing_map: [batch_size, num_experts]
 
 
772
 
773
  return backcast, probs, residual, routing_map
774
 
775
  def experts_compute(
776
  self,
777
+ input: torch.Tensor, # [num_permuted_samples_after_dispatch, seq_len]
778
+ probs: torch.Tensor, # [num_permuted_samples_after_dispatch]
779
+ residual: torch.Tensor, # [batch_size, seq_len]
780
  rotary_pos_emb: torch.Tensor,
781
+ routing_map:torch.Tensor, # [seq_len, 1, 1, kv_channels(hidden_size // num_heads)]
782
  ):
783
  """Computes the output of the experts on the dispatched time series(sample).
784
 
 
790
  """
791
  # shared_expert_output: [batch_size, seq_len (+ pred_len)]
792
  shared_experts_output = self.shared_experts(residual, rotary_pos_emb)
793
+
794
  # dispatched_input (global_input_tokens): [num_permuted_samples_after_dispatch_postprocess(sorted), seq_len]
795
  # tokens_per_expert (global_probs): [num_experts]
796
  # permuted_probs (global_probs): [num_permuted_samples_after_dispatch_postprocess(sorted)]
797
+
798
  experts_output = self.experts(input, routing_map, rotary_pos_emb, probs)
799
 
800
+
801
  return experts_output, shared_experts_output
802
+
803
  def postprocess(
804
+ self,
805
+ backcast: torch.Tensor, # [batch_size, seq_len]
806
+ forecast: torch.Tensor, # [batch_size, pred_len]
807
  output_backcast: torch.Tensor, # [batch_size, seq_len]
808
  output_forecast: torch.Tensor, # [batch_size, pred_len]
809
  ):
 
817
  stdev (torch.Tensor): The standard deviation of the non-masked backcast time series (samples). [batch_size, 1]
818
  backcast_mask (torch.Tensor): The previous layer's backcast mask of time series (samples) . [batch_size, seq_len]
819
  """
820
+ if output_backcast is not None:
821
+ # 25/8/14 @modified by xiaming replace the revin with layernorm after the moe layer
822
+ # And if we multiply the output_backcast with the input mask, the performance will be hurted
823
+ output_backcast = self.backcast_layernorm(output_backcast) # LayerNorm
824
  if self.config.residual_backcast:
825
  output_backcast = backcast - output_backcast
826
 
827
+ output_backcast[~self.input_mask] = self.config.mask_pad_value # Important! Recover the mask time point back to mask_pad_value(default:255.)
828
+
829
+ if self.config.do_expert_forecast and forecast is not None: # The first layer's forecast is None
 
 
 
 
830
  output_forecast = forecast + output_forecast
831
+
832
  return output_backcast, output_forecast
833
+
834
 
835
  def combine(
836
  self,
 
843
  experts (e.g., via an All-to-All communication). It then adds the output
844
  from the shared expert if it exists.
845
  """
846
+ assert experts_output.shape == shared_experts_output.shape,\
847
+ f'experts_output shape {experts_output.shape} doesn\'t equal to shared_experts_output shape:{shared_experts_output.shape}'
 
848
  output = experts_output + shared_experts_output
849
 
850
  if self.is_last_layer and self.config.heterogeneous_moe_layer:
851
  output_backcast = None
852
  output_forecast = output
853
+ assert output_forecast.shape[1] == self.pred_length, \
854
+ f'heterogeneous_moe_layer=True, expected the last moe layer\'s output pred len: {self.pred_length}, but got {output_forecast.shape[1]}'
 
855
  else:
856
  # Noting: the mask time point there maybe not mask_pad_value(default:255.), it will be postprocessed
857
+ output_backcast = output[:, :self.seq_length] # [batch_size, seq_len]
858
+
859
  if self.config.do_expert_forecast:
860
+ output_forecast = output[:, self.seq_length:] # [batch_size, pred_len]
861
+ assert output_forecast.shape[1] == self.pred_length, \
862
+ f'do_expert_forecast=True, expected the last moe layer\'s output pred len: {self.pred_length}, but got {output_forecast.shape[1]}'
 
863
  else:
864
  output_forecast = None
865
+
866
  return output_backcast, output_forecast
867
 
868
+
869
+ def forward(self, backcast,forecast,rotary_pos_emb):
870
  inputs, probs, residual, routing_map = self.router_and_preprocess(backcast)
871
+ experts_output, shared_experts_output = self.experts_compute(inputs, probs, residual, rotary_pos_emb, routing_map)
 
 
872
  output_backcast, output_forecast = self.combine(experts_output, shared_experts_output)
873
+ output_backcast, output_forecast = self.postprocess(backcast, forecast, output_backcast, output_forecast)
 
 
874
  return output_backcast, output_forecast
875
 
876
 
877
+
878
+ class FalconTSTBlock(nn.Module):
879
+ def __init__(self,config):
880
  super().__init__()
881
  self.config = config
882
+ self.layers = nn.ModuleList([
883
+ FalconTSTMoELayer(config,layer_num +1)
 
884
  for layer_num in range(self.config.num_hidden_layers)
885
+ ])
886
+ def forward(self, x,rotary_pos_emb):
 
 
887
  backcast = x
888
  forecast = None
889
  for layer in self.layers:
890
+ backcast, forecast = layer(backcast,forecast,rotary_pos_emb)
891
+ return backcast,forecast
892
 
893
 
894
+
895
+ class FalconTSTPreTrainedModel(PreTrainedModel):
896
+ config_class = FalconTSTConfig
897
  base_model_prefix = "model"
898
  supports_gradient_checkpointing = True
899
+ _no_split_modules = ["FalconTSTMoELayer"]
900
  _skip_keys_device_placement = "past_key_values"
901
  _supports_flash_attn_2 = True
902
  _supports_sdpa = False
 
912
  if module.padding_idx is not None:
913
  module.weight.data[module.padding_idx].zero_()
914
 
915
+ class FalconTSTModel(FalconTSTPreTrainedModel):
916
+ def __init__(self, config: FalconTSTConfig):
 
917
  super().__init__(config)
918
  self.config = config
919
  self.seq_length = config.seq_length
920
  self.rotary_pos_emb = RotaryEmbedding(
921
+ kv_channels=self.config.kv_channels,
922
+ rotary_base=config.rotary_base,
923
+ use_cpu_initialization=self.config.use_cpu_initialization,
924
+ rotary_interleaved=self.config.rotary_interleaved
925
  )
926
+ self.decoder = FalconTSTBlock(
927
+ config=config
928
+ )
929
  if self.config.do_expert_forecast and self.config.heterogeneous_moe_layer:
930
  self.output_layer = IdentityOp()
931
  else:
932
+ self.output_layer = nn.Linear(in_features=self.seq_length, out_features=self.config.pred_length, bias=self.config.add_bias_linear,)
933
+
 
 
 
934
 
935
  def revin(
936
  self,
937
+ input: Tensor, # [batch_size, seq_len]
938
+ input_mask: Tensor, # [batch_size, seq_len] 0:mask, 1:unmask
939
  ):
940
+ """ Normalization from Non-stationary Transformer"""
941
 
942
  input_data = input * input_mask
943
+ sum_per_sample = torch.sum(input_data, dim=1, keepdim=True).detach() # [batch_size, 1], torch.bfloat16
944
+ count_per_sample = torch.sum(input_mask, dim=1, keepdim=True).detach() # [batch_size, 1], torch.int64
945
+ assert torch.any(count_per_sample == 0) == False, \
946
+ f'There is zero in count_per_sample, shape: {input[torch.where(count_per_sample.squeeze(1) == 0)[0]]}'
947
+ means = sum_per_sample / count_per_sample # [batch_size, 1]
 
 
 
 
 
948
  input_data = input_data - means
949
  input_data = input_data * input_mask
950
+ var_per_sample = torch.sum(input_data ** 2, dim=1, keepdim=True).detach() / count_per_sample # [batch_size, 1]
 
 
951
  stdev = torch.sqrt(var_per_sample + 1e-9)
952
  input_data = input_data / stdev
953
  input_data = input_data * input_mask
954
 
955
+ #recover the mask_pad_value(default:255.)
956
  input = input * ~(input_mask) + input_data
957
 
958
  return input, means, stdev
959
 
960
  def forward(self, input, revin):
961
+ # Apply rotary position embeddings
962
+ # seq_len = patches.size(1)
963
+ # pos_emb = self.rotary_pos_emb(seq_len, patches.device)
964
+ # patches = patches + pos_emb
965
+
966
  batch_size, input_len = input.shape
967
+ # @created by xiaming @modified by baichun
968
+ # realize varied input length
969
  if input_len > self.seq_length:
970
+ input = input[:, -self.seq_length:]
971
  elif input_len < self.seq_length:
972
  pad_len = self.seq_length - input_len
973
+ input = F.pad(input, pad=(pad_len, 0), mode='constant', value=self.config.mask_pad_value)
 
 
974
  input_len = self.seq_length
975
 
976
+ input_mask = (input != self.config.mask_pad_value)
977
 
978
  # Step1. RevIN
979
  if revin:
980
  input, means, stdev = self.revin(input, input_mask)
981
+
982
  # Step2. Get rotary_pos_emb
983
  # rotary_pos_emb [input_len, 1, 1, kv_channels(hidden_size // num_heads)]
984
  rotary_pos_emb = self.rotary_pos_emb(input_len, device=input.device)
 
986
  # Step3. Do one-step inference to get mixed forecasts from multiple forecast heads
987
  # mixed_pred: [batch_size, sum(multi_forecast_head)]
988
  mixed_pred = self._inference_step(
989
+ input=input,
990
+ input_mask=input_mask,
991
+ rotary_pos_emb=rotary_pos_emb
992
  )
993
 
994
+ # Step4. Based on the mixed forecasts, do auto-regressive inference according to
995
  # the step list of each forecast head
996
+ if self.config.multi_forecast_head_type == 'single':
997
  final_output = self._auto_regressive_single_head(
998
+ input=input,
999
+ input_mask=input_mask,
1000
+ FalconTST_forecast=mixed_pred,
1001
+ rotary_pos_emb=rotary_pos_emb
1002
  )
1003
  else:
1004
  raise NotImplementedError
1005
+
1006
  # Step5. RevIN
1007
  if revin:
1008
  final_output = final_output * (stdev.repeat(1, self.config.inference_length))
 
1011
  return final_output.detach().float()
1012
 
1013
  def _inference_step(
1014
+ self,
1015
+ input,
1016
+ input_mask,
1017
  rotary_pos_emb,
1018
+ ):
1019
  if self.config.do_base_forecast:
1020
  base_forecast, _ = self.base_output_layer(input)
1021
  else:
1022
  base_forecast = None
1023
 
1024
  decoder_backcast, decoder_forecast = self.decoder(
1025
+ input, # [batch_size, seq_len]
1026
+ rotary_pos_emb, # [input_len, 1, 1, kv_channels(hidden_size // num_heads)]
1027
  )
1028
 
1029
  if self.config.do_expert_forecast:
1030
+ assert decoder_forecast is not None, f'decoder_forecast is None'
1031
  if self.config.heterogeneous_moe_layer:
1032
  decoder_forecast = self.output_layer(decoder_forecast) # IdentityOp
1033
  else:
1034
+ final_forecast= self.output_layer(decoder_backcast * input_mask)
1035
  decoder_forecast = decoder_forecast + final_forecast
1036
  else:
1037
  # The decoder_backcast contains the mask_pad_val(default:255.)
1038
  decoder_forecast, _ = self.output_layer(decoder_backcast * input_mask)
1039
+
1040
  if self.config.do_base_forecast:
1041
+ assert base_forecast is not None, f'base_forecast is None'
1042
+ FalconTST_forecast = base_forecast + decoder_forecast
1043
  else:
1044
+ FalconTST_forecast = decoder_forecast
1045
+
1046
+ return FalconTST_forecast
1047
 
1048
  def _auto_regressive_single_head(
1049
  self,
1050
+ input, # [batch_size, seq_len]
1051
+ input_mask, # [batch_size, seq_len]
1052
+ FalconTST_forecast, # [batch_size, max(multi_forecast_head)]
1053
+ rotary_pos_emb, # [seq_len, 1, 1, kv_channels(hidden_size // num_heads)]
1054
+ auto_regressive_strategy='from_long_to_short'
1055
  ):
1056
  """auto regressive prediction with [single] head"""
1057
+ assert self.config.multi_forecast_head_type == 'single', \
1058
+ f'_auto_regressive_single_head only support multi_forecast_head_type==single '
 
1059
 
1060
+ if auto_regressive_strategy == 'from_long_to_short':
1061
  # From long to short
1062
  multi_forecast_head_list = sorted(self.config.multi_forecast_head_list, reverse=True)
1063
 
1064
+ final_output = FalconTST_forecast
1065
  while final_output.shape[1] < self.config.inference_length:
1066
  # adaptive choose the forecast head
1067
  remain_pred_len = self.config.inference_length - final_output.shape[1]
 
1071
  if idx == len(multi_forecast_head_list):
1072
  idx = len(multi_forecast_head_list) - 1
1073
  head_pred_len = multi_forecast_head_list[idx]
1074
+
1075
  # one-step model prediction
1076
+ input = torch.cat([input, FalconTST_forecast], dim=1)[:, -self.seq_length:].contiguous()
 
 
1077
  input_mask = torch.cat(
1078
+ [input_mask,
1079
+ torch.ones(FalconTST_forecast.shape, dtype=input_mask.dtype, device=input_mask.device)],
1080
+ dim=1)[:, -self.seq_length:].contiguous() # 0:mask, 1:unmask
1081
+
1082
+ FalconTST_forecast = self._inference_step(
1083
+ input=input,
1084
+ input_mask=input_mask,
1085
+ rotary_pos_emb=rotary_pos_emb,
 
 
 
 
 
 
 
 
 
1086
  )
1087
 
1088
  # the core idea of multi forecast head type of [single]
1089
+ FalconTST_forecast = FalconTST_forecast[:, :head_pred_len]
1090
+
1091
+ final_output = torch.cat([final_output, FalconTST_forecast], dim=1)
1092
+
1093
+ final_output = final_output[:, :self.config.inference_length]
1094
 
1095
+ elif auto_regressive_strategy == 'from_short_to_long':
1096
  # From short to long
1097
  # in validate_args, it has been sorted, and check the valid config
1098
  multi_forecast_head_list = sorted(self.config.multi_forecast_head_list)
 
1103
  else:
1104
  ar_step = min(
1105
  self.config.autoregressive_step_list[idx],
1106
+ self.config.multi_forecast_head_list[idx + 1] // self.config.multi_forecast_head_list[idx]
 
1107
  )
1108
  # ar_step = multi_forecast_head_list[idx + 1] // multi_forecast_head_list[idx]
1109
+
1110
  multi_forecast_head_dict[head_pred_len] = ar_step
1111
+
1112
  # the core idea of strategy [from_short_to_long]
1113
+ mixed_pred = FalconTST_forecast
1114
  output_list = []
1115
  cur_pred = None
1116
  cur_pred_len = 0
 
1124
  if ar_step == 0:
1125
  # Ignore the current forecast head
1126
  continue
1127
+
1128
  # Add current head's first auto-regressive step of prediction
1129
+ head_pred = mixed_pred[:, :head_pred_len] # [single]
1130
  output_list.append(head_pred[:, cur_pred_len:])
1131
  cur_pred = torch.cat(output_list, dim=1)
1132
  cur_pred_len = cur_pred.shape[1]
1133
  if cur_pred_len >= self.config.inference_length:
1134
  break
1135
+
1136
  # Do auto-regressive of the rest of the steps
1137
  for _ in range(1, ar_step + 1):
1138
  # one-step model prediction
1139
+ cur_input = torch.cat([input, cur_pred], dim=1)[:, -self.seq_length:].contiguous()
 
 
1140
  cur_input_mask = torch.cat(
1141
+ [input_mask,
1142
+ torch.ones(cur_pred.shape, dtype=input_mask.dtype, device=input_mask.device)],
1143
+ dim=1)[:, -self.seq_length:].contiguous() # 0:mask, 1:unmask
1144
+
1145
+ FalconTST_forecast = self._inference_step(
1146
+ input=cur_input,
1147
+ input_mask=cur_input_mask,
1148
+ rotary_pos_emb=rotary_pos_emb,
 
 
 
 
 
 
 
1149
  )
1150
 
1151
+ head_pred = FalconTST_forecast[:, :head_pred_len]
1152
  output_list.append(head_pred)
1153
  cur_pred = torch.cat(output_list, dim=1)
1154
  cur_pred_len = cur_pred.shape[1]
1155
  if cur_pred_len >= self.config.inference_length:
1156
  break
1157
+
1158
  if cur_pred_len >= self.config.inference_length:
1159
  break
1160
+
1161
+ final_output = cur_pred[:, :self.config.inference_length] # [batch_size, inference_len]
 
 
1162
 
1163
  assert final_output.shape[1] == self.config.inference_length
1164
  return final_output
1165
 
1166
+ class FalconTSTForPrediction(FalconTSTPreTrainedModel, FalconTSTGenerationMixin):
1167
+ def __init__(self, config: FalconTSTConfig):
 
1168
  super().__init__(config)
1169
  self.config = config
1170
+ self.model = FalconTSTModel(self.config)
1171
  self.post_init()
1172
 
1173
  def forward(
 
1180
  revin: Optional[bool] = False,
1181
  ):
1182
  self.model.config.inference_length = max_output_length
1183
+ outputs = self.model(
1184
+ input=input_ids,
1185
+ revin=revin
1186
+ )
1187
 
1188
  loss = None
1189
  logits = outputs
 
1205
  attention_mask=None,
1206
  inputs_embeds=None,
1207
  revin=False,
1208
+ **kwargs
1209
  ):
1210
  """
1211
  Prepare model inputs for autoregressive generation.
 
1213
 
1214
  model_inputs = {"input_ids": input_ids}
1215
 
1216
+ model_inputs.update({
1217
+ "revin": revin,
1218
+ })
 
 
1219
 
1220
+ return model_inputs