# NOTE: the "Spaces: Sleeping" status banner below came from the hosting page
# (Hugging Face Spaces) when this source was captured; it is not program code.
# Standard library
import base64
import io
import logging
import os
import time
from dataclasses import dataclass, field
from datetime import datetime
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

# Third party
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from dotenv import load_dotenv
from matplotlib.backends.backend_pdf import PdfPages
from retrying import retry
from simple_salesforce import Salesforce
# Load environment variables from .env file
load_dotenv()

# Configure logging with rotation.  The handler registration is guarded:
# this module can be imported more than once (e.g. by a dev-server reloader),
# and unconditionally re-adding handlers would duplicate every log line.
logger = logging.getLogger(__name__)
if not logger.handlers:
    handler = RotatingFileHandler('labour_forecast.log', maxBytes=5 * 1024 * 1024, backupCount=3)
    handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)
    logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)

# Salesforce credentials from environment variables
SF_USERNAME = os.getenv('SF_USERNAME')
SF_PASSWORD = os.getenv('SF_PASSWORD')
SF_SECURITY_TOKEN = os.getenv('SF_SECURITY_TOKEN')
SF_DOMAIN = os.getenv('SF_DOMAIN', 'login')

# Fail fast at import time if any mandatory credential is missing.
if not all([SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN]):
    logger.error("One or more Salesforce credentials are missing.")
    raise ValueError("Salesforce credentials not set. Please set SF_USERNAME, SF_PASSWORD, and SF_SECURITY_TOKEN in your .env file.")
# The @dataclass decorator is required here: without it the
# `field(default_factory=...)` assignments stay as raw dataclasses.Field
# objects on the class, so CONFIG.REQUIRED_COLUMNS etc. would not be real
# lists/dicts and every consumer would break.
@dataclass(frozen=True)
class Config:
    """Central, immutable configuration for CSV validation and forecasting."""

    # Columns that must be present in the uploaded CSV.
    REQUIRED_COLUMNS: List[str] = field(default_factory=lambda: ['Date', 'Attendance', 'Trade', 'Weather'])
    # Columns that are used when present but may be absent.
    OPTIONAL_COLUMNS: List[str] = field(default_factory=lambda: ['Alert_status'])
    # Candidate file encodings, tried in order when reading the CSV.
    ENCODINGS: List[str] = field(default_factory=lambda: ['utf-8', 'latin1', 'iso-8859-1', 'utf-16'])
    # Number of days forecast beyond the anchor date.
    FORECAST_DAYS: int = 3
    # Weighted-moving-average weights keyed by window length (1-3 days).
    WMA_WEIGHTS: Dict[int, np.ndarray] = field(default_factory=lambda: {
        3: np.array([0.5, 0.3, 0.2]),
        2: np.array([0.6, 0.4]),
        1: np.array([1.0])
    })
    # Relative attendance impact per weather condition (0 = none, 1 = worst).
    WEATHER_IMPACT: Dict[str, float] = field(default_factory=lambda: {
        'Sunny': 0,
        'Rainy': 1,
        'Cloudy': 0.5,
        'N/A': 0.5
    })
    # Multiplier applied to forecasts that fall on a weekend.
    WEEKEND_ADJUSTMENT: float = 0.8
    # Recognised Alert_status values; anything else is coerced to 'Normal'.
    VALID_ALERT_STATUSES: frozenset = frozenset({'Normal', 'Critical', 'Warning'})
    # Shortage probability implied by each alert status.
    ALERT_TO_SHORTAGE: Dict[str, float] = field(default_factory=lambda: {
        'Normal': 0.0,
        'Warning': 0.41,
        'Critical': 0.71
    })


CONFIG = Config()
| # Salesforce connection with retry mechanism | |
def connect_to_salesforce(max_attempts: int = 3, wait_seconds: float = 2.0) -> Optional[Salesforce]:
    """Establish connection to Salesforce with retry mechanism.

    The original code advertised a retry mechanism (and the file imports
    `retrying`) but never retried; this implements a bounded retry loop.

    Args:
        max_attempts: Total connection attempts before giving up.
        wait_seconds: Pause between attempts.

    Returns:
        A live Salesforce client.

    Raises:
        Exception: The last connection error, if every attempt fails.
    """
    last_error: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            sf = Salesforce(
                username=SF_USERNAME,
                password=SF_PASSWORD,
                security_token=SF_SECURITY_TOKEN,
                domain=SF_DOMAIN
            )
            logger.info("Successfully connected to Salesforce")
            return sf
        except Exception as e:
            last_error = e
            logger.warning(f"Salesforce connection attempt {attempt}/{max_attempts} failed: {str(e)}")
            if attempt < max_attempts:
                time.sleep(wait_seconds)
    logger.error(f"Failed to connect to Salesforce: {str(last_error)}")
    raise last_error
