Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| from datetime import datetime | |
| from simple_salesforce import Salesforce | |
| from dotenv import load_dotenv | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import io | |
| import base64 | |
| from matplotlib.backends.backend_pdf import PdfPages | |
| import matplotlib.pyplot as plt | |
# Load environment variables (e.g. from a local .env file) before reading credentials.
load_dotenv()

# Salesforce credentials; each is None when the corresponding variable is unset.
SF_USERNAME = os.getenv('SF_USERNAME')
SF_PASSWORD = os.getenv('SF_PASSWORD')
SF_SECURITY_TOKEN = os.getenv('SF_SECURITY_TOKEN')

# Connect to Salesforce once at import time. Any failure leaves ``sf`` as None so
# the rest of the app can still start; code using ``sf`` must check for that.
sf = None
try:
    sf = Salesforce(
        username=SF_USERNAME,
        password=SF_PASSWORD,
        security_token=SF_SECURITY_TOKEN,
    )
except Exception as exc:
    print(f"Error connecting to Salesforce: {str(exc)}")
# Weighted moving average forecast with heuristic shortage probability
def weighted_moving_average_forecast(df, trade, site_calendar_date):
    """Forecast headcount for the three days after ``site_calendar_date``.

    For each of the next three days, a row already present in ``df`` for this
    trade wins: its Attendance (missing -> 0) and Shortage_risk (missing -> 0.5)
    are used as-is. Otherwise the headcount is a weighted moving average of the
    last (up to) three attendance values on or before ``site_calendar_date``,
    with the most recent day weighted highest. That average is damped by up to
    10% for the latest weather, and to 80% when the anchor date is a weekend.

    Args:
        df: Frame with ``Date``, ``Trade``, ``Attendance``, ``Weather`` and
            ``Shortage_risk`` columns. It is not modified.
        trade: Trade name to forecast.
        site_calendar_date: Anchor date (``YYYY-MM-DD`` string or date-like).

    Returns:
        ``(predictions, shortage_prob, site_calendar_value, error)``:

        * ``predictions`` is a list of ``{"date": "YYYY-MM-DD", "headcount": int}``.
        * ``shortage_prob`` is the risk for the last forecast day.
        * ``site_calendar_value`` looks like ``"2025-05-24 (Weekend)"``.
        * ``error`` is None on success.

        On failure the result is ``([], 0.5, None, message)``.
    """
    # Work on a copy so the caller's Date column is not rewritten.
    df = df.copy()
    df['Date'] = pd.to_datetime(df['Date'], format='%Y-%m-%d', errors='coerce').dt.date
    trade_df = df[df['Trade'] == trade]
    if trade_df.empty:
        return [], 0.5, None, f"No data found for trade: {trade}"

    try:
        parsed = pd.to_datetime(site_calendar_date, format='%Y-%m-%d')
        if pd.isna(parsed):  # also covers None, which to_datetime passes through
            raise ValueError(site_calendar_date)
    except (ValueError, TypeError):
        return [], 0.5, None, f"Invalid site calendar date: {site_calendar_date}"
    site_calendar_date = parsed.date()
    is_weekday = site_calendar_date.weekday() < 5
    # Weekend anchor dates scale the fallback forecast down to 80%.
    day_factor = 1 if is_weekday else 0.8

    # Up to 30 rows on or before the anchor date, sorted by date so that
    # "recent" means the latest days rather than the last rows of the CSV.
    history = trade_df[trade_df['Date'] <= site_calendar_date].sort_values('Date', kind='mergesort')
    recent_data = history.tail(30)[['Date', 'Attendance', 'Weather', 'Shortage_risk']]
    if recent_data.empty:
        return [], 0.5, None, f"No data available for trade {trade} on or before {site_calendar_date}"

    # Fallback headcount (the same for every day that has no CSV row).
    attendance = pd.to_numeric(recent_data['Attendance'], errors='coerce').fillna(0).to_numpy(dtype=float)
    window = attendance[-3:]
    # The window runs oldest -> newest, so the newest value gets the largest weight.
    weights = {3: [0.2, 0.3, 0.5], 2: [0.4, 0.6], 1: [1.0]}[len(window)]
    forecast_value = float(np.average(window, weights=weights))
    weather_penalty = {'Sunny': 0, 'Rainy': 1, 'Cloudy': 0.5}
    latest_weather = str(recent_data['Weather'].iloc[-1]).strip().capitalize()
    # Unknown or missing weather (e.g. "N/A") is treated as neutral 0.5. A NaN
    # here used to make round() raise.
    forecast_value *= 1 - 0.1 * weather_penalty.get(latest_weather, 0.5)
    fallback_headcount = round(forecast_value * day_factor)

    # Fallback shortage probability: mean recent risk nudged by the attendance trend.
    base_risk = pd.to_numeric(recent_data['Shortage_risk'], errors='coerce').mean()
    if pd.isna(base_risk):
        base_risk = 0.5
    trend = 0.0
    if len(attendance) > 1:
        # A step from zero attendance gives +/-inf (and 0 -> 0 gives NaN); skip those.
        changes = pd.Series(attendance).pct_change(fill_method=None).replace([np.inf, -np.inf], np.nan)
        trend = changes.mean()
        if pd.isna(trend):
            trend = 0.0
    historical_prob = min(max(base_risk + trend * 0.1, 0), 1)

    predictions = []
    shortage_prob = 0.5
    for stamp in pd.date_range(site_calendar_date, periods=4, freq='D')[1:]:
        day = stamp.date()
        csv_rows = trade_df[trade_df['Date'] == day]
        if csv_rows.empty:
            headcount = fallback_headcount
            shortage_prob = historical_prob
        else:
            row = csv_rows.iloc[0]
            headcount = int(row['Attendance']) if pd.notna(row['Attendance']) else 0
            shortage_prob = row['Shortage_risk'] if pd.notna(row['Shortage_risk']) else 0.5
        predictions.append({"date": day.strftime('%Y-%m-%d'), "headcount": headcount})

    site_calendar_value = site_calendar_date.strftime('%Y-%m-%d') + f" ({'Weekday' if is_weekday else 'Weekend'})"
    return predictions, shortage_prob, site_calendar_value, None
