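# NOTE: the next four lines are repository-page residue, kept as comments so the module parses.
# tch / forecast_model.py
# prudhviLatha's picture
# Create forecast_model.py
# eddc53a verified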
import pandas as pd
from prophet import Prophet
import json
def train_and_forecast(data):
    """Fit a Prophet model on attendance history and forecast 30 days ahead.

    Args:
        data: Attendance records, either an existing ``pd.DataFrame`` or
            anything ``pd.DataFrame(...)`` accepts (for example a list of
            dicts). ``Date`` and ``Attendance`` columns are required.
            ``Risk`` and ``Alert`` columns are optional; when both are
            present their values override the estimated ones on matching
            historical dates.

    Returns:
        On success, a list of dicts with the keys ``Date`` (string),
        ``Attendance`` (the model's estimate, rounded to one decimal),
        ``Risk`` and ``Alert``, covering every historical date plus the
        following 30 days. Note that ``Attendance`` for historical dates is
        the model's fitted value, not the value that was passed in.

        On failure, a dict with a single ``error`` key describing the
        problem. Nothing is raised to the caller: validation failures and
        unexpected exceptions are both reported this way.
    """
    try:
        # A DataFrame has no unambiguous truth value, so ``not data`` would
        # raise for it; test emptiness explicitly and work on a copy so the
        # caller's frame is never modified by the conversions below.
        if isinstance(data, pd.DataFrame):
            if data.empty:
                return {"error": "Input data is empty"}
            df = data.copy()
        else:
            if not data:
                return {"error": "Input data is empty"}
            df = pd.DataFrame(data)

        # Validate required columns
        required_cols = ('Date', 'Attendance')
        if any(col not in df.columns for col in required_cols):
            return {"error": "Input must contain 'Date' and 'Attendance' columns"}

        # Unparseable values are coerced to NaT/NaN and then rejected, so a
        # single bad cell fails the whole request instead of being dropped.
        df['Date'] = pd.to_datetime(df['Date'], errors='coerce')
        if df['Date'].isna().any():
            return {"error": "Invalid 'Date' format. Use YYYY-MM-DD"}
        df['Attendance'] = pd.to_numeric(df['Attendance'], errors='coerce')
        if df['Attendance'].isna().any():
            return {"error": "'Attendance' must be numeric"}

        # Check for sufficient data
        if len(df) < 2:
            return {"error": "At least 2 data points are required for forecasting"}

        # Keep caller-supplied Risk/Alert (first row per date) so they can
        # override the estimates for historical dates later. Dates are
        # stringified here to match the forecast's string Date column.
        history_labels = None
        if 'Risk' in df.columns and 'Alert' in df.columns:
            history_labels = (
                df[['Date', 'Risk', 'Alert']]
                .drop_duplicates(subset='Date')
                .assign(Date=lambda frame: frame['Date'].astype(str))
            )

        # Prophet expects the columns 'ds' and 'y'. Only those two are passed
        # so unrelated input columns cannot collide with the renamed ones.
        df_prophet = df[['Date', 'Attendance']].rename(
            columns={'Date': 'ds', 'Attendance': 'y'}
        )

        # NOTE(review): all three seasonalities are forced on regardless of
        # how much history is supplied -- confirm this is intended for short
        # or daily-granularity series.
        model = Prophet(
            yearly_seasonality=True,
            weekly_seasonality=True,
            daily_seasonality=True
        )
        model.fit(df_prophet)

        # Predict over the historical dates plus the next 30 days.
        future = model.make_future_dataframe(periods=30, freq='D')
        forecast = model.predict(future)[['ds', 'yhat']].rename(
            columns={'ds': 'Date', 'yhat': 'Attendance'}
        )
        forecast['Date'] = forecast['Date'].astype(str)

        # Risk and Alert are derived from how far the predicted attendance
        # falls below the historical mean; both use the same two cutoffs.
        avg_attendance = df['Attendance'].mean()
        low_cutoff = avg_attendance * 0.8
        medium_cutoff = avg_attendance * 0.5

        def estimate_risk(attendance):
            shortfall = avg_attendance - attendance
            if attendance >= low_cutoff:
                return round(10.0 + shortfall * 2, 1)
            if attendance >= medium_cutoff:
                return round(15.0 + shortfall * 3, 1)
            return round(20.0 + shortfall * 4, 1)

        def estimate_alert(attendance):
            if attendance >= low_cutoff:
                return "Low"
            if attendance >= medium_cutoff:
                return "Medium"
            return "High"

        # Estimates use the unrounded prediction; rounding happens afterwards.
        forecast['Risk'] = forecast['Attendance'].apply(estimate_risk)
        forecast['Alert'] = forecast['Attendance'].apply(estimate_alert)

        # Where the caller supplied Risk/Alert for a date, prefer those and
        # fall back to the estimate only when the supplied value is missing.
        if history_labels is not None:
            forecast = forecast.merge(
                history_labels, on='Date', how='left', suffixes=('', '_hist')
            )
            forecast['Risk'] = forecast['Risk_hist'].combine_first(forecast['Risk'])
            forecast['Alert'] = forecast['Alert_hist'].combine_first(forecast['Alert'])
            forecast = forecast.drop(columns=['Risk_hist', 'Alert_hist'])

        # Round numeric values
        forecast['Attendance'] = forecast['Attendance'].round(1)

        # Output relevant columns
        return forecast[['Date', 'Attendance', 'Risk', 'Alert']].to_dict('records')
    except Exception as e:
        # Boundary handler: callers receive an error dict rather than an
        # exception, whatever went wrong inside pandas or Prophet.
        return {"error": f"Forecasting failed: {str(e)}"}