| # Data processing with improved validation | |
def process_csv(file_path: Union[str, Path]) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """Read and validate the attendance CSV.

    Encodings from CONFIG.ENCODINGS are tried in order until one decodes the
    file.  Unlike the previous version, data-validation failures (missing
    columns, no parseable dates) are reported immediately instead of being
    retried under every remaining encoding and then surfaced as a misleading
    "Could not decode" message.

    Args:
        file_path: Path (or string path) of the CSV to load.

    Returns:
        (DataFrame, None) on success, or (None, error_message) on failure.
    """
    file_path = Path(file_path)  # Path(Path(...)) is a no-op, so no isinstance check needed
    for encoding in CONFIG.ENCODINGS:
        try:
            df = pd.read_csv(file_path, encoding=encoding, dtype_backend='numpy_nullable')
        except Exception as e:
            # Read failure (wrong encoding, malformed bytes): try the next encoding.
            logger.warning(f"Failed with encoding {encoding}: {str(e)}")
            continue
        try:
            df.columns = df.columns.str.strip().str.capitalize()
            missing_required = [col for col in CONFIG.REQUIRED_COLUMNS if col not in df.columns]
            if missing_required:
                raise ValueError(f"Missing required columns: {', '.join(missing_required)}")
            df['Date'] = pd.to_datetime(df['Date'], dayfirst=True, errors='coerce')
            if df['Date'].isna().all():
                raise ValueError("All dates in CSV are invalid")
            df['Attendance'] = pd.to_numeric(df['Attendance'], errors='coerce').fillna(0).astype('Int64')
            # astype(str) renders missing values as 'nan' (NumPy backend) or
            # '<NA>' (nullable backend); normalise both to the app's sentinels.
            df['Trade'] = df['Trade'].astype(str).replace(['nan', '<NA>'], 'N/A')
            df['Weather'] = df['Weather'].astype(str).replace(['nan', '<NA>'], 'N/A')
            if 'Alert_status' in df.columns:
                df['Alert_status'] = df['Alert_status'].astype(str).replace(['nan', '<NA>'], 'Normal')
                # Coerce anything outside the known statuses to 'Normal'.
                df['Alert_status'] = df['Alert_status'].apply(
                    lambda x: x if x in CONFIG.VALID_ALERT_STATUSES else 'Normal'
                )
                logger.info(f"Processed Alert_status values: {df['Alert_status'].unique()}")
            else:
                df['Alert_status'] = 'Normal'
                logger.info("Alert_status column missing; defaulting to 'Normal'")
            logger.info("CSV file processed successfully")
            return df, None
        except ValueError as e:
            # Structural problem with the data itself — another encoding
            # cannot fix it, so surface the real message to the caller.
            logger.error(f"CSV validation failed: {str(e)}")
            return None, str(e)
        except Exception as e:
            logger.warning(f"Failed with encoding {encoding}: {str(e)}")
            continue
    return None, f"Could not decode CSV file: {file_path.name}"
