import os
import logging
import gradio as gr
import pandas as pd
import numpy as np
from datetime import datetime
import plotly.express as px
import plotly.graph_objects as go
import io
import base64
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from simple_salesforce import Salesforce
# Configure logging
# Records at INFO level and above are formatted as "time - level - message"
# and sent to two handlers: a file named 'labour_forecast.log' (relative path)
# and a stream handler for console output.
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('labour_forecast.log'),
logging.StreamHandler()
]
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# Configuration class
class Config:
    """Static settings read by the CSV loader and the forecasting routines."""

    # Columns an uploaded CSV must provide (checked after header names are
    # stripped and capitalised).
    REQUIRED_COLUMNS = ['Date', 'Attendance', 'Trade', 'Weather']

    # Encodings attempted, in this order, when reading the CSV.
    ENCODINGS = [
        'utf-8',
        'utf-8-sig',
        'latin1',
        'iso-8859-1',
        'cp1252',
        'utf-16',
    ]

    # Number of days forecast after the selected site calendar date.
    FORECAST_DAYS = 3

    # Weight vectors for the weighted moving average, keyed by the number of
    # attendance observations available.
    WMA_WEIGHTS = {
        3: np.array([0.5, 0.3, 0.2]),
        2: np.array([0.6, 0.4]),
        1: np.array([1.0]),
    }

    # Fraction by which expected attendance is reduced for each weather label.
    WEATHER_IMPACT = {
        'Sunny': 0.0,
        'Rainy': 0.25,
        'Cloudy': 0.15,
        'N/A': 0.05,
    }

    # Multiplier applied to attendance when the selected date is a weekend.
    WEEKEND_ADJUSTMENT = 0.8

    # Alert labels the forecast can emit (not referenced elsewhere in the
    # visible code).
    VALID_ALERT_STATUSES = {'Normal', 'Critical', 'Warning'}
# Salesforce connection
def connect_to_salesforce():
    """Open a Salesforce session from credentials held in environment variables.

    Reads ``SF_USERNAME``, ``SF_PASSWORD``, ``SF_SECURITY_TOKEN`` and
    ``SF_DOMAIN`` (the last falls back to ``'login'`` when unset).

    Returns:
        The ``Salesforce`` client on success, or ``None`` when the connection
        attempt raised; the failure is logged rather than propagated.
    """
    credentials = {
        'username': os.getenv('SF_USERNAME'),
        'password': os.getenv('SF_PASSWORD'),
        'security_token': os.getenv('SF_SECURITY_TOKEN'),
        'domain': os.getenv('SF_DOMAIN', 'login'),
    }
    try:
        client = Salesforce(**credentials)
    except Exception as exc:
        # Best-effort: callers treat None as "no Salesforce available".
        logger.error(f"Failed to connect to Salesforce: {str(exc)}")
        return None
    else:
        logger.info("Successfully connected to Salesforce")
        return client
# Data processing
def _normalise_attendance_frame(df):
    """Validate and coerce a freshly parsed attendance frame in place.

    Header names are stripped and capitalised, ``Date`` is parsed day-first,
    ``Attendance`` becomes a nullable integer (unparseable values become 0),
    and missing ``Weather``/``Trade`` values become the label ``'N/A'``.

    Raises:
        ValueError: if a required column is missing or no date can be parsed.
    """
    df.columns = df.columns.str.strip().str.capitalize()
    missing_columns = [col for col in Config.REQUIRED_COLUMNS if col not in df.columns]
    if missing_columns:
        raise ValueError(f"Missing columns: {', '.join(missing_columns)}")
    df['Date'] = pd.to_datetime(df['Date'], dayfirst=True, errors='coerce')
    if df['Date'].isna().all():
        raise ValueError("All dates in CSV are invalid")
    df['Attendance'] = pd.to_numeric(df['Attendance'], errors='coerce').fillna(0).astype('Int64')
    for col in ['Weather', 'Trade']:
        # Fill missing values BEFORE converting to plain str: with the
        # nullable backend a missing cell is pd.NA, which would otherwise be
        # stringified (e.g. as '<NA>') instead of becoming 'N/A'.
        df[col] = df[col].astype('string').fillna('N/A').astype(str)
    return df


