prudhviLatha commited on
Commit
f8aed7a
·
verified ·
1 Parent(s): 34a9e7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +354 -6
app.py CHANGED
@@ -173,14 +173,362 @@ def weighted_moving_average_forecast(df, trade, site_calendar_date):
173
  except Exception as e:
174
  logger.error(f"Forecast error for trade {trade}: {str(e)}")
175
  return [], [], None, 'N/A', 'Normal', f"Forecast error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
-
178
-
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
-
181
-
182
-
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
 
 
 
 
 
 
 
173
  except Exception as e:
174
  logger.error(f"Forecast error for trade {trade}: {str(e)}")
175
  return [], [], None, 'N/A', 'Normal', f"Forecast error: {str(e)}"
176
+
177
+ # Real-time shortage risk heatmap for the selected day
178
+ def create_heatmap(df, predictions_dict, shortage_probs_dict, site_calendar_date):
179
+ try:
180
+ site_calendar_date = pd.to_datetime(site_calendar_date)
181
+ heatmap_data = []
182
+ # Extend to 6 days to match the screenshot (2025-04-24 to 2025-04-29)
183
+ future_dates = pd.date_range(site_calendar_date, periods=6, freq='D')
184
+
185
+ for trade in predictions_dict.keys():
186
+ probs = shortage_probs_dict.get(trade, [0.5] * len(future_dates))
187
+ for i, date in enumerate(future_dates):
188
+ # Use the shortage probability for the current day (index 0) and future days
189
+ prob = probs[i] if i < len(probs) else probs[-1] # Fallback to last prob if not enough data
190
+ heatmap_data.append({
191
+ 'Date': date.strftime('%Y-%m-%d'),
192
+ 'Trade': trade,
193
+ 'Shortage_Probability': prob
194
+ })
195
+
196
+ heatmap_df = pd.DataFrame(heatmap_data)
197
+ if heatmap_df.empty:
198
+ return go.Figure().update_layout(title="Shortage Risk Heatmap (No Data)")
199
+
200
+ display_probs = heatmap_df['Shortage_Probability'] * 100
201
+
202
+ # Custom colorscale to match screenshot: red at 0, transitioning to blues
203
+ custom_colorscale = [
204
+ [0, 'red'], # 0 maps to red
205
+ [0.001, '#1f77b4'], # Slightly above 0 starts with a blue shade
206
+ [0.5, '#aec7e8'], # Mid-range blue
207
+ [1, '#08306b'] # Dark blue at 1
208
+ ]
209
+
210
+ fig = go.Figure(data=go.Heatmap(
211
+ x=heatmap_df['Date'],
212
+ y=heatmap_df['Trade'],
213
+ z=heatmap_df['Shortage_Probability'],
214
+ colorscale=custom_colorscale,
215
+ zmin=0, zmax=1,
216
+ text=display_probs.round(0).astype(int).astype(str) + '%',
217
+ texttemplate="%{text}",
218
+ textfont={"size": 14, "color": "black"},
219
+ hovertemplate="Trade: %{y}<br>Date: %{x}<br>Shortage Risk: %{text}<extra></extra>",
220
+ colorbar=dict(
221
+ title="Shortage Risk",
222
+ tickvals=[0, 0.5, 1],
223
+ ticktext=["0%", "50%", "100%"]
224
+ )
225
+ ))
226
+
227
+ fig.update_layout(
228
+ title="Shortage Risk Heatmap",
229
+ xaxis_title="Date",
230
+ yaxis_title="Trade",
231
+ xaxis=dict(
232
+ tickangle=45,
233
+ tickformat="%Y-%m-%d",
234
+ showgrid=False
235
+ ),
236
+ yaxis=dict(
237
+ autorange="reversed",
238
+ showgrid=False
239
+ ),
240
+ font=dict(size=14),
241
+ margin=dict(l=100, r=50, t=100, b=100),
242
+ plot_bgcolor="white",
243
+ paper_bgcolor="white",
244
+ showlegend=False
245
+ )
246
+ return fig
247
+ except Exception as e:
248
+ logger.error(f"Error creating heatmap: {str(e)}")
249
+ return go.Figure().update_layout(title=f"Error in Heatmap: {str(e)}")
250
+
251
+ def create_chart(df, predictions_dict):
252
+ try:
253
+ combined_df = pd.DataFrame()
254
+ for trade, predictions in predictions_dict.items():
255
+ trade_df = df[df['Trade'].str.lower() == trade.lower()][['Date', 'Attendance']].copy()
256
+ trade_df['Type'] = 'Historical'
257
+ trade_df['Trade'] = trade
258
+
259
+ forecast_df = pd.DataFrame(predictions)
260
+ if not forecast_df.empty:
261
+ forecast_df['Date'] = pd.to_datetime(forecast_df['date'])
262
+ forecast_df['Attendance'] = forecast_df['headcount']
263
+ forecast_df['Type'] = 'Forecast'
264
+ forecast_df['Trade'] = trade
265
+ combined_df = pd.concat([combined_df, trade_df, forecast_df[['Date', 'Attendance', 'Type', 'Trade']]])
266
+
267
+ if combined_df.empty:
268
+ return go.Figure().update_layout(title="Labour Attendance Forecast (No Data)")
269
+
270
+ fig = px.line(
271
+ combined_df,
272
+ x='Date',
273
+ y='Attendance',
274
+ color='Trade',
275
+ line_dash='Type',
276
+ markers=True,
277
+ title='Labour Attendance Forecast by Trade'
278
+ )
279
+ return fig
280
+ except Exception as e:
281
+ logger.error(f"Error creating chart: {str(e)}")
282
+ return go.Figure().update_layout(title=f"Error in Chart: {str(e)}")
283
+
284
+ def generate_pdf_summary(trade_results):
285
+ try:
286
+ buffer = io.BytesIO()
287
+ with PdfPages(buffer) as pdf:
288
+ fig, ax = plt.subplots(figsize=(10, 6))
289
+ if not trade_results:
290
+ ax.text(0.1, 0.5, "No data available for summary", fontsize=12)
291
+ else:
292
+ for i, (trade, data) in enumerate(trade_results.items()):
293
+ ax.text(0.1, 0.9 - 0.1*i,
294
+ f"{trade}: {data['Attendance']} (Actual), Shortage Risk: {data['Shortage_risk'][0]*100:.0f}%", fontsize=12)
295
+ ax.set_title("Weekly Labour Forecast Summary")
296
+ ax.axis('off')
297
+ pdf.savefig()
298
+ plt.close()
299
+ pdf_base64 = base64.b64encode(buffer.getvalue()).decode()
300
+ logger.info("PDF summary generated")
301
+ return pdf_base64
302
+ except Exception as e:
303
+ logger.error(f"Error generating PDF: {str(e)}")
304
+ return None
305
+
306
+ def format_output(trade_results, site_calendar_date):
307
+ output_columns = Config.REQUIRED_COLUMNS + ['Forecast_Next_3_Days__c', 'Shortage_risk', 'Suggested_actions', 'Alert_status']
308
+ output = []
309
+ notifications = []
310
+
311
+ for trade, data in trade_results.items():
312
+ output.append(f"Trade: {trade}")
313
+ for key in output_columns:
314
+ if key == 'Date':
315
+ value = pd.to_datetime(site_calendar_date).strftime('%Y-%m-%d') if pd.notna(site_calendar_date) else 'N/A'
316
+ elif key == 'Forecast_Next_3_Days__c':
317
+ value = ', '.join([f"{item['date']}: {item['headcount']}" for item in data.get(key, [])]) if data.get(key) else 'N/A'
318
+ else:
319
+ value = data.get(key, 'N/A')
320
+ if key in ['Weather', 'Alert_status', 'Suggested_actions', 'Trade'] and value is not None:
321
+ value = str(value)
322
+ elif key == 'Shortage_risk' and value is not None:
323
+ value = str(round(value[0], 2))
324
+ elif key == 'Attendance' and value is not None:
325
+ value = str(int(value))
326
+ output.append(f" • {key}: {value}")
327
+
328
+ alert_status = data.get('Alert_status', 'Normal')
329
+ suggested_actions = data.get('Suggested_actions', 'Monitor')
330
+ if alert_status == 'Critical':
331
+ notification = f"Urgent Alert for {trade}: {suggested_actions} due to high shortage risk of {round(data.get('Shortage_risk', [0])[0] * 100)}%."
332
+ elif alert_status == 'Warning':
333
+ notification = f"Warning for {trade}: {suggested_actions} due to moderate shortage risk of {round(data.get('Shortage_risk', [0])[0] * 100)}%."
334
+ else:
335
+ notification = f"Notice for {trade}: {suggested_actions}, shortage risk is low at {round(data.get('Shortage_risk', [0])[0] * 100)}%."
336
+ notifications.append(notification)
337
+
338
+ output.append("")
339
+
340
+ formatted_output = "\n".join(output) if trade_results else "No valid trade data available."
341
+ formatted_notifications = "Contractor Notifications:\n" + "\n".join([f" • {notification}" for notification in notifications]) if notifications else "No notifications available."
342
+
343
+ return formatted_output, formatted_notifications
344
+
345
+ def push_to_salesforce(sf, trade_results, site_calendar_date):
346
+ try:
347
+ if sf is None:
348
+ return "Salesforce connection not established"
349
 
