prudhviLatha commited on
Commit
39fd53f
·
verified ·
1 Parent(s): 8702e4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -0
app.py CHANGED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import numpy as np
5
+ from datetime import datetime
6
+ from simple_salesforce import Salesforce
7
+ from dotenv import load_dotenv
8
+ import plotly.express as px
9
+
10
+ # Load environment variables from .env
11
+ load_dotenv()
12
+
13
+ # Salesforce credentials
14
+ SF_USERNAME = os.getenv('SF_USERNAME')
15
+ SF_PASSWORD = os.getenv('SF_PASSWORD')
16
+ SF_SECURITY_TOKEN = os.getenv('SF_SECURITY_TOKEN')
17
+
18
+ # Connect to Salesforce
19
+ try:
20
+ sf = Salesforce(
21
+ username=SF_USERNAME,
22
+ password=SF_PASSWORD,
23
+ security_token=SF_SECURITY_TOKEN
24
+ )
25
+ except Exception as e:
26
+ sf = None
27
+ print(f"Error connecting to Salesforce: {str(e)}")
28
+
29
+ # Function to fetch Project ID from Salesforce automatically
30
+ def get_project_id():
31
+ if not sf:
32
+ return None, "Salesforce connection failed. Check credentials."
33
+ try:
34
+ query = "SELECT Id FROM Project__c ORDER BY CreatedDate DESC LIMIT 1"
35
+ result = sf.query(query)
36
+ if result['totalSize'] > 0:
37
+ return result['records'][0]['Id'], None
38
+ return None, "No project found in Salesforce."
39
+ except Exception as e:
40
+ return None, f"Error fetching Project ID: {str(e)}"
41
+
42
+ # Simple moving average forecast
43
+ def simple_forecast(df):
44
+ df['Date'] = pd.to_datetime(df['Date'], dayfirst=True)
45
+ df['Forecast'] = df['Attendance'].rolling(window=3, min_periods=1).mean()
46
+ future_dates = pd.date_range(df['Date'].max(), periods=4, freq='D')[1:]
47
+ future_preds = np.repeat(df['Forecast'].iloc[-1], 3)
48
+ predictions = [
49
+ {"date": date.strftime('%Y-%m-%d'), "headcount": round(pred)}
50
+ for date, pred in zip(future_dates, future_preds)
51
+ ]
52
+ return predictions
53
+
54
+ # Save record to Salesforce
55
+ def save_to_salesforce(record):
56
+ if not sf:
57
+ return {"error": "Salesforce connection failed. Check credentials."}
58
+ try:
59
+ result = sf.Labour_Attendance_Forecast__c.create(record)
60
+ return {"success": f"Record created successfully for {record['Trade__c']}", "record_id": result['id']}
61
+ except Exception as e:
62
+ return {"error": f"Error uploading data to Salesforce for {record['Trade__c']}: {str(e)}"}
63
+
64
+ # Create line chart for multiple trades
65
+ def create_chart(df, predictions_dict):
66
+ combined_df = pd.DataFrame()
67
+ for trade, predictions in predictions_dict.items():
68
+ trade_df = df[df['Trade'] == trade].copy()
69
+ trade_df['Type'] = 'Historical'
70
+ trade_df['Trade'] = trade
71
+
72
+ forecast_df = pd.DataFrame(predictions)
73
+ forecast_df['Date'] = pd.to_datetime(forecast_df['date'])
74
+ forecast_df['Attendance'] = forecast_df['headcount']
75
+ forecast_df['Type'] = 'Forecast'
76
+ forecast_df['Trade'] = trade
77
+
78
+ combined_df = pd.concat([
79
+ combined_df,
80
+ trade_df[['Date', 'Attendance', 'Type', 'Trade']],
81
+ forecast_df[['Date', 'Attendance', 'Type', 'Trade']]
82
+ ])
83
+
84
+ fig = px.line(
85
+ combined_df,
86
+ x='Date',
87
+ y='Attendance',
88
+ color='Trade',
89
+ line_dash='Type',
90
+ markers=True,
91
+ title='Labour Attendance Forecast by Trade'
92
+ )
93
+ return fig
94
+
95
+ # Format output in bullet/line-by-line style for multiple trades
96
+ def format_output(trade_results):
97
+ exclude_keys = {'Project__c', 'record_id', 'success'}
98
+ output = []
99
+ for trade, data in trade_results.items():
100
+ output.append(f"Trade: {trade}")
101
+ for key, value in data.items():
102
+ if key in exclude_keys:
103
+ continue
104
+ if isinstance(value, list):
105
+ value = ', '.join(str(item) for item in value)
106
+ output.append(f" • {key}: {value}")
107
+ output.append("")
108
+ return "\n".join(output)
109
+
110
+ # Forecast function for Gradio
111
+ def forecast_labour(csv_file):
112
+ try:
113
+ encodings = ['utf-8', 'latin1', 'iso-8859-1', 'utf-16']
114
+ df = None
115
+ for encoding in encodings:
116
+ try:
117
+ df = pd.read_csv(csv_file.name, encoding=encoding)
118
+ break
119
+ except UnicodeDecodeError:
120
+ continue
121
+ if df is None:
122
+ return "Error: Could not decode CSV file with any supported encoding (utf-8, latin1, iso-8859-1, utf-16). Please ensure the file is properly encoded.", None
123
+
124
+ df.columns = df.columns.str.strip().str.capitalize()
125
+
126
+ required_columns = ['Date', 'Attendance', 'Trade', 'Weather', 'Alert_status', 'Shortage_risk', 'Suggested_actions']
127
+ missing_columns = [col for col in required_columns if col not in df.columns]
128
+ if missing_columns:
129
+ return f"Error: CSV missing required columns: {', '.join(missing_columns)}", None
130
+
131
+ df['Date'] = pd.to_datetime(df['Date'], dayfirst=True)
132
+ df['Attendance'] = df['Attendance'].astype(int)
133
+ df['Shortage_risk'] = df['Shortage_risk'].replace('%', '', regex=True).astype(float) / 100
134
+
135
+ unique_trades = df['Trade'].unique()
136
+ if len(unique_trades) < 10:
137
+ return f"Error: CSV contains only {len(unique_trades)} trades, but a minimum of 10 trades is required.", None
138
+
139
+ selected_trades = unique_trades[:10]
140
+ trade_results = {}
141
+ predictions_dict = {}
142
+
143
+ project_id, error = get_project_id()
144
+ if error:
145
+ return f"Error: {error}", None
146
+
147
+ for trade in selected_trades:
148
+ trade_df = df[df['Trade'] == trade].copy()
149
+ if trade_df.empty:
150
+ continue
151
+
152
+ predictions = simple_forecast(trade_df)
153
+ predictions_dict[trade] = predictions
154
+
155
+ latest_record = trade_df.sort_values(by='Date').iloc[-1]
156
+ weather = latest_record['Weather']
157
+ alert_status = latest_record['Alert_status']
158
+ shortage_risk = latest_record['Shortage_risk']
159
+ suggested_actions = latest_record['Suggested_actions']
160
+
161
+ result_data = {
162
+ "Title": f"Labour Attendance Data for {trade}",
163
+ "Date": trade_df['Date'].max().strftime('%B %Y'),
164
+ "Trade": trade,
165
+ "Weather": weather,
166
+ "Forecast": predictions,
167
+ "Alert Status": alert_status,
168
+ "Shortage_risk": shortage_risk,
169
+ "Suggested_actions": suggested_actions,
170
+ "Expected_headcount": predictions[0]['headcount'],
171
+ "Actual_headcount": int(trade_df['Attendance'].iloc[-1]),
172
+ "Forecast_Next_3_Days__c": predictions,
173
+ "Project__c": project_id
174
+ }
175
+
176
+ salesforce_record = {
177
+ 'Trade__c': trade,
178
+ 'Shortage_Risk__c': shortage_risk,
179
+ 'Suggested_Actions__c': suggested_actions,
180
+ 'Expected_Headcount__c': result_data['Expected_headcount'],
181
+ 'Actual_Headcount__c': result_data['Actual_headcount'],
182
+ 'Forecast_Next_3_Days__c': str(predictions),
183
+ 'Project_ID__c': project_id,
184
+ 'Alert_Status__c': alert_status,
185
+ 'Dashboard_Display__c': True,
186
+ 'Date__c': trade_df['Date'].max().date().isoformat()
187
+ }
188
+
189
+ sf_result = save_to_salesforce(salesforce_record)
190
+ result_data.update(sf_result)
191
+ trade_results[trade] = result_data
192
+
193
+ chart = create_chart(df, predictions_dict)
194
+ return format_output(trade_results), chart
195
+
196
+ except Exception as e:
197
+ return f"Error processing file: {str(e)}", None
198
+
199
+ # Gradio UI without share
200
+ def gradio_interface():
201
+ gr.Interface(
202
+ fn=forecast_labour,
203
+ inputs=[
204
+ gr.File(label="Upload CSV with required columns for at least 10 trades")
205
+ ],
206
+ outputs=[
207
+ gr.Textbox(label="Forecast Result", lines=20),
208
+ gr.Plot(label="Forecast Chart")
209
+ ],
210
+ title="Labour Attendance Forecast",
211
+ description="Upload a CSV file with columns: Date, Attendance, Trade, Weather, Alert_Status, Shortage_Risk (e.g. 22%), Suggested_Actions. The file must contain data for at least 10 trades. "
212
+ ).launch(share=False)
213
+
214
+ if __name__ == '__main__':
215
+ gradio_interface()