omniverse1 commited on
Commit
88b0bb9
·
verified ·
1 Parent(s): e41f18c

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +17 -9
utils.py CHANGED
@@ -6,7 +6,6 @@ from datetime import datetime, timedelta
6
  import plotly.graph_objects as go
7
  from plotly.subplots import make_subplots
8
  import spaces
9
- # Gunakan Chronos2Pipeline karena ini adalah kelas generik untuk model Chronos
10
  from chronos import Chronos2Pipeline
11
 
12
  def get_indonesian_stocks():
@@ -158,25 +157,25 @@ def format_large_number(num):
158
  @spaces.GPU(duration=120)
159
  def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
160
  try:
161
- # PENTING: Ganti ke Chronos-T5-Tiny untuk mencegah TimeOut/Crash GPU
162
- pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-t5-tiny", device_map="auto")
163
 
 
164
  context_df = data[['Close', 'Volume']].reset_index()
165
  context_df.columns = ['timestamp', 'target', 'volume']
166
  context_df['id'] = 'stock_price'
167
 
 
168
  context_df['timestamp'] = pd.to_datetime(context_df['timestamp'])
169
  context_df = context_df.set_index('timestamp').asfreq('D')
170
 
171
- # IMPUTATION: Harga (target) ffill, Volume (covariate) fillna(0)
172
  context_df['target'] = context_df['target'].fillna(method='ffill')
173
  context_df['volume'] = context_df['volume'].fillna(0)
174
 
175
- # FINAL GUARD: Drop baris di awal yang mungkin masih NaN (jika history start pada hari non-dagang)
176
- context_df = context_df.dropna(subset=['target', 'volume'])
177
-
178
  context_df = context_df.reset_index()
179
 
 
180
  context_df['id'] = 'stock_price'
181
  context_df = context_df[['timestamp', 'target', 'volume', 'id']]
182
 
@@ -190,11 +189,15 @@ def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
190
  quantile_levels=[0.1, 0.5, 0.9]
191
  )
192
 
 
193
  required_cols = ['target_0.1', 'target_0.5', 'target_0.9']
194
  if pred_df.empty or not all(col in pred_df.columns for col in required_cols):
 
195
  missing = [col for col in required_cols if col not in pred_df.columns]
196
  raise RuntimeError(f"Prediction failed. Result DataFrame is empty or incomplete. Missing: {missing}")
 
197
 
 
198
  q05_forecast = pred_df['target_0.5'].values.astype(np.float32)
199
  q09_forecast = pred_df['target_0.9'].values.astype(np.float32)
200
  q01_forecast = pred_df['target_0.1'].values.astype(np.float32)
@@ -216,12 +219,13 @@ def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
216
  'change_pct': change_pct,
217
  'q01': q01_forecast,
218
  'q09': q09_forecast,
219
- 'summary': f"AI Model: Amazon Chronos-T5-Tiny (Volume Covariate)\nPredicted High: {predicted_high:.2f}\nPredicted Low: {predicted_low:.2f}\nExpected Change: {change_pct:.2f}%"
220
  }
221
 
222
  except Exception as e:
223
  error_message = f'Model prediction failed: {e}'
224
  print(f"Error in prediction: {e}")
 
225
  return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': error_message, 'q01': [], 'q09': []}
226
 
227
  def create_prediction_chart(data, predictions):
@@ -230,8 +234,10 @@ def create_prediction_chart(data, predictions):
230
 
231
  fig = go.Figure()
232
 
 
233
  fig.add_trace(go.Scatter(x=data.index, y=data['Close'].values, name='Historical Price', line=dict(color='blue', width=2)))
234
 
 
235
  fig.add_trace(go.Scatter(
236
  x=predictions['dates'],
237
  y=predictions['q09'],
@@ -248,10 +254,11 @@ def create_prediction_chart(data, predictions):
248
  fillcolor='rgba(255,182,193,0.3)'
249
  ))
250
 
 
251
  fig.add_trace(go.Scatter(x=predictions['dates'], y=predictions['values'], name='Median Forecast (Q0.5)', line=dict(color='red', width=3, dash='solid')))
252
 
253
  fig.update_layout(
254
- title=f'Probabilistic Price Forecast - Next {len(predictions["dates"])} Days (Chronos-T5-Tiny)',
255
  xaxis_title='Date',
256
  yaxis_title='Price (IDR)',
257
  hovermode='x unified',
@@ -259,6 +266,7 @@ def create_prediction_chart(data, predictions):
259
  legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01)
260
  )
261
 
 
262
  last_hist_date = data.index[-1]
263
  last_hist_price = data['Close'].iloc[-1]
