# app.py — Labour Attendance Forecast (Gradio + Salesforce)
# Origin: Hugging Face Space "pvk" by prudhviLatha, commit 3f1ba1e ("Create app.py").
import os
import logging
import gradio as gr
import pandas as pd
import numpy as np
from datetime import datetime
import plotly.express as px
import plotly.graph_objects as go
import io
import base64
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from simple_salesforce import Salesforce
from typing import List, Dict, Tuple, Optional, Union
from pathlib import Path
from dataclasses import dataclass, field
from logging.handlers import RotatingFileHandler
from retrying import retry
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
# Configure logging: rotate the log file at 5 MB keeping 3 backups, and
# mirror every record to the console via a plain StreamHandler.
logger = logging.getLogger(__name__)
handler = RotatingFileHandler('labour_forecast.log', maxBytes=5*1024*1024, backupCount=3)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)
# Salesforce credentials from environment variables (populated by load_dotenv above).
SF_USERNAME = os.getenv('SF_USERNAME')
SF_PASSWORD = os.getenv('SF_PASSWORD')
SF_SECURITY_TOKEN = os.getenv('SF_SECURITY_TOKEN')
# 'login' is the production auth endpoint; sandboxes would use 'test'.
SF_DOMAIN = os.getenv('SF_DOMAIN', 'login')
# Verify Salesforce credentials — fail fast at import time rather than on first use.
if not all([SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN]):
    logger.error("One or more Salesforce credentials are missing.")
    raise ValueError("Salesforce credentials not set. Please set SF_USERNAME, SF_PASSWORD, and SF_SECURITY_TOKEN in your .env file.")
# Configuration class with default_factory for mutable fields
@dataclass(frozen=True)
class Config:
    """Immutable application settings (frozen dataclass, shared via CONFIG)."""
    # Columns that must exist in the uploaded CSV (after header normalisation).
    REQUIRED_COLUMNS: List[str] = field(default_factory=lambda: ['Date', 'Attendance', 'Trade', 'Weather'])
    # Columns that are defaulted when absent from the CSV.
    OPTIONAL_COLUMNS: List[str] = field(default_factory=lambda: ['Alert_status'])
    # Candidate encodings tried in order when reading the CSV.
    ENCODINGS: List[str] = field(default_factory=lambda: ['utf-8', 'latin1', 'iso-8859-1', 'utf-16'])
    # Number of days forecast beyond the anchor date.
    FORECAST_DAYS: int = 3
    # Weighted-moving-average weights keyed by how many recent attendance
    # samples are available (1–3); weights align positionally with the
    # oldest-first slice taken by the forecaster.
    WMA_WEIGHTS: Dict[int, np.ndarray] = field(default_factory=lambda: {
        3: np.array([0.5, 0.3, 0.2]),
        2: np.array([0.6, 0.4]),
        1: np.array([1.0])
    })
    # Weather severity factor in [0, 1]; the forecaster scales attendance
    # down by 10% × this factor.
    WEATHER_IMPACT: Dict[str, float] = field(default_factory=lambda: {
        'Sunny': 0,
        'Rainy': 1,
        'Cloudy': 0.5,
        'N/A': 0.5
    })
    # Multiplier applied to forecasts when the anchor date is a weekend.
    WEEKEND_ADJUSTMENT: float = 0.8
    # Recognised Alert_status values; anything else is coerced to 'Normal'.
    VALID_ALERT_STATUSES: set = frozenset({'Normal', 'Critical', 'Warning'})
    # Shortage probability assigned per alert status.
    ALERT_TO_SHORTAGE: Dict[str, float] = field(default_factory=lambda: {
        'Normal': 0.0,
        'Warning': 0.41,
        'Critical': 0.71
    })