# Fetch Project ID from Salesforce
def get_project_id():
    """Look up the most recently created ``Project__c`` record.

    Returns:
        ``(project_id, None)`` on success, or ``(None, message)`` when there is
        no Salesforce connection, no project exists, or the query fails.
    """
    if not sf:
        return None, "Salesforce connection failed."
    soql = "SELECT Id FROM Project__c ORDER BY CreatedDate DESC LIMIT 1"
    try:
        response = sf.query(soql)
        if not response['totalSize'] > 0:
            return None, "No project found in Salesforce."
        newest = response['records'][0]
        return newest['Id'], None
    except Exception as exc:
        return None, f"Error fetching Project ID: {str(exc)}"
# Save to Salesforce
def save_to_salesforce(record):
    """Create one ``Labour_Attendance_Forecast__c`` record from *record*.

    Returns a dict: ``{"success": ..., "record_id": ...}`` when the insert
    succeeds, otherwise ``{"error": ...}`` with a human-readable reason.
    """
    if sf:
        try:
            response = sf.Labour_Attendance_Forecast__c.create(record)
            return {
                "success": f"Record created for {record['Trade__c']}",
                "record_id": response['id'],
            }
        except Exception as exc:
            return {"error": f"Error uploading to Salesforce for {record['Trade__c']}: {str(exc)}"}
    return {"error": "Salesforce connection failed."}