def process_csv(file_path):
    """Load an attendance CSV, trying each configured encoding in turn.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        ``(DataFrame, None)`` on success, or ``(None, error_message)`` when
        the file is missing, cannot be decoded, or fails validation.
    """
    if not os.path.exists(file_path):
        error_msg = f"File not found: {file_path}"
        logger.error(error_msg)
        return None, error_msg

    # First validation failure seen on a file that *did* decode. Later
    # encodings are still tried (a wrong-but-permissive encoding such as
    # latin1 can decode bytes into garbage headers), but if none succeeds this
    # is reported instead of a misleading "could not decode" message.
    validation_error = None

    for encoding in Config.ENCODINGS:
        try:
            df = pd.read_csv(file_path, encoding=encoding, dtype_backend='numpy_nullable')
        except Exception as e:
            logger.warning(f"Failed with encoding '{encoding}': {str(e)}")
            continue

        try:
            df = _normalise_attendance_frame(df)
        except Exception as e:
            logger.warning(f"Failed with encoding '{encoding}': {str(e)}")
            if validation_error is None:
                validation_error = str(e)
            continue

        logger.info(f"Successfully processed CSV with encoding '{encoding}'. Rows: {len(df)}. Attendance summary: {df['Attendance'].describe().to_dict()}")
        return df, None

    if validation_error is not None:
        error_msg = f"Invalid CSV content: {validation_error}"
    else:
        error_msg = "Could not decode CSV file. Tried encodings: " + ", ".join(Config.ENCODINGS) + ". Ensure the file is a valid CSV and uses a supported encoding."
    logger.error(error_msg)
    return None, error_msg
# Forecasting logic with real-time adjustments
def _relative_change_stats(values):
    """Return ``(trend, volatility)`` of day-over-day relative changes.

    ``trend`` is the mean and ``volatility`` the sample standard deviation of
    ``values[i] / values[i-1] - 1``. Non-finite changes (division by a zero
    attendance, or NaN inputs) are discarded. Falls back to a trend of 0.0
    when no usable change exists and a volatility of 0.1 when fewer than two
    usable changes exist, so the result is always finite.
    """
    values = np.asarray(values, dtype='float64')
    if values.size < 2:
        return 0.0, 0.1
    with np.errstate(divide='ignore', invalid='ignore'):
        changes = values[1:] / values[:-1] - 1.0
    changes = changes[np.isfinite(changes)]
    trend = float(changes.mean()) if changes.size else 0.0
    volatility = float(changes.std(ddof=1)) if changes.size > 1 else 0.1
    return trend, volatility