# Singleton settings instance used throughout the module.
CONFIG = Config()
# Salesforce connection with retry mechanism
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce() -> Optional[Salesforce]:
    """Establish connection to Salesforce with retry mechanism.

    The `retrying` decorator re-invokes this up to 3 times, waiting 2000 ms
    between attempts; the exception is re-raised if every attempt fails.

    Returns:
        An authenticated Salesforce client.

    Raises:
        Exception: whatever simple_salesforce raised, after logging it.
    """
    try:
        sf = Salesforce(
            username=SF_USERNAME,
            password=SF_PASSWORD,
            security_token=SF_SECURITY_TOKEN,
            domain=SF_DOMAIN
        )
        logger.info("Successfully connected to Salesforce")
        return sf
    except Exception as e:
        logger.error(f"Failed to connect to Salesforce: {str(e)}")
        raise  # re-raise so @retry can attempt again
# Data processing with improved validation
def process_csv(file_path: Union[str, Path]) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """Read and preprocess a labour-attendance CSV.

    Tries each encoding in CONFIG.ENCODINGS until one decodes the file, then
    validates and normalises the columns. Only *decoding* failures advance to
    the next encoding; validation failures (missing columns, unparseable
    dates) are reported immediately — previously they were swallowed by a
    broad ``except`` and mis-reported as "Could not decode CSV file".

    Args:
        file_path: Path to the CSV file.

    Returns:
        (DataFrame, None) on success, or (None, error_message) on failure.
    """
    file_path = Path(file_path)
    for encoding in CONFIG.ENCODINGS:
        try:
            df = pd.read_csv(file_path, encoding=encoding, dtype_backend='numpy_nullable')
        except (UnicodeDecodeError, UnicodeError) as e:
            # Wrong encoding for this file: try the next candidate.
            logger.warning(f"Failed with encoding {encoding}: {str(e)}")
            continue
        try:
            # Normalise headers, e.g. ' alert_STATUS ' -> 'Alert_status'.
            df.columns = df.columns.str.strip().str.capitalize()
            missing_required = [col for col in CONFIG.REQUIRED_COLUMNS if col not in df.columns]
            if missing_required:
                raise ValueError(f"Missing required columns: {', '.join(missing_required)}")
            # dayfirst=True: dates are expected DD/MM/YYYY; unparseable -> NaT.
            df['Date'] = pd.to_datetime(df['Date'], dayfirst=True, errors='coerce')
            if df['Date'].isna().all():
                raise ValueError("All dates in CSV are invalid")
            df['Attendance'] = pd.to_numeric(df['Attendance'], errors='coerce').fillna(0).astype('Int64')
            # replace() matches the whole cell value 'nan' (stringified NaN).
            df['Trade'] = df['Trade'].astype(str).replace('nan', 'N/A')
            df['Weather'] = df['Weather'].astype(str).replace('nan', 'N/A')
            if 'Alert_status' in df.columns:
                df['Alert_status'] = df['Alert_status'].astype(str).replace('nan', 'Normal')
                # Any status outside the known set is coerced to 'Normal'.
                df['Alert_status'] = df['Alert_status'].apply(
                    lambda x: x if x in CONFIG.VALID_ALERT_STATUSES else 'Normal'
                )
                logger.info(f"Processed Alert_status values: {df['Alert_status'].unique()}")
            else:
                df['Alert_status'] = 'Normal'
                logger.info("Alert_status column missing; defaulting to 'Normal'")
            logger.info("CSV file processed successfully")
            return df, None
        except Exception as e:
            # Validation failure is not an encoding problem: surface the real
            # error instead of retrying encodings and masking it.
            logger.warning(f"Validation failed for {file_path.name}: {str(e)}")
            return None, str(e)
    return None, f"Could not decode CSV file: {file_path.name}"