350
+ records_to_upsert = []
351
+ for trade, data in trade_results.items():
352
+ forecast_json = ', '.join([f"{item['date']}: {item['headcount']}" for item in data.get('Forecast_Next_3_Days__c', [])])
353
+ record = {
354
+ 'Trade__c': trade,
355
+ 'Date__c': site_calendar_date.strftime('%Y-%m-%d'),
356
+ 'Expected_Headcount__c': int(data['Attendance']),
357
+ 'Actual_Headcount__c': int(data['Attendance']),
358
+ 'Forecast_Next_3_Days__c': forecast_json,
359
+ 'Shortage_Risk__c': float(data['Shortage_risk'][0]),
360
+ 'Suggested_Actions__c': str(data['Suggested_actions']),
361
+ 'Alert_Status__c': str(data['Alert_status']),
362
+ 'Dashboard_Display__c': True
363
+ }
364
+ records_to_upsert.append(record)
365
 
366
+ for record in records_to_upsert:
367
+ sf.Labour_Attendance_Forecast__c.create(record)
 
368
 
369
+ logger.info(f"Successfully pushed {len(records_to_upsert)} records to Salesforce")
370
+ return None
371
+ except Exception as e:
372
+ logger.error(f"Error pushing to Salesforce: {str(e)}")
373
+ return f"Error pushing to Salesforce: {str(e)}"
374
+
375
+ def generate_sample_csv():
376
+ sample_data = {
377
+ 'Date': ['2025-06-12', '2025-06-12', '2025-06-12', '2025-06-12'],
378
+ 'Attendance': [10, 15, 20, 12],
379
+ 'Trade': ['Painter', 'Electrician', 'Carpenter', 'Plumber'],
380
+ 'Weather': ['Sunny', 'Rainy', 'Cloudy', 'Sunny']
381
+ }
382
+ df = pd.DataFrame(sample_data)
383
+ buffer = io.StringIO()
384
+ df.to_csv(buffer, index=False, encoding='utf-8')
385
+ csv_base64 = base64.b64encode(buffer.getvalue().encode('utf-8')).decode()
386
+ return csv_base64
387
+
388
+ # Main forecast function
389
+ def forecast_labour(csv_file, trade_filter=None, site_calendar_date=None):
390
+ try:
391
+ logger.info("Starting forecast process")
392
+ if csv_file is None:
393
+ return "Error: No CSV file uploaded", None, None, None, "No notifications available."
394
+
395
+ # Validate site calendar date format
396
+ try:
397
+ if not site_calendar_date:
398
+ raise ValueError("Site calendar date is required")
399
+ logger.info(f"Raw site_calendar_date input: '{site_calendar_date}'")
400
+ site_calendar_date = site_calendar_date.strip()
401
+ try:
402
+ site_calendar_date = pd.to_datetime(site_calendar_date, format='%Y-%m-%d')
403
+ except ValueError as strict_error:
404
+ logger.warning(f"Strict date parsing failed: {str(strict_error)}. Attempting mixed format parsing.")
405
+ site_calendar_date = pd.to_datetime(site_calendar_date, format='mixed', dayfirst=True, errors='coerce')
406
+ if pd.isna(site_calendar_date):
407
+ raise ValueError("Invalid site calendar date format. Use YYYY-MM-DD (e.g., 2025-06-13)")
408
+ except ValueError as e:
409
+ logger.error(f"Date validation error: {str(e)}")
410
+ return f"Error: {str(e)}", None, None, None, "No notifications available."
411
+
412
+ logger.info(f"Processing CSV file: {csv_file}")
413
+ df, error = process_csv(csv_file)
414
+ if error:
415
+ return error, None, None, None, "No notifications available."
416
+
417
+ unique_trades = df['Trade'].dropna().unique()
418
+ logger.info(f"Unique trades in CSV: {list(unique_trades)}")
419
+
420
+ if trade_filter and trade_filter.strip():
421
+ selected_trades = [t.strip() for t in trade_filter.split(',') if t.strip()]
422
+ selected_trades = [t for t in selected_trades if any(t.lower() == ut.lower() for ut in unique_trades)]
423
+ if not selected_trades:
424
+ logger.warning(f"No valid trades found in filter: {trade_filter}. Defaulting to all trades.")
425
+ selected_trades = unique_trades
426
+ else:
427
+ logger.info("Trade filter empty. Using all trades.")
428
+ selected_trades = unique_trades
429
 
