Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| from datetime import datetime | |
| from simple_salesforce import Salesforce | |
| from dotenv import load_dotenv | |
| import plotly.express as px | |
# Pull the SF_* settings from a local .env file into the process environment.
load_dotenv()

# Salesforce credentials; each is None when the variable is unset.
SF_USERNAME = os.environ.get('SF_USERNAME')
SF_PASSWORD = os.environ.get('SF_PASSWORD')
SF_SECURITY_TOKEN = os.environ.get('SF_SECURITY_TOKEN')

# Open a single module-wide Salesforce session. If login fails, `sf` is set to
# None and the failure is printed; the Salesforce helpers test `if not sf`.
try:
    sf = Salesforce(
        username=SF_USERNAME,
        password=SF_PASSWORD,
        security_token=SF_SECURITY_TOKEN,
    )
except Exception as exc:
    sf = None
    print(f"Error connecting to Salesforce: {exc}")
def get_project_id():
    """Look up the most recently created Project__c record.

    Returns:
        A ``(project_id, error)`` tuple: ``(Id, None)`` on success, otherwise
        ``(None, message)`` describing why no Id could be fetched.
    """
    if not sf:
        return None, "Salesforce connection failed. Check credentials."
    soql = "SELECT Id FROM Project__c ORDER BY CreatedDate DESC LIMIT 1"
    try:
        response = sf.query(soql)
        if response['totalSize'] <= 0:
            return None, "No project found in Salesforce."
        newest = response['records'][0]
        return newest['Id'], None
    except Exception as exc:
        return None, f"Error fetching Project ID: {exc}"
def simple_forecast(df, window=3, horizon=3):
    """Forecast daily headcount with a trailing simple moving average.

    The forecast is flat: each of the next ``horizon`` days gets the rolling
    mean of the ``window`` most recent attendance values, ordered by date.

    Args:
        df: DataFrame with a ``Date`` column (datetimes or day-first date
            strings) and a numeric ``Attendance`` column. It is not modified.
        window: Number of most recent observations to average. Defaults to 3.
        horizon: Number of future days to forecast. Defaults to 3.

    Returns:
        A list of ``{"date": "YYYY-MM-DD", "headcount": int}`` dicts, one per
        day after the latest date in ``df``. Empty if ``df`` has no rows.
    """
    # Work on a copy so the caller's frame keeps its original columns/dtypes.
    work = df[['Date', 'Attendance']].copy()
    work['Date'] = pd.to_datetime(work['Date'], dayfirst=True)
    # Rolling windows follow row order, so sort chronologically first;
    # otherwise the "latest" average would depend on the CSV's row order.
    work = work.sort_values('Date', kind='stable', na_position='first')
    if work.empty:
        return []
    rolling_mean = work['Attendance'].rolling(window=window, min_periods=1).mean()
    latest_mean = rolling_mean.iloc[-1]
    # periods=horizon + 1 includes the last observed day, which [1:] drops.
    future_dates = pd.date_range(work['Date'].max(), periods=horizon + 1, freq='D')[1:]
    return [
        {"date": day.strftime('%Y-%m-%d'), "headcount": round(float(latest_mean))}
        for day in future_dates
    ]
def save_to_salesforce(record):
    """Create one Labour_Attendance_Forecast__c record in Salesforce.

    Args:
        record: Field-name -> value mapping; must include ``Trade__c``, which
            is used in the returned messages.

    Returns:
        ``{"success": ..., "record_id": ...}`` when the record is created,
        otherwise ``{"error": ...}``. Exceptions raised while creating the
        record are caught and reported through the ``error`` message.
    """
    if not sf:
        return {"error": "Salesforce connection failed. Check credentials."}
    try:
        created = sf.Labour_Attendance_Forecast__c.create(record)
        return {
            "success": f"Record created successfully for {record['Trade__c']}",
            "record_id": created['id'],
        }
    except Exception as exc:
        return {"error": f"Error uploading data to Salesforce for {record['Trade__c']}: {exc}"}