| # Forecasting logic with optimized performance | |
def weighted_moving_average_forecast(
    df: pd.DataFrame,
    trade: str,
    site_calendar_date: pd.Timestamp
) -> Tuple[List[Dict], List[float], Optional[str], str, str, Optional[str]]:
    """Generate a CONFIG.FORECAST_DAYS-day forecast for one trade.

    For each forecast date, actual CSV rows are used verbatim when present;
    otherwise a weighted moving average of the most recent attendance values
    is taken, discounted by weather impact and a weekend adjustment.
    Per-day shortage probabilities are derived from the Alert_status column
    via CONFIG.ALERT_TO_SHORTAGE.

    Args:
        df: Fully preprocessed attendance DataFrame (see process_csv).
        trade: Trade name, matched case-insensitively.
        site_calendar_date: Anchor date; day 0 of the returned probabilities.

    Returns:
        (predictions, shortage_probs, site_calendar_label, suggested_actions,
         alert_status, error) where error is None on success.
    """
    try:
        trade_df = df[df['Trade'].str.lower() == trade.lower()].copy()
        if trade_df.empty:
            return [], [], None, 'N/A', 'Normal', f"No data found for trade: {trade}"
        is_weekday = site_calendar_date.weekday() < 5
        site_calendar = 1 if is_weekday else 0
        trade_df = trade_df.sort_values('Date')
        recent_data = trade_df.tail(30)[['Date', 'Attendance', 'Weather', 'Alert_status']].copy()
        if recent_data.empty:
            return [], [], None, 'N/A', 'Normal', f"No data for trade {trade} on or before {site_calendar_date.strftime('%Y-%m-%d')}"
        predictions: List[Dict] = []
        shortage_probs: List[float] = []
        alert_statuses: List[str] = []
        future_dates = pd.date_range(site_calendar_date, periods=CONFIG.FORECAST_DAYS + 1, freq='D')[1:]
        # Day 0: the site calendar date itself (probability only, no prediction row).
        date_match = trade_df[trade_df['Date'] == site_calendar_date]
        if not date_match.empty:
            alert_status = date_match['Alert_status'].iloc[0]
            shortage_probs.append(CONFIG.ALERT_TO_SHORTAGE.get(alert_status, 0.0))
            alert_statuses.append(alert_status)
        else:
            shortage_probs.append(0.0)
            alert_statuses.append('Normal')
        # Forecast dates.
        for i, date in enumerate(future_dates):
            date_data = trade_df[trade_df['Date'] == date]
            if not date_data.empty:
                # Actuals exist for this date — use them verbatim.
                headcount = int(date_data['Attendance'].iloc[0])
                weather = date_data['Weather'].iloc[0]
                alert_status = date_data['Alert_status'].iloc[0]
                shortage_prob = CONFIG.ALERT_TO_SHORTAGE.get(alert_status, 0.0)
            else:
                # No actuals: weighted moving average over the last <=3 observations.
                # to_numpy(dtype=float) avoids np.average choking on the
                # nullable Int64 extension array.
                recent_attendance = recent_data['Attendance'].to_numpy(dtype=float)[-3:]
                weights = CONFIG.WMA_WEIGHTS.get(len(recent_attendance), CONFIG.WMA_WEIGHTS[1])
                # Bug fix: recent_attendance is in ascending date order while the
                # weight vectors list the heaviest weight first, so the old code
                # gave the most weight to the OLDEST observation.  Reverse the
                # weights so the most recent day is weighted most heavily.
                forecast_value = np.average(recent_attendance, weights=weights[::-1])
                weather_idx = i % len(recent_data)
                weather = recent_data['Weather'].iloc[-weather_idx - 1]
                weather_impact = CONFIG.WEATHER_IMPACT.get(weather, 0.5)
                forecast_value *= (1 - 0.1 * weather_impact)
                headcount = round(forecast_value * (1 if site_calendar == 1 else CONFIG.WEEKEND_ADJUSTMENT))
                # No data for the date: carry a recent alert status forward.
                recent_alert_idx = i % len(recent_data)
                alert_status = recent_data['Alert_status'].iloc[-recent_alert_idx - 1]
                shortage_prob = CONFIG.ALERT_TO_SHORTAGE.get(alert_status, 0.0)
            predictions.append({
                "date": date.strftime('%Y-%m-%d'),
                "headcount": headcount,
                "source": 'CSV' if not date_data.empty else 'WMA'
            })
            shortage_probs.append(shortage_prob)
            alert_statuses.append(alert_status)
        # Suggested action is driven by day 0's alert status.
        alert_status = alert_statuses[0]
        if alert_status == 'Normal':
            suggested_actions = 'No action needed'
        elif alert_status == 'Warning':
            suggested_actions = 'Review staffing levels'
        elif alert_status == 'Critical':
            suggested_actions = 'Urgent hiring needed'
        else:
            suggested_actions = 'Monitor'
        site_calendar_value = site_calendar_date.strftime('%Y-%m-%d') + f" ({'Weekday' if is_weekday else 'Weekend'})"
        logger.info(f"Forecast generated for trade: {trade}")
        return predictions, shortage_probs, site_calendar_value, suggested_actions, alert_status, None
    except Exception as e:
        logger.error(f"Forecast error for trade {trade}: {str(e)}")
        return [], [], None, 'N/A', 'Normal', f"Forecast error: {str(e)}"
