omniverse1 commited on
Commit
9284def
·
verified ·
1 Parent(s): 6001213

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +40 -15
utils.py CHANGED
@@ -6,6 +6,9 @@ from datetime import datetime, timedelta
6
  import plotly.graph_objects as go
7
  from plotly.subplots import make_subplots
8
  import spaces
 
 
 
9
 
10
  def get_indonesian_stocks():
11
  return {
@@ -156,26 +159,46 @@ def format_large_number(num):
156
  @spaces.GPU(duration=120)
157
  def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
158
  try:
159
- prices = data['Close'].values.astype(np.float32)
160
- from chronos import BaseChronosPipeline
161
- pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
 
 
 
 
 
162
  with torch.no_grad():
163
- # Fix: Use context_tensor instead of context
164
- forecast = pipeline.predict(context_tensor=torch.tensor(prices), prediction_length=prediction_days)
165
- forecast_np = forecast.squeeze().cpu().numpy() if isinstance(forecast, torch.Tensor) else np.array(forecast)
166
- if forecast_np.ndim > 1:
167
- mean_forecast = forecast_np.mean(axis=tuple(range(forecast_np.ndim - 1)))
168
- else:
169
- mean_forecast = forecast_np
170
- last_price = prices[-1]
 
 
 
 
 
 
 
 
171
  predicted_high = float(np.max(mean_forecast))
172
  predicted_low = float(np.min(mean_forecast))
173
  predicted_mean = float(np.mean(mean_forecast))
174
  change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 else 0
175
- return {'values': mean_forecast, 'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=len(mean_forecast), freq='D'), 'high_30d': predicted_high, 'low_30d': predicted_low, 'mean_30d': predicted_mean, 'change_pct': change_pct, 'summary': f"AI Model: Amazon Chronos-Bolt (Base)\nPredicted High: {predicted_high:.2f}\nPredicted Low: {predicted_low:.2f}\nExpected Change: {change_pct:.2f}%"}
 
 
 
176
  except Exception as e:
 
 
 
 
177
  print(f"Error in prediction: {e}")
178
- return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': f'Model error: {e}'}
179
 
180
  def create_prediction_chart(data, predictions):
181
  if not len(predictions['values']):
@@ -183,12 +206,14 @@ def create_prediction_chart(data, predictions):
183
  fig = go.Figure()
184
  fig.add_trace(go.Scatter(x=data.index[-60:], y=data['Close'].values[-60:], name='Historical Price', line=dict(color='blue', width=2)))
185
  fig.add_trace(go.Scatter(x=predictions['dates'], y=predictions['values'], name='AI Prediction', line=dict(color='red', width=2, dash='dash')))
186
- pred_std = np.std(predictions['values'])
 
 
187
  upper_band = predictions['values'] + (pred_std * 1.96)
188
  lower_band = predictions['values'] - (pred_std * 1.96)
189
  fig.add_trace(go.Scatter(x=predictions['dates'], y=upper_band, name='Upper Band', line=dict(color='lightcoral', width=1)))
190
  fig.add_trace(go.Scatter(x=predictions['dates'], y=lower_band, name='Lower Band', line=dict(color='lightcoral', width=1), fill='tonexty', fillcolor='rgba(255,182,193,0.2)'))
191
- fig.update_layout(title=f'Price Prediction - Next {len(predictions["dates"])} Days', xaxis_title='Date', yaxis_title='Price (IDR)', hovermode='x unified', height=500)
192
  return fig
193
 
194
  def create_price_chart(data, indicators):
 
6
  import plotly.graph_objects as go
7
  from plotly.subplots import make_subplots
8
  import spaces
9
+ # Impor Chronos2Pipeline untuk model amazon/chronos-2
10
+ from chronos import Chronos2Pipeline
11
+
12
 
13
  def get_indonesian_stocks():
14
  return {
 
159
  @spaces.GPU(duration=120)
160
  def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
161
  try:
162
+ # PENGGANTIAN: Gunakan Chronos2Pipeline dan model amazon/chronos-2
163
+ pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2", device_map="auto")
164
+
165
+ # Chronos-2 membutuhkan input dalam format DataFrame (predict_df)
166
+ context_df = data[['Close']].reset_index()
167
+ context_df.columns = ['timestamp', 'target']
168
+ context_df['id'] = 'stock_price'
169
+
170
  with torch.no_grad():
171
+ # Menggunakan predict_df()
172
+ pred_df = pipeline.predict_df(
173
+ context_df,
174
+ prediction_length=prediction_days,
175
+ id_column="id",
176
+ timestamp_column="timestamp",
177
+ target="target",
178
+ # Kita ambil kuantil 0.5 (median) sebagai prediksi mean
179
+ quantile_levels=[0.5]
180
+ )
181
+
182
+ # Ekstraksi hasil prediksi dari DataFrame
183
+ mean_forecast = pred_df['target_0.5'].values.astype(np.float32)
184
+ predicted_dates = pred_df['timestamp']
185
+
186
+ last_price = data['Close'].iloc[-1]
187
  predicted_high = float(np.max(mean_forecast))
188
  predicted_low = float(np.min(mean_forecast))
189
  predicted_mean = float(np.mean(mean_forecast))
190
  change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 else 0
191
+
192
+ # Perbarui summary model
193
+ return {'values': mean_forecast, 'dates': predicted_dates, 'high_30d': predicted_high, 'low_30d': predicted_low, 'mean_30d': predicted_mean, 'change_pct': change_pct, 'summary': f"AI Model: Amazon Chronos-2\nPredicted High: {predicted_high:.2f}\nPredicted Low: {predicted_low:.2f}\nExpected Change: {change_pct:.2f}%"}
194
+
195
  except Exception as e:
196
+ # Error handling yang lebih baik
197
+ error_message = f'Model error: {e}'
198
+ if "context_tensor" in str(e) or "context" in str(e) or "pipeline" in str(e):
199
+ error_message = f"Prediction API Error (Chronos-2): Cek instalasi 'chronos-forecasting' dan argumen di predict_prices(). Detail: {e}"
200
  print(f"Error in prediction: {e}")
201
+ return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': error_message}
202
 
203
  def create_prediction_chart(data, predictions):
204
  if not len(predictions['values']):
 
206
  fig = go.Figure()
207
  fig.add_trace(go.Scatter(x=data.index[-60:], y=data['Close'].values[-60:], name='Historical Price', line=dict(color='blue', width=2)))
208
  fig.add_trace(go.Scatter(x=predictions['dates'], y=predictions['values'], name='AI Prediction', line=dict(color='red', width=2, dash='dash')))
209
+ # Perhitungan band harus menggunakan prediksi yang sudah dikuantil (jika ada) atau distandardisasi
210
+ # Karena kita hanya menggunakan satu kuantil (0.5), kita asumsikan pred_std kecil
211
+ pred_std = np.std(predictions['values']) if len(predictions['values']) > 1 else 0.05 * predictions['values'][0]
212
  upper_band = predictions['values'] + (pred_std * 1.96)
213
  lower_band = predictions['values'] - (pred_std * 1.96)
214
  fig.add_trace(go.Scatter(x=predictions['dates'], y=upper_band, name='Upper Band', line=dict(color='lightcoral', width=1)))
215
  fig.add_trace(go.Scatter(x=predictions['dates'], y=lower_band, name='Lower Band', line=dict(color='lightcoral', width=1), fill='tonexty', fillcolor='rgba(255,182,193,0.2)'))
216
+ fig.update_layout(title=f'Price Prediction - Next {len(predictions["dates"])} Days (Chronos-2)', xaxis_title='Date', yaxis_title='Price (IDR)', hovermode='x unified', height=500)
217
  return fig
218
 
219
  def create_price_chart(data, indicators):