Spaces:
Sleeping
Sleeping
update utils
Browse files
utils.py
CHANGED
|
@@ -275,14 +275,18 @@ def predict_prices(data, model, tokenizer, prediction_days=30):
|
|
| 275 |
|
| 276 |
# Generate predictions
|
| 277 |
with torch.no_grad():
|
|
|
|
|
|
|
| 278 |
forecast = model.generate(
|
| 279 |
prediction_input,
|
| 280 |
-
prediction_length=prediction_days
|
| 281 |
-
temperature=1.0,
|
| 282 |
-
# REMOVED: 'top_k=50' and 'top_p=0.9' to eliminate warnings
|
| 283 |
)
|
| 284 |
|
| 285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
# Calculate prediction statistics
|
| 288 |
last_price = prices[-1]
|
|
|
|
| 275 |
|
| 276 |
# Generate predictions
|
| 277 |
with torch.no_grad():
|
| 278 |
+
# FIXED: Removed unnecessary generation flags (top_p, top_k, temperature)
|
| 279 |
+
# to eliminate warnings and potential conflicts with model's custom implementation.
|
| 280 |
forecast = model.generate(
|
| 281 |
prediction_input,
|
| 282 |
+
prediction_length=prediction_days
|
|
|
|
|
|
|
| 283 |
)
|
| 284 |
|
| 285 |
+
# Handle case where output is a tuple (common for Seq2Seq models)
|
| 286 |
+
if isinstance(forecast, tuple):
|
| 287 |
+
predictions = forecast[0].numpy()
|
| 288 |
+
else:
|
| 289 |
+
predictions = forecast[0].numpy()
|
| 290 |
|
| 291 |
# Calculate prediction statistics
|
| 292 |
last_price = prices[-1]
|