def weighted_moving_average_forecast(df, trade, site_calendar_date):
    """Forecast headcount and shortage risk for one trade.

    Args:
        df: Attendance frame with ``Date``, ``Attendance``, ``Trade`` and
            ``Weather`` columns.
        trade: Trade name, matched case-insensitively against ``df['Trade']``.
        site_calendar_date: The selected day; anything ``pd.to_datetime``
            accepts.

    Returns:
        A 6-tuple ``(predictions, shortage_probs, site_calendar_value,
        suggested_actions, alert_status, error)``. ``predictions`` is a list
        of ``{"date", "headcount"}`` dicts for the following
        ``Config.FORECAST_DAYS`` days. ``shortage_probs`` holds the selected
        day's probability first, then one per forecast day. ``error`` is
        ``None`` on success; on failure the lists are empty and ``error``
        carries the message.
    """
    try:
        site_calendar_date = pd.to_datetime(site_calendar_date)
        trade_df = df[df['Trade'].str.lower() == trade.lower()].copy()
        if trade_df.empty:
            return [], [], None, 'N/A', 'Normal', f"No data found for trade: {trade}"

        is_weekday = site_calendar_date.weekday() < 5
        # NOTE(review): the weekend factor is decided by the SELECTED date and
        # applied to every forecast day, not by each future date's own
        # weekday. Kept as-is; confirm this is intended.
        calendar_factor = 1 if is_weekday else Config.WEEKEND_ADJUSTMENT

        # Keep rows up to the selected date, in chronological order. Sorting
        # makes the "most recent" windows below independent of CSV row order;
        # a stable sort keeps same-day rows in their original order.
        trade_df = trade_df[trade_df['Date'] <= site_calendar_date]
        trade_df = trade_df.sort_values('Date', kind='stable')
        recent_data = trade_df.tail(30)[['Date', 'Attendance', 'Weather']]
        if recent_data.empty:
            return [], [], None, 'N/A', 'Normal', f"No data for trade {trade} on or before {site_calendar_date.strftime('%Y-%m-%d')}"

        # Real-time attendance for the selected date.
        current_day_data = trade_df[trade_df['Date'] == site_calendar_date]
        if current_day_data.empty:
            return [], [], None, 'N/A', 'Normal', f"No data for trade {trade} on {site_calendar_date.strftime('%Y-%m-%d')}"
        current_attendance = current_day_data['Attendance'].iloc[0]
        current_weather = current_day_data['Weather'].iloc[0]

        predictions = []
        shortage_probs = []
        future_dates = pd.date_range(site_calendar_date, periods=Config.FORECAST_DAYS + 1, freq='D')[1:]

        # Historical metrics, computed on plain floats so nullable-integer
        # quirks (pd.NA, inf from zero attendance) cannot leak into the result.
        attendance = pd.to_numeric(recent_data['Attendance'], errors='coerce').astype('float64')
        attendance_values = attendance.to_numpy()
        attendance_mean = attendance.mean()
        attendance_trend, attendance_volatility = _relative_change_stats(attendance_values)
        logger.info(f"Trade: {trade}, Mean Attendance: {attendance_mean}, Trend: {attendance_trend}, Volatility: {attendance_volatility}")
        scale_factor = 1.0 if attendance_mean == 0 else min(100 / attendance_mean, 2.0)

        # Shortage risk for the selected day: how far actual attendance falls
        # below the weather/weekend-adjusted historical mean.
        weather_impact = Config.WEATHER_IMPACT.get(current_weather, 0.05)
        expected_attendance = attendance_mean * (1 - weather_impact) * calendar_factor
        shortage_ratio = 1 - (current_attendance / expected_attendance) if expected_attendance > 0 else 0
        shortage_prob = 0.5 + (shortage_ratio * 0.5) + (weather_impact * 0.3) + (attendance_trend * 0.2 * scale_factor)
        shortage_prob = float(min(max(shortage_prob, 0.4), 0.9))
        shortage_probs.append(shortage_prob)

        if shortage_prob > 0.7:
            suggested_actions = 'Urgent hiring needed'
            alert_status = 'Critical'
        elif shortage_prob > 0.5:
            suggested_actions = 'Reschedule tasks'
            alert_status = 'Warning'
        else:
            suggested_actions = 'Monitor'
            alert_status = 'Normal'

        # The weighted average of the latest observations does not depend on
        # the forecast day, so compute it once.
        # NOTE(review): the values are oldest-to-newest while the weight
        # vectors are descending, so the OLDEST of the three observations gets
        # the largest weight. Kept as-is; confirm whether the weights were
        # meant to favour the most recent day.
        recent_attendance = attendance_values[-3:]
        weights = Config.WMA_WEIGHTS.get(len(recent_attendance), Config.WMA_WEIGHTS[1])
        base_forecast = np.average(recent_attendance, weights=weights)
        base_prob = 0.5 + (attendance_trend * 0.5 * scale_factor)
        weather_history = recent_data['Weather']

        for i, date in enumerate(future_dates):
            # Reuse recent weather, walking backwards from the latest row, as
            # a stand-in for the unknown future weather.
            weather_idx = i % len(weather_history)
            future_weather = Config.WEATHER_IMPACT.get(weather_history.iloc[-weather_idx - 1], 0.05)
            forecast_value = base_forecast * (1 - future_weather)
            headcount = int(round(float(forecast_value * calendar_factor)))

            day_adjustment = (i + 1) * 0.02 * attendance_volatility
            weather_adjustment = future_weather * 0.3
            future_shortage_prob = base_prob + day_adjustment + weather_adjustment
            future_shortage_prob = float(min(max(future_shortage_prob * 0.7 + 0.3, 0.5), 0.7))
            shortage_probs.append(future_shortage_prob)
            predictions.append({"date": date.strftime('%Y-%m-%d'), "headcount": headcount})

        logger.info(f"Trade: {trade}, Shortage Probabilities: {shortage_probs}")
        site_calendar_value = site_calendar_date.strftime('%Y-%m-%d') + f" ({'Weekday' if is_weekday else 'Weekend'})"
        logger.info(f"Forecast generated for trade: {trade}")
        return predictions, shortage_probs, site_calendar_value, suggested_actions, alert_status, None
    except Exception as e:
        logger.error(f"Forecast error for trade {trade}: {str(e)}")
        return [], [], None, 'N/A', 'Normal', f"Forecast error: {str(e)}"
