omniverse1 commited on
Commit
b19777d
·
verified ·
1 Parent(s): 2b1f62e

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +22 -14
utils.py CHANGED
@@ -7,7 +7,6 @@ import plotly.graph_objects as go
7
  import plotly.express as px
8
  from plotly.subplots import make_subplots
9
  import spaces
10
- from chronos import BaseChronosPipeline
11
 
12
  def get_indonesian_stocks():
13
  return {
@@ -217,30 +216,39 @@ def format_large_number(num):
217
  def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
218
  try:
219
  prices = data['Close'].values.astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
220
  pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
221
  with torch.no_grad():
222
  forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
223
- mean_forecast = forecast.mean(dim=1).squeeze().cpu().numpy()
 
 
 
224
  pred_len = len(mean_forecast)
225
  last_price = prices[-1]
226
- predicted_high = np.max(mean_forecast)
227
- predicted_low = np.min(mean_forecast)
228
- predicted_mean = np.mean(mean_forecast)
229
- change_pct = ((predicted_mean - last_price) / last_price) * 100
230
  return {
231
  'values': mean_forecast,
232
- 'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=pred_len, freq='D'),
233
  'high_30d': predicted_high,
234
  'low_30d': predicted_low,
235
  'mean_30d': predicted_mean,
236
  'change_pct': change_pct,
237
- 'summary': f"""
238
- AI Model: Amazon Chronos-Bolt (Base)
239
- Prediction Period: {pred_len} days
240
- Expected Change: {change_pct:.2f}%
241
- Confidence: Medium
242
- Note: AI predictions are for reference only and not financial advice
243
- """.strip()
244
  }
245
  except Exception as e:
246
  print(f"Error in prediction: {e}")
 
7
  import plotly.express as px
8
  from plotly.subplots import make_subplots
9
  import spaces
 
10
 
11
  def get_indonesian_stocks():
12
  return {
 
216
  def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
217
  try:
218
  prices = data['Close'].values.astype(np.float32)
219
+ try:
220
+ from chronos import BaseChronosPipeline
221
+ except Exception as ie:
222
+ return {
223
+ 'values': [],
224
+ 'dates': [],
225
+ 'high_30d': 0,
226
+ 'low_30d': 0,
227
+ 'mean_30d': 0,
228
+ 'change_pct': 0,
229
+ 'summary': 'chronos package not installed. install with: pip install chronos-forecasting'
230
+ }
231
  pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
232
  with torch.no_grad():
233
  forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
234
+ if hasattr(forecast, 'mean'):
235
+ mean_forecast = forecast.mean(dim=1).squeeze().cpu().numpy()
236
+ else:
237
+ mean_forecast = forecast.mean(dim=1).squeeze().cpu().numpy()
238
  pred_len = len(mean_forecast)
239
  last_price = prices[-1]
240
+ predicted_high = np.max(mean_forecast) if pred_len > 0 else 0
241
+ predicted_low = np.min(mean_forecast) if pred_len > 0 else 0
242
+ predicted_mean = np.mean(mean_forecast) if pred_len > 0 else 0
243
+ change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 and pred_len > 0 else 0
244
  return {
245
  'values': mean_forecast,
246
+ 'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=pred_len, freq='D') if pred_len > 0 else [],
247
  'high_30d': predicted_high,
248
  'low_30d': predicted_low,
249
  'mean_30d': predicted_mean,
250
  'change_pct': change_pct,
251
+ 'summary': f"AI Model: Amazon Chronos-Bolt (Base)\nPrediction Period: {pred_len} days\nExpected Change: {change_pct:.2f}%\nConfidence: Medium\nNote: AI predictions are for reference only and not financial advice"
 
 
 
 
 
 
252
  }
253
  except Exception as e:
254
  print(f"Error in prediction: {e}")