# Forecasting logic with optimized performance
def weighted_moving_average_forecast(
    df: pd.DataFrame,
    trade: str,
    site_calendar_date: pd.Timestamp
) -> Tuple[List[Dict], List[float], Optional[str], str, str, Optional[str]]:
    """Generate forecast using weighted moving average with day-specific shortage probabilities based on alert status.

    Returns a 6-tuple:
        (predictions, shortage_probs, site_calendar_value,
         suggested_actions, alert_status, error_message)
    where error_message is None on success. shortage_probs has one entry for
    the anchor date (index 0) plus one per forecast day.
    """
    try:
        # Trade matching is case-insensitive.
        trade_df = df[df['Trade'].str.lower() == trade.lower()].copy()
        if trade_df.empty:
            return [], [], None, 'N/A', 'Normal', f"No data found for trade: {trade}"
        # Monday–Friday (weekday() 0-4) count as site working days.
        is_weekday = site_calendar_date.weekday() < 5
        site_calendar = 1 if is_weekday else 0
        trade_df = trade_df.sort_values('Date')
        # Most recent 30 rows feed the moving average and fallbacks.
        # NOTE(review): tail(30) takes the newest rows overall, not rows on or
        # before site_calendar_date as the error message below implies — confirm.
        recent_data = trade_df.tail(30)[['Date', 'Attendance', 'Weather', 'Alert_status']].copy()
        if recent_data.empty:
            return [], [], None, 'N/A', 'Normal', f"No data for trade {trade} on or before {site_calendar_date.strftime('%Y-%m-%d')}"
        predictions = []
        shortage_probs = []
        alert_statuses = []
        # The FORECAST_DAYS calendar days *after* the anchor date.
        future_dates = pd.date_range(site_calendar_date, periods=CONFIG.FORECAST_DAYS + 1, freq='D')[1:]
        # Process the initial date (site_calendar_date)
        date_match = trade_df[trade_df['Date'] == site_calendar_date]
        if not date_match.empty:
            alert_status = date_match['Alert_status'].iloc[0]
            # Unknown statuses map to probability 0.0.
            shortage_prob = CONFIG.ALERT_TO_SHORTAGE.get(alert_status, 0.0)
            shortage_probs.append(shortage_prob)
            alert_statuses.append(alert_status)
        else:
            shortage_probs.append(0.0)
            alert_statuses.append('Normal')
        # Process forecast dates
        for i, date in enumerate(future_dates):
            date_data = trade_df[trade_df['Date'] == date]
            if not date_data.empty:
                # Actual CSV data exists for this future date; use it verbatim.
                headcount = int(date_data['Attendance'].iloc[0])
                weather = date_data['Weather'].iloc[0]
                alert_status = date_data['Alert_status'].iloc[0]
                shortage_prob = CONFIG.ALERT_TO_SHORTAGE.get(alert_status, 0.0)
            else:
                # Weighted average of up to the last 3 attendance samples
                # (oldest-first slice; weights come from CONFIG.WMA_WEIGHTS).
                recent_attendance = recent_data['Attendance'].values[-3:]
                weights = CONFIG.WMA_WEIGHTS.get(len(recent_attendance), CONFIG.WMA_WEIGHTS[1])
                forecast_value = np.average(recent_attendance, weights=weights)
                # Cycle backwards through recent weather observations.
                weather_idx = i % len(recent_data['Weather'])
                weather = recent_data['Weather'].iloc[-weather_idx-1]
                weather_impact = CONFIG.WEATHER_IMPACT.get(weather, 0.5)
                # Up to -10% attendance at the worst weather (impact == 1).
                forecast_value *= (1 - 0.1 * weather_impact)
                # NOTE(review): the weekend scale-down keys off the *anchor*
                # date's weekday, not each forecast date's — confirm intended.
                headcount = round(forecast_value * (1 if site_calendar == 1 else CONFIG.WEEKEND_ADJUSTMENT))
                # Use the most recent alert status for forecast dates if no data exists
                recent_alert_idx = i % len(recent_data['Alert_status'])
                alert_status = recent_data['Alert_status'].iloc[-recent_alert_idx-1]
                shortage_prob = CONFIG.ALERT_TO_SHORTAGE.get(alert_status, 0.0)
            predictions.append({
                "date": date.strftime('%Y-%m-%d'),
                "headcount": headcount,
                "source": 'CSV' if not date_data.empty else 'WMA'
            })
            shortage_probs.append(shortage_prob)
            alert_statuses.append(alert_status)
        # Set suggested actions based on the alert status of the initial date
        alert_status = alert_statuses[0]
        if alert_status == 'Normal':
            suggested_actions = 'No action needed'
        elif alert_status == 'Warning':
            suggested_actions = 'Review staffing levels'
        elif alert_status == 'Critical':
            suggested_actions = 'Urgent hiring needed'
        else:
            suggested_actions = 'Monitor'
        site_calendar_value = site_calendar_date.strftime('%Y-%m-%d') + f" ({'Weekday' if is_weekday else 'Weekend'})"
        logger.info(f"Forecast generated for trade: {trade}")
        return predictions, shortage_probs, site_calendar_value, suggested_actions, alert_status, None
    except Exception as e:
        logger.error(f"Forecast error for trade {trade}: {str(e)}")
        return [], [], None, 'N/A', 'Normal', f"Forecast error: {str(e)}"
