# NOTE(review): the three lines below were Hugging Face Space status text
# ("Spaces: Sleeping") captured by the page scrape — not Python source.
# Kept here as a comment so the file parses.
| import os | |
| import logging | |
| import csv | |
| from typing import Optional, Dict, List, Tuple, Any | |
| import pandas as pd | |
| import numpy as np | |
| from datetime import datetime | |
| import gradio as gr | |
| from simple_salesforce import Salesforce | |
| from dotenv import load_dotenv | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import io | |
| import base64 | |
| from matplotlib.backends.backend_pdf import PdfPages | |
| import matplotlib.pyplot as plt | |
| import json | |
| import mimetypes | |
| from difflib import get_close_matches | |
# Configure logging.
# The format string references %(context)s, which is not a standard LogRecord
# attribute: any record logged without extra={'context': ...} (including logs
# emitted by third-party libraries through the root handlers) would fail to
# format. A filter injects a default so such records still render.
class ContextDefaultFilter(logging.Filter):
    """Ensure every log record carries a 'context' attribute for the formatter."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Records logged via extra={'context': ...} keep their value.
        if not hasattr(record, 'context'):
            record.context = 'General'
        return True


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s - [Context: %(context)s]',
    handlers=[
        logging.FileHandler('labour_forecast.log'),
        logging.StreamHandler()
    ]
)
# Attach the default-context filter to the root handlers so every record
# passing through them is safe to format.
for _handler in logging.getLogger().handlers:
    _handler.addFilter(ContextDefaultFilter())
logger = logging.getLogger(__name__)
# Configuration class
class Config:
    """Central configuration constants for the labour-forecast app."""

    # Columns every uploaded CSV must provide (after name normalization).
    REQUIRED_COLUMNS = [
        'Date', 'Attendance', 'Trade', 'Weather',
        'Alert_status', 'Shortage_risk', 'Suggested_actions',
    ]
    # Candidate encodings tried in order when reading the CSV.
    ENCODINGS = [
        'utf-8', 'utf-8-sig', 'latin1', 'iso-8859-1',
        'utf-16', 'cp1252', 'ascii', 'mac_roman',
    ]
    # Number of future days to forecast.
    FORECAST_DAYS = 3
    # Weighted-moving-average weights, keyed by how many recent points exist.
    WMA_WEIGHTS = {
        3: np.array([0.5, 0.3, 0.2]),
        2: np.array([0.6, 0.4]),
        1: np.array([1.0]),
    }
    # Relative attendance impact of each weather category.
    WEATHER_IMPACT = {'Sunny': 0, 'Rainy': 1, 'Cloudy': 0.5, 'N/A': 0.5}
    # Multiplier applied to forecasts that fall on weekends.
    WEEKEND_ADJUSTMENT = 0.8
    # Minimum span of historical data (in days) required in the CSV.
    MIN_HISTORY_DAYS = 30
# Salesforce connection
def connect_salesforce() -> Tuple[Optional[Salesforce], Optional[str]]:
    """Open a Salesforce session using credentials from the environment.

    Credentials are read from SF_USERNAME, SF_PASSWORD and SF_SECURITY_TOKEN
    (loaded from a .env file when present).

    Returns:
        (connection, None) on success, or (None, error message) on failure.
    """
    load_dotenv()
    try:
        credentials = {
            'username': os.getenv('SF_USERNAME'),
            'password': os.getenv('SF_PASSWORD'),
            'security_token': os.getenv('SF_SECURITY_TOKEN'),
        }
        sf = Salesforce(**credentials)
    except Exception as e:
        logger.error(f"Failed to connect to Salesforce: {str(e)}", extra={'context': 'Salesforce'})
        return None, f"Salesforce connection failed: {str(e)}"
    logger.info("Successfully connected to Salesforce", extra={'context': 'Salesforce'})
    return sf, None
# Data processing
def process_csv(file_path: str) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """Load and validate the uploaded attendance CSV.

    Sniffs the delimiter, tries a list of encodings in order, normalizes
    column names case-insensitively, validates required columns and a minimum
    span of history, then coerces the column types.

    Args:
        file_path: Path to the CSV file on disk.

    Returns:
        (DataFrame, None) on success, or (None, error message) on failure.
    """
    if not os.path.exists(file_path):
        error_msg = f"File not found: {file_path}"
        logger.error(error_msg, extra={'context': 'DataProcessing'})
        return None, error_msg
    try:
        # Validate file type from the filename-based MIME guess.
        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type not in ['text/csv', 'text/plain']:
            error_msg = f"Invalid file type for {file_path}. Expected CSV, got {mime_type or 'unknown'}"
            logger.error(error_msg, extra={'context': 'DataProcessing'})
            return None, error_msg
        # Check if file is empty
        if os.path.getsize(file_path) == 0:
            error_msg = f"CSV file is empty: {file_path}"
            logger.error(error_msg, extra={'context': 'DataProcessing'})
            return None, error_msg
        # Detect delimiter from a small sample; fall back to a comma.
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            try:
                sample = f.read(1024)
                sniffer = csv.Sniffer()
                delimiter = sniffer.sniff(sample).delimiter
                logger.info(f"Detected delimiter: '{delimiter}' for {file_path}", extra={'context': 'DataProcessing'})
            except csv.Error:
                delimiter = ','
                logger.warning(f"Could not detect delimiter for {file_path}, using default ','",
                               extra={'context': 'DataProcessing'})
        last_error = None  # remember the underlying failure for the final message
        for encoding in Config.ENCODINGS:
            try:
                df = pd.read_csv(file_path, encoding=encoding, dtype_backend='numpy_nullable', sep=delimiter)
                # Log raw column names
                raw_columns = df.columns.tolist()
                logger.info(f"Raw columns found: {', '.join(raw_columns)} in {file_path}",
                            extra={'context': 'DataProcessing'})
                # Normalize column names: strip spaces, convert to title case
                df.columns = df.columns.str.strip().str.title()
                normalized_columns = df.columns.tolist()
                logger.info(f"Normalized columns: {', '.join(normalized_columns)} in {file_path}",
                            extra={'context': 'DataProcessing'})
                # Create a mapping of normalized to original columns for case-insensitive matching
                col_map = {col.lower(): col for col in df.columns}
                missing_columns = [col for col in Config.REQUIRED_COLUMNS if col.lower() not in col_map]
                if missing_columns:
                    # Offer fuzzy-matched suggestions for each missing column.
                    suggestions = {}
                    for missing_col in missing_columns:
                        close_matches = get_close_matches(missing_col, df.columns, n=2, cutoff=0.6)
                        if close_matches:
                            suggestions[missing_col] = close_matches
                    suggestion_msg = "\n".join([f" - {col}: Did you mean {', '.join(suggestions[col])}?"
                                                for col in suggestions])
                    error_msg = (f"Missing columns: {', '.join(missing_columns)} in {file_path}\n"
                                 f"Required columns: {', '.join(Config.REQUIRED_COLUMNS)}\n"
                                 f"Found columns (after normalization): {', '.join(normalized_columns)}\n"
                                 f"Raw columns: {', '.join(raw_columns)}\n"
                                 f"Suggestions:\n{suggestion_msg if suggestion_msg else 'No similar column names found.'}")
                    logger.error(error_msg, extra={'context': 'DataProcessing'})
                    return None, error_msg
                # Rename columns to match expected case
                rename_dict = {col_map[required_col.lower()]: required_col for required_col in Config.REQUIRED_COLUMNS}
                df.rename(columns=rename_dict, inplace=True)
                if df.empty:
                    error_msg = f"CSV file contains no data: {file_path}"
                    logger.error(error_msg, extra={'context': 'DataProcessing'})
                    return None, error_msg
                # Day-first date parsing; unparseable dates become NaT.
                df['Date'] = pd.to_datetime(df['Date'], dayfirst=True, errors='coerce')
                if df['Date'].isna().all():
                    error_msg = f"All dates in CSV are invalid: {file_path}"
                    logger.error(error_msg, extra={'context': 'DataProcessing'})
                    return None, error_msg
                # Require a minimum span of history (max date minus min date).
                date_range = (df['Date'].max() - df['Date'].min()).days
                if date_range < Config.MIN_HISTORY_DAYS - 1:
                    error_msg = f"CSV must contain at least {Config.MIN_HISTORY_DAYS} days of data, found {date_range} days: {file_path}"
                    logger.error(error_msg, extra={'context': 'DataProcessing'})
                    return None, error_msg
                df['Attendance'] = pd.to_numeric(df['Attendance'], errors='coerce').fillna(0).astype('Int64')
                # Fix: scale percentages to fractions BEFORE filling missing
                # values, so an absent Shortage_risk defaults to 0.5 (50%) —
                # the default used throughout the app — rather than
                # 0.5 / 100 = 0.005 as before.
                df['Shortage_risk'] = (pd.to_numeric(
                    df['Shortage_risk'].replace('%', '', regex=True), errors='coerce'
                ) / 100).fillna(0.5)
                for col in ['Weather', 'Alert_status', 'Suggested_actions', 'Trade']:
                    # Fix: with the nullable dtype backend, missing values
                    # stringify as '<NA>' (not 'nan'); normalize both to 'N/A'.
                    df[col] = df[col].astype(str).replace(['nan', '<NA>'], 'N/A')
                logger.info(f"CSV file processed successfully with encoding {encoding} and delimiter '{delimiter}': {file_path}",
                            extra={'context': 'DataProcessing'})
                return df, None
            except Exception as e:
                last_error = str(e)
                logger.warning(f"Failed with encoding {encoding} for {file_path}: {str(e)}",
                               extra={'context': 'DataProcessing'})
                continue
        # Include the last underlying error so post-decode failures (e.g. bad
        # data raising inside the loop) are not misreported as pure encoding
        # problems.
        error_msg = f"Could not decode CSV file with any supported encoding: {file_path}. Tried encodings: {', '.join(Config.ENCODINGS)}"
        if last_error:
            error_msg += f" Last error: {last_error}"
        logger.error(error_msg, extra={'context': 'DataProcessing'})
        return None, error_msg
    except Exception as e:
        error_msg = f"Unexpected error reading CSV file {file_path}: {str(e)}"
        logger.error(error_msg, extra={'context': 'DataProcessing'})
        return None, error_msg
# Forecasting logic
def weighted_moving_average_forecast(
    df: pd.DataFrame, trade: str, site_calendar_date: pd.Timestamp
) -> Tuple[List[Dict[str, Any]], float, str, Optional[str]]:
    """Forecast the next Config.FORECAST_DAYS days of attendance for one trade.

    Uses actual rows when a future date already exists in the data; otherwise
    applies a weighted moving average over the most recent attendance values,
    dampened by the latest weather and a weekend adjustment.

    Args:
        df: Full processed attendance DataFrame.
        trade: Trade name to forecast.
        site_calendar_date: Anchor date; only rows on or before it feed the WMA.

    Returns:
        (predictions, shortage_probability, site-calendar label, error), where
        predictions is a list of {"date", "headcount"} dicts. On failure the
        list is empty, the probability defaults to 0.5, and error is set.
    """
    try:
        trade_df = df[df['Trade'] == trade].copy()
        if trade_df.empty:
            return [], 0.5, "", f"No data found for trade: {trade}"
        is_weekday = site_calendar_date.weekday() < 5
        site_calendar = 1 if is_weekday else 0
        # Only history up to the anchor date may inform the forecast.
        # Fix: sort by date so tail(30) really selects the most recent rows
        # regardless of row order in the uploaded CSV.
        trade_df = trade_df[trade_df['Date'] <= site_calendar_date].sort_values('Date')
        recent_data = trade_df.tail(30)[['Date', 'Attendance', 'Weather', 'Shortage_risk']]
        if recent_data.empty:
            return [], 0.5, "", f"No data for trade {trade} on or before {site_calendar_date.strftime('%Y-%m-%d')}"
        predictions = []
        shortage_prob = recent_data['Shortage_risk'].mean()
        future_dates = pd.date_range(site_calendar_date, periods=Config.FORECAST_DAYS + 1, freq='D')[1:]
        for date in future_dates:
            # Prefer an actual row when the future date already exists in the data.
            future_data = df[(df['Trade'] == trade) & (df['Date'] == date)]
            if not future_data.empty:
                record = future_data.iloc[0]
                headcount = int(record['Attendance']) if pd.notna(record['Attendance']) else 0
                shortage_prob = record['Shortage_risk'] if pd.notna(record['Shortage_risk']) else 0.5
            else:
                # Weighted moving average over up to the last 3 observations.
                recent_attendance = recent_data['Attendance'].values[-3:]
                weights = Config.WMA_WEIGHTS.get(len(recent_attendance), Config.WMA_WEIGHTS[1])
                forecast_value = np.average(recent_attendance, weights=weights)
                # Dampen the forecast by up to 10% for bad weather.
                latest_weather = Config.WEATHER_IMPACT.get(recent_data['Weather'].iloc[-1], 0.5)
                forecast_value *= (1 - 0.1 * latest_weather)
                headcount = round(forecast_value * (1 if site_calendar == 1 else Config.WEEKEND_ADJUSTMENT))
                # Nudge the shortage probability by the recent attendance trend,
                # clamped to [0, 1].
                attendance_trend = recent_data['Attendance'].pct_change().mean() if len(recent_data) > 1 else 0
                shortage_prob = min(max(shortage_prob + attendance_trend * 0.1, 0), 1)
            predictions.append({"date": date.strftime('%Y-%m-%d'), "headcount": headcount})
        site_calendar_value = site_calendar_date.strftime('%Y-%m-%d') + f" ({'Weekday' if is_weekday else 'Weekend'})"
        logger.info(f"Forecast generated for trade: {trade}", extra={'context': 'Forecasting'})
        return predictions, shortage_prob, site_calendar_value, None
    except Exception as e:
        logger.error(f"Forecast error for trade {trade}: {str(e)}", extra={'context': 'Forecasting'})
        return [], 0.5, "", f"Forecast error: {str(e)}"
# Salesforce operations
def get_project_id(sf: Salesforce) -> Tuple[Optional[str], Optional[str]]:
    """Return the Id of the most recently created Project__c record.

    Returns:
        (project_id, None) on success, or (None, error message) when no
        project exists or the query fails.
    """
    try:
        result = sf.query("SELECT Id FROM Project__c ORDER BY CreatedDate DESC LIMIT 1")
        if result['totalSize'] == 0:
            return None, "No project found in Salesforce"
        return result['records'][0]['Id'], None
    except Exception as e:
        logger.error(f"Error fetching Project ID: {str(e)}", extra={'context': 'Salesforce'})
        return None, f"Error fetching Project ID: {str(e)}"
def save_to_salesforce(sf: Salesforce, record: Dict[str, Any]) -> Dict[str, Any]:
    """Create a Labour_Attendance_Forecast__c record in Salesforce.

    Returns:
        {"success": ..., "record_id": ...} when the insert succeeds,
        or {"error": ...} when it fails.
    """
    trade = record['Trade__c']
    try:
        result = sf.Labour_Attendance_Forecast__c.create(record)
    except Exception as e:
        logger.error(f"Error uploading to Salesforce for {trade}: {str(e)}", extra={'context': 'Salesforce'})
        return {"error": f"Error uploading to Salesforce: {str(e)}"}
    logger.info(f"Record created for {trade}: {result['id']}", extra={'context': 'Salesforce'})
    return {"success": f"Record created for {trade}", "record_id": result['id']}
# Visualization
def create_heatmap(
    df: pd.DataFrame, predictions_dict: Dict[str, List[Dict[str, Any]]],
    shortage_probs: Dict[str, float], site_calendar_date: pd.Timestamp
) -> go.Figure:
    """Build a trade-by-date heatmap of shortage probabilities.

    Covers the anchor date plus the next Config.FORECAST_DAYS days. When an
    actual row exists for a trade/date its Shortage_risk is used; otherwise
    the forecast probability from shortage_probs (or 0.5) fills in.

    Args:
        df: Processed attendance DataFrame.
        predictions_dict: Per-trade forecast lists; only its keys are used here.
        shortage_probs: Per-trade fallback probabilities for future dates.
        site_calendar_date: Anchor date of the forecast.

    Returns:
        A Plotly figure; on error, an empty figure whose title carries the error.
    """
    try:
        heatmap_data = []
        # Next FORECAST_DAYS calendar days after (excluding) the anchor date.
        future_dates = pd.date_range(site_calendar_date, periods=Config.FORECAST_DAYS + 1, freq='D')[1:]
        for trade in predictions_dict.keys():
            # Probability cell for the anchor date itself.
            trade_df = df[(df['Trade'] == trade) & (df['Date'] == site_calendar_date)]
            prob = trade_df.iloc[0]['Shortage_risk'] if not trade_df.empty and pd.notna(trade_df.iloc[0]['Shortage_risk']) else 0.5
            heatmap_data.append({'Date': site_calendar_date.strftime('%Y-%m-%d'), 'Trade': trade, 'Shortage_Probability': prob})
            # One cell per forecast day; actual data wins over the forecast.
            for date in future_dates:
                future_data = df[(df['Trade'] == trade) & (df['Date'] == date)]
                prob = future_data.iloc[0]['Shortage_risk'] if not future_data.empty and pd.notna(future_data.iloc[0]['Shortage_risk']) else shortage_probs.get(trade, 0.5)
                heatmap_data.append({'Date': date.strftime('%Y-%m-%d'), 'Trade': trade, 'Shortage_Probability': prob})
        heatmap_df = pd.DataFrame(heatmap_data)
        if heatmap_df.empty:
            return go.Figure().update_layout(title="Shortage Risk Heatmap (No Data)")
        # Custom colorscale: red at 0%, blue shades for values > 0% to 100%
        custom_colorscale = [
            [0.0, 'rgb(255, 0, 0)'],    # Red at exactly 0%
            [0.001, 'rgb(0, 0, 139)'],  # Dark blue just above 0%
            [1.0, 'rgb(135, 206, 250)'] # Light blue at 100%
        ]
        fig = go.Figure(data=go.Heatmap(
            x=heatmap_df['Date'],
            y=heatmap_df['Trade'],
            z=heatmap_df['Shortage_Probability'],
            colorscale=custom_colorscale,
            zmin=0, zmax=1,
            # Cell labels as whole-percent strings.
            text=(heatmap_df['Shortage_Probability'] * 100).round(0).astype(int).astype(str) + '%',
            texttemplate="%{text}",
            textfont={"size": 12},
            colorbar=dict(title="Shortage Risk", tickvals=[0, 0.5, 1], ticktext=["0%", "50%", "100%"])
        ))
        fig.update_layout(
            title="Shortage Risk Heatmap",
            xaxis_title="Date",
            yaxis_title="Trade",
            xaxis=dict(tickangle=45, tickformat="%Y-%m-%d"),
            yaxis=dict(autorange="reversed"),
            font=dict(size=14, family="Arial"),
            margin=dict(l=100, r=50, t=100, b=100),
            plot_bgcolor="white",
            paper_bgcolor="white",
            showlegend=False,
            hovermode="closest"
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating heatmap: {str(e)}", extra={'context': 'Visualization'})
        return go.Figure().update_layout(title=f"Error in Heatmap: {str(e)}")
def create_chart(df: pd.DataFrame, predictions_dict: Dict[str, List[Dict[str, Any]]]) -> go.Figure:
    """Build a line chart of historical and forecast attendance per trade.

    Args:
        df: Processed attendance DataFrame (historical rows).
        predictions_dict: Per-trade forecast lists of {"date", "headcount"}.

    Returns:
        A Plotly figure; on error, an empty figure whose title carries the error.
    """
    try:
        frames = []
        for trade, predictions in predictions_dict.items():
            trade_df = df[df['Trade'] == trade][['Date', 'Attendance']].copy()
            trade_df['Type'] = 'Historical'
            trade_df['Trade'] = trade
            frames.append(trade_df)
            forecast_df = pd.DataFrame(predictions)
            # Fix: include the forecast frame only when predictions exist;
            # selecting the renamed columns from an empty frame would raise
            # KeyError and blank out the whole chart.
            if not forecast_df.empty:
                forecast_df['Date'] = pd.to_datetime(forecast_df['date'])
                forecast_df['Attendance'] = forecast_df['headcount']
                forecast_df['Type'] = 'Forecast'
                forecast_df['Trade'] = trade
                frames.append(forecast_df[['Date', 'Attendance', 'Type', 'Trade']])
        combined_df = pd.concat(frames) if frames else pd.DataFrame()
        if combined_df.empty:
            return go.Figure().update_layout(title="Labour Attendance Forecast (No Data)")
        fig = px.line(
            combined_df,
            x='Date',
            y='Attendance',
            color='Trade',
            line_dash='Type',
            markers=True,
            title='Labour Attendance Forecast by Trade'
        )
        fig.update_traces(line=dict(width=3), marker=dict(size=8))
        fig.update_layout(
            font=dict(size=14, family="Arial"),
            plot_bgcolor="white",
            paper_bgcolor="white",
            hovermode="x unified"
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating chart: {str(e)}", extra={'context': 'Visualization'})
        return go.Figure().update_layout(title=f"Error in Chart: {str(e)}")
def generate_pdf_summary(trade_results: Dict[str, Dict[str, Any]], project_id: str) -> Optional[str]:
    """Render a one-page weekly summary PDF and return it base64-encoded.

    Returns:
        The base64 string of the PDF bytes, or None when rendering fails.
    """
    try:
        buffer = io.BytesIO()
        with PdfPages(buffer) as pdf:
            fig, ax = plt.subplots(figsize=(10, 6))
            if trade_results:
                # One text line per trade, stacked downward from the top.
                for row, (trade, data) in enumerate(trade_results.items()):
                    ax.text(0.1, 0.9 - 0.1 * row,
                            f"{trade}: {data['Attendance']} (Actual)", fontsize=12)
            else:
                ax.text(0.1, 0.5, "No data available for summary", fontsize=12)
            ax.set_title(f"Weekly Summary for Project {project_id}")
            ax.axis('off')
            pdf.savefig(fig)
            plt.close(fig)
        encoded = base64.b64encode(buffer.getvalue()).decode()
        logger.info("PDF summary generated", extra={'context': 'PDFGeneration'})
        return encoded
    except Exception as e:
        logger.error(f"Error generating PDF: {str(e)}", extra={'context': 'PDFGeneration'})
        return None
def notify_contractor(trade: str, alert_status: str) -> str:
    """Log and return a (simulated) contractor notification message."""
    message = f"Notification sent to contractor for {trade} with alert status: {alert_status}"
    logger.info(message, extra={'context': 'Notification'})
    return message
def format_output(trade_results: Dict[str, Dict[str, Any]], site_calendar_date: pd.Timestamp) -> str:
    """Render the per-trade results as a human-readable bulleted report.

    Args:
        trade_results: Per-trade dicts keyed by the CSV column names plus
            'Forecast_Next_3_Days__c'.
        site_calendar_date: Anchor date shown in each trade's Date line.

    Returns:
        The formatted report, or a placeholder string when no trades exist.
    """
    if not trade_results:
        return "No valid trade data available."
    columns = Config.REQUIRED_COLUMNS + ['Forecast_Next_3_Days__c']
    lines = []
    for trade, data in trade_results.items():
        lines.append(f"Trade: {trade}")
        for key in columns:
            # Resolve the raw value for this field.
            if key == 'Date':
                value = pd.to_datetime(site_calendar_date).strftime('%Y-%m-%d') if pd.notna(site_calendar_date) else 'N/A'
            elif key == 'Forecast_Next_3_Days__c':
                forecast = data.get(key)
                value = ', '.join(f"{item['date']}: {item['headcount']}" for item in forecast) if forecast else 'N/A'
            else:
                value = data.get(key, 'N/A')
            # Normalize the value's string representation per field type.
            if value is not None:
                if key in ('Weather', 'Alert_status', 'Suggested_actions', 'Trade'):
                    value = str(value)
                elif key == 'Shortage_risk':
                    value = str(round(value, 2))
                elif key == 'Attendance':
                    value = str(int(value))
            lines.append(f" • {key}: {value}")
        lines.append("")
    return "\n".join(lines)
def forecast_labour(
    csv_file: Any, trade_filter: Optional[str] = None, site_calendar_date: Optional[str] = None
) -> Tuple[str, Optional[go.Figure], Optional[go.Figure], Optional[str], Optional[str]]:
    """End-to-end pipeline: CSV -> forecasts -> Salesforce -> charts/PDF.

    Args:
        csv_file: Uploaded file object (its .name attribute is the disk path).
        trade_filter: Optional comma-separated trade names to include.
        site_calendar_date: Anchor date string for the forecast.

    Returns:
        (result text, line chart, heatmap, PDF download HTML, notification
        message). On failure, the text carries the error and the rest are None.
    """
    try:
        logger.info("Starting forecast process", extra={'context': 'Forecast'})
        if not csv_file:
            return "Error: No CSV file provided", None, None, None, None
        df, error = process_csv(csv_file.name)
        if error:
            return error, None, None, None, None
        sf, sf_error = connect_salesforce()
        if sf_error:
            return sf_error, None, None, None, None
        # Parse and validate the anchor date (pd.to_datetime raises or yields
        # NaT on bad input; both paths produce the same user-facing error).
        try:
            site_calendar_date = pd.to_datetime(site_calendar_date)
            if pd.isna(site_calendar_date):
                raise ValueError("Invalid site calendar date")
        except ValueError as e:
            logger.error(f"Date error: {str(e)}", extra={'context': 'Forecast'})
            return f"Error: {str(e)}", None, None, None, None
        # Resolve the trades to process (all trades when no filter is given).
        unique_trades = df['Trade'].dropna().unique()
        selected_trades = [t.strip() for t in trade_filter.split(',') if t.strip()] if trade_filter else unique_trades
        selected_trades = [t for t in selected_trades if t in unique_trades]
        if not selected_trades:
            return f"Error: None of the specified trades '{trade_filter}' found in CSV", None, None, None, None
        trade_results = {}
        predictions_dict = {}
        shortage_probs = {}
        errors = []
        project_id, error = get_project_id(sf)
        if error:
            return f"Error: {error}", None, None, None, None
        for trade in selected_trades:
            trade_df = df[df['Trade'] == trade]
            # The anchor date must have an actual row for this trade.
            date_match = trade_df[trade_df['Date'] == site_calendar_date]
            if date_match.empty:
                errors.append(f"No data for trade {trade} on {site_calendar_date.strftime('%Y-%m-%d')}")
                continue
            if len(date_match) > 1:
                errors.append(f"Warning: Multiple rows for trade {trade} on {site_calendar_date.strftime('%Y-%m-%d')}")
            predictions, shortage_prob, site_calendar, forecast_error = weighted_moving_average_forecast(
                df, trade, site_calendar_date
            )
            if forecast_error:
                errors.append(forecast_error)
                continue
            predictions_dict[trade] = predictions
            shortage_probs[trade] = shortage_prob
            record = date_match.iloc[0]
            # Data shown to the user for this trade.
            result_data = {
                'Date': site_calendar_date,
                'Trade': trade,
                'Weather': record['Weather'],
                'Alert_status': record['Alert_status'],
                'Shortage_risk': record['Shortage_risk'],
                'Suggested_actions': record['Suggested_actions'],
                'Attendance': record['Attendance'],
                'Forecast_Next_3_Days__c': predictions,
                'Shortage_Probability': round(shortage_prob, 2)
            }
            # Payload mapped to the Salesforce custom-object field names.
            salesforce_record = {
                'Trade__c': trade,
                'Shortage_Risk__c': record['Shortage_risk'],
                'Suggested_Actions__c': record['Suggested_actions'],
                'Expected_Headcount__c': predictions[0]['headcount'] if predictions else 0,
                'Actual_Headcount__c': int(record['Attendance']) if pd.notna(record['Attendance']) else 0,
                'Forecast_Next_3_Days__c': json.dumps(predictions),
                'Project_ID__c': project_id,
                'Alert_Status__c': record['Alert_status'],
                'Dashboard_Display__c': True,
                'Date__c': site_calendar_date.date().isoformat()
            }
            sf_result = save_to_salesforce(sf, salesforce_record)
            result_data.update(sf_result)
            trade_results[trade] = result_data
        if not trade_results:
            error_msg = "No valid trade data processed"
            if errors:
                error_msg += f". Errors: {'; '.join(errors)}"
            return error_msg, None, None, None, None
        line_chart = create_chart(df, predictions_dict)
        heatmap = create_heatmap(df, predictions_dict, shortage_probs, site_calendar_date)
        pdf_summary = generate_pdf_summary(trade_results, project_id)
        # NOTE(review): only the first selected trade triggers a notification —
        # confirm whether every trade should be notified.
        notification = notify_contractor(selected_trades[0], trade_results[selected_trades[0]]['Alert_status'])
        error_msg = "; ".join(errors) if errors else None
        return (
            format_output(trade_results, site_calendar_date) + (f"\nWarnings: {error_msg}" if error_msg else ""),
            line_chart,
            heatmap,
            f'<a href="data:application/pdf;base64,{pdf_summary}" download="summary.pdf">Download Summary PDF</a>' if pdf_summary else "Error generating PDF",
            notification
        )
    except Exception as e:
        logger.error(f"Unexpected error in forecast: {str(e)}", extra={'context': 'Forecast'})
        return f"Unexpected error processing file: {str(e)}", None, None, None, None
def gradio_interface():
    """Build and launch the Gradio web UI for the labour forecast app."""
    with gr.Blocks(theme=gr.themes.Default()) as interface:
        gr.Markdown(
            """
            # Labour Attendance Forecast
            Upload a CSV file with the following columns (case-insensitive, comma-separated):
            ```
            Date,Attendance,Trade,Weather,Alert_status,Shortage_risk,Suggested_actions
            ```
            Example:
            ```
            Date,Attendance,Trade,Weather,Alert_status,Shortage_risk,Suggested_actions
            2025-05-24,50,Painter,Sunny,Normal,20%,Monitor
            ```
            - The CSV must contain at least 30 days of historical data.
            - Use UTF-8 encoding and comma delimiters.
            - Column names are now matched case-insensitively (e.g., `attendance` will be recognized as `Attendance`).
            - Check `labour_forecast.log` for detailed error messages, including raw and normalized column names.
            Optionally, filter by trade names (comma-separated) and specify a site calendar date (YYYY-MM-DD).
            """
        )
        # Inputs: CSV upload, optional trade filter, anchor date.
        with gr.Row():
            csv_input = gr.File(label="Upload CSV", file_types=[".csv"])
            trade_input = gr.Textbox(
                label="Filter by Trades",
                placeholder="e.g., Painter, Electrician (leave blank for all trades)"
            )
            site_calendar_input = gr.Textbox(
                label="Site Calendar Date (YYYY-MM-DD)",
                placeholder="e.g., 2025-05-24"
            )
        forecast_button = gr.Button("Generate Forecast", variant="primary")
        # Outputs: text report, two plots, the PDF link, and the notification.
        result_output = gr.Textbox(label="Forecast Result", lines=20, show_copy_button=True)
        line_chart_output = gr.Plot(label="Forecast Trendline")
        heatmap_output = gr.Plot(label="Shortage Risk Heatmap")
        pdf_output = gr.HTML(label="Download Summary PDF")
        notification_output = gr.Textbox(label="Contractor Notification")
        # Wire the button to the end-to-end pipeline.
        forecast_button.click(
            fn=forecast_labour,
            inputs=[csv_input, trade_input, site_calendar_input],
            outputs=[result_output, line_chart_output, heatmap_output, pdf_output, notification_output]
        )
    logger.info("Launching Gradio interface", extra={'context': 'Gradio'})
    interface.launch()
if __name__ == '__main__':
    # Entry point: build and launch the Gradio UI.
    gradio_interface()