Spaces:
Sleeping
Sleeping
update utils
Browse files
utils.py
CHANGED
|
@@ -275,17 +275,16 @@ def predict_prices(data, model, tokenizer, prediction_days=30):
|
|
| 275 |
|
| 276 |
# Generate predictions
|
| 277 |
with torch.no_grad():
|
| 278 |
-
# FIXED: Removed unnecessary generation flags (top_p, top_k, temperature)
|
| 279 |
-
# to eliminate warnings and potential conflicts with model's custom implementation.
|
| 280 |
forecast = model.generate(
|
| 281 |
prediction_input,
|
| 282 |
prediction_length=prediction_days
|
| 283 |
)
|
| 284 |
|
| 285 |
-
# Handle case where output is a tuple (common for Seq2Seq models)
|
| 286 |
if isinstance(forecast, tuple):
|
| 287 |
predictions = forecast[0].numpy()
|
| 288 |
else:
|
|
|
|
| 289 |
predictions = forecast[0].numpy()
|
| 290 |
|
| 291 |
# Calculate prediction statistics
|
|
|
|
| 275 |
|
| 276 |
# Generate predictions
|
| 277 |
with torch.no_grad():
|
|
|
|
|
|
|
| 278 |
forecast = model.generate(
|
| 279 |
prediction_input,
|
| 280 |
prediction_length=prediction_days
|
| 281 |
)
|
| 282 |
|
| 283 |
+
# FIXED: Handle the case where the output is a tuple (common for Seq2Seq models)
|
| 284 |
if isinstance(forecast, tuple):
|
| 285 |
predictions = forecast[0].numpy()
|
| 286 |
else:
|
| 287 |
+
# Ensure the output is a numpy array (even if it's already a tensor)
|
| 288 |
predictions = forecast[0].numpy()
|
| 289 |
|
| 290 |
# Calculate prediction statistics
|