Update ts_generation_mixin.py
Browse files- ts_generation_mixin.py +1 -1
ts_generation_mixin.py
CHANGED
|
@@ -39,7 +39,7 @@ class FalconTSTGenerationMixin(GenerationMixin):
|
|
| 39 |
channel = 1
|
| 40 |
if len(inputs.shape) == 3:
|
| 41 |
channel = inputs.shape[2]
|
| 42 |
-
inputs = inputs.
|
| 43 |
elif len(inputs.shape) > 3:
|
| 44 |
raise ValueError("Input shape must be [batch, seq_len, channel] or [batch, seq_len]")
|
| 45 |
|
|
|
|
| 39 |
channel = 1
|
| 40 |
if len(inputs.shape) == 3:
|
| 41 |
channel = inputs.shape[2]
|
| 42 |
+
inputs = inputs.permute(0, 2, 1).reshape(batch_size * channel, length)
|
| 43 |
elif len(inputs.shape) > 3:
|
| 44 |
raise ValueError("Input shape must be [batch, seq_len, channel] or [batch, seq_len]")
|
| 45 |
|