prudhviLatha commited on
Commit
eb9dd55
·
verified ·
1 Parent(s): 2cb6075

Update forecast_model.py

Browse files
Files changed (1) hide show
  1. forecast_model.py +89 -0
forecast_model.py CHANGED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from prophet import Prophet
3
+ import json
4
+
5
+ def train_and_forecast(data):
6
+ try:
7
+ # Convert input data to DataFrame
8
+ if not data:
9
+ return {"error": "Input data is empty"}
10
+ df = pd.DataFrame(data)
11
+
12
+ # Validate required columns
13
+ required_cols = ['Date', 'Attendance']
14
+ if not all(col in df.columns for col in required_cols):
15
+ return {"error": "Input must contain 'Date' and 'Attendance' columns"}
16
+
17
+ # Validate and process Date and Attendance
18
+ df['Date'] = pd.to_datetime(df['Date'], errors='coerce')
19
+ if df['Date'].isna().any():
20
+ return {"error": "Invalid 'Date' format. Use YYYY-MM-DD"}
21
+ df['Attendance'] = pd.to_numeric(df['Attendance'], errors='coerce')
22
+ if df['Attendance'].isna().any():
23
+ return {"error": "'Attendance' must be numeric"}
24
+
25
+ # Check for sufficient data
26
+ if len(df) < 2:
27
+ return {"error": "At least 2 data points are required for forecasting"}
28
+
29
+ # Preserve Risk and Alert for historical dates
30
+ df_risk_alert = None
31
+ if 'Risk' in df.columns and 'Alert' in df.columns:
32
+ df_risk_alert = df[['Date', 'Risk', 'Alert']].drop_duplicates(subset='Date')
33
+
34
+ # Rename columns for Prophet
35
+ df_prophet = df.rename(columns={'Date': 'ds', 'Attendance': 'y'})
36
+
37
+ # Initialize and train Prophet model
38
+ model = Prophet(
39
+ yearly_seasonality=True,
40
+ weekly_seasonality=True,
41
+ daily_seasonality=True
42
+ )
43
+ model.fit(df_prophet)
44
+
45
+ # Create future dates (30 days)
46
+ future = model.make_future_dataframe(periods=30, freq='D')
47
+ forecast = model.predict(future)
48
+
49
+ # Select relevant columns and rename back
50
+ forecast = forecast[['ds', 'yhat']].rename(columns={'ds': 'Date', 'yhat': 'Attendance'})
51
+ forecast['Date'] = forecast['Date'].astype(str)
52
+
53
+ # Estimate Risk and Alert for all dates
54
+ avg_attendance = df['Attendance'].mean()
55
+ def estimate_risk(attendance):
56
+ if attendance >= avg_attendance * 0.8:
57
+ return round(10.0 + (avg_attendance - attendance) * 2, 1)
58
+ elif attendance >= avg_attendance * 0.5:
59
+ return round(15.0 + (avg_attendance - attendance) * 3, 1)
60
+ else:
61
+ return round(20.0 + (avg_attendance - attendance) * 4, 1)
62
+
63
+ def estimate_alert(attendance):
64
+ if attendance >= avg_attendance * 0.8:
65
+ return "Low"
66
+ elif attendance >= avg_attendance * 0.5:
67
+ return "Medium"
68
+ else:
69
+ return "High"
70
+
71
+ forecast['Risk'] = forecast['Attendance'].apply(estimate_risk)
72
+ forecast['Alert'] = forecast['Attendance'].apply(estimate_alert)
73
+
74
+ # Merge historical Risk/Alert
75
+ if df_risk_alert is not None:
76
+ df_risk_alert['Date'] = df_risk_alert['Date'].astype(str)
77
+ forecast = forecast.merge(df_risk_alert, on='Date', how='left', suffixes=('', '_hist'))
78
+ forecast['Risk'] = forecast['Risk_hist'].combine_first(forecast['Risk'])
79
+ forecast['Alert'] = forecast['Alert_hist'].combine_first(forecast['Alert'])
80
+ forecast = forecast.drop(columns=['Risk_hist', 'Alert_hist'])
81
+
82
+ # Round numeric values
83
+ forecast['Attendance'] = forecast['Attendance'].round(1)
84
+
85
+ # Output relevant columns
86
+ forecast = forecast[['Date', 'Attendance', 'Risk', 'Alert']]
87
+ return forecast.to_dict('records')
88
+ except Exception as e:
89
+ return {"error": f"Forecasting failed: {str(e)}"}