# Create heatmap for shortfall risk
def create_heatmap(df, predictions_dict, shortage_probs, site_calendar_date):
    """Build a Plotly heatmap of shortage risk per trade and day.

    Cells are produced as follows:

    * The site calendar date gets a cell only when the CSV has a row for that
      trade and date.
    * Each of the following three days always gets a cell. It uses the CSV's
      Shortage_risk when a row exists, otherwise ``shortage_probs[trade]``
      (default 0.5).

    A CSV row whose risk is missing counts as 0.5.

    Args:
        df: CSV frame whose ``Date`` column holds ``datetime.date`` values.
        predictions_dict: Mapping whose keys are the trades to plot.
        shortage_probs: ``{trade: probability}`` fallback for days without CSV rows.
        site_calendar_date: Anchor date, ``YYYY-MM-DD`` string or date-like.

    Returns:
        A ``go.Figure``; a titled empty figure when there are no cells.
    """
    anchor = pd.to_datetime(site_calendar_date, format='%Y-%m-%d').date()
    upcoming = [stamp.date() for stamp in pd.date_range(anchor, periods=4, freq='D')[1:]]

    def csv_risk(trade, day):
        # None means "no CSV row"; an existing row with a missing value yields 0.5.
        rows = df[(df['Trade'] == trade) & (df['Date'] == day)]
        if rows.empty:
            return None
        risk = rows.iloc[0]['Shortage_risk']
        return risk if pd.notna(risk) else 0.5

    cells = []
    for trade in predictions_dict:
        today_risk = csv_risk(trade, anchor)
        if today_risk is not None:
            cells.append({
                'Date': anchor.strftime('%Y-%m-%d'),
                'Trade': trade,
                'Shortage_Probability': today_risk,
            })
        for day in upcoming:
            risk = csv_risk(trade, day)
            cells.append({
                'Date': day.strftime('%Y-%m-%d'),
                'Trade': trade,
                'Shortage_Probability': shortage_probs.get(trade, 0.5) if risk is None else risk,
            })

    if not cells:
        return go.Figure().update_layout(title="Shortage Risk Heatmap (No Data)")

    heatmap_df = pd.DataFrame(cells)
    probabilities = heatmap_df['Shortage_Probability']
    trace = go.Heatmap(
        x=heatmap_df['Date'],
        y=heatmap_df['Trade'],
        z=probabilities,
        colorscale='Blues',
        zmin=0,
        zmax=1,
        text=probabilities.round(2),
        texttemplate="%{text}",
        textfont={"size": 12},
        colorbar={"title": "Shortage Risk", "tickvals": [0, 0.5, 1], "ticktext": ["0%", "50%", "100%"]},
    )
    fig = go.Figure(data=trace)
    fig.update_layout(
        title="Shortage Risk Heatmap",
        xaxis_title="Date",
        yaxis_title="Trade",
        xaxis={"tickangle": 45, "tickformat": "%Y-%m-%d"},
        yaxis={"autorange": "reversed"},
        font={"size": 14},
        margin={"l": 100, "r": 50, "t": 100, "b": 100},
        plot_bgcolor="white",
        paper_bgcolor="white",
        showlegend=False,
        grid={"rows": 1, "columns": 1},
    )
    fig.update_xaxes(showgrid=True, gridcolor="lightgray")
    fig.update_yaxes(showgrid=True, gridcolor="lightgray")
    return fig
# Create line chart for forecasts
def create_chart(df, predictions_dict):
    """Plot historical and forecast attendance per trade as a Plotly line chart.

    Args:
        df: Cleaned CSV frame with at least ``Date`` (``datetime.date``),
            ``Attendance`` and ``Trade`` columns.
        predictions_dict: ``{trade: [{"date": "YYYY-MM-DD", "headcount": int}, ...]}``.

    Returns:
        A Plotly figure. Colour distinguishes trades and dash style
        distinguishes ``Historical`` from ``Forecast``. Returns a titled empty
        figure when nothing can be plotted.
    """
    columns = ['Date', 'Attendance', 'Type', 'Trade']
    # Collect per-trade frames and concatenate once. Concatenating onto an
    # empty DataFrame inside the loop is quadratic, and deprecated in pandas 2.1+.
    frames = []
    for trade, predictions in predictions_dict.items():
        history = df[df['Trade'] == trade]
        if history.empty:
            continue
        forecast = pd.DataFrame(predictions)
        if forecast.empty:
            continue
        frames.append(history.assign(Type='Historical', Trade=trade)[columns])
        frames.append(forecast.assign(
            Date=pd.to_datetime(forecast['date'], format='%Y-%m-%d').dt.date,
            Attendance=forecast['headcount'],
            Type='Forecast',
            Trade=trade,
        )[columns])

    combined_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(columns=columns)
    # Rows whose date failed to parse cannot be placed on the x-axis.
    combined_df = combined_df[combined_df['Date'].notna()]
    if combined_df.empty:
        return go.Figure().update_layout(title="Labour Attendance Forecast (No Data)")
    # px.line joins points in row order; sort so an unsorted CSV does not zig-zag.
    combined_df = combined_df.sort_values('Date', kind='mergesort')

    fig = px.line(
        combined_df,
        x='Date',
        y='Attendance',
        color='Trade',
        line_dash='Type',
        markers=True,
        title='Labour Attendance Forecast by Trade'
    )
    return fig
