omniverse1 commited on
Commit
7d19279
·
verified ·
1 Parent(s): 2500608

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +25 -12
utils.py CHANGED
@@ -271,34 +271,47 @@ def predict_prices(data, model, tokenizer, prediction_days=30):
271
  input_sequence = prices[-context_length:]
272
 
273
  # Create prediction input
274
- prediction_input = torch.tensor(input_sequence).unsqueeze(0).float()
 
275
 
276
  # Generate predictions
277
  with torch.no_grad():
 
 
278
  forecast = model.generate(
279
  prediction_input,
280
- prediction_length=prediction_days
281
  )
282
 
283
- # FIXED: Handle the case where the output is a tuple (common for Seq2Seq models)
284
- if isinstance(forecast, tuple):
285
- predictions = forecast[0].numpy()
286
- else:
287
- # Ensure the output is a numpy array (even if it's already a tensor)
288
- predictions = forecast[0].numpy()
289
 
290
- # Calculate prediction statistics
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  last_price = prices[-1]
292
  predicted_high = np.max(predictions)
293
  predicted_low = np.min(predictions)
294
  predicted_mean = np.mean(predictions)
295
  change_pct = ((predicted_mean - last_price) / last_price) * 100
296
 
 
 
 
297
  return {
298
  'values': predictions,
299
  'dates': pd.date_range(
300
  start=data.index[-1] + timedelta(days=1),
301
- periods=prediction_days,
302
  freq='D'
303
  ),
304
  'high_30d': predicted_high,
@@ -307,7 +320,7 @@ def predict_prices(data, model, tokenizer, prediction_days=30):
307
  'change_pct': change_pct,
308
  'summary': f"""
309
  AI Model: Amazon Chronos-Bolt
310
- Prediction Period: {prediction_days} days
311
  Expected Change: {change_pct:.2f}%
312
  Confidence: Medium (based on historical patterns)
313
  Note: AI predictions are for reference only and not financial advice
@@ -322,7 +335,7 @@ Note: AI predictions are for reference only and not financial advice
322
  'low_30d': 0,
323
  'mean_30d': 0,
324
  'change_pct': 0,
325
- 'summary': 'Prediction unavailable due to model error'
326
  }
327
 
328
  def create_price_chart(data, indicators):
 
271
  input_sequence = prices[-context_length:]
272
 
273
  # Create prediction input
274
+ # CRITICAL FIX: Ensure tensor is on the same device as the model (GPU/auto-mapped)
275
+ prediction_input = torch.tensor(input_sequence).unsqueeze(0).float().to(model.device)
276
 
277
  # Generate predictions
278
  with torch.no_grad():
279
+ # CRITICAL FIX: Removed prediction_length keyword which caused warnings/potential internal failure.
280
+ # Model will use its default configuration for generation length.
281
  forecast = model.generate(
282
  prediction_input,
 
283
  )
284
 
285
+ # CRITICAL FIX: Handle complex Chronos output: [batch_size, num_samples, prediction_length]
 
 
 
 
 
286
 
287
+ # 1. Get the actual tensor from the tuple/list if necessary
288
+ output_tensor = forecast[0] if isinstance(forecast, (tuple, list)) else forecast
289
+
290
+ # 2. Average across the samples (dim=1) and ensure it is a simple 1D numpy array
291
+ # output_tensor shape: [1, num_samples, prediction_length]
292
+ # .mean(dim=1) averages the probabilistic samples
293
+ # .squeeze() removes unnecessary dimensions (like batch size 1)
294
+ predictions = output_tensor.mean(dim=1).squeeze().cpu().numpy()
295
+
296
+ # Handle case where predictions is a single scalar (convert to array for safety)
297
+ if predictions.ndim == 0:
298
+ predictions = np.array([predictions.item()])
299
+
300
+ # Calculation is now safe as predictions is a 1D array
301
  last_price = prices[-1]
302
  predicted_high = np.max(predictions)
303
  predicted_low = np.min(predictions)
304
  predicted_mean = np.mean(predictions)
305
  change_pct = ((predicted_mean - last_price) / last_price) * 100
306
 
307
+ # Use actual prediction length from the output tensor
308
+ pred_len = len(predictions)
309
+
310
  return {
311
  'values': predictions,
312
  'dates': pd.date_range(
313
  start=data.index[-1] + timedelta(days=1),
314
+ periods=pred_len,
315
  freq='D'
316
  ),
317
  'high_30d': predicted_high,
 
320
  'change_pct': change_pct,
321
  'summary': f"""
322
  AI Model: Amazon Chronos-Bolt
323
+ Prediction Period: {pred_len} days
324
  Expected Change: {change_pct:.2f}%
325
  Confidence: Medium (based on historical patterns)
326
  Note: AI predictions are for reference only and not financial advice
 
335
  'low_30d': 0,
336
  'mean_30d': 0,
337
  'change_pct': 0,
338
+ 'summary': f'Prediction unavailable due to model error: {e}'
339
  }
340
 
341
  def create_price_chart(data, indicators):