430
+ logger.info(f"Selected trades: {list(selected_trades)}")
431
+
432
+ trade_results = {}
433
+ predictions_dict = {}
434
+ shortage_probs_dict = {}
435
+ alert_statuses = {}
436
+ errors = []
437
+
438
+ for trade in selected_trades:
439
+ trade_df = df[df['Trade'].str.lower() == trade.lower()]
440
+ date_match = trade_df[trade_df['Date'] == site_calendar_date]
441
+ if date_match.empty:
442
+ errors.append(f"No data for trade {trade} on {site_calendar_date.strftime('%Y-%m-%d')}")
443
+ continue
444
+ if len(date_match) > 1:
445
+ errors.append(f"Warning: Multiple rows for trade {trade} on {site_calendar_date.strftime('%Y-%m-%d')}")
446
+
447
+ predictions, shortage_probs, site_calendar, suggested_actions, alert_status, forecast_error = weighted_moving_average_forecast(df, trade, site_calendar_date)
448
+ if forecast_error:
449
+ errors.append(forecast_error)
450
+ continue
451
+
452
+ predictions_dict[trade] = predictions
453
+ shortage_probs_dict[trade] = shortage_probs
454
+ alert_statuses[trade] = alert_status
455
+ record = date_match.iloc[0]
456
+
457
+ result_data = {
458
+ 'Date': site_calendar_date,
459
+ 'Trade': trade,
460
+ 'Weather': record['Weather'],
461
+ 'Attendance': record['Attendance'],
462
+ 'Forecast_Next_3_Days__c': predictions,
463
+ 'Shortage_risk': shortage_probs,
464
+ 'Suggested_actions': suggested_actions,
465
+ 'Alert_status': alert_status
466
+ }
467
+
468
+ trade_results[trade] = result_data
469
+
470
+ if not trade_results:
471
+ error_msg = "No valid trade data processed"
472
+ if errors:
473
+ error_msg += f". Errors: {'; '.join(errors)}"
474
+ return error_msg, None, None, None, "No notifications available."
475
+
476
+ sf = connect_to_salesforce()
477
+ sf_error = push_to_salesforce(sf, trade_results, site_calendar_date)
478
+ if sf_error:
479
+ errors.append(sf_error)
480
+
481
+ line_chart = create_chart(df, predictions_dict)
482
+ heatmap = create_heatmap(df, predictions_dict, shortage_probs_dict, site_calendar_date)
483
+ pdf_summary = generate_pdf_summary(trade_results)
484
+
485
+ formatted_output, formatted_notifications = format_output(trade_results, site_calendar_date)
486
+
487
+ error_msg = "; ".join(errors) if errors else None
488
+ final_output = formatted_output + (f"\nWarnings: {error_msg}" if error_msg else "")
489
+
490
+ return (
491
+ final_output,
492
+ line_chart,
493
+ heatmap,
494
+ f'<a href="data:application/pdf;base64,{pdf_summary}" download="summary.pdf">Download Summary PDF</a>',
495
+ formatted_notifications
496
+ )
497
+ except Exception as e:
498
+ logger.error(f"Unexpected error in forecast: {str(e)}", exc_info=True)
499
+ return f"Error processing file: {str(e)}", None, None, None, "No notifications available."
500
+
501
+ # Gradio interface
502
+ def gradio_interface():
503
+ sample_csv = generate_sample_csv()
504
+ with gr.Blocks(theme=gr.themes.Soft()) as interface:
505
+ gr.Markdown("# Labour Attendance Forecast")
506
+ gr.Markdown("Upload a CSV with columns: Date, Attendance, Trade, Weather")
507
+ gr.Markdown("Enter trade names (e.g., 'Painter, Electrician') or leave blank for all trades")
508
+ gr.Markdown("Enter site calendar date (YYYY-MM-DD) for CSV data and 3-day forecast")
509
+ gr.Markdown(f'<a href="data:text/csv;base64,{sample_csv}" download="sample_labour_data.csv">Download Sample CSV</a>')
510
+
511
+ with gr.Row():
512
+ csv_input = gr.File(label="Upload CSV", file_types=[".csv"])
513
+ trade_input = gr.Textbox(label="Filter by Trades", placeholder="e.g., Painter, Electrician")
514
+ site_calendar_input = gr.Textbox(label="Site Calendar Date (YYYY-MM-DD)", placeholder="e.g., 2025-06-13")
515
+
516
+ forecast_button = gr.Button("Generate Forecast")
517
+ result_output = gr.Textbox(label="Forecast Result", lines=20)
518
+ line_chart_output = gr.Plot(label="Forecast Trendline")
519
+ heatmap_output = gr.Plot(label="Real-Time Shortage Risk Heatmap")
520
+ notification_output = gr.Textbox(label="Contractor Notifications", lines=5)
521
+ pdf_output = gr.HTML(label="Download Summary PDF")
522
+
523
+ forecast_button.click(
524
+ fn=forecast_labour,
525
+ inputs=[csv_input, trade_input, site_calendar_input],
526
+ outputs=[result_output, line_chart_output, heatmap_output, pdf_output, notification_output]
527
+ )
528
 
529
+ logger.info("Launching Gradio interface")
530
+ return interface
531
+
532
+ if __name__ == '__main__':
533
+ interface = gradio_interface()
534
+ interface.launch(share=False)