omniverse1 commited on
Commit
d5c8405
·
verified ·
1 Parent(s): b19c127

update utils

Browse files
Files changed (1) hide show
  1. utils.py +8 -4
utils.py CHANGED
@@ -275,14 +275,18 @@ def predict_prices(data, model, tokenizer, prediction_days=30):
275
 
276
  # Generate predictions
277
  with torch.no_grad():
 
 
278
  forecast = model.generate(
279
  prediction_input,
280
- prediction_length=prediction_days,
281
- temperature=1.0,
282
- # REMOVED: 'top_k=50' and 'top_p=0.9' to eliminate warnings
283
  )
284
 
285
- predictions = forecast[0].numpy()
 
 
 
 
286
 
287
  # Calculate prediction statistics
288
  last_price = prices[-1]
 
275
 
276
  # Generate predictions
277
  with torch.no_grad():
278
+ # FIXED: Removed unnecessary generation flags (top_p, top_k, temperature)
279
+ # to eliminate warnings and potential conflicts with model's custom implementation.
280
  forecast = model.generate(
281
  prediction_input,
282
+ prediction_length=prediction_days
 
 
283
  )
284
 
285
+ # Handle case where output is a tuple (common for Seq2Seq models)
286
+ if isinstance(forecast, tuple):
287
+ predictions = forecast[0].numpy()
288
+ else:
289
+ predictions = forecast[0].numpy()
290
 
291
  # Calculate prediction statistics
292
  last_price = prices[-1]