Spaces:
Sleeping
Sleeping
Update stock_analysis.py
Browse files- stock_analysis.py +20 -17
stock_analysis.py
CHANGED
|
@@ -10,32 +10,35 @@ def is_business_day(a_date):
|
|
| 10 |
return a_date.weekday() < 5
|
| 11 |
|
| 12 |
def forecast_series(series, model="ARIMA", forecast_period=FORECAST_PERIOD):
|
| 13 |
-
predictions = []
|
| 14 |
-
confidence_intervals = []
|
| 15 |
-
|
| 16 |
if series.shape[1] > 1:
|
| 17 |
-
series = series['Close'].values
|
| 18 |
|
| 19 |
if model == "ARIMA":
|
| 20 |
model = ARIMA(series, order=(5, 1, 0))
|
| 21 |
model_fit = model.fit()
|
| 22 |
-
forecast = model_fit.forecast(steps=forecast_period
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
# Check if forecast is a numpy array (newer statsmodels) or a ForecastResults object (older statsmodels)
|
| 25 |
-
if isinstance(forecast, np.ndarray):
|
| 26 |
-
predictions = forecast
|
| 27 |
-
confidence_intervals = model_fit.get_forecast(steps=forecast_period).conf_int()
|
| 28 |
-
else:
|
| 29 |
-
predictions = forecast.predicted_mean
|
| 30 |
-
confidence_intervals = forecast.conf_int()
|
| 31 |
elif model == "Prophet":
|
| 32 |
# Implement Prophet forecasting method
|
| 33 |
pass
|
| 34 |
elif model == "LSTM":
|
| 35 |
# Implement LSTM forecasting method
|
| 36 |
pass
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
return predictions,
|
| 39 |
|
| 40 |
def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method, start_date, end_date):
|
| 41 |
stock_name, ticker_name = stock.split(":")
|
|
@@ -49,14 +52,14 @@ def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method,
|
|
| 49 |
predictions, confidence_intervals = forecast_series(series, model=forecast_method)
|
| 50 |
|
| 51 |
last_date = pd.to_datetime(series['Date'].values[-1])
|
| 52 |
-
forecast_dates = pd.date_range(start=last_date + timedelta(days=1), periods=
|
| 53 |
-
forecast_dates = [date for date in forecast_dates if is_business_day(date)]
|
| 54 |
|
| 55 |
forecast = pd.DataFrame({
|
| 56 |
"Date": forecast_dates,
|
| 57 |
"Forecast": predictions,
|
| 58 |
-
"Lower_CI": confidence_intervals[
|
| 59 |
-
"Upper_CI": confidence_intervals[
|
| 60 |
})
|
| 61 |
|
| 62 |
if graph_type == 'Line Graph':
|
|
|
|
| 10 |
return a_date.weekday() < 5
|
| 11 |
|
| 12 |
def forecast_series(series, model="ARIMA", forecast_period=FORECAST_PERIOD):
|
|
|
|
|
|
|
|
|
|
| 13 |
if series.shape[1] > 1:
|
| 14 |
+
series = series['Close'].values
|
| 15 |
|
| 16 |
if model == "ARIMA":
|
| 17 |
model = ARIMA(series, order=(5, 1, 0))
|
| 18 |
model_fit = model.fit()
|
| 19 |
+
forecast = model_fit.forecast(steps=forecast_period)
|
| 20 |
+
|
| 21 |
+
# Get confidence intervals
|
| 22 |
+
conf_int = model_fit.get_forecast(steps=forecast_period).conf_int()
|
| 23 |
+
lower_ci = conf_int.iloc[:, 0] if isinstance(conf_int, pd.DataFrame) else conf_int[:, 0]
|
| 24 |
+
upper_ci = conf_int.iloc[:, 1] if isinstance(conf_int, pd.DataFrame) else conf_int[:, 1]
|
| 25 |
+
|
| 26 |
+
# Ensure all arrays have the same length
|
| 27 |
+
min_length = min(len(forecast), len(lower_ci), len(upper_ci))
|
| 28 |
+
predictions = forecast[:min_length]
|
| 29 |
+
lower_ci = lower_ci[:min_length]
|
| 30 |
+
upper_ci = upper_ci[:min_length]
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
elif model == "Prophet":
|
| 33 |
# Implement Prophet forecasting method
|
| 34 |
pass
|
| 35 |
elif model == "LSTM":
|
| 36 |
# Implement LSTM forecasting method
|
| 37 |
pass
|
| 38 |
+
else:
|
| 39 |
+
raise ValueError(f"Unsupported model: {model}")
|
| 40 |
|
| 41 |
+
return predictions, pd.DataFrame({'Lower_CI': lower_ci, 'Upper_CI': upper_ci})
|
| 42 |
|
| 43 |
def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method, start_date, end_date):
|
| 44 |
stock_name, ticker_name = stock.split(":")
|
|
|
|
| 52 |
predictions, confidence_intervals = forecast_series(series, model=forecast_method)
|
| 53 |
|
| 54 |
last_date = pd.to_datetime(series['Date'].values[-1])
|
| 55 |
+
forecast_dates = pd.date_range(start=last_date + timedelta(days=1), periods=len(predictions))
|
| 56 |
+
forecast_dates = [date for date in forecast_dates if is_business_day(date)][:len(predictions)]
|
| 57 |
|
| 58 |
forecast = pd.DataFrame({
|
| 59 |
"Date": forecast_dates,
|
| 60 |
"Forecast": predictions,
|
| 61 |
+
"Lower_CI": confidence_intervals['Lower_CI'],
|
| 62 |
+
"Upper_CI": confidence_intervals['Upper_CI']
|
| 63 |
})
|
| 64 |
|
| 65 |
if graph_type == 'Line Graph':
|