# Real-time shortage risk heatmap for the selected day
def create_heatmap(df, predictions_dict, shortage_probs_dict, site_calendar_date):
    """Build a trade-by-date heatmap of shortage probabilities.

    Args:
        df: Attendance frame; accepted for interface compatibility but not
            read by this function.
        predictions_dict: Mapping whose keys are the trades to plot.
        shortage_probs_dict: Mapping of trade to a list of probabilities
            (selected day first, then forecast days).
        site_calendar_date: First date on the x-axis; anything
            ``pd.to_datetime`` accepts.

    Returns:
        A plotly ``Figure``. On failure or when there is nothing to plot, an
        empty figure whose title explains why.
    """
    try:
        site_calendar_date = pd.to_datetime(site_calendar_date)
        # Six columns starting at the selected date; days beyond the supplied
        # probabilities repeat the last known value.
        future_dates = pd.date_range(site_calendar_date, periods=6, freq='D')
        default_probs = [0.5] * len(future_dates)

        heatmap_data = []
        for trade in predictions_dict:
            # An empty list would make the probs[-1] fallback raise, so it is
            # treated the same as a missing entry.
            probs = shortage_probs_dict.get(trade) or default_probs
            for i, date in enumerate(future_dates):
                prob = probs[i] if i < len(probs) else probs[-1]
                heatmap_data.append({
                    'Date': date.strftime('%Y-%m-%d'),
                    'Trade': trade,
                    'Shortage_Probability': prob
                })

        heatmap_df = pd.DataFrame(heatmap_data)
        if heatmap_df.empty:
            return go.Figure().update_layout(title="Shortage Risk Heatmap (No Data)")

        display_probs = heatmap_df['Shortage_Probability'] * 100

        # Custom colorscale adjusted to make 46% and 50% red.
        custom_colorscale = [
            [0, 'red'],            # 0 maps to red
            [0.001, '#1f77b4'],    # Slightly above 0 starts with a blue shade
            [0.45, '#1f77b4'],     # Keep blue until just before 46%
            [0.46, 'red'],         # 46% maps to red
            [0.47, '#1f77b4'],     # Back to blue after 46%
            [0.49, '#1f77b4'],     # Keep blue until just before 50%
            [0.5, 'red'],          # 50% maps to red
            [0.51, '#aec7e8'],     # Resume the original transition
            [1, '#08306b']         # Dark blue at 1
        ]

        fig = go.Figure(data=go.Heatmap(
            x=heatmap_df['Date'],
            y=heatmap_df['Trade'],
            z=heatmap_df['Shortage_Probability'],
            colorscale=custom_colorscale,
            zmin=0, zmax=1,
            text=display_probs.round(0).astype(int).astype(str) + '%',
            texttemplate="%{text}",
            textfont={"size": 14, "color": "black"},
            # Hover lines are separated with <br>; a string literal cannot
            # contain raw line breaks.
            hovertemplate="Trade: %{y}<br>Date: %{x}<br>Shortage Risk: %{text}",
            colorbar=dict(
                title="Shortage Risk",
                tickvals=[0, 0.5, 1],
                ticktext=["0%", "50%", "100%"]
            )
        ))
        fig.update_layout(
            title="Shortage Risk Heatmap",
            xaxis_title="Date",
            yaxis_title="Trade",
            xaxis=dict(
                tickangle=45,
                tickformat="%Y-%m-%d",
                showgrid=False
            ),
            yaxis=dict(
                autorange="reversed",
                showgrid=False
            ),
            font=dict(size=14),
            margin=dict(l=100, r=50, t=100, b=100),
            plot_bgcolor="white",
            paper_bgcolor="white",
            showlegend=False
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating heatmap: {str(e)}")
        return go.Figure().update_layout(title=f"Error in Heatmap: {str(e)}")
def create_chart(df, predictions_dict):
    """Plot historical attendance and forecast headcount per trade.

    Args:
        df: Attendance frame with ``Date``, ``Attendance`` and ``Trade``.
        predictions_dict: Mapping of trade to a list of
            ``{"date", "headcount"}`` dicts.

    Returns:
        A plotly line ``Figure`` coloured by trade and dashed by series type
        (``Historical`` / ``Forecast``). On failure or when there is nothing
        to plot, an empty figure whose title explains why.
    """
    try:
        trade_key = df['Trade'].str.lower()
        frames = []  # collected and concatenated once, not once per trade
        for trade, predictions in predictions_dict.items():
            history = df.loc[trade_key == trade.lower(), ['Date', 'Attendance']].copy()
            # Lines are drawn in row order, so plot history chronologically
            # regardless of how the CSV was ordered.
            history = history.sort_values('Date', kind='stable')
            history['Type'] = 'Historical'
            history['Trade'] = trade
            if not history.empty:
                frames.append(history)

            # Only build the forecast frame when there is something in it; an
            # empty frame has no 'date'/'headcount' columns to select.
            if predictions:
                forecast = pd.DataFrame({
                    'Date': pd.to_datetime([item['date'] for item in predictions]),
                    'Attendance': [item['headcount'] for item in predictions],
                })
                forecast['Type'] = 'Forecast'
                forecast['Trade'] = trade
                frames.append(forecast)

        if not frames:
            return go.Figure().update_layout(title="Labour Attendance Forecast (No Data)")

        combined_df = pd.concat(frames, ignore_index=True)
        fig = px.line(
            combined_df,
            x='Date',
            y='Attendance',
            color='Trade',
            line_dash='Type',
            markers=True,
            title='Labour Attendance Forecast by Trade'
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating chart: {str(e)}")
        return go.Figure().update_layout(title=f"Error in Chart: {str(e)}")
def generate_pdf_summary(trade_results):
    """Render a one-page PDF summary and return it base64-encoded.

    Args:
        trade_results: Mapping of trade to a result dict providing
            ``'Attendance'`` and ``'Shortage_risk'`` (a sequence whose first
            item is the selected day's probability).

    Returns:
        The PDF bytes as a base64 ``str``, or ``None`` if generation failed
        (the failure is logged).
    """
    try:
        buffer = io.BytesIO()
        with PdfPages(buffer) as pdf:
            fig, ax = plt.subplots(figsize=(10, 6))
            try:
                if not trade_results:
                    ax.text(0.1, 0.5, "No data available for summary", fontsize=12)
                else:
                    # Rows are 0.1 apart for short lists and squeezed closer
                    # for long ones so every trade stays on the page.
                    row_step = min(0.1, 0.85 / len(trade_results))
                    for i, (trade, data) in enumerate(trade_results.items()):
                        ax.text(
                            0.1, 0.9 - row_step * i,
                            f"{trade}: {data['Attendance']} (Actual), Shortage Risk: {data['Shortage_risk'][0]*100:.0f}%",
                            fontsize=12,
                        )
                ax.set_title("Weekly Labour Forecast Summary")
                ax.axis('off')
                # Save and close this specific figure rather than whatever
                # pyplot currently considers "current".
                pdf.savefig(fig)
            finally:
                # Always release the figure, even if drawing or saving raised.
                plt.close(fig)
        pdf_base64 = base64.b64encode(buffer.getvalue()).decode()
        logger.info("PDF summary generated")
        return pdf_base64
    except Exception as e:
        logger.error(f"Error generating PDF: {str(e)}")
        return None
def format_output(trade_results, site_calendar_date):
    """Format per-trade results and contractor notifications as plain text.

    Args:
        trade_results: Mapping of trade to a result dict (keys such as
            ``'Weather'``, ``'Attendance'``, ``'Forecast_Next_3_Days__c'``,
            ``'Shortage_risk'``, ``'Suggested_actions'``, ``'Alert_status'``).
            Missing keys are rendered as ``N/A`` instead of raising.
        site_calendar_date: The selected date, shown for every trade.

    Returns:
        ``(formatted_output, formatted_notifications)`` as two strings.
    """
    output_columns = Config.REQUIRED_COLUMNS + ['Forecast_Next_3_Days__c', 'Shortage_risk', 'Suggested_actions', 'Alert_status']
    output = []
    notifications = []

    for trade, data in trade_results.items():
        # Selected-day shortage probability, or None when absent/empty/NaN.
        risks = data.get('Shortage_risk')
        current_risk = risks[0] if risks is not None and len(risks) > 0 else None
        if current_risk is not None and pd.isna(current_risk):
            current_risk = None

        output.append(f"Trade: {trade}")
        for key in output_columns:
            if key == 'Date':
                value = pd.to_datetime(site_calendar_date).strftime('%Y-%m-%d') if pd.notna(site_calendar_date) else 'N/A'
            elif key == 'Forecast_Next_3_Days__c':
                forecast_items = data.get(key) or []
                value = ', '.join(f"{item['date']}: {item['headcount']}" for item in forecast_items) if forecast_items else 'N/A'
            elif key == 'Shortage_risk':
                value = str(round(current_risk, 2)) if current_risk is not None else 'N/A'
            elif key == 'Attendance':
                attendance = data.get(key)
                value = 'N/A' if attendance is None or pd.isna(attendance) else str(int(attendance))
            else:
                value = data.get(key, 'N/A')
            output.append(f" • {key}: {value}")

        alert_status = data.get('Alert_status', 'Normal')
        suggested_actions = data.get('Suggested_actions', 'Monitor')
        # An unknown risk is reported as 0%.
        risk_percent = round(current_risk * 100) if current_risk is not None else 0
        if alert_status == 'Critical':
            notification = f"Urgent Alert for {trade}: {suggested_actions} due to high shortage risk of {risk_percent}%."
        elif alert_status == 'Warning':
            notification = f"Warning for {trade}: {suggested_actions} due to moderate shortage risk of {risk_percent}%."
        else:
            notification = f"Notice for {trade}: {suggested_actions}, shortage risk is low at {risk_percent}%."
        notifications.append(notification)
        output.append("")

    formatted_output = "\n".join(output) if trade_results else "No valid trade data available."
    if notifications:
        formatted_notifications = "Contractor Notifications:\n" + "\n".join(f" • {notification}" for notification in notifications)
    else:
        formatted_notifications = "No notifications available."
    return formatted_output, formatted_notifications
def push_to_salesforce(sf, trade_results, site_calendar_date):
    """Create one ``Labour_Attendance_Forecast__c`` record per trade.

    Args:
        sf: Connected Salesforce client, or ``None``.
        trade_results: Mapping of trade to a result dict providing
            ``'Attendance'``, ``'Shortage_risk'``, ``'Suggested_actions'``,
            ``'Alert_status'`` and optionally ``'Forecast_Next_3_Days__c'``.
        site_calendar_date: The selected date; anything ``pd.to_datetime``
            accepts.

    Returns:
        ``None`` on success, otherwise an error message string. Records are
        created one at a time, so a failure part-way leaves earlier records
        in place; the message says how many.
    """
    if sf is None:
        return "Salesforce connection not established"

    created = 0
    try:
        # Accept strings as well as Timestamps, like the other functions in
        # this file that take the site calendar date.
        date_text = pd.to_datetime(site_calendar_date).strftime('%Y-%m-%d')

        records = []
        for trade, data in trade_results.items():
            forecast_text = ', '.join(
                f"{item['date']}: {item['headcount']}"
                for item in data.get('Forecast_Next_3_Days__c', [])
            )
            records.append({
                'Trade__c': trade,
                'Date__c': date_text,
                # NOTE(review): expected and actual headcount are both filled
                # from the recorded attendance - confirm that is intended.
                'Expected_Headcount__c': int(data['Attendance']),
                'Actual_Headcount__c': int(data['Attendance']),
                'Forecast_Next_3_Days__c': forecast_text,
                'Shortage_Risk__c': float(data['Shortage_risk'][0]),
                'Suggested_Actions__c': str(data['Suggested_actions']),
                'Alert_Status__c': str(data['Alert_status']),
                'Dashboard_Display__c': True
            })

        for record in records:
            sf.Labour_Attendance_Forecast__c.create(record)
            created += 1

        logger.info(f"Successfully pushed {len(records)} records to Salesforce")
        return None
    except Exception as e:
        error_msg = f"Error pushing to Salesforce: {str(e)}"
        if created:
            error_msg += f" ({created} record(s) were already created)"
        logger.error(error_msg)
        return error_msg
def generate_sample_csv():
    """Return a small example attendance CSV, base64-encoded.

    The sample has one row per trade, all on the same date, with the four
    columns Date, Attendance, Trade and Weather.
    """
    rows = pd.DataFrame({
        'Date': ['2025-06-12', '2025-06-12', '2025-06-12', '2025-06-12'],
        'Attendance': [10, 15, 20, 12],
        'Trade': ['Painter', 'Electrician', 'Carpenter', 'Plumber'],
        'Weather': ['Sunny', 'Rainy', 'Cloudy', 'Sunny'],
    })
    # Write into an in-memory text buffer, then encode the text as UTF-8
    # bytes for base64.
    text_buffer = io.StringIO()
    rows.to_csv(text_buffer, index=False, encoding='utf-8')
    csv_bytes = text_buffer.getvalue().encode('utf-8')
    return base64.b64encode(csv_bytes).decode()
# Main forecast function
def _parse_site_calendar_date(raw_value):
    """Parse the user-entered site calendar date into a Timestamp.

    Tries the strict ``YYYY-MM-DD`` format first, then a lenient day-first
    parse.

    Raises:
        ValueError: if the value is empty or cannot be parsed as a date.
    """
    if not raw_value:
        raise ValueError("Site calendar date is required")
    logger.info(f"Raw site_calendar_date input: '{raw_value}'")
    text = str(raw_value).strip()

    invalid_msg = "Invalid site calendar date format. Use YYYY-MM-DD (e.g., 2025-06-13)"
    try:
        parsed = pd.to_datetime(text, format='%Y-%m-%d')
    except ValueError as strict_error:
        logger.warning(f"Strict date parsing failed: {str(strict_error)}. Attempting mixed format parsing.")
        parsed = pd.to_datetime(text, format='mixed', dayfirst=True, errors='coerce')
    # A blank string parses to NaT without raising, so check after BOTH paths.
    if pd.isna(parsed):
        raise ValueError(invalid_msg)
    return parsed


def forecast_labour(csv_file, trade_filter=None, site_calendar_date=None):
    """Run the full forecast pipeline for the Gradio UI.

    Args:
        csv_file: Uploaded CSV, either a path string or an object exposing
            the path as ``.name``.
        trade_filter: Optional comma-separated trade names; blank or
            unmatched means all trades in the CSV.
        site_calendar_date: The selected date as text (``YYYY-MM-DD``
            preferred).

    Returns:
        A 5-tuple ``(result_text, line_chart, heatmap, pdf_link_html,
        notifications_text)``. On error the three middle items are ``None``
        and ``result_text`` carries the message.
    """
    no_notifications = "No notifications available."
    try:
        logger.info("Starting forecast process")
        if csv_file is None:
            return "Error: No CSV file uploaded", None, None, None, no_notifications

        try:
            site_calendar_date = _parse_site_calendar_date(site_calendar_date)
        except ValueError as e:
            logger.error(f"Date validation error: {str(e)}")
            return f"Error: {str(e)}", None, None, None, no_notifications
        date_label = site_calendar_date.strftime('%Y-%m-%d')

        # The upload may arrive as a plain path or as a file-like wrapper.
        csv_path = getattr(csv_file, 'name', csv_file)
        logger.info(f"Processing CSV file: {csv_path}")
        df, error = process_csv(csv_path)
        if error:
            return error, None, None, None, no_notifications

        unique_trades = df['Trade'].dropna().unique()
        logger.info(f"Unique trades in CSV: {list(unique_trades)}")
        known_trades = {str(ut).lower() for ut in unique_trades}

        if trade_filter and trade_filter.strip():
            requested = [t.strip() for t in trade_filter.split(',') if t.strip()]
            selected_trades = [t for t in requested if t.lower() in known_trades]
            if not selected_trades:
                logger.warning(f"No valid trades found in filter: {trade_filter}. Defaulting to all trades.")
                selected_trades = unique_trades
        else:
            logger.info("Trade filter empty. Using all trades.")
            selected_trades = unique_trades
        logger.info(f"Selected trades: {list(selected_trades)}")

        trade_results = {}
        predictions_dict = {}
        shortage_probs_dict = {}
        errors = []

        for trade in selected_trades:
            trade_df = df[df['Trade'].str.lower() == trade.lower()]
            date_match = trade_df[trade_df['Date'] == site_calendar_date]
            if date_match.empty:
                errors.append(f"No data for trade {trade} on {date_label}")
                continue
            if len(date_match) > 1:
                # Not fatal: the first matching row is used.
                errors.append(f"Warning: Multiple rows for trade {trade} on {date_label}")

            predictions, shortage_probs, _, suggested_actions, alert_status, forecast_error = weighted_moving_average_forecast(df, trade, site_calendar_date)
            if forecast_error:
                errors.append(forecast_error)
                continue

            predictions_dict[trade] = predictions
            shortage_probs_dict[trade] = shortage_probs
            record = date_match.iloc[0]
            trade_results[trade] = {
                'Date': site_calendar_date,
                'Trade': trade,
                'Weather': record['Weather'],
                'Attendance': record['Attendance'],
                'Forecast_Next_3_Days__c': predictions,
                'Shortage_risk': shortage_probs,
                'Suggested_actions': suggested_actions,
                'Alert_status': alert_status
            }

        if not trade_results:
            error_msg = "No valid trade data processed"
            if errors:
                error_msg += f". Errors: {'; '.join(errors)}"
            return error_msg, None, None, None, no_notifications

        # A Salesforce failure is reported as a warning; it does not stop the
        # forecast from being shown.
        sf = connect_to_salesforce()
        sf_error = push_to_salesforce(sf, trade_results, site_calendar_date)
        if sf_error:
            errors.append(sf_error)

        line_chart = create_chart(df, predictions_dict)
        heatmap = create_heatmap(df, predictions_dict, shortage_probs_dict, site_calendar_date)
        pdf_summary = generate_pdf_summary(trade_results)
        formatted_output, formatted_notifications = format_output(trade_results, site_calendar_date)

        error_msg = "; ".join(errors) if errors else None
        final_output = formatted_output + (f"\nWarnings: {error_msg}" if error_msg else "")

        # Embed the generated PDF as a data-URI download link; without this
        # the label is shown but nothing can be downloaded.
        if pdf_summary:
            pdf_link = (
                f'<a href="data:application/pdf;base64,{pdf_summary}" '
                f'download="labour_forecast_summary.pdf">Download Summary PDF</a>'
            )
        else:
            pdf_link = "Summary PDF could not be generated."

        return (
            final_output,
            line_chart,
            heatmap,
            pdf_link,
            formatted_notifications
        )
    except Exception as e:
        logger.error(f"Unexpected error in forecast: {str(e)}", exc_info=True)
        return f"Error processing file: {str(e)}", None, None, None, no_notifications
# Gradio interface
def gradio_interface():
    """Build the Gradio Blocks UI for the labour forecast.

    Lays out the upload/filter/date inputs, the forecast button and the
    result, chart, heatmap, notification and PDF-link outputs, and wires the
    button to ``forecast_labour``.

    Returns:
        The constructed ``gr.Blocks`` interface (not yet launched).
    """
    sample_csv = generate_sample_csv()
    with gr.Blocks(theme=gr.themes.Soft()) as interface:
        gr.Markdown("# Labour Attendance Forecast")
        gr.Markdown("Upload a CSV with columns: Date, Attendance, Trade, Weather")
        gr.Markdown("Enter trade names (e.g., 'Painter, Electrician') or leave blank for all trades")
        gr.Markdown("Enter site calendar date (YYYY-MM-DD) for CSV data and 3-day forecast")
        # Embed the sample as a data-URI download link; without the link the
        # generated sample was never reachable from the page.
        gr.HTML(
            f'<a href="data:text/csv;base64,{sample_csv}" '
            f'download="sample_labour_data.csv">Download Sample CSV</a>'
        )

        with gr.Row():
            csv_input = gr.File(label="Upload CSV", file_types=[".csv"])
            trade_input = gr.Textbox(label="Filter by Trades", placeholder="e.g., Painter, Electrician")
            site_calendar_input = gr.Textbox(label="Site Calendar Date (YYYY-MM-DD)", placeholder="e.g., 2025-06-13")

        forecast_button = gr.Button("Generate Forecast")
        result_output = gr.Textbox(label="Forecast Result", lines=20)
        line_chart_output = gr.Plot(label="Forecast Trendline")
        heatmap_output = gr.Plot(label="Real-Time Shortage Risk Heatmap")
        notification_output = gr.Textbox(label="Contractor Notifications", lines=5)
        pdf_output = gr.HTML(label="Download Summary PDF")

        # Output order must match the 5-tuple returned by forecast_labour.
        forecast_button.click(
            fn=forecast_labour,
            inputs=[csv_input, trade_input, site_calendar_input],
            outputs=[result_output, line_chart_output, heatmap_output, pdf_output, notification_output]
        )
    logger.info("Launching Gradio interface")
    return interface
if __name__ == '__main__':
    # Build the UI and serve it without creating a public share link.
    interface = gradio_interface()
    interface.launch(share=False)