omniverse1 commited on
Commit
2500608
·
verified ·
1 Parent(s): d5c8405

update utils

Browse files
Files changed (1) hide show
  1. utils.py +2 -3
utils.py CHANGED
@@ -275,17 +275,16 @@ def predict_prices(data, model, tokenizer, prediction_days=30):
275
 
276
  # Generate predictions
277
  with torch.no_grad():
278
- # FIXED: Removed unnecessary generation flags (top_p, top_k, temperature)
279
- # to eliminate warnings and potential conflicts with model's custom implementation.
280
  forecast = model.generate(
281
  prediction_input,
282
  prediction_length=prediction_days
283
  )
284
 
285
- # Handle case where output is a tuple (common for Seq2Seq models)
286
  if isinstance(forecast, tuple):
287
  predictions = forecast[0].numpy()
288
  else:
 
289
  predictions = forecast[0].numpy()
290
 
291
  # Calculate prediction statistics
 
275
 
276
  # Generate predictions
277
  with torch.no_grad():
 
 
278
  forecast = model.generate(
279
  prediction_input,
280
  prediction_length=prediction_days
281
  )
282
 
283
+ # FIXED: Handle the case where the output is a tuple (common for Seq2Seq models)
284
  if isinstance(forecast, tuple):
285
  predictions = forecast[0].numpy()
286
  else:
287
+ # Ensure the output is a numpy array (even if it's already a tensor)
288
  predictions = forecast[0].numpy()
289
 
290
  # Calculate prediction statistics