Spaces:
Sleeping
Sleeping
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -7,7 +7,6 @@ import plotly.graph_objects as go
|
|
| 7 |
import plotly.express as px
|
| 8 |
from plotly.subplots import make_subplots
|
| 9 |
import spaces
|
| 10 |
-
from chronos import BaseChronosPipeline
|
| 11 |
|
| 12 |
def get_indonesian_stocks():
|
| 13 |
return {
|
|
@@ -217,30 +216,39 @@ def format_large_number(num):
|
|
| 217 |
def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
| 218 |
try:
|
| 219 |
prices = data['Close'].values.astype(np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
|
| 221 |
with torch.no_grad():
|
| 222 |
forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
| 224 |
pred_len = len(mean_forecast)
|
| 225 |
last_price = prices[-1]
|
| 226 |
-
predicted_high = np.max(mean_forecast)
|
| 227 |
-
predicted_low = np.min(mean_forecast)
|
| 228 |
-
predicted_mean = np.mean(mean_forecast)
|
| 229 |
-
change_pct = ((predicted_mean - last_price) / last_price) * 100
|
| 230 |
return {
|
| 231 |
'values': mean_forecast,
|
| 232 |
-
'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=pred_len, freq='D'),
|
| 233 |
'high_30d': predicted_high,
|
| 234 |
'low_30d': predicted_low,
|
| 235 |
'mean_30d': predicted_mean,
|
| 236 |
'change_pct': change_pct,
|
| 237 |
-
'summary': f""
|
| 238 |
-
AI Model: Amazon Chronos-Bolt (Base)
|
| 239 |
-
Prediction Period: {pred_len} days
|
| 240 |
-
Expected Change: {change_pct:.2f}%
|
| 241 |
-
Confidence: Medium
|
| 242 |
-
Note: AI predictions are for reference only and not financial advice
|
| 243 |
-
""".strip()
|
| 244 |
}
|
| 245 |
except Exception as e:
|
| 246 |
print(f"Error in prediction: {e}")
|
|
|
|
| 7 |
import plotly.express as px
|
| 8 |
from plotly.subplots import make_subplots
|
| 9 |
import spaces
|
|
|
|
| 10 |
|
| 11 |
def get_indonesian_stocks():
|
| 12 |
return {
|
|
|
|
| 216 |
def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
| 217 |
try:
|
| 218 |
prices = data['Close'].values.astype(np.float32)
|
| 219 |
+
try:
|
| 220 |
+
from chronos import BaseChronosPipeline
|
| 221 |
+
except Exception as ie:
|
| 222 |
+
return {
|
| 223 |
+
'values': [],
|
| 224 |
+
'dates': [],
|
| 225 |
+
'high_30d': 0,
|
| 226 |
+
'low_30d': 0,
|
| 227 |
+
'mean_30d': 0,
|
| 228 |
+
'change_pct': 0,
|
| 229 |
+
'summary': 'chronos package not installed. install with: pip install chronos-forecasting'
|
| 230 |
+
}
|
| 231 |
pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
|
| 232 |
with torch.no_grad():
|
| 233 |
forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
|
| 234 |
+
if hasattr(forecast, 'mean'):
|
| 235 |
+
mean_forecast = forecast.mean(dim=1).squeeze().cpu().numpy()
|
| 236 |
+
else:
|
| 237 |
+
mean_forecast = forecast.mean(dim=1).squeeze().cpu().numpy()
|
| 238 |
pred_len = len(mean_forecast)
|
| 239 |
last_price = prices[-1]
|
| 240 |
+
predicted_high = np.max(mean_forecast) if pred_len > 0 else 0
|
| 241 |
+
predicted_low = np.min(mean_forecast) if pred_len > 0 else 0
|
| 242 |
+
predicted_mean = np.mean(mean_forecast) if pred_len > 0 else 0
|
| 243 |
+
change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 and pred_len > 0 else 0
|
| 244 |
return {
|
| 245 |
'values': mean_forecast,
|
| 246 |
+
'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=pred_len, freq='D') if pred_len > 0 else [],
|
| 247 |
'high_30d': predicted_high,
|
| 248 |
'low_30d': predicted_low,
|
| 249 |
'mean_30d': predicted_mean,
|
| 250 |
'change_pct': change_pct,
|
| 251 |
+
'summary': f"AI Model: Amazon Chronos-Bolt (Base)\nPrediction Period: {pred_len} days\nExpected Change: {change_pct:.2f}%\nConfidence: Medium\nNote: AI predictions are for reference only and not financial advice"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
}
|
| 253 |
except Exception as e:
|
| 254 |
print(f"Error in prediction: {e}")
|