| # Visualization with improved readability | |
def create_heatmap(
    df: pd.DataFrame,
    predictions_dict: Dict,
    shortage_probs_dict: Dict,
    site_calendar_date: pd.Timestamp
) -> go.Figure:
    """Render a per-trade/per-date heatmap of shortage probabilities.

    Args:
        df: Source DataFrame (unused here; kept for interface compatibility).
        predictions_dict: trade -> forecast rows (keys define heatmap rows).
        shortage_probs_dict: trade -> [day0..dayN] probabilities.
        site_calendar_date: Anchor date (day 0).

    Returns:
        A Plotly figure, or a titled placeholder figure on error/no data.
    """
    try:
        # One heatmap cell per (trade, date): the anchor date plus each forecast date.
        future_dates = pd.date_range(site_calendar_date, periods=CONFIG.FORECAST_DAYS + 1, freq='D')[1:]
        all_dates = [site_calendar_date] + list(future_dates)
        cells = []
        for trade in predictions_dict.keys():
            probs = shortage_probs_dict.get(trade, [0.0] * (CONFIG.FORECAST_DAYS + 1))
            for day_idx, day in enumerate(all_dates):
                cells.append({
                    'Date': day.strftime('%Y-%m-%d'),
                    'Trade': trade,
                    'Shortage_Probability': probs[day_idx]
                })
        heatmap_df = pd.DataFrame(cells)
        if heatmap_df.empty:
            return go.Figure().update_layout(title="Shortage Risk Heatmap (No Data)")
        percent_labels = (heatmap_df['Shortage_Probability'] * 100).round(1).astype(str) + '%'
        # Discrete colour bands: green = Normal, orange = Warning, red = Critical.
        band_colorscale = [
            [0.0, '#00FF00'],
            [0.4, '#00FF00'],
            [0.41, '#FFA500'],
            [0.7, '#FFA500'],
            [0.71, '#FF0000'],
            [1.0, '#FF0000']
        ]
        trace = go.Heatmap(
            x=heatmap_df['Date'],
            y=heatmap_df['Trade'],
            z=heatmap_df['Shortage_Probability'],
            colorscale=band_colorscale,
            zmin=0, zmax=1,
            text=percent_labels,
            texttemplate="%{text}",
            textfont={"size": 18, "color": "black"},
            hovertemplate="Trade: %{y}<br>Date: %{x}<br>Shortage Risk: %{text}<extra></extra>",
            colorbar=dict(
                title="Shortage Risk",
                tickvals=[0, 0.41, 0.71],
                ticktext=["0%", "41%", "71%"],
                len=0.8,
                thickness=20
            )
        )
        fig = go.Figure(data=trace)
        fig.update_layout(
            title="Shortage Risk Heatmap",
            xaxis_title="Date",
            yaxis_title="Trade",
            xaxis=dict(tickangle=45, tickformat="%Y-%m-%d", gridcolor='black', gridwidth=1),
            yaxis=dict(autorange="reversed", gridcolor='black', gridwidth=1),
            font=dict(size=14),
            margin=dict(l=100, r=50, t=80, b=100),
            plot_bgcolor="white",
            paper_bgcolor="white",
            width=800,
            # Grow the figure with the number of trades so cells stay readable.
            height=400 + len(heatmap_df['Trade'].unique()) * 50
        )
        fig.update_traces(xgap=3, ygap=3)
        return fig
    except Exception as e:
        logger.error(f"Error creating heatmap: {str(e)}")
        return go.Figure().update_layout(title=f"Error in Heatmap: {str(e)}")
def create_chart(df: pd.DataFrame, predictions_dict: Dict) -> go.Figure:
    """Plot historical attendance and forecast headcounts as one line chart."""
    try:
        combined = pd.DataFrame()
        for trade, predictions in predictions_dict.items():
            # Historical actuals for this trade, case-insensitive match.
            history = df[df['Trade'].str.lower() == trade.lower()][['Date', 'Attendance']].copy()
            history['Type'] = 'Historical'
            history['Trade'] = trade
            # Rows produced by the forecasting step.
            future = pd.DataFrame(predictions)
            if not future.empty:
                future['Date'] = pd.to_datetime(future['date'])
                future['Attendance'] = future['headcount']
                future['Type'] = 'Forecast'
                future['Trade'] = trade
            combined = pd.concat([combined, history, future[['Date', 'Attendance', 'Type', 'Trade']]])
        if combined.empty:
            return go.Figure().update_layout(title="Labour Attendance Forecast (No Data)")
        return px.line(
            combined,
            x='Date',
            y='Attendance',
            color='Trade',
            line_dash='Type',
            markers=True,
            title='Labour Attendance Forecast by Trade'
        )
    except Exception as e:
        logger.error(f"Error creating chart: {str(e)}")
        return go.Figure().update_layout(title=f"Error in Chart: {str(e)}")