# Visualization with improved readability
def create_heatmap(
    df: pd.DataFrame,
    predictions_dict: Dict,
    shortage_probs_dict: Dict,
    site_calendar_date: pd.Timestamp
) -> go.Figure:
    """Create a clean and clear heatmap with day-specific shortage probabilities.

    Rows are trades, columns are the anchor date plus the forecast dates,
    and cell values are shortage probabilities in [0, 1]. Errors produce a
    figure whose title carries the error text instead of raising.
    """
    try:
        heatmap_data = []
        future_dates = pd.date_range(site_calendar_date, periods=CONFIG.FORECAST_DAYS + 1, freq='D')[1:]
        for trade in predictions_dict.keys():
            # One probability per day; index 0 corresponds to the anchor date.
            probs = shortage_probs_dict.get(trade, [0.0] * (CONFIG.FORECAST_DAYS + 1))
            # Initial date
            heatmap_data.append({
                'Date': site_calendar_date.strftime('%Y-%m-%d'),
                'Trade': trade,
                'Shortage_Probability': probs[0]
            })
            # Forecasted dates
            for i, date in enumerate(future_dates):
                heatmap_data.append({
                    'Date': date.strftime('%Y-%m-%d'),
                    'Trade': trade,
                    'Shortage_Probability': probs[i + 1]
                })
        heatmap_df = pd.DataFrame(heatmap_data)
        if heatmap_df.empty:
            return go.Figure().update_layout(title="Shortage Risk Heatmap (No Data)")
        # Percentage strings for the cell labels / hover text.
        display_probs = heatmap_df['Shortage_Probability'] * 100
        # Discrete colour bands matching the ALERT_TO_SHORTAGE thresholds:
        # [0, 0.40] green, [0.41, 0.70] orange, [0.71, 1.0] red.
        custom_colorscale = [
            [0.0, '#00FF00'],  # green for Normal
            [0.4, '#00FF00'],
            [0.41, '#FFA500'],  # orange for Warning
            [0.7, '#FFA500'],
            [0.71, '#FF0000'],  # red for Critical
            [1.0, '#FF0000']
        ]
        fig = go.Figure(data=go.Heatmap(
            x=heatmap_df['Date'],
            y=heatmap_df['Trade'],
            z=heatmap_df['Shortage_Probability'],
            colorscale=custom_colorscale,
            zmin=0, zmax=1,
            text=display_probs.round(1).astype(str) + '%',
            texttemplate="%{text}",
            textfont={"size": 18, "color": "black"},
            hovertemplate="Trade: %{y}<br>Date: %{x}<br>Shortage Risk: %{text}<extra></extra>",
            colorbar=dict(
                title="Shortage Risk",
                tickvals=[0, 0.41, 0.71],
                ticktext=["0%", "41%", "71%"],
                len=0.8,
                thickness=20
            )
        ))
        fig.update_layout(
            title="Shortage Risk Heatmap",
            xaxis_title="Date",
            yaxis_title="Trade",
            xaxis=dict(tickangle=45, tickformat="%Y-%m-%d", gridcolor='black', gridwidth=1),
            # Reversed y-axis keeps the first trade at the top of the chart.
            yaxis=dict(autorange="reversed", gridcolor='black', gridwidth=1),
            font=dict(size=14),
            margin=dict(l=100, r=50, t=80, b=100),
            plot_bgcolor="white",
            paper_bgcolor="white",
            width=800,
            # Grow the figure by 50 px per trade row.
            height=400 + len(heatmap_df['Trade'].unique()) * 50
        )
        # Small gaps between cells for readability.
        fig.update_traces(xgap=3, ygap=3)
        return fig
    except Exception as e:
        logger.error(f"Error creating heatmap: {str(e)}")
        return go.Figure().update_layout(title=f"Error in Heatmap: {str(e)}")
