prudhviLatha commited on
Commit
5a56751
·
verified ·
1 Parent(s): 9415caf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -10
app.py CHANGED
@@ -6,6 +6,7 @@ from datetime import datetime
6
  from simple_salesforce import Salesforce
7
  from dotenv import load_dotenv
8
  import plotly.express as px
 
9
 
10
  # Load environment variables from .env
11
  load_dotenv()
@@ -26,7 +27,7 @@ except Exception as e:
26
  sf = None
27
  print(f"Error connecting to Salesforce: {str(e)}")
28
 
29
- # Function to fetch Project ID from Salesforce automatically
30
  def get_project_id():
31
  if not sf:
32
  return None, "Salesforce connection failed. Check credentials."
@@ -39,9 +40,10 @@ def get_project_id():
39
  except Exception as e:
40
  return None, f"Error fetching Project ID: {str(e)}"
41
 
42
- # Simple moving average forecast
43
  def simple_forecast(df):
44
  df['Date'] = pd.to_datetime(df['Date'], dayfirst=True)
 
45
  df['Forecast'] = df['Attendance'].rolling(window=3, min_periods=1).mean()
46
  future_dates = pd.date_range(df['Date'].max(), periods=4, freq='D')[1:]
47
  future_preds = np.repeat(df['Forecast'].iloc[-1], 3)
@@ -66,10 +68,14 @@ def create_chart(df, predictions_dict):
66
  combined_df = pd.DataFrame()
67
  for trade, predictions in predictions_dict.items():
68
  trade_df = df[df['Trade'] == trade].copy()
 
 
69
  trade_df['Type'] = 'Historical'
70
  trade_df['Trade'] = trade
71
 
72
  forecast_df = pd.DataFrame(predictions)
 
 
73
  forecast_df['Date'] = pd.to_datetime(forecast_df['date'])
74
  forecast_df['Attendance'] = forecast_df['headcount']
75
  forecast_df['Type'] = 'Forecast'
@@ -81,6 +87,9 @@ def create_chart(df, predictions_dict):
81
  forecast_df[['Date', 'Attendance', 'Type', 'Trade']]
82
  ])
83
 
 
 
 
84
  fig = px.line(
85
  combined_df,
86
  x='Date',
@@ -105,7 +114,7 @@ def format_output(trade_results):
105
  value = ', '.join(str(item) for item in value)
106
  output.append(f" • {key}: {value}")
107
  output.append("")
108
- return "\n".join(output)
109
 
110
  # Forecast function for Gradio
111
  def forecast_labour(csv_file):
@@ -119,37 +128,54 @@ def forecast_labour(csv_file):
119
  except UnicodeDecodeError:
120
  continue
121
  if df is None:
122
- return "Error: Could not decode CSV file with any supported encoding (utf-8, latin1, iso-8859-1, utf-16). Please ensure the file is properly encoded.", None
123
 
124
  df.columns = df.columns.str.strip().str.capitalize()
125
-
126
  required_columns = ['Date', 'Attendance', 'Trade', 'Weather', 'Alert_status', 'Shortage_risk', 'Suggested_actions']
127
  missing_columns = [col for col in required_columns if col not in df.columns]
128
  if missing_columns:
129
  return f"Error: CSV missing required columns: {', '.join(missing_columns)}", None
130
 
131
- df['Date'] = pd.to_datetime(df['Date'], dayfirst=True)
132
- df['Attendance'] = df['Attendance'].astype(int)
 
 
133
  df['Shortage_risk'] = df['Shortage_risk'].replace('%', '', regex=True).astype(float) / 100
 
134
 
135
  unique_trades = df['Trade'].unique()
136
  if len(unique_trades) < 10:
137
- return f"Error: CSV contains only {len(unique_trades)} trades, but a minimum of 10 trades is required.", None
 
 
 
 
 
 
138
 
139
  selected_trades = unique_trades[:10]
140
  trade_results = {}
141
  predictions_dict = {}
 
 
142
 
143
  project_id, error = get_project_id()
144
  if error:
145
  return f"Error: {error}", None
146
 
147
  for trade in selected_trades:
 
 
 
148
  trade_df = df[df['Trade'] == trade].copy()
149
  if trade_df.empty:
 
150
  continue
151
 
152
  predictions = simple_forecast(trade_df)
 
 
 
153
  predictions_dict[trade] = predictions
154
 
155
  latest_record = trade_df.sort_values(by='Date').iloc[-1]
@@ -190,6 +216,12 @@ def forecast_labour(csv_file):
190
  result_data.update(sf_result)
191
  trade_results[trade] = result_data
192
 
 
 
 
 
 
 
193
  chart = create_chart(df, predictions_dict)
194
  return format_output(trade_results), chart
195
 
@@ -208,8 +240,8 @@ def gradio_interface():
208
  gr.Plot(label="Forecast Chart")
209
  ],
210
  title="Labour Attendance Forecast",
211
- description="Upload a CSV file with columns: Date, Attendance, Trade, Weather, Alert_Status, Shortage_Risk (e.g. 22%), Suggested_Actions. The file must contain data for at least 10 trades. "
212
  ).launch(share=False)
213
 
214
  if __name__ == '__main__':
215
- gradio_interface()
 
6
  from simple_salesforce import Salesforce
7
  from dotenv import load_dotenv
8
  import plotly.express as px
9
+ import plotly.graph_objects as go
10
 
11
  # Load environment variables from .env
12
  load_dotenv()
 
27
  sf = None
28
  print(f"Error connecting to Salesforce: {str(e)}")
29
 