264
  fig.add_trace(go.Scatter(
 
6
  import plotly.graph_objects as go
7
  from plotly.subplots import make_subplots
8
  import spaces
 
9
  from chronos import Chronos2Pipeline
10
 
11
  def get_indonesian_stocks():
 
157
  @spaces.GPU(duration=120)
158
  def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
159
  try:
160
+ # Panggil pipeline di sini untuk memastikan instance baru tiap run (mencegah error memori/state)
161
+ pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2", device_map="auto")
162
 
163
+ # Chronos-2 with Covariate: Menggunakan Close (target) dan Volume (covariate)
164
  context_df = data[['Close', 'Volume']].reset_index()
165
  context_df.columns = ['timestamp', 'target', 'volume']
166
  context_df['id'] = 'stock_price'
167
 
168
+ # Fix Error: Could not infer frequency & FIX VOLUME COVARIATE IMPUTATION
169
  context_df['timestamp'] = pd.to_datetime(context_df['timestamp'])
170
  context_df = context_df.set_index('timestamp').asfreq('D')
171
 
172
+ # IMPUTATION FIX: Target ffill, Covariate (Volume) fillna(0)
173
  context_df['target'] = context_df['target'].fillna(method='ffill')
174
  context_df['volume'] = context_df['volume'].fillna(0)
175
 
 
 
 
176
  context_df = context_df.reset_index()
177
 
178
+ # Pastikan kolom sesuai urutan Chronos-2: timestamp, target, covariate(s), id
179
  context_df['id'] = 'stock_price'
180
  context_df = context_df[['timestamp', 'target', 'volume', 'id']]
181
 
 
189
  quantile_levels=[0.1, 0.5, 0.9]
190
  )
191
 
192
+ # --- FIX UTAMA: Pengecekan kolom hasil prediksi yang lebih ketat ---
193
  required_cols = ['target_0.1', 'target_0.5', 'target_0.9']
194
  if pred_df.empty or not all(col in pred_df.columns for col in required_cols):
195
+ # Jika gagal, pastikan kita tahu errornya dan melempar Runtime yang akan ditangkap di luar
196
  missing = [col for col in required_cols if col not in pred_df.columns]
197
  raise RuntimeError(f"Prediction failed. Result DataFrame is empty or incomplete. Missing: {missing}")
198
+ # ------------------------------------------------------------------
199
 
200
+ # Ekstraksi hasil prediksi kuantil
201
  q05_forecast = pred_df['target_0.5'].values.astype(np.float32)
202
  q09_forecast = pred_df['target_0.9'].values.astype(np.float32)
203
  q01_forecast = pred_df['target_0.1'].values.astype(np.float32)
 
219
  'change_pct': change_pct,
220
  'q01': q01_forecast,
221
  'q09': q09_forecast,
222
+ 'summary': f"AI Model: Amazon Chronos-2 (Volume Covariate)\nPredicted High: {predicted_high:.2f}\nPredicted Low: {predicted_low:.2f}\nExpected Change: {change_pct:.2f}%"
223
  }
224
 
225
  except Exception as e:
226
  error_message = f'Model prediction failed: {e}'
227
  print(f"Error in prediction: {e}")
228
+ # Mengembalikan objek error yang valid
229
  return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': error_message, 'q01': [], 'q09': []}
230
 
231
  def create_prediction_chart(data, predictions):
 
234
 
235
  fig = go.Figure()
236
 
237
+ # Historical Price: Menggunakan seluruh data historis
238
  fig.add_trace(go.Scatter(x=data.index, y=data['Close'].values, name='Historical Price', line=dict(color='blue', width=2)))
239
 
240
+ # Prediction Interval (Band): Menggunakan Q0.1 dan Q0.9
241
  fig.add_trace(go.Scatter(
242
  x=predictions['dates'],
243
  y=predictions['q09'],
 
254
  fillcolor='rgba(255,182,193,0.3)'
255
  ))
256
 
257
+ # Median Forecast (Q0.5) - Garis Utama Prediksi
258
  fig.add_trace(go.Scatter(x=predictions['dates'], y=predictions['values'], name='Median Forecast (Q0.5)', line=dict(color='red', width=3, dash='solid')))
259
 
260
  fig.update_layout(
261
+ title=f'Probabilistic Price Forecast - Next {len(predictions["dates"])} Days (Chronos-2)',
262
  xaxis_title='Date',
263
  yaxis_title='Price (IDR)',
264
  hovermode='x unified',
 
266
  legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01)
267
  )
268
 
269
+ # Menandai titik harga terakhir
270
  last_hist_date = data.index[-1]
271
  last_hist_price = data['Close'].iloc[-1]
272
  fig.add_trace(go.Scatter(