def create_chart(df: pd.DataFrame, predictions_dict: Dict) -> go.Figure:
    """Create a line chart of historical vs. forecast attendance per trade.

    Historical rows come from the CSV; forecast rows come from
    predictions_dict (dicts with 'date'/'headcount' keys). Dashed-vs-solid
    lines distinguish the two via the 'Type' column.

    Fixes vs. the previous version: the empty-forecast case no longer
    raises KeyError (an empty DataFrame has none of the selected columns),
    and frames are collected and concatenated once instead of re-running
    pd.concat inside the loop.

    Returns:
        A plotly figure; on error, a figure whose title carries the message.
    """
    try:
        frames = []
        for trade, predictions in predictions_dict.items():
            trade_df = df[df['Trade'].str.lower() == trade.lower()][['Date', 'Attendance']].copy()
            trade_df['Type'] = 'Historical'
            trade_df['Trade'] = trade
            frames.append(trade_df)
            forecast_df = pd.DataFrame(predictions)
            if not forecast_df.empty:
                forecast_df['Date'] = pd.to_datetime(forecast_df['date'])
                forecast_df['Attendance'] = forecast_df['headcount']
                forecast_df['Type'] = 'Forecast'
                forecast_df['Trade'] = trade
                frames.append(forecast_df[['Date', 'Attendance', 'Type', 'Trade']])
        combined_df = pd.concat(frames) if frames else pd.DataFrame()
        if combined_df.empty:
            return go.Figure().update_layout(title="Labour Attendance Forecast (No Data)")
        fig = px.line(
            combined_df,
            x='Date',
            y='Attendance',
            color='Trade',
            line_dash='Type',
            markers=True,
            title='Labour Attendance Forecast by Trade'
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating chart: {str(e)}")
        return go.Figure().update_layout(title=f"Error in Chart: {str(e)}")
def generate_pdf_summary(trade_results: Dict) -> Optional[str]:
    """Generate a one-page PDF summary of the forecast.

    Args:
        trade_results: Mapping of trade name -> result dict containing
            'Attendance', 'Shortage_risk' (list, anchor-date value first)
            and 'Alert_status'.

    Returns:
        The PDF encoded as a base64 string, or None if generation failed.
    """
    fig = None
    try:
        buffer = io.BytesIO()
        with PdfPages(buffer) as pdf:
            fig, ax = plt.subplots(figsize=(10, 6))
            if not trade_results:
                ax.text(0.1, 0.5, "No data available for summary", fontsize=12)
            else:
                # One text line per trade, stacked downward from the top.
                for i, (trade, data) in enumerate(trade_results.items()):
                    ax.text(
                        0.1, 0.9 - 0.1*i,
                        f"{trade}: {data['Attendance']} (Actual), Shortage Risk: {data['Shortage_risk'][0]*100:.0f}%, Alert: {data['Alert_status']}",
                        fontsize=12
                    )
            ax.set_title("Weekly Labour Forecast Summary")
            ax.axis('off')
            # Save/close this specific figure rather than relying on pyplot's
            # implicit "current figure" state.
            pdf.savefig(fig)
            plt.close(fig)
            fig = None
        pdf_base64 = base64.b64encode(buffer.getvalue()).decode()
        logger.info("PDF summary generated")
        return pdf_base64
    except Exception as e:
        logger.error(f"Error generating PDF: {str(e)}")
        # Don't leak the figure if we failed between subplots() and close().
        if fig is not None:
            plt.close(fig)
        return None
def format_output(trade_results: Dict, site_calendar_date: pd.Timestamp) -> Tuple[str, str]:
    """Format the output for display and return contractor notifications.

    Args:
        trade_results: Mapping of trade name -> result dict (see forecast_labour).
        site_calendar_date: Anchor date shown in each trade's section.

    Returns:
        (formatted forecast text, formatted notifications text).
    """
    output = []
    notifications = []
    for trade, data in trade_results.items():
        output.append(f"Trade: {trade}")
        for key in CONFIG.REQUIRED_COLUMNS + ['Alert_status', 'Forecast_Next_3_Days__c', 'Shortage_risk', 'Suggested_actions']:
            if key == 'Date':
                value = site_calendar_date.strftime('%Y-%m-%d')
            elif key == 'Forecast_Next_3_Days__c':
                # Render the forecast list as "date: headcount, ..." text.
                value = ', '.join([f"{item['date']}: {item['headcount']}" for item in data.get(key, [])]) or 'N/A'
            else:
                value = data.get(key, 'N/A')
            if key in ['Weather', 'Alert_status', 'Suggested_actions', 'Trade']:
                value = str(value)
            elif key == 'Shortage_risk':
                # Show only the anchor-date probability, rounded to 2 dp.
                value = str(round(value[0], 2))
            elif key == 'Attendance':
                value = str(int(value))
            output.append(f" • {key}: {value}")
        alert_status = data.get('Alert_status', 'Normal')
        shortage_risk = data.get('Shortage_risk', [0])[0] * 100
        suggested_actions = data.get('Suggested_actions', 'No action needed')
        if alert_status == 'Critical':
            notification = f"Urgent Alert for {trade}: {suggested_actions} due to high shortage risk of {round(shortage_risk)}%."
        elif alert_status == 'Warning':
            notification = f"Warning for {trade}: {suggested_actions} due to moderate shortage risk of {round(shortage_risk)}%."
        else:
            notification = f"Notice for {trade}: {suggested_actions}, shortage risk is low at {round(shortage_risk)}%."
        notifications.append(notification)
        output.append("")  # blank line between trade sections
    formatted_output = "\n".join(output) or "No valid trade data available."
    # Bug fix: the old `"header" + join(...) or fallback` expression could
    # never reach the fallback — the non-empty header made the concatenation
    # truthy even with zero notifications. Branch explicitly instead.
    if notifications:
        formatted_notifications = "Contractor Notifications:\n" + "\n".join(
            f" • {notification}" for notification in notifications
        )
    else:
        formatted_notifications = "No notifications available."
    return formatted_output, formatted_notifications
def push_to_salesforce(sf: Optional[Salesforce], trade_results: Dict, site_calendar_date: pd.Timestamp) -> Optional[str]:
    """Push forecast results to Salesforce Labour_Attendance_Forecast__c object.

    Args:
        sf: Authenticated Salesforce client, or None.
        trade_results: Mapping of trade name -> result dict (see forecast_labour).
        site_calendar_date: Anchor date written to Date__c.

    Returns:
        None on success, otherwise a human-readable error string.
    """
    try:
        if sf is None:
            return "Salesforce connection not established"
        records_to_upsert = []
        for trade, data in trade_results.items():
            # Serialise the 3-day forecast as "YYYY-MM-DD: headcount, ..." text.
            forecast_json = ', '.join([f"{item['date']}: {item['headcount']}" for item in data.get('Forecast_Next_3_Days__c', [])])
            record = {
                'Trade__c': trade,
                'Date__c': site_calendar_date.strftime('%Y-%m-%d'),
                # NOTE(review): expected and actual headcount are both set from
                # the same Attendance value — confirm that is intended.
                'Expected_Headcount__c': int(data['Attendance']),
                'Actual_Headcount__c': int(data['Attendance']),
                'Forecast_Next_3_Days__c': forecast_json,
                # Anchor-date shortage probability only.
                'Shortage_Risk__c': float(data['Shortage_risk'][0]),
                'Suggested_Actions__c': str(data['Suggested_actions']),
                'Alert_Status__c': str(data['Alert_status']),
                'Dashboard_Display__c': True
            }
            records_to_upsert.append(record)
        # NOTE(review): despite the "upsert" naming, create() inserts a new
        # record each run — re-running the same date duplicates rows; consider
        # sf.<object>.upsert() with an external-id field.
        for record in records_to_upsert:
            sf.Labour_Attendance_Forecast__c.create(record)
        logger.info(f"Successfully pushed {len(records_to_upsert)} records to Salesforce")
        return None
    except Exception as e:
        logger.error(f"Error pushing to Salesforce: {str(e)}")
        return f"Error pushing to Salesforce: {str(e)}"
# Main forecast function with improved error handling
def forecast_labour(
    csv_file: Optional[Union[str, Path]],
    trade_filter: Optional[str] = None,
    site_calendar_date: Optional[str] = None
) -> Tuple[str, Optional[go.Figure], Optional[go.Figure], Optional[str], str]:
    """Generate labour forecast based on CSV input and push to Salesforce.

    Args:
        csv_file: Path to the uploaded CSV file.
        trade_filter: Comma-separated trade names; blank/None means all trades.
        site_calendar_date: Anchor date string (YYYY-MM-DD) for the forecast.

    Returns:
        (formatted text, line chart, heatmap, PDF-download HTML, notifications).
        On failure the figure/PDF slots are None and the text carries the error.

    Fix vs. previous version: a Salesforce outage no longer crashes the whole
    forecast — connect_to_salesforce() re-raises after its retries, which used
    to discard all computed results; it is now downgraded to a warning.
    """
    try:
        logger.info("Starting forecast process")
        if not csv_file:
            return "Error: No CSV file uploaded", None, None, None, "No notifications available."
        df, error = process_csv(csv_file)
        if error:
            return error, None, None, None, "No notifications available."
        try:
            site_calendar_date = pd.to_datetime(site_calendar_date)
            if pd.isna(site_calendar_date):
                raise ValueError("Invalid site calendar date")
        except (ValueError, TypeError) as e:
            # TypeError covers non-string inputs that to_datetime rejects.
            logger.error(f"Date error: {str(e)}")
            return f"Error: {str(e)}", None, None, None, "No notifications available."
        unique_trades = df['Trade'].dropna().unique()
        logger.info(f"Unique trades in CSV: {list(unique_trades)}")
        # Resolve the trade filter (case-insensitive); fall back to all trades.
        selected_trades = []
        if trade_filter and trade_filter.strip():
            selected_trades = [t.strip() for t in trade_filter.split(',') if t.strip()]
            selected_trades = [t for t in selected_trades if any(t.lower() == ut.lower() for ut in unique_trades)]
            if not selected_trades:
                logger.warning(f"No valid trades found in filter: {trade_filter}. Defaulting to all trades.")
                selected_trades = unique_trades
        else:
            logger.info("Trade filter empty. Using all trades.")
            selected_trades = unique_trades
        logger.info(f"Selected trades: {list(selected_trades)}")
        trade_results = {}
        predictions_dict = {}
        shortage_probs_dict = {}
        errors = []
        for trade in selected_trades:
            trade_df = df[df['Trade'].str.lower() == trade.lower()]
            date_match = trade_df[trade_df['Date'] == site_calendar_date]
            if date_match.empty:
                errors.append(f"No data for trade {trade} on {site_calendar_date.strftime('%Y-%m-%d')}")
                continue
            if len(date_match) > 1:
                errors.append(f"Warning: Multiple rows for trade {trade} on {site_calendar_date.strftime('%Y-%m-%d')}")
            predictions, shortage_probs, site_calendar, suggested_actions, alert_status, forecast_error = weighted_moving_average_forecast(df, trade, site_calendar_date)
            if forecast_error:
                errors.append(forecast_error)
                continue
            predictions_dict[trade] = predictions
            shortage_probs_dict[trade] = shortage_probs
            record = date_match.iloc[0]  # first row wins when duplicates exist
            trade_results[trade] = {
                'Date': site_calendar_date,
                'Trade': trade,
                'Weather': record['Weather'],
                'Attendance': record['Attendance'],
                'Forecast_Next_3_Days__c': predictions,
                'Shortage_risk': shortage_probs,
                'Suggested_actions': suggested_actions,
                'Alert_status': alert_status
            }
        if not trade_results:
            error_msg = "No valid trade data processed"
            if errors:
                error_msg += f". Errors: {'; '.join(errors)}"
            return error_msg, None, None, None, "No notifications available."
        # Sync to Salesforce, but never let a sync failure destroy the local
        # forecast output that was just computed.
        try:
            sf = connect_to_salesforce()
            sf_error = push_to_salesforce(sf, trade_results, site_calendar_date)
        except Exception as e:
            sf_error = f"Salesforce sync failed: {str(e)}"
        if sf_error:
            errors.append(sf_error)
        line_chart = create_chart(df, predictions_dict)
        heatmap = create_heatmap(df, predictions_dict, shortage_probs_dict, site_calendar_date)
        pdf_summary = generate_pdf_summary(trade_results)
        formatted_output, formatted_notifications = format_output(trade_results, site_calendar_date)
        error_msg = "; ".join(errors) if errors else None
        final_output = formatted_output + (f"\nWarnings: {error_msg}" if error_msg else "")
        # Embed the PDF as a base64 data URI so the browser can download it.
        pdf_link = f'<a href="data:application/pdf;base64,{pdf_summary}" download="summary.pdf">Download Summary PDF</a>' if pdf_summary else "PDF generation failed."
        return final_output, line_chart, heatmap, pdf_link, formatted_notifications
    except Exception as e:
        logger.error(f"Unexpected error in forecast: {str(e)}")
        return f"Error processing file: {str(e)}", None, None, None, "No notifications available."
# Gradio interface with improved UI
def gradio_interface():
    """Build and return the Gradio Blocks UI.

    Wires a single "Generate Forecast" button to forecast_labour; launching
    the server is left to the caller.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as interface:
        gr.Markdown("# Labour Attendance Forecast (Updated June 2025)")
        gr.Markdown("Upload a CSV with columns: Date, Attendance, Trade, Weather, Alert_status (optional)")
        gr.Markdown("Enter trade names (e.g., 'Painter, Electrician') or leave blank for all trades")
        gr.Markdown("Enter site calendar date (YYYY-MM-DD) for CSV data and 3-day forecast")
        with gr.Row():
            csv_input = gr.File(label="Upload CSV", file_types=[".csv"])
            trade_input = gr.Textbox(label="Filter by Trades", placeholder="e.g., Painter, Electrician")
            site_calendar_input = gr.Textbox(
                label="Site Calendar Date (YYYY-MM-DD)",
                placeholder="e.g., 2025-04-25",
                value="2025-04-01"  # Default date for example
            )
        forecast_button = gr.Button("Generate Forecast")
        result_output = gr.Textbox(label="Forecast Result", lines=20, show_copy_button=True)
        line_chart_output = gr.Plot(label="Forecast Trendline")
        heatmap_output = gr.Plot(label="Shortage Risk Heatmap")
        notification_output = gr.Textbox(label="Contractor Notifications", lines=5, show_copy_button=True)
        pdf_output = gr.HTML(label="Download Summary PDF")
        # Output order must match forecast_labour's return tuple:
        # (text, line chart, heatmap, pdf link, notifications).
        forecast_button.click(
            fn=forecast_labour,
            inputs=[csv_input, trade_input, site_calendar_input],
            outputs=[result_output, line_chart_output, heatmap_output, pdf_output, notification_output]
        )
    logger.info("Launching Gradio interface")
    return interface
if __name__ == '__main__':
    # Build the UI and serve it locally; share=False keeps it off the
    # public Gradio tunnel.
    app = gradio_interface()
    app.launch(share=False)