Spaces:
Sleeping
Sleeping
Update forecast_model.py
Browse files- forecast_model.py +89 -0
forecast_model.py
CHANGED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from prophet import Prophet
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
def _estimate_risk(attendance, avg_attendance):
    """Return the heuristic risk score for one attendance value.

    The score is a band-specific base (10, 15 or 20) plus the shortfall
    below ``avg_attendance`` scaled by the band's weight (2, 3 or 4),
    rounded to one decimal. Bands are: at least 80 % of the average, at
    least 50 % of the average, and everything below that. An attendance
    above the average gives a negative shortfall, so the score can fall
    below the band's base value.
    """
    shortfall = avg_attendance - attendance
    if attendance >= avg_attendance * 0.8:
        return round(10.0 + shortfall * 2, 1)
    if attendance >= avg_attendance * 0.5:
        return round(15.0 + shortfall * 3, 1)
    return round(20.0 + shortfall * 4, 1)


def _estimate_alert(attendance, avg_attendance):
    """Return "Low", "Medium" or "High" using the same bands as the risk score."""
    if attendance >= avg_attendance * 0.8:
        return "Low"
    if attendance >= avg_attendance * 0.5:
        return "Medium"
    return "High"


def train_and_forecast(data):
    """Fit a Prophet model on attendance history and forecast 30 more days.

    Args:
        data: Records accepted by ``pd.DataFrame`` (for example a list of
            dicts) or a pandas DataFrame. It must provide ``Date`` and
            ``Attendance`` columns. Optional ``Risk`` and ``Alert`` columns
            (both must be present) are carried over for the dates they cover.

    Returns:
        On success, a list of dicts with the keys ``Date`` (string),
        ``Attendance`` (the model's ``yhat`` rounded to one decimal),
        ``Risk`` and ``Alert``, one per row of the model's prediction frame.
        On failure, a dict ``{"error": <message>}``. Validation problems get
        a specific message; any other ``Exception`` is caught and reported
        as ``"Forecasting failed: ..."`` rather than propagated.
    """
    try:
        # Convert input data to DataFrame. A DataFrame has no usable truth
        # value, so it is checked with `.empty` and copied so that the
        # column conversions below never touch the caller's object.
        if isinstance(data, pd.DataFrame):
            if data.empty:
                return {"error": "Input data is empty"}
            df = data.copy()
        else:
            if not data:
                return {"error": "Input data is empty"}
            df = pd.DataFrame(data)

        # Validate required columns
        required_cols = ['Date', 'Attendance']
        if not all(col in df.columns for col in required_cols):
            return {"error": "Input must contain 'Date' and 'Attendance' columns"}

        # Validate and process Date and Attendance; unparseable values are
        # coerced to NaT/NaN so they can be reported instead of raising.
        df['Date'] = pd.to_datetime(df['Date'], errors='coerce')
        if df['Date'].isna().any():
            return {"error": "Invalid 'Date' format. Use YYYY-MM-DD"}
        df['Attendance'] = pd.to_numeric(df['Attendance'], errors='coerce')
        if df['Attendance'].isna().any():
            return {"error": "'Attendance' must be numeric"}

        # Check for sufficient data
        if len(df) < 2:
            return {"error": "At least 2 data points are required for forecasting"}

        # Preserve Risk and Alert for historical dates (first row per date wins)
        df_risk_alert = None
        if 'Risk' in df.columns and 'Alert' in df.columns:
            df_risk_alert = df[['Date', 'Risk', 'Alert']].drop_duplicates(subset='Date')

        # Hand Prophet only the two columns it needs, under the names it
        # expects. Extra caller columns (including ones that might already
        # be called 'ds' or 'y') are kept out of the model input.
        df_prophet = df[['Date', 'Attendance']].rename(
            columns={'Date': 'ds', 'Attendance': 'y'}
        )

        # Initialize and train Prophet model
        model = Prophet(
            yearly_seasonality=True,
            weekly_seasonality=True,
            daily_seasonality=True
        )
        model.fit(df_prophet)

        # Create future dates (30 days) and predict
        future = model.make_future_dataframe(periods=30, freq='D')
        forecast = model.predict(future)

        # Select relevant columns and rename back
        forecast = forecast[['ds', 'yhat']].rename(columns={'ds': 'Date', 'yhat': 'Attendance'})
        forecast['Date'] = forecast['Date'].astype(str)

        # Estimate Risk and Alert for every predicted row, relative to the
        # mean of the attendance values that were supplied.
        avg_attendance = df['Attendance'].mean()
        forecast['Risk'] = forecast['Attendance'].apply(
            lambda value: _estimate_risk(value, avg_attendance)
        )
        forecast['Alert'] = forecast['Attendance'].apply(
            lambda value: _estimate_alert(value, avg_attendance)
        )

        # Merge historical Risk/Alert: supplied values take precedence over
        # the estimates wherever they are present for a date.
        if df_risk_alert is not None:
            # `assign` builds a new frame instead of writing into the result
            # of drop_duplicates, which avoids SettingWithCopyWarning.
            df_risk_alert = df_risk_alert.assign(Date=df_risk_alert['Date'].astype(str))
            forecast = forecast.merge(df_risk_alert, on='Date', how='left', suffixes=('', '_hist'))
            forecast['Risk'] = forecast['Risk_hist'].combine_first(forecast['Risk'])
            forecast['Alert'] = forecast['Alert_hist'].combine_first(forecast['Alert'])
            forecast = forecast.drop(columns=['Risk_hist', 'Alert_hist'])

        # Round numeric values (done after risk estimation, which uses the
        # unrounded prediction)
        forecast['Attendance'] = forecast['Attendance'].round(1)

        # Output relevant columns
        forecast = forecast[['Date', 'Attendance', 'Risk', 'Alert']]
        return forecast.to_dict('records')
    except Exception as e:
        # Boundary handler: callers receive an error dict, never an exception.
        return {"error": f"Forecasting failed: {str(e)}"}
|