# pp / app.py
# prudhviLatha's picture
# Update app.py
# 39fd53f verified
import os
import gradio as gr
import pandas as pd
import numpy as np
from datetime import datetime
from simple_salesforce import Salesforce
from dotenv import load_dotenv
import plotly.express as px
# Read settings from the .env file before any credential lookup.
load_dotenv()

# Salesforce credentials, taken from the environment (None when unset).
SF_USERNAME = os.getenv('SF_USERNAME')
SF_PASSWORD = os.getenv('SF_PASSWORD')
SF_SECURITY_TOKEN = os.getenv('SF_SECURITY_TOKEN')

# Open the Salesforce session once at import time. `sf` stays None when the
# login fails, so the functions below can report the problem instead of
# crashing on a missing client.
sf = None
try:
    sf = Salesforce(
        username=SF_USERNAME,
        password=SF_PASSWORD,
        security_token=SF_SECURITY_TOKEN,
    )
except Exception as exc:
    print(f"Error connecting to Salesforce: {str(exc)}")
# Look up the Project ID in Salesforce automatically
def get_project_id():
    """Return the Id of the most recently created ``Project__c`` record.

    Returns:
        A ``(project_id, error)`` pair. On success ``error`` is ``None``;
        otherwise ``project_id`` is ``None`` and ``error`` is a message
        describing why no Id could be obtained (no connection, no project
        found, or the query raised).
    """
    if not sf:
        return None, "Salesforce connection failed. Check credentials."

    soql = "SELECT Id FROM Project__c ORDER BY CreatedDate DESC LIMIT 1"
    try:
        response = sf.query(soql)
        if response['totalSize'] <= 0:
            return None, "No project found in Salesforce."
        newest = response['records'][0]
        return newest['Id'], None
    except Exception as exc:
        return None, f"Error fetching Project ID: {str(exc)}"
# Simple moving average forecast
def simple_forecast(df):
    """Forecast the next three days of attendance with a 3-row moving average.

    Rows are ordered chronologically before the rolling mean is taken, so the
    forecast is based on the most recent dates regardless of the row order of
    ``df``. The work is done on a copy; the caller's frame is not modified.

    Args:
        df: DataFrame with a ``Date`` column (datetime values or day-first
            date strings) and a numeric ``Attendance`` column.

    Returns:
        A list of three ``{"date": "YYYY-MM-DD", "headcount": int}`` dicts for
        the three days following the latest date in ``df``, each carrying the
        same rounded moving-average value. An empty list when ``df`` has no
        rows.
    """
    if df.empty:
        return []

    history = df[['Date', 'Attendance']].copy()
    history['Date'] = pd.to_datetime(history['Date'], dayfirst=True)
    # Stable sort: rows sharing a date keep their original relative order.
    history = history.sort_values('Date', kind='mergesort')

    # min_periods=1 lets trades with fewer than three rows still get a value.
    smoothed = history['Attendance'].rolling(window=3, min_periods=1).mean()
    # Convert to a plain int so the value prints and serialises cleanly.
    headcount = int(round(float(smoothed.iloc[-1])))

    # periods=4 includes the last known day itself, which is dropped by [1:].
    future_dates = pd.date_range(history['Date'].max(), periods=4, freq='D')[1:]
    return [
        {"date": day.strftime('%Y-%m-%d'), "headcount": headcount}
        for day in future_dates
    ]
# Save record to Salesforce
def save_to_salesforce(record):
    """Create one ``Labour_Attendance_Forecast__c`` record in Salesforce.

    Args:
        record: Dict of field API names to values. ``Trade__c`` is used to
            label the status message; a record without it is labelled
            ``unknown trade``.

    Returns:
        ``{"success": message, "record_id": id}`` when the record was created,
        or ``{"error": message}`` when there is no Salesforce connection or
        the create call raised.
    """
    if not sf:
        return {"error": "Salesforce connection failed. Check credentials."}

    # Resolve the label up front with .get(): indexing record['Trade__c']
    # inside the except handler would raise a second, uncaught KeyError for a
    # record that lacks the field.
    trade = record.get('Trade__c', 'unknown trade')
    try:
        result = sf.Labour_Attendance_Forecast__c.create(record)
        return {"success": f"Record created successfully for {trade}", "record_id": result['id']}
    except Exception as e:
        return {"error": f"Error uploading data to Salesforce for {trade}: {str(e)}"}