def create_chart(df, predictions_dict):
    """Build a Plotly line chart of historical and forecast attendance per trade.

    Args:
        df: Full attendance DataFrame with ``Date``, ``Attendance`` and
            ``Trade`` columns.
        predictions_dict: Mapping of trade -> list of ``{"date", "headcount"}``
            dicts (the shape returned by ``simple_forecast``). Trades with an
            empty list are plotted with history only.

    Returns:
        The figure from ``px.line``, coloured by trade, with the line dash
        distinguishing ``Historical`` from ``Forecast`` points.
    """
    columns = ['Date', 'Attendance', 'Type', 'Trade']
    frames = []
    for trade, predictions in predictions_dict.items():
        # px.line joins points in row order, so sort each trade's history to
        # avoid zig-zag lines when the uploaded CSV is not chronological.
        history = (
            df.loc[df['Trade'] == trade, ['Date', 'Attendance']]
            .sort_values('Date', kind='stable')
            .assign(Type='Historical', Trade=trade)
        )
        frames.append(history[columns])
        if not predictions:
            continue
        forecast = pd.DataFrame(predictions)
        frames.append(pd.DataFrame({
            'Date': pd.to_datetime(forecast['date']),
            'Attendance': forecast['headcount'],
            'Type': 'Forecast',
            'Trade': trade,
        }))
    # Concatenate once (not inside the loop); an empty input still yields the
    # columns px.line needs.
    if frames:
        combined_df = pd.concat(frames, ignore_index=True)
    else:
        combined_df = pd.DataFrame(columns=columns)
    fig = px.line(
        combined_df,
        x='Date',
        y='Attendance',
        color='Trade',
        line_dash='Type',
        markers=True,
        title='Labour Attendance Forecast by Trade'
    )
    return fig
def format_output(trade_results):
    """Render per-trade result dicts as line-by-line bullet text.

    Each trade contributes a ``Trade: <name>`` header, one `` • key: value``
    line per key not in the hidden set (list values are comma-joined), and a
    trailing blank line.
    """
    hidden = {'Project__c', 'record_id', 'success'}

    def render(value):
        # Lists (e.g. the forecast entries) are flattened to "a, b, c".
        return ', '.join(map(str, value)) if isinstance(value, list) else value

    lines = []
    for trade_name, fields in trade_results.items():
        lines.append(f"Trade: {trade_name}")
        lines += [
            f" • {name}: {render(value)}"
            for name, value in fields.items()
            if name not in hidden
        ]
        lines.append("")
    return "\n".join(lines)
def _resolve_upload_path(csv_file):
    """Return the filesystem path for an uploaded file.

    Accepts an object exposing the path as ``.name`` (what the original code
    expected from Gradio) or a plain path string, e.g. when called directly.
    """
    return getattr(csv_file, 'name', csv_file)


def _read_uploaded_csv(path):
    """Read the CSV at *path*, returning a DataFrame or None if undecodable."""
    with open(path, 'rb') as handle:
        head = handle.read(2)
    if head in (b'\xff\xfe', b'\xfe\xff'):
        # UTF-16 byte-order mark: any single-byte codec would "succeed" but
        # yield garbage column names, so only utf-16 is attempted.
        encodings = ['utf-16']
    else:
        # utf-8-sig also strips a UTF-8 BOM. latin1 maps every byte, so it
        # never raises and must stay last in the list.
        encodings = ['utf-8-sig', 'latin1']
    for encoding in encodings:
        try:
            return pd.read_csv(path, encoding=encoding)
        except UnicodeDecodeError:
            continue
    return None


def _to_sf_value(value):
    """Convert a pandas/numpy scalar into a JSON-safe Salesforce field value.

    numpy scalars become native Python objects, and float NaN (e.g. from an
    empty CSV cell) becomes None because NaN is not valid JSON.
    """
    if isinstance(value, np.generic):
        value = value.item()
    if isinstance(value, float) and np.isnan(value):
        return None
    return value