def generate_pdf_summary(trade_results: Dict) -> Optional[str]:
    """Build a one-page PDF summary and return it base64-encoded.

    Returns None if rendering fails for any reason.
    """
    try:
        buffer = io.BytesIO()
        with PdfPages(buffer) as pdf:
            fig, ax = plt.subplots(figsize=(10, 6))
            if not trade_results:
                ax.text(0.1, 0.5, "No data available for summary", fontsize=12)
            else:
                # One text line per trade, stacked from the top of the page.
                for row, (trade, data) in enumerate(trade_results.items()):
                    summary_line = f"{trade}: {data['Attendance']} (Actual), Shortage Risk: {data['Shortage_risk'][0]*100:.0f}%, Alert: {data['Alert_status']}"
                    ax.text(0.1, 0.9 - 0.1*row, summary_line, fontsize=12)
            ax.set_title("Weekly Labour Forecast Summary")
            ax.axis('off')
            pdf.savefig()
            plt.close()
        encoded = base64.b64encode(buffer.getvalue()).decode()
        logger.info("PDF summary generated")
        return encoded
    except Exception as e:
        logger.error(f"Error generating PDF: {str(e)}")
        return None
def format_output(trade_results: Dict, site_calendar_date: pd.Timestamp) -> Tuple[str, str]:
    """Format per-trade results for display and build contractor notifications.

    Args:
        trade_results: trade -> result dict assembled by forecast_labour.
        site_calendar_date: Anchor date, echoed in each trade's 'Date' line.

    Returns:
        (formatted_output, formatted_notifications) as display-ready strings.
    """
    output = []
    notifications = []
    for trade, data in trade_results.items():
        output.append(f"Trade: {trade}")
        for key in CONFIG.REQUIRED_COLUMNS + ['Alert_status', 'Forecast_Next_3_Days__c', 'Shortage_risk', 'Suggested_actions']:
            if key == 'Date':
                value = site_calendar_date.strftime('%Y-%m-%d')
            elif key == 'Forecast_Next_3_Days__c':
                value = ', '.join([f"{item['date']}: {item['headcount']}" for item in data.get(key, [])]) or 'N/A'
            else:
                value = data.get(key, 'N/A')
            if key in ['Weather', 'Alert_status', 'Suggested_actions', 'Trade']:
                value = str(value)
            elif key == 'Shortage_risk':
                # Guard: a missing key leaves the 'N/A' string, for which
                # round(value[0], 2) would raise TypeError.
                value = str(round(value[0], 2)) if isinstance(value, (list, tuple)) else 'N/A'
            elif key == 'Attendance':
                value = str(int(value))
            output.append(f" • {key}: {value}")
        alert_status = data.get('Alert_status', 'Normal')
        shortage_risk = data.get('Shortage_risk', [0])[0] * 100
        suggested_actions = data.get('Suggested_actions', 'No action needed')
        if alert_status == 'Critical':
            notification = f"Urgent Alert for {trade}: {suggested_actions} due to high shortage risk of {round(shortage_risk)}%."
        elif alert_status == 'Warning':
            notification = f"Warning for {trade}: {suggested_actions} due to moderate shortage risk of {round(shortage_risk)}%."
        else:
            notification = f"Notice for {trade}: {suggested_actions}, shortage risk is low at {round(shortage_risk)}%."
        notifications.append(notification)
        output.append("")
    formatted_output = "\n".join(output) or "No valid trade data available."
    # Bug fix: the old `"prefix" + "\n".join(...) or fallback` expression was
    # always truthy (the non-empty prefix), so the fallback text was dead code.
    if notifications:
        formatted_notifications = "Contractor Notifications:\n" + "\n".join([f" • {notification}" for notification in notifications])
    else:
        formatted_notifications = "No notifications available."
    return formatted_output, formatted_notifications