# Create line chart for multiple trades
def create_chart(df, predictions_dict):
    """Build a line chart of historical and forecast attendance per trade.

    Args:
        df: DataFrame with ``Date``, ``Attendance`` and ``Trade`` columns.
        predictions_dict: Mapping of trade name to a list of
            ``{"date": ..., "headcount": ...}`` dicts. A trade with an empty
            list contributes only its historical rows.

    Returns:
        The figure returned by ``px.line``: one colour per trade, with the
        line dash distinguishing ``Historical`` from ``Forecast`` points.
    """
    columns = ['Date', 'Attendance', 'Type', 'Trade']

    # Collect the pieces and concatenate once; growing a frame inside the loop
    # re-copies all earlier rows on every iteration.
    frames = []
    for trade, predictions in predictions_dict.items():
        # Sort so each line is drawn left to right even if the CSV is unordered.
        history = df.loc[df['Trade'] == trade, ['Date', 'Attendance']]
        history = history.sort_values('Date', kind='mergesort')
        history = history.assign(Type='Historical', Trade=trade)
        frames.append(history[columns])

        if not predictions:
            continue
        raw_forecast = pd.DataFrame(predictions)
        forecast = pd.DataFrame({
            'Date': pd.to_datetime(raw_forecast['date']),
            'Attendance': raw_forecast['headcount'],
            'Type': 'Forecast',
            'Trade': trade,
        })
        frames.append(forecast[columns])

    # Drop empty pieces before concatenating, and fall back to an empty frame
    # that still has the plotted columns when there is nothing to show.
    frames = [frame for frame in frames if not frame.empty]
    if frames:
        combined_df = pd.concat(frames, ignore_index=True)
    else:
        combined_df = pd.DataFrame(columns=columns)

    fig = px.line(
        combined_df,
        x='Date',
        y='Attendance',
        color='Trade',
        line_dash='Type',
        markers=True,
        title='Labour Attendance Forecast by Trade'
    )
    return fig
# Render the per-trade results as bullet-style text, one field per line
def format_output(trade_results):
    """Format the per-trade result dicts as plain text for display.

    Each trade produces a ``Trade: <name>`` heading, one bullet line per
    field, and a trailing blank line. The ``Project__c``, ``record_id`` and
    ``success`` fields are omitted, and list values are shown as a
    comma-separated string of their items.

    Args:
        trade_results: Mapping of trade name to a dict of field name to value.

    Returns:
        The formatted text; an empty string when ``trade_results`` is empty.
    """
    hidden = {'Project__c', 'record_id', 'success'}

    def render(value):
        # Lists are flattened to "a, b, c"; everything else is shown as-is.
        if isinstance(value, list):
            return ', '.join(str(item) for item in value)
        return value

    lines = []
    for trade, fields in trade_results.items():
        lines.append(f"Trade: {trade}")
        lines.extend(
            f" • {name}: {render(value)}"
            for name, value in fields.items()
            if name not in hidden
        )
        lines.append("")
    return "\n".join(lines)
# Forecast function for Gradio
def _sniff_bom_encoding(path):
    """Return the codec named by the file's byte-order mark, or None if it has none."""
    with open(path, 'rb') as handle:
        head = handle.read(4)
    # The UTF-32 LE mark begins with the UTF-16 LE mark, so test UTF-32 first.
    if head.startswith((b'\xff\xfe\x00\x00', b'\x00\x00\xfe\xff')):
        return 'utf-32'
    if head.startswith((b'\xff\xfe', b'\xfe\xff')):
        return 'utf-16'
    if head.startswith(b'\xef\xbb\xbf'):
        return 'utf-8-sig'
    return None


def _read_csv_any_encoding(path):
    """Read the CSV at *path*, or return None if no candidate encoding decodes it."""
    # latin1 accepts every byte sequence, so it must be the last resort:
    # anything tried after it could never be reached, and anything that
    # needs a different codec (e.g. UTF-16) would be read as garbage.
    candidates = ['utf-8', 'latin1']
    bom_encoding = _sniff_bom_encoding(path)
    if bom_encoding:
        candidates.insert(0, bom_encoding)
    for encoding in candidates:
        try:
            return pd.read_csv(path, encoding=encoding)
        except UnicodeError:
            continue
    return None


