Spaces:
Sleeping
Sleeping
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -6,6 +6,9 @@ from datetime import datetime, timedelta
|
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
from plotly.subplots import make_subplots
|
| 8 |
import spaces
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def get_indonesian_stocks():
|
| 11 |
return {
|
|
@@ -156,26 +159,46 @@ def format_large_number(num):
|
|
| 156 |
@spaces.GPU(duration=120)
|
| 157 |
def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
| 158 |
try:
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
with torch.no_grad():
|
| 163 |
-
#
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
predicted_high = float(np.max(mean_forecast))
|
| 172 |
predicted_low = float(np.min(mean_forecast))
|
| 173 |
predicted_mean = float(np.mean(mean_forecast))
|
| 174 |
change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 else 0
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
| 176 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
print(f"Error in prediction: {e}")
|
| 178 |
-
return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary':
|
| 179 |
|
| 180 |
def create_prediction_chart(data, predictions):
|
| 181 |
if not len(predictions['values']):
|
|
@@ -183,12 +206,14 @@ def create_prediction_chart(data, predictions):
|
|
| 183 |
fig = go.Figure()
|
| 184 |
fig.add_trace(go.Scatter(x=data.index[-60:], y=data['Close'].values[-60:], name='Historical Price', line=dict(color='blue', width=2)))
|
| 185 |
fig.add_trace(go.Scatter(x=predictions['dates'], y=predictions['values'], name='AI Prediction', line=dict(color='red', width=2, dash='dash')))
|
| 186 |
-
|
|
|
|
|
|
|
| 187 |
upper_band = predictions['values'] + (pred_std * 1.96)
|
| 188 |
lower_band = predictions['values'] - (pred_std * 1.96)
|
| 189 |
fig.add_trace(go.Scatter(x=predictions['dates'], y=upper_band, name='Upper Band', line=dict(color='lightcoral', width=1)))
|
| 190 |
fig.add_trace(go.Scatter(x=predictions['dates'], y=lower_band, name='Lower Band', line=dict(color='lightcoral', width=1), fill='tonexty', fillcolor='rgba(255,182,193,0.2)'))
|
| 191 |
-
fig.update_layout(title=f'Price Prediction - Next {len(predictions["dates"])} Days', xaxis_title='Date', yaxis_title='Price (IDR)', hovermode='x unified', height=500)
|
| 192 |
return fig
|
| 193 |
|
| 194 |
def create_price_chart(data, indicators):
|
|
|
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
from plotly.subplots import make_subplots
|
| 8 |
import spaces
|
| 9 |
+
# Impor Chronos2Pipeline untuk model amazon/chronos-2
|
| 10 |
+
from chronos import Chronos2Pipeline
|
| 11 |
+
|
| 12 |
|
| 13 |
def get_indonesian_stocks():
|
| 14 |
return {
|
|
|
|
| 159 |
@spaces.GPU(duration=120)
|
| 160 |
def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
| 161 |
try:
|
| 162 |
+
# PENGGANTIAN: Gunakan Chronos2Pipeline dan model amazon/chronos-2
|
| 163 |
+
pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2", device_map="auto")
|
| 164 |
+
|
| 165 |
+
# Chronos-2 membutuhkan input dalam format DataFrame (predict_df)
|
| 166 |
+
context_df = data[['Close']].reset_index()
|
| 167 |
+
context_df.columns = ['timestamp', 'target']
|
| 168 |
+
context_df['id'] = 'stock_price'
|
| 169 |
+
|
| 170 |
with torch.no_grad():
|
| 171 |
+
# Menggunakan predict_df()
|
| 172 |
+
pred_df = pipeline.predict_df(
|
| 173 |
+
context_df,
|
| 174 |
+
prediction_length=prediction_days,
|
| 175 |
+
id_column="id",
|
| 176 |
+
timestamp_column="timestamp",
|
| 177 |
+
target="target",
|
| 178 |
+
# Kita ambil kuantil 0.5 (median) sebagai prediksi mean
|
| 179 |
+
quantile_levels=[0.5]
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# Ekstraksi hasil prediksi dari DataFrame
|
| 183 |
+
mean_forecast = pred_df['target_0.5'].values.astype(np.float32)
|
| 184 |
+
predicted_dates = pred_df['timestamp']
|
| 185 |
+
|
| 186 |
+
last_price = data['Close'].iloc[-1]
|
| 187 |
predicted_high = float(np.max(mean_forecast))
|
| 188 |
predicted_low = float(np.min(mean_forecast))
|
| 189 |
predicted_mean = float(np.mean(mean_forecast))
|
| 190 |
change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 else 0
|
| 191 |
+
|
| 192 |
+
# Perbarui summary model
|
| 193 |
+
return {'values': mean_forecast, 'dates': predicted_dates, 'high_30d': predicted_high, 'low_30d': predicted_low, 'mean_30d': predicted_mean, 'change_pct': change_pct, 'summary': f"AI Model: Amazon Chronos-2\nPredicted High: {predicted_high:.2f}\nPredicted Low: {predicted_low:.2f}\nExpected Change: {change_pct:.2f}%"}
|
| 194 |
+
|
| 195 |
except Exception as e:
|
| 196 |
+
# Error handling yang lebih baik
|
| 197 |
+
error_message = f'Model error: {e}'
|
| 198 |
+
if "context_tensor" in str(e) or "context" in str(e) or "pipeline" in str(e):
|
| 199 |
+
error_message = f"Prediction API Error (Chronos-2): Cek instalasi 'chronos-forecasting' dan argumen di predict_prices(). Detail: {e}"
|
| 200 |
print(f"Error in prediction: {e}")
|
| 201 |
+
return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': error_message}
|
| 202 |
|
| 203 |
def create_prediction_chart(data, predictions):
|
| 204 |
if not len(predictions['values']):
|
|
|
|
| 206 |
fig = go.Figure()
|
| 207 |
fig.add_trace(go.Scatter(x=data.index[-60:], y=data['Close'].values[-60:], name='Historical Price', line=dict(color='blue', width=2)))
|
| 208 |
fig.add_trace(go.Scatter(x=predictions['dates'], y=predictions['values'], name='AI Prediction', line=dict(color='red', width=2, dash='dash')))
|
| 209 |
+
# Perhitungan band harus menggunakan prediksi yang sudah dikuantil (jika ada) atau distandardisasi
|
| 210 |
+
# Karena kita hanya menggunakan satu kuantil (0.5), kita asumsikan pred_std kecil
|
| 211 |
+
pred_std = np.std(predictions['values']) if len(predictions['values']) > 1 else 0.05 * predictions['values'][0]
|
| 212 |
upper_band = predictions['values'] + (pred_std * 1.96)
|
| 213 |
lower_band = predictions['values'] - (pred_std * 1.96)
|
| 214 |
fig.add_trace(go.Scatter(x=predictions['dates'], y=upper_band, name='Upper Band', line=dict(color='lightcoral', width=1)))
|
| 215 |
fig.add_trace(go.Scatter(x=predictions['dates'], y=lower_band, name='Lower Band', line=dict(color='lightcoral', width=1), fill='tonexty', fillcolor='rgba(255,182,193,0.2)'))
|
| 216 |
+
fig.update_layout(title=f'Price Prediction - Next {len(predictions["dates"])} Days (Chronos-2)', xaxis_title='Date', yaxis_title='Price (IDR)', hovermode='x unified', height=500)
|
| 217 |
return fig
|
| 218 |
|
| 219 |
def create_price_chart(data, indicators):
|