def _summarise_trade(trade, trade_df, project_id):
    """Forecast one trade, save its record to Salesforce, and build its result dict.

    Returns:
        ``(predictions, result_data)``. ``result_data`` is also updated with
        the ``success``/``record_id`` or ``error`` keys from save_to_salesforce.
    """
    # Chronological order (undated rows first) so iloc[-1] is the most recent
    # dated row, independent of the CSV's row order.
    trade_df = trade_df.sort_values(by='Date', kind='stable', na_position='first')
    predictions = simple_forecast(trade_df)
    latest_record = trade_df.iloc[-1]
    latest_date = trade_df['Date'].max()
    weather = latest_record['Weather']
    alert_status = latest_record['Alert_status']
    shortage_risk = latest_record['Shortage_risk']
    suggested_actions = latest_record['Suggested_actions']
    expected_headcount = predictions[0]['headcount']
    actual_headcount = int(latest_record['Attendance'])
    result_data = {
        "Title": f"Labour Attendance Data for {trade}",
        "Date": latest_date.strftime('%B %Y'),
        "Trade": trade,
        "Weather": weather,
        "Forecast": predictions,
        "Alert Status": alert_status,
        "Shortage_risk": shortage_risk,
        "Suggested_actions": suggested_actions,
        "Expected_headcount": expected_headcount,
        "Actual_headcount": actual_headcount,
        "Forecast_Next_3_Days__c": predictions,
        "Project__c": project_id
    }
    salesforce_record = {
        'Trade__c': _to_sf_value(trade),
        'Shortage_Risk__c': _to_sf_value(shortage_risk),
        'Suggested_Actions__c': _to_sf_value(suggested_actions),
        'Expected_Headcount__c': _to_sf_value(expected_headcount),
        'Actual_Headcount__c': actual_headcount,
        'Forecast_Next_3_Days__c': str(predictions),
        'Project_ID__c': project_id,
        'Alert_Status__c': _to_sf_value(alert_status),
        'Dashboard_Display__c': True,
        'Date__c': latest_date.date().isoformat()
    }
    result_data.update(save_to_salesforce(salesforce_record))
    return predictions, result_data


def forecast_labour(csv_file):
    """Gradio handler: forecast attendance per trade and sync results to Salesforce.

    Only the first 10 distinct trades in the file are processed.

    Args:
        csv_file: The uploaded file, either an object whose ``.name`` is the
            file path or a path string. None means nothing was uploaded.

    Returns:
        ``(text, chart)``: the bullet-formatted per-trade summary and the
        chart figure, or ``(error_message, None)`` on any failure.
    """
    try:
        if csv_file is None:
            return "Error: No CSV file uploaded.", None
        df = _read_uploaded_csv(_resolve_upload_path(csv_file))
        if df is None:
            return "Error: Could not decode CSV file with any supported encoding (utf-8, latin1, iso-8859-1, utf-16). Please ensure the file is properly encoded.", None
        # capitalize() lowercases everything after the first character, so
        # "Alert_Status" becomes "Alert_status" (matching required_columns).
        df.columns = df.columns.str.strip().str.capitalize()
        required_columns = ['Date', 'Attendance', 'Trade', 'Weather', 'Alert_status', 'Shortage_risk', 'Suggested_actions']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            return f"Error: CSV missing required columns: {', '.join(missing_columns)}", None
        df['Date'] = pd.to_datetime(df['Date'], dayfirst=True)
        df['Attendance'] = df['Attendance'].astype(int)
        # "22%" -> 0.22
        df['Shortage_risk'] = df['Shortage_risk'].replace('%', '', regex=True).astype(float) / 100
        unique_trades = df['Trade'].unique()
        if len(unique_trades) < 10:
            return f"Error: CSV contains only {len(unique_trades)} trades, but a minimum of 10 trades is required.", None
        project_id, error = get_project_id()
        if error:
            return f"Error: {error}", None
        trade_results = {}
        predictions_dict = {}
        for trade in unique_trades[:10]:
            trade_df = df[df['Trade'] == trade].copy()
            # A NaN trade never compares equal, so its slice is empty.
            if trade_df.empty:
                continue
            predictions, result_data = _summarise_trade(trade, trade_df, project_id)
            predictions_dict[trade] = predictions
            trade_results[trade] = result_data
        chart = create_chart(df, predictions_dict)
        return format_output(trade_results), chart
    except Exception as e:
        return f"Error processing file: {str(e)}", None
def gradio_interface():
    """Build the Gradio UI around forecast_labour and launch it with share=False."""
    upload = gr.File(label="Upload CSV with required columns for at least 10 trades")
    result_box = gr.Textbox(label="Forecast Result", lines=20)
    chart_panel = gr.Plot(label="Forecast Chart")
    app = gr.Interface(
        fn=forecast_labour,
        inputs=[upload],
        outputs=[result_box, chart_panel],
        title="Labour Attendance Forecast",
        description=(
            "Upload a CSV file with columns: Date, Attendance, Trade, Weather, "
            "Alert_Status, Shortage_Risk (e.g. 22%), Suggested_Actions. "
            "The file must contain data for at least 10 trades. "
        ),
    )
    # No public share link; serve locally only.
    app.launch(share=False)
# Start the UI only when run as a script, not when this module is imported.
if __name__ == "__main__":
    gradio_interface()