Update ts_generation_mixin.py
Browse files — ts_generation_mixin.py (+3 −3)
ts_generation_mixin.py
CHANGED
|
@@ -226,12 +226,12 @@ class TSGenerationMixin(GenerationMixin):
|
|
| 226 |
if "decoder_attention_mask" in model_kwargs:
|
| 227 |
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
| 228 |
model_kwargs["decoder_attention_mask"] = torch.cat(
|
| 229 |
-
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
|
| 230 |
dim=-1,
|
| 231 |
)
|
| 232 |
|
| 233 |
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
|
| 234 |
-
|
| 235 |
-
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
|
| 236 |
|
| 237 |
return model_kwargs
|
|
|
|
| 226 |
if "decoder_attention_mask" in model_kwargs:
|
| 227 |
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
| 228 |
model_kwargs["decoder_attention_mask"] = torch.cat(
|
| 229 |
+
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], horizon_length))],
|
| 230 |
dim=-1,
|
| 231 |
)
|
| 232 |
|
| 233 |
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
|
| 234 |
+
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length
|
| 235 |
+
# model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
|
| 236 |
|
| 237 |
return model_kwargs
|