|
|
import json
import logging
import os

import gradio as gr
import pandas as pd
from dotenv import load_dotenv
from openai import OpenAI
from transformers import pipeline
|
|
|
|
|
|
|
|
# Load variables from a local .env file (when present) into the process
# environment before any credentials are read.
load_dotenv()

# OpenAI client used by the HR-query step for chat completions.
api_key = os.environ.get("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)

# Sentiment classifier applied to every feedback entry; built once at import
# time so each request reuses the same pipeline object.
_SENTIMENT_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
pipe = pipeline(task="text-classification", model=_SENTIMENT_MODEL_ID)
|
|
|
|
|
|
|
|
def process_csv(file):
    """Load an uploaded feedback CSV and tag every row with a sentiment label.

    Args:
        file: Value from the ``gr.File`` component (a path or file-like object
            accepted by ``pandas.read_csv``), or ``None`` when nothing has been
            uploaded yet.

    Returns:
        A ``(state, message)`` tuple. On success ``state`` is ``{"df": df}``,
        where ``df`` is the loaded frame with an added ``Sentiment`` column. On
        failure ``state`` is ``None`` and ``message`` explains the problem.
    """
    if file is None:
        return None, "❌ Error: Please upload a CSV file first."

    try:
        df = pd.read_csv(file)
    except (OSError, UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError) as exc:
        return None, f"❌ Error: Could not read the CSV file ({exc})."

    if "Feedback" not in df.columns or "Employee" not in df.columns:
        return None, "❌ Error: CSV must contain 'Employee' and 'Feedback' columns."

    # The classifier only accepts strings: blank cells arrive as NaN (a float)
    # and non-text values must be stringified first.
    texts = df["Feedback"].fillna("").astype(str).tolist()

    if texts:
        # A single batched call is far cheaper than one call per row. Passing
        # truncation=True clips overly long feedback to the model's maximum
        # input length instead of raising.
        predictions = pipe(texts, truncation=True)
        df["Sentiment"] = [prediction["label"] for prediction in predictions]
    else:
        df["Sentiment"] = []

    return {"df": df}, "✅ CSV file processed successfully!"
|
|
|
|
|
|
|
|
# Risk message for each sentiment label this module recognises.
_RISK_BY_SENTIMENT = {
    "positive": "Low Risk - Positive sentiment detected.",
    "neutral": "Medium Risk - Neutral sentiment detected.",
    "negative": "High Risk - Negative sentiment detected.",
}


def predict_attrition_risk(employee_name: str, sentiment: str) -> str:
    """Predicts attrition risk based on employee sentiment.

    Args:
        employee_name: Name used as the prefix of the result.
        sentiment: One of "positive", "neutral" or "negative". Matching
            ignores case and surrounding whitespace.

    Returns:
        ``"<employee_name>: <risk message>"``, or
        ``"<employee_name>: Unknown Sentiment"`` for an unrecognised label.
    """
    # The label can come from LLM-generated tool arguments, so normalise
    # casing and stray whitespace before the lookup.
    key = sentiment.strip().lower()
    return f"{employee_name}: {_RISK_BY_SENTIMENT.get(key, 'Unknown Sentiment')}"
|
|
|
|
|
|
|
|
_logger = logging.getLogger(__name__)

# OpenAI "tools" schema that exposes predict_attrition_risk to the model. The
# sentiment enum mirrors the labels predict_attrition_risk recognises.
_ATTRITION_RISK_TOOL = {
    "type": "function",
    "function": {
        "name": "predict_attrition_risk",
        "description": "Predicts attrition risk based on sentiment.",
        "parameters": {
            "type": "object",
            "properties": {
                "employee_name": {"type": "string", "description": "Employee's name"},
                "sentiment": {
                    "type": "string",
                    "description": "Extracted sentiment",
                    "enum": ["positive", "neutral", "negative"],
                },
            },
            "required": ["employee_name", "sentiment"],
        },
    },
}


def _employee_records(df):
    """Convert the processed frame into JSON-serializable dicts for the prompt."""

    def as_text(value):
        # numpy scalars (e.g. integer employee IDs) are not JSON-serializable,
        # and NaN would be emitted as the invalid JSON token NaN.
        return "" if pd.isna(value) else str(value)

    return [
        {
            "employee_name": as_text(name),
            "sentiment": as_text(sentiment),
            "feedback": as_text(feedback),
        }
        for name, sentiment, feedback in zip(df["Employee"], df["Sentiment"], df["Feedback"])
    ]


def _run_tool_calls(tool_calls):
    """Execute each predict_attrition_risk tool call and return formatted results."""
    sections = []
    for call in tool_calls:
        if call.function.name != "predict_attrition_risk":
            continue
        try:
            args = json.loads(call.function.arguments)
        except json.JSONDecodeError as exc:
            sections.append(f"❌ Error processing LLM function call: {exc}")
            continue
        if not isinstance(args, dict):
            continue
        employee_name = args.get("employee_name")
        sentiment = args.get("sentiment")
        if employee_name and sentiment:
            sections.append(
                f"🤖 **LLM Called Function:** `predict_attrition_risk({employee_name}, {sentiment})`\n\n"
                + predict_attrition_risk(str(employee_name), str(sentiment))
            )
    return sections


def analyze_attrition_with_llm(df_dict, hr_query):
    """Answer an HR query about attrition risk using LLM tool calling.

    The processed employee data is embedded in the prompt. The model may call
    ``predict_attrition_risk`` any number of times (for example, once per
    employee). Each call is executed locally and its result is rendered.

    Args:
        df_dict: State produced by ``process_csv``: ``{"df": DataFrame}`` with
            ``Employee``, ``Feedback`` and ``Sentiment`` columns, or ``None``
            if no CSV has been processed.
        hr_query: Free-text question from the HR user.

    Returns:
        Markdown text containing one section per executed tool call. If the
        model made no usable call, its plain-text reply is returned instead;
        otherwise an error or fallback message.
    """
    if df_dict is None or "df" not in df_dict:
        return "❌ Error: No processed employee data available. Please upload a CSV file first."
    if not hr_query or not hr_query.strip():
        return "❌ Error: Please enter an HR query."

    employees_data = _employee_records(df_dict["df"])
    prompt = (
        f"HR asked: '{hr_query}'. Here is the employee sentiment data:\n"
        f"{json.dumps(employees_data, indent=2, ensure_ascii=False)}\n"
        "Based on sentiment, determine attrition risk and call the function "
        "predict_attrition_risk for specific employees."
    )

    try:
        response = client.chat.completions.create(
            model="gpt-4-turbo",
            messages=[{"role": "user", "content": prompt}],
            tools=[_ATTRITION_RISK_TOOL],
            tool_choice="auto",
        )
    except Exception as exc:  # UI boundary: report API/network failures to the user
        _logger.exception("OpenAI chat completion request failed")
        return f"❌ Error contacting the LLM: {exc}"

    # The response embeds employee feedback, so log it at debug level only
    # rather than printing it unconditionally.
    _logger.debug("LLM response: %s", response)

    message = response.choices[0].message
    sections = _run_tool_calls(message.tool_calls or [])
    if sections:
        return "\n\n".join(sections)
    if message.content:
        # The model answered in prose instead of calling the tool.
        return f"🤖 {message.content}"
    return "🤖 No specific attrition risk prediction was made."
|
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>AI-Driven Employee Attrition Risk Analysis</h1>")

    # Step 1: upload and sentiment-tag the feedback CSV.
    file_input = gr.File(label="Upload Employee Feedback CSV", file_types=[".csv"])
    process_button = gr.Button("Process CSV")
    process_message = gr.Markdown(label="Processing Status")

    # Step 2: ask the LLM an HR question about the processed data.
    hr_input = gr.Textbox(label="HR Query (e.g., 'Which employees are at high risk?')")
    analyze_button = gr.Button("Ask HR Query")
    output_text = gr.Markdown(label="Attrition Risk Prediction")

    # Holds the {"df": ...} value returned by process_csv so the analyze
    # handler can read it.
    df_state = gr.State()

    process_button.click(process_csv, inputs=file_input, outputs=[df_state, process_message])
    analyze_button.click(analyze_attrition_with_llm, inputs=[df_state, hr_input], outputs=output_text)


# Launch only when run as a script, so importing this module (tests, tooling)
# does not start a server.
if __name__ == "__main__":
    # NOTE(review): share=True creates a public share link, so anyone with the
    # URL can upload and query employee feedback. Confirm this is intended.
    demo.launch(share=True)
|
|
|