def push_to_salesforce(sf: Optional[Salesforce], trade_results: Dict, site_calendar_date: pd.Timestamp) -> Optional[str]:
    """Create one Labour_Attendance_Forecast__c record per trade.

    Returns None on success, otherwise a human-readable error string.

    NOTE(review): records are inserted with .create(), not an external-ID
    upsert, so re-running the same date will add duplicate rows — confirm
    whether a true upsert is intended.
    """
    try:
        if sf is None:
            return "Salesforce connection not established"
        date_str = site_calendar_date.strftime('%Y-%m-%d')
        payloads = [
            {
                'Trade__c': trade,
                'Date__c': date_str,
                'Expected_Headcount__c': int(data['Attendance']),
                'Actual_Headcount__c': int(data['Attendance']),
                'Forecast_Next_3_Days__c': ', '.join(
                    f"{item['date']}: {item['headcount']}"
                    for item in data.get('Forecast_Next_3_Days__c', [])
                ),
                'Shortage_Risk__c': float(data['Shortage_risk'][0]),
                'Suggested_Actions__c': str(data['Suggested_actions']),
                'Alert_Status__c': str(data['Alert_status']),
                'Dashboard_Display__c': True
            }
            for trade, data in trade_results.items()
        ]
        for payload in payloads:
            sf.Labour_Attendance_Forecast__c.create(payload)
        logger.info(f"Successfully pushed {len(payloads)} records to Salesforce")
        return None
    except Exception as e:
        logger.error(f"Error pushing to Salesforce: {str(e)}")
        return f"Error pushing to Salesforce: {str(e)}"
| # Main forecast function with improved error handling | |
def forecast_labour(
    csv_file: Optional[Union[str, Path]],
    trade_filter: Optional[str] = None,
    site_calendar_date: Optional[str] = None
) -> Tuple[str, Optional[go.Figure], Optional[go.Figure], Optional[str], str]:
    """Generate labour forecast based on CSV input and push to Salesforce.

    Orchestrates the full pipeline: CSV parsing/validation, per-trade
    weighted-moving-average forecasting, Salesforce upload, chart and
    heatmap rendering, PDF summary generation, and output formatting.

    Args:
        csv_file: Path of the uploaded CSV (falsy value aborts early).
        trade_filter: Comma-separated trade names; blank selects all trades.
        site_calendar_date: Anchor date string, parsed by pandas.

    Returns:
        (result_text, line_chart, heatmap, pdf_link_html, notifications_text);
        the figure/link slots are None when processing fails.
    """
    try:
        logger.info("Starting forecast process")
        if not csv_file:
            return "Error: No CSV file uploaded", None, None, None, "No notifications available."
        df, error = process_csv(csv_file)
        if error:
            return error, None, None, None, "No notifications available."
        try:
            # Parse the anchor date; unparseable input raises ValueError and
            # NaT-like results are rejected explicitly below.
            site_calendar_date = pd.to_datetime(site_calendar_date)
            if pd.isna(site_calendar_date):
                raise ValueError("Invalid site calendar date")
        except ValueError as e:
            logger.error(f"Date error: {str(e)}")
            return f"Error: {str(e)}", None, None, None, "No notifications available."
        unique_trades = df['Trade'].dropna().unique()
        logger.info(f"Unique trades in CSV: {list(unique_trades)}")
        selected_trades = []
        if trade_filter and trade_filter.strip():
            # Keep only filter entries that match a CSV trade (case-insensitive);
            # an entirely invalid filter falls back to all trades.
            selected_trades = [t.strip() for t in trade_filter.split(',') if t.strip()]
            selected_trades = [t for t in selected_trades if any(t.lower() == ut.lower() for ut in unique_trades)]
            if not selected_trades:
                logger.warning(f"No valid trades found in filter: {trade_filter}. Defaulting to all trades.")
                selected_trades = unique_trades
        else:
            logger.info("Trade filter empty. Using all trades.")
            selected_trades = unique_trades
        logger.info(f"Selected trades: {list(selected_trades)}")
        trade_results = {}
        predictions_dict = {}
        shortage_probs_dict = {}
        errors = []
        for trade in selected_trades:
            trade_df = df[df['Trade'].str.lower() == trade.lower()]
            date_match = trade_df[trade_df['Date'] == site_calendar_date]
            if date_match.empty:
                # A trade is only reported when actuals exist on the anchor date.
                errors.append(f"No data for trade {trade} on {site_calendar_date.strftime('%Y-%m-%d')}")
                continue
            if len(date_match) > 1:
                errors.append(f"Warning: Multiple rows for trade {trade} on {site_calendar_date.strftime('%Y-%m-%d')}")
            predictions, shortage_probs, site_calendar, suggested_actions, alert_status, forecast_error = weighted_moving_average_forecast(df, trade, site_calendar_date)
            if forecast_error:
                errors.append(forecast_error)
                continue
            predictions_dict[trade] = predictions
            shortage_probs_dict[trade] = shortage_probs
            # First row wins when duplicate rows exist for the anchor date.
            record = date_match.iloc[0]
            result_data = {
                'Date': site_calendar_date,
                'Trade': trade,
                'Weather': record['Weather'],
                'Attendance': record['Attendance'],
                'Forecast_Next_3_Days__c': predictions,
                'Shortage_risk': shortage_probs,
                'Suggested_actions': suggested_actions,
                'Alert_status': alert_status
            }
            trade_results[trade] = result_data
        if not trade_results:
            error_msg = "No valid trade data processed"
            if errors:
                error_msg += f". Errors: {'; '.join(errors)}"
            return error_msg, None, None, None, "No notifications available."
        # Salesforce upload failures are reported as warnings, not fatal errors.
        sf = connect_to_salesforce()
        sf_error = push_to_salesforce(sf, trade_results, site_calendar_date)
        if sf_error:
            errors.append(sf_error)
        line_chart = create_chart(df, predictions_dict)
        heatmap = create_heatmap(df, predictions_dict, shortage_probs_dict, site_calendar_date)
        pdf_summary = generate_pdf_summary(trade_results)
        formatted_output, formatted_notifications = format_output(trade_results, site_calendar_date)
        error_msg = "; ".join(errors) if errors else None
        final_output = formatted_output + (f"\nWarnings: {error_msg}" if error_msg else "")
        # Embed the PDF as a base64 data URI so Gradio can serve it as a download link.
        pdf_link = f'<a href="data:application/pdf;base64,{pdf_summary}" download="summary.pdf">Download Summary PDF</a>' if pdf_summary else "PDF generation failed."
        return final_output, line_chart, heatmap, pdf_link, formatted_notifications
    except Exception as e:
        logger.error(f"Unexpected error in forecast: {str(e)}")
        return f"Error processing file: {str(e)}", None, None, None, "No notifications available."