30
+ # Function to fetch Project ID from Salesforce
31
  def get_project_id():
32
  if not sf:
33
  return None, "Salesforce connection failed. Check credentials."
 
40
  except Exception as e:
41
  return None, f"Error fetching Project ID: {str(e)}"
42
 
43
+ # Simple moving average forecast (works with 1+ days of data)
44
  def simple_forecast(df):
45
  df['Date'] = pd.to_datetime(df['Date'], dayfirst=True)
46
+ # Use rolling mean with min_periods=1 to allow single-day data
47
  df['Forecast'] = df['Attendance'].rolling(window=3, min_periods=1).mean()
48
  future_dates = pd.date_range(df['Date'].max(), periods=4, freq='D')[1:]
49
  future_preds = np.repeat(df['Forecast'].iloc[-1], 3)
 
68
  combined_df = pd.DataFrame()
69
  for trade, predictions in predictions_dict.items():
70
  trade_df = df[df['Trade'] == trade].copy()
71
+ if trade_df.empty:
72
+ continue
73
  trade_df['Type'] = 'Historical'
74
  trade_df['Trade'] = trade
75
 
76
  forecast_df = pd.DataFrame(predictions)
77
+ if forecast_df.empty:
78
+ continue
79
  forecast_df['Date'] = pd.to_datetime(forecast_df['date'])
80
  forecast_df['Attendance'] = forecast_df['headcount']
81
  forecast_df['Type'] = 'Forecast'
 
87
  forecast_df[['Date', 'Attendance', 'Type', 'Trade']]
88
  ])
89
 
90
+ if combined_df.empty:
91
+ return go.Figure().update_layout(title="Labour Attendance Forecast (No Data)")
92
+
93
  fig = px.line(
94
  combined_df,
95
  x='Date',
 
114
  value = ', '.join(str(item) for item in value)
115
  output.append(f" • {key}: {value}")
116
  output.append("")
117
+ return "\n".join(output) if trade_results else "No valid trade data available."
118
 
119
  # Forecast function for Gradio
120
  def forecast_labour(csv_file):
 
128
  except UnicodeDecodeError:
129
  continue
130
  if df is None:
131
+ return "Error: Could not decode CSV file with any supported encoding (utf-8, latin1, iso-8859-1, utf-16).", None
132
 
133
  df.columns = df.columns.str.strip().str.capitalize()
 
134
  required_columns = ['Date', 'Attendance', 'Trade', 'Weather', 'Alert_status', 'Shortage_risk', 'Suggested_actions']
135
  missing_columns = [col for col in required_columns if col not in df.columns]
136
  if missing_columns:
137
  return f"Error: CSV missing required columns: {', '.join(missing_columns)}", None
138
 
139
+ df['Date'] = pd.to_datetime(df['Date'], dayfirst=True, errors='coerce')
140
+ if df['Date'].isna().all():
141
+ return "Error: All dates in CSV are invalid.", None
142
+ df['Attendance'] = pd.to_numeric(df['Attendance'], errors='coerce').fillna(0).astype(int)
143
  df['Shortage_risk'] = df['Shortage_risk'].replace('%', '', regex=True).astype(float) / 100
144
+ df['Shortage_risk'] = df['Shortage_risk'].fillna(0.5)
145
 
146
  unique_trades = df['Trade'].unique()
147
  if len(unique_trades) < 10:
148
+ return f"Error: CSV contains only {len(unique_trades)} trades, minimum 10 required.", None
149
+
150
+ # Check data sufficiency per trade
151
+ trade_data_counts = {trade: len(df[df['Trade'] == trade]) for trade in unique_trades}
152
+ insufficient_trades = [trade for trade, count in trade_data_counts.items() if count < 1]
153
+ if insufficient_trades:
154
+ return (f"Error: The following trades have no data: {', '.join(insufficient_trades)}"), None
155
 
156
  selected_trades = unique_trades[:10]
157
  trade_results = {}
158
  predictions_dict = {}
159
+ processed_trades = set() # Track processed trades to avoid duplicates
160
+ errors = []
161
 
162
  project_id, error = get_project_id()
163
  if error:
164
  return f"Error: {error}", None
165
 
166
  for trade in selected_trades:
167
+ if trade in processed_trades:
168
+ continue # Skip duplicates
169
+ processed_trades.add(trade)
170
  trade_df = df[df['Trade'] == trade].copy()
171
  if trade_df.empty:
172
+ errors.append(f"No data for trade: {trade}")
173
  continue
174
 
175
  predictions = simple_forecast(trade_df)
176
+ if not predictions:
177
+ errors.append(f"No forecast generated for trade: {trade}")
178
+ continue
179
  predictions_dict[trade] = predictions
180
 
181
  latest_record = trade_df.sort_values(by='Date').iloc[-1]
 
216
  result_data.update(sf_result)
217
  trade_results[trade] = result_data
218
 
219
+ if not trade_results:
220
+ error_msg = "No valid trade data processed."
221
+ if errors:
222
+ error_msg += " Errors: " + "; ".join(errors)
223
+ return error_msg, None
224
+
225
  chart = create_chart(df, predictions_dict)
226
  return format_output(trade_results), chart
227
 
 
240
  gr.Plot(label="Forecast Chart")
241
  ],
242
  title="Labour Attendance Forecast",
243
+ description="Upload a CSV file with columns: Date, Attendance, Trade, Weather, Alert_Status, Shortage_Risk (e.g. 22%), Suggested_Actions. The file must contain data for at least 10 trades."
244
  ).launch(share=False)
245
 
246
  if __name__ == '__main__':
247
+ gradio_interface()