Harryx2025 commited on
Commit
6119a49
·
verified ·
1 Parent(s): 153429e

Update ts_generation_mixin.py

Browse files
Files changed (1) hide show
  1. ts_generation_mixin.py +1 -1
ts_generation_mixin.py CHANGED
@@ -39,7 +39,7 @@ class FalconTSTGenerationMixin(GenerationMixin):
39
  channel = 1
40
  if len(inputs.shape) == 3:
41
  channel = inputs.shape[2]
42
- inputs = inputs.transpose(0, 2, 1).reshape(batch_size * channel, length)
43
  elif len(inputs.shape) > 3:
44
  raise ValueError("Input shape must be [batch, seq_len, channel] or [batch, seq_len]")
45
 
 
39
  channel = 1
40
  if len(inputs.shape) == 3:
41
  channel = inputs.shape[2]
42
+ inputs = inputs.permute(0, 2, 1).reshape(batch_size * channel, length)
43
  elif len(inputs.shape) > 3:
44
  raise ValueError("Input shape must be [batch, seq_len, channel] or [batch, seq_len]")
45