| # Gradio interface with improved UI | |
def gradio_interface():
    """Assemble and return the Gradio Blocks UI for the forecaster."""
    with gr.Blocks(theme=gr.themes.Soft()) as ui:
        gr.Markdown("# Labour Attendance Forecast (Updated June 2025)")
        gr.Markdown("Upload a CSV with columns: Date, Attendance, Trade, Weather, Alert_status (optional)")
        gr.Markdown("Enter trade names (e.g., 'Painter, Electrician') or leave blank for all trades")
        gr.Markdown("Enter site calendar date (YYYY-MM-DD) for CSV data and 3-day forecast")
        # Input controls, laid out side by side.
        with gr.Row():
            file_box = gr.File(label="Upload CSV", file_types=[".csv"])
            trades_box = gr.Textbox(label="Filter by Trades", placeholder="e.g., Painter, Electrician")
            date_box = gr.Textbox(
                label="Site Calendar Date (YYYY-MM-DD)",
                placeholder="e.g., 2025-04-25",
                value="2025-04-01"  # default anchor date shown on load
            )
        run_button = gr.Button("Generate Forecast")
        # Output panels.
        result_box = gr.Textbox(label="Forecast Result", lines=20, show_copy_button=True)
        trend_plot = gr.Plot(label="Forecast Trendline")
        heatmap_plot = gr.Plot(label="Shortage Risk Heatmap")
        notify_box = gr.Textbox(label="Contractor Notifications", lines=5, show_copy_button=True)
        pdf_html = gr.HTML(label="Download Summary PDF")
        run_button.click(
            fn=forecast_labour,
            inputs=[file_box, trades_box, date_box],
            outputs=[result_box, trend_plot, heatmap_plot, pdf_html, notify_box]
        )
    logger.info("Launching Gradio interface")
    return ui
if __name__ == '__main__':
    # Build the UI and serve it locally (no public share link).
    gradio_interface().launch(share=False)