def forecast_labour(csv_file):
    """Forecast attendance for the first ten trades in an uploaded CSV.

    The CSV is validated and normalised, a three-day forecast is computed for
    each selected trade, one record per trade is saved to Salesforce, and the
    results are returned as display text together with a chart.

    Args:
        csv_file: The uploaded file, either an object exposing the file path
            as ``.name`` or the path itself as a string.

    Returns:
        A ``(text, figure)`` pair. On any failure ``text`` is an error message
        starting with ``Error`` and ``figure`` is ``None``.
    """
    if csv_file is None:
        return "Error: No file uploaded. Please upload a CSV file.", None
    try:
        csv_path = getattr(csv_file, 'name', csv_file)
        df = _read_csv_any_encoding(csv_path)
        if df is None:
            return "Error: Could not decode CSV file with any supported encoding (utf-8, latin1, iso-8859-1, utf-16). Please ensure the file is properly encoded.", None

        # Normalise headers so e.g. " alert_Status " matches "Alert_status".
        df.columns = df.columns.str.strip().str.capitalize()
        required_columns = ['Date', 'Attendance', 'Trade', 'Weather', 'Alert_status', 'Shortage_risk', 'Suggested_actions']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            return f"Error: CSV missing required columns: {', '.join(missing_columns)}", None

        df['Date'] = pd.to_datetime(df['Date'], dayfirst=True)
        df['Attendance'] = df['Attendance'].astype(int)
        # "22%" -> 0.22
        df['Shortage_risk'] = df['Shortage_risk'].replace('%', '', regex=True).astype(float) / 100

        unique_trades = df['Trade'].unique()
        if len(unique_trades) < 10:
            return f"Error: CSV contains only {len(unique_trades)} trades, but a minimum of 10 trades is required.", None
        selected_trades = unique_trades[:10]

        project_id, error = get_project_id()
        if error:
            return f"Error: {error}", None

        trade_results = {}
        predictions_dict = {}
        for trade in selected_trades:
            trade_df = df[df['Trade'] == trade].copy()
            if trade_df.empty:
                continue
            # Order chronologically once so that the forecast and every
            # "latest" value below refer to the same, most recent row.
            trade_df = trade_df.sort_values(by='Date', kind='mergesort')
            latest_record = trade_df.iloc[-1]
            latest_date = trade_df['Date'].max()

            predictions = simple_forecast(trade_df)
            predictions_dict[trade] = predictions

            shortage_risk = latest_record['Shortage_risk']
            suggested_actions = latest_record['Suggested_actions']
            alert_status = latest_record['Alert_status']

            result_data = {
                "Title": f"Labour Attendance Data for {trade}",
                "Date": latest_date.strftime('%B %Y'),
                "Trade": trade,
                "Weather": latest_record['Weather'],
                "Forecast": predictions,
                "Alert Status": alert_status,
                "Shortage_risk": shortage_risk,
                "Suggested_actions": suggested_actions,
                "Expected_headcount": predictions[0]['headcount'],
                "Actual_headcount": int(latest_record['Attendance']),
                "Forecast_Next_3_Days__c": predictions,
                "Project__c": project_id
            }
            salesforce_record = {
                'Trade__c': trade,
                'Shortage_Risk__c': shortage_risk,
                'Suggested_Actions__c': suggested_actions,
                'Expected_Headcount__c': result_data['Expected_headcount'],
                'Actual_Headcount__c': result_data['Actual_headcount'],
                'Forecast_Next_3_Days__c': str(predictions),
                'Project_ID__c': project_id,
                'Alert_Status__c': alert_status,
                'Dashboard_Display__c': True,
                'Date__c': latest_date.date().isoformat()
            }
            # Merge the upload outcome in so a per-trade error is displayed.
            result_data.update(save_to_salesforce(salesforce_record))
            trade_results[trade] = result_data

        chart = create_chart(df, predictions_dict)
        return format_output(trade_results), chart
    except Exception as e:
        # Boundary for the UI: report any failure as text instead of raising.
        return f"Error processing file: {str(e)}", None
# Gradio UI, launched with share=False
def gradio_interface():
    """Build the CSV-upload forecasting interface and launch it.

    The interface takes one file input, passes it to ``forecast_labour`` and
    shows the two results in a text box and a plot. It is launched with
    ``share=False``.
    """
    upload = gr.File(label="Upload CSV with required columns for at least 10 trades")
    result_box = gr.Textbox(label="Forecast Result", lines=20)
    chart_box = gr.Plot(label="Forecast Chart")

    app = gr.Interface(
        fn=forecast_labour,
        inputs=[upload],
        outputs=[result_box, chart_box],
        title="Labour Attendance Forecast",
        description="Upload a CSV file with columns: Date, Attendance, Trade, Weather, Alert_Status, Shortage_Risk (e.g. 22%), Suggested_Actions. The file must contain data for at least 10 trades. ",
    )
    app.launch(share=False)


# Start the UI only when this file is run as a script, not when imported.
if __name__ == '__main__':
    gradio_interface()