# Generate PDF summary
def generate_pdf_summary(trade_results, project_id):
    """Render a one-page PDF summary and return it base64-encoded.

    Each trade in *trade_results* is listed with its actual attendance under
    the title "Weekly Summary for Project <project_id>".

    Args:
        trade_results: ``{trade: {"Attendance": ..., ...}}``; may be empty.
        project_id: Shown in the page title.

    Returns:
        The PDF bytes encoded as an ASCII base64 string.
    """
    buffer = io.BytesIO()
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if not trade_results:
            ax.text(0.1, 0.5, "No data available for summary", fontsize=12)
        else:
            # Keep every line inside the axes: the fixed 0.1 step pushed the
            # 11th and later trades below the bottom of the page.
            step = min(0.1, 0.9 / len(trade_results))
            for i, (trade, data) in enumerate(trade_results.items()):
                ax.text(0.1, 0.9 - step * i,
                        f"{trade}: {data.get('Attendance', 'N/A')} (Actual)",
                        fontsize=12)
        ax.set_title(f"Weekly Summary for Project {project_id}")
        ax.axis('off')
        with PdfPages(buffer) as pdf:
            pdf.savefig(fig)
    finally:
        # Close this specific figure (even on error); pyplot otherwise keeps it alive.
        plt.close(fig)
    return base64.b64encode(buffer.getvalue()).decode()
# Notify contractor (mock)
def notify_contractor(trade, alert_status):
    """Mock notification: nothing is sent; the confirmation text is returned."""
    message = "Notification sent to contractor for {} with alert status: {}".format(trade, alert_status)
    return message
# Format output to display CSV file values and Forecast_Next_3_Days__c
def format_output(trade_results, site_calendar_date):
    """Render per-trade results as the bullet list shown in the UI.

    Formatting rules:

    * Shortage_risk is shown to two decimals and Attendance as an integer.
    * A missing, NA or non-numeric value for either of those is shown as ``N/A``
      (these used to raise TypeError/ValueError).
    * Forecast_Next_3_Days__c is rendered as ``date: headcount`` pairs.

    Args:
        trade_results: ``{trade: {column: value, ...}}``.
        site_calendar_date: ``YYYY-MM-DD`` string or a date/datetime.

    Returns:
        The formatted text, or "No valid trade data available." when empty.
    """
    if not trade_results:
        return "No valid trade data available."

    csv_columns = ['Date', 'Trade', 'Weather', 'Alert_status', 'Shortage_risk', 'Suggested_actions', 'Attendance', 'Forecast_Next_3_Days__c']

    # The date is the same for every trade, so format it once.
    if site_calendar_date is None or pd.isna(site_calendar_date):
        date_text = 'N/A'
    else:
        date_value = site_calendar_date
        if not hasattr(date_value, 'strftime'):
            date_value = pd.to_datetime(date_value, format='%Y-%m-%d', errors='coerce')
        date_text = 'N/A' if pd.isna(date_value) else date_value.strftime('%Y-%m-%d')

    def numeric_text(key, value):
        # Missing, NA or non-numeric values render as 'N/A' instead of crashing.
        if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
            return 'N/A'
        try:
            return str(round(value, 2)) if key == 'Shortage_risk' else str(int(value))
        except (TypeError, ValueError):
            return 'N/A'

    output = []
    for trade, data in trade_results.items():
        output.append(f"Trade: {trade}")
        for key in csv_columns:
            if key == 'Date':
                value = date_text
            elif key == 'Forecast_Next_3_Days__c':
                forecast = data.get(key)
                value = ', '.join(f"{item['date']}: {item['headcount']}" for item in forecast) if forecast else 'N/A'
            elif key in ('Shortage_risk', 'Attendance'):
                value = numeric_text(key, data.get(key, 'N/A'))
            else:
                value = data.get(key, 'N/A')
            output.append(f" • {key}: {value}")
        output.append("")
    return "\n".join(output)
# Gradio forecast function
_REQUIRED_COLUMNS = ['Date', 'Attendance', 'Trade', 'Weather', 'Alert_status', 'Shortage_risk', 'Suggested_actions']
_TEXT_COLUMNS = ['Weather', 'Alert_status', 'Suggested_actions', 'Trade']


def _fail(message):
    """Build the 5-tuple forecast_labour returns on error: message plus empty outputs."""
    return message, None, None, None, None


def _read_csv_with_fallback(path):
    """Read the CSV at *path*, trying UTF-16 (only when BOM-marked), UTF-8, then Latin-1."""
    with open(path, 'rb') as fh:
        head = fh.read(2)
    # Latin-1 decodes any byte sequence, so nothing listed after it is ever tried.
    # UTF-16 is therefore attempted first, but only when its byte-order mark is present.
    encodings = ['utf-16'] if head in (b'\xff\xfe', b'\xfe\xff') else []
    encodings += ['utf-8', 'latin1']
    for encoding in encodings:
        try:
            return pd.read_csv(path, encoding=encoding, dtype_backend='numpy_nullable')
        except UnicodeError:
            continue
    return None


def _coerce_columns(df):
    """Normalise Attendance, Shortage_risk (to a 0-1 fraction) and the text columns in place."""
    df['Attendance'] = pd.to_numeric(df['Attendance'], errors='coerce').fillna(0).astype('Int64')
    risk_text = df['Shortage_risk'].astype(str).str.replace('%', '', regex=False).str.strip()
    # Divide first, then default: "fillna(0.5) / 100" turned a missing risk into 0.005.
    df['Shortage_risk'] = (pd.to_numeric(risk_text, errors='coerce') / 100).fillna(0.5)
    for column in _TEXT_COLUMNS:
        values = df[column]
        # With dtype_backend='numpy_nullable' a missing cell is pd.NA, which
        # astype(str) renders as '<NA>' rather than 'nan', so mask it explicitly.
        df[column] = values.astype(object).where(values.notna(), 'N/A').astype(str).replace('nan', 'N/A')


def _parse_site_date(value):
    """Return *value* (YYYY-MM-DD) as a datetime.date, or None when missing or invalid."""
    if value is None:
        return None
    if isinstance(value, str):
        value = value.strip()
    parsed = pd.to_datetime(value, format='%Y-%m-%d', errors='coerce')
    return None if pd.isna(parsed) else parsed.date()


def forecast_labour(csv_file, trade_filter=None, site_calendar_date=None):
    """Gradio handler: load the CSV, forecast each trade, upload, and build outputs.

    Args:
        csv_file: The uploaded file. This is either a file-like object with a
            ``.name`` path (older Gradio) or a path string (newer Gradio).
        trade_filter: Optional comma-separated trade names; blank means all trades.
        site_calendar_date: ``YYYY-MM-DD`` date to report on; the forecast
            covers the following three days.

    Returns:
        A 5-tuple ``(result_text, line_chart, heatmap, pdf_link_html, notification)``.
        On any error, ``result_text`` holds the message and the other four are None.
    """
    try:
        if csv_file is None:
            return _fail("Error: No CSV file uploaded.")
        df = _read_csv_with_fallback(getattr(csv_file, 'name', csv_file))
        if df is None:
            return _fail("Error: Could not decode CSV file.")
        df.columns = df.columns.str.strip().str.capitalize()
        missing_columns = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
        if missing_columns:
            return _fail(f"Error: CSV missing columns: {', '.join(missing_columns)}")
        # Parse dates with explicit format
        df['Date'] = pd.to_datetime(df['Date'], format='%Y-%m-%d', errors='coerce').dt.date
        if df['Date'].isna().all():
            return _fail("Error: All dates in CSV are invalid.")
        _coerce_columns(df)

        unique_trades = set(df['Trade'])
        if trade_filter:
            requested = [t.strip() for t in trade_filter.split(',') if t.strip()]
            # dict.fromkeys de-duplicates while keeping the user's order, so a trade
            # listed twice is not processed (and uploaded) twice.
            selected_trades = [t for t in dict.fromkeys(requested) if t in unique_trades]
            if not selected_trades:
                return _fail(f"Error: None of the specified trades '{trade_filter}' found in CSV.")
        else:
            selected_trades = list(df['Trade'].unique())

        # Validate the user's date before calling Salesforce, and echo what the
        # user typed (not "NaT") in the error message.
        parsed_date = _parse_site_date(site_calendar_date)
        if parsed_date is None:
            return _fail(f"Error: Invalid site calendar date: {site_calendar_date}")
        site_calendar_date = parsed_date

        project_id, error = get_project_id()
        if error:
            return _fail(f"Error: {error}")

        trade_results = {}
        predictions_dict = {}
        shortage_probs = {}
        errors = []
        for trade in selected_trades:
            trade_df = df[df['Trade'] == trade].copy()
            if trade_df.empty:
                errors.append(f"No data for trade: {trade}")
                continue
            date_match = trade_df[trade_df['Date'] == site_calendar_date]
            if date_match.empty:
                errors.append(f"No data found for trade {trade} on {site_calendar_date}")
                continue
            if len(date_match) > 1:
                errors.append(f"Warning: Multiple rows found for trade {trade} on {site_calendar_date}. Using first row.")
            predictions, shortage_prob, _site_calendar, forecast_error = weighted_moving_average_forecast(trade_df, trade, site_calendar_date)
            if forecast_error:
                errors.append(forecast_error)
                continue
            predictions_dict[trade] = predictions
            shortage_probs[trade] = shortage_prob
            record = date_match.iloc[0]
            result_data = {
                'Date': site_calendar_date,
                'Trade': record['Trade'],
                'Weather': record['Weather'],
                'Alert_status': record['Alert_status'],
                'Shortage_risk': record['Shortage_risk'],
                'Suggested_actions': record['Suggested_actions'],
                'Attendance': record['Attendance'],
                'Forecast': predictions,
                'Shortage_Probability': round(shortage_prob, 2),
                'Forecast_Next_3_Days__c': predictions,
                'Project__c': project_id
            }
            salesforce_record = {
                'Trade__c': trade,
                'Shortage_Risk__c': record['Shortage_risk'],
                'Suggested_Actions__c': record['Suggested_actions'],
                'Expected_Headcount__c': predictions[0]['headcount'] if predictions else 0,
                'Actual_Headcount__c': int(record['Attendance']) if pd.notna(record['Attendance']) else 0,
                'Forecast_Next_3_Days__c': str(predictions),
                'Project_ID__c': project_id,
                'Alert_Status__c': record['Alert_status'],
                'Dashboard_Display__c': True,
                'Date__c': pd.Timestamp(site_calendar_date).isoformat()
            }
            result_data.update(save_to_salesforce(salesforce_record))
            trade_results[trade] = result_data

        if not trade_results:
            error_msg = "No valid trade data processed for the specified date."
            if errors:
                error_msg += " Errors: " + "; ".join(errors)
            return _fail(error_msg)

        line_chart = create_chart(df, predictions_dict)
        heatmap = create_heatmap(df, predictions_dict, shortage_probs, site_calendar_date)
        pdf_summary = generate_pdf_summary(trade_results, project_id)
        # Notify about the first trade that actually produced a result. The first
        # *selected* trade may have been skipped, and indexing it raised KeyError.
        notification_trade = next(iter(trade_results))
        notification = notify_contractor(notification_trade, trade_results[notification_trade]['Alert_status'])
        error_msg = "; ".join(errors) if errors else None
        return (
            format_output(trade_results, site_calendar_date) + (f"\nWarnings: {error_msg}" if error_msg else ""),
            line_chart,
            heatmap,
            f'<a href="data:application/pdf;base64,{pdf_summary}" download="summary.pdf">Download Summary PDF</a>',
            notification
        )
    except Exception as e:
        return _fail(f"Error processing file: {str(e)}")
# Gradio UI
def gradio_interface():
    """Build the Labour Attendance Forecast Blocks app and launch it with share=False."""
    instructions = [
        "# Labour Attendance Forecast",
        "Upload a CSV with columns: Date, Attendance, Trade, Weather, Alert_Status, Shortage_Risk (e.g. 22%), Suggested_Actions.",
        "Enter trade names (e.g., 'Painter, Electrician') separated by commas, or leave blank to process all trades.",
        "Enter a specific date for the site calendar (YYYY-MM-DD) to display CSV data for that date and forecast the next 3 days.",
    ]
    with gr.Blocks(theme=gr.themes.Soft()) as interface:
        for text in instructions:
            gr.Markdown(text)
        with gr.Row():
            csv_input = gr.File(label="Upload CSV")
            trade_input = gr.Textbox(
                label="Filter by Trades (e.g., Painter, Electrician)",
                placeholder="Enter trade names separated by commas or leave blank for all trades",
            )
            site_calendar_input = gr.Textbox(
                label="Site Calendar Date (YYYY-MM-DD)",
                placeholder="e.g., 2025-05-24",
            )
        forecast_button = gr.Button("Generate Forecast")
        # Order must match the 5-tuple returned by forecast_labour.
        outputs = [
            gr.Textbox(label="Forecast Result", lines=20),
            gr.Plot(label="Forecast Trendline"),
            gr.Plot(label="Shortage Risk Heatmap"),
            gr.HTML(label="Download Summary PDF"),
            gr.Textbox(label="Contractor Notification"),
        ]
        forecast_button.click(
            fn=forecast_labour,
            inputs=[csv_input, trade_input, site_calendar_input],
            outputs=outputs,
        )
    interface.launch(share=False)
if __name__ == "__main__":
    # Start the web UI only when run as a script, not when imported.
    gradio_interface()