# CDART / app.py
# SailajaS's picture
# Update app.py
# c45f63b verified
from fastapi import FastAPI
import pandas as pd
import uvicorn
import joblib
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from pydantic import BaseModel
import gradio as gr
import os
import requests
# ASGI application object: the /predict/ route below is registered on it and
# start_app() hands it to uvicorn.
app = FastAPI()
# Direct-download URL of the training CSV on Hugging Face. The hex segment
# after "resolve/" looks like a pinned dataset revision (TODO confirm), so the
# file contents should not change between runs.
DATASET_URL = "https://huggingface.co/datasets/SailajaS/CDART/resolve/9bfa82e31390ff6523f6b93777f745a78ecb2cd6/Sample_Case_Records__Real_Unique.csv?download=true"
# Local file the downloaded CSV is written to and later read back from.
DATASET_PATH = "dataset.csv"
def download_dataset():
    """Download the dataset CSV from ``DATASET_URL`` and save it to ``DATASET_PATH``.

    The whole response body is fetched with a 10 second timeout and written
    to disk in binary mode, overwriting any previous copy.

    Raises:
        RuntimeError: If the request itself fails (connection error, timeout,
            ...) or the server answers with a status other than 200. For
            request failures the underlying ``requests`` exception is chained
            as ``__cause__`` so the real reason is visible in the traceback.
    """
    print("πŸ“₯ Downloading dataset from Hugging Face...")

    # Only the network call can raise RequestException, so only it is guarded.
    try:
        response = requests.get(DATASET_URL, timeout=10)
    except requests.exceptions.RequestException as e:
        print(f"❌ Error downloading dataset: {e}")
        raise RuntimeError("Dataset download failed.") from e

    if response.status_code != 200:
        raise RuntimeError(f"❌ Failed to download dataset: {response.status_code}")

    with open(DATASET_PATH, "wb") as file:
        file.write(response.content)
    print("βœ… Dataset downloaded successfully!")
# Fetch the CSV at import time: the model further down is trained from it
# before the API or the UI can serve a request.
download_dataset()

# Columns the training code and the request handlers rely on.
required_columns = ["Case Problem", "Feedback"]


def _read_dataset(path, required, delimiters=(",", ";")):
    """Read the CSV at *path*, trying each delimiter in turn.

    With ``on_bad_lines="skip"`` a wrong delimiter generally does not raise;
    it parses into a frame with one wide column. A delimiter is therefore
    only accepted outright when the resulting frame contains every column in
    *required*. If no delimiter yields all required columns, the first frame
    that parsed at all is returned so the caller can report exactly which
    column is missing.

    Args:
        path: Location of the CSV file.
        required: Column names that must be present in the parsed frame.
        delimiters: Candidate field separators, tried in order.

    Returns:
        The parsed ``pandas.DataFrame``.

    Raises:
        RuntimeError: If the file cannot be parsed with any delimiter; the
            last parser error is chained as ``__cause__``.
    """
    first_parsed = None
    last_error = None
    for delimiter in delimiters:
        try:
            candidate = pd.read_csv(
                path, encoding="utf-8", delimiter=delimiter, on_bad_lines="skip"
            )
        except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as exc:
            last_error = exc
            continue
        if all(column in candidate.columns for column in required):
            return candidate
        if first_parsed is None:
            first_parsed = candidate
    if first_parsed is None:
        raise RuntimeError("❌ Unable to read CSV. Check delimiter and format.") from last_error
    return first_parsed


df = _read_dataset(DATASET_PATH, required_columns)

# Fail fast with a precise message if the dataset lacks a required column.
for col in required_columns:
    if col not in df.columns:
        raise RuntimeError(f"❌ Column '{col}' is missing from the dataset!")
# Normalise both text columns: cast to str, trim leading/trailing whitespace
# (inner spaces are kept) and lower-case. The request handlers normalise user
# input with the same strip + lower steps before looking it up.
# NOTE(review): astype(str) presumably turns missing values into the literal
# string "nan", which would then become a category of its own - confirm
# whether rows with missing values should be dropped first.
df["Case Problem"] = df["Case Problem"].astype(str).str.strip().str.lower()
df["Feedback"] = df["Feedback"].astype(str).str.strip().str.lower()
# One LabelEncoder per column: the input text and the target text are each
# mapped to integer codes, stored in new "... Encoded" columns.
case_problem_encoder = LabelEncoder()
feedback_encoder = LabelEncoder()
df["Case Problem Encoded"] = case_problem_encoder.fit_transform(df["Case Problem"])
df["Feedback Encoded"] = feedback_encoder.fit_transform(df["Feedback"])
# Persist the fitted encoders; the request handlers load these files again.
joblib.dump(case_problem_encoder, "case_problem_encoder.pkl")
joblib.dump(feedback_encoder, "feedback_encoder.pkl")
# Train the model: a single feature (the encoded case problem) predicts the
# encoded feedback. X is kept as a one-column DataFrame.
X = df[["Case Problem Encoded"]]
y = df["Feedback Encoded"]
# 80/20 split with a fixed seed. Only the training part is used below;
# X_test / y_test are not referenced again in the visible code.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
# Persist the trained model next to the encoders.
joblib.dump(model, "feedback_model.pkl")
print("βœ… Model trained successfully!")
# Request body schema for POST /predict/.
# (Documented with comments rather than a docstring so the generated schema
# description stays exactly as it was.)
class PredictionInput(BaseModel):
    # Free-text case problem supplied by the client; the handler trims and
    # lower-cases it before looking it up in the dataset.
    case_problem: str
def _normalise_case_problem(raw):
    """Trim and lower-case *raw* the same way the training column was normalised.

    ``None`` is treated as an empty string so a cleared UI textbox yields the
    regular "invalid case problem" answer instead of an AttributeError.
    """
    return (raw or "").strip().lower()


def _known_case_problems():
    """Return the distinct normalised case problems, in dataset order."""
    return list(df["Case Problem"].unique())


def _predict_feedback_label(case_problem_key):
    """Return the predicted feedback text for an already validated case problem.

    Uses the encoders fitted at import time (the very objects that were
    dumped to the ``.pkl`` files) instead of re-reading them from disk on
    every request. The feature is passed as a DataFrame carrying the same
    column name the model was fitted with, so scikit-learn's feature-name
    check is satisfied.
    """
    encoded = case_problem_encoder.transform([case_problem_key])
    features = pd.DataFrame({"Case Problem Encoded": encoded})
    prediction = model.predict(features)
    # str() guarantees a plain Python string for JSON serialisation.
    return str(feedback_encoder.inverse_transform(prediction)[0])


@app.post("/predict/")
async def predict_feedback(data: PredictionInput):
    """Predict the feedback for the case problem given in the request body.

    Returns a JSON object with one of:
      * ``{"Predicted Feedback": <text>}`` on success;
      * ``{"error": ..., "Valid Categories": [...]}`` when the case problem
        does not occur in the dataset;
      * ``{"error": <message>}`` when the model is missing or prediction fails.
    """
    if model is None:
        return {"error": "Model is not trained yet."}

    case_problem_key = _normalise_case_problem(data.case_problem)
    valid_problems = _known_case_problems()
    if case_problem_key not in valid_problems:
        return {
            "error": "Invalid case problem. Please enter a valid category.",
            "Valid Categories": valid_problems,
        }

    # Request boundary: report any prediction failure to the client as JSON.
    try:
        return {"Predicted Feedback": _predict_feedback_label(case_problem_key)}
    except Exception as e:
        return {"error": str(e)}


def gradio_interface(case_problem):
    """Gradio callback: return a one-line message for the typed case problem.

    Args:
        case_problem: Text from the input box (``None`` is treated as empty).

    Returns:
        ``"Predicted Feedback: ..."`` on success, otherwise a message that
        explains the problem (unknown category with the list of options,
        missing model, or the prediction error).
    """
    if model is None:
        return "Model not trained yet."

    case_problem_key = _normalise_case_problem(case_problem)
    known_problems = _known_case_problems()
    if case_problem_key not in known_problems:
        valid_problems = ", ".join(known_problems)
        return f"Invalid case problem. Please enter a valid category. Options: {valid_problems}"

    # UI boundary: show the failure in the output box instead of raising.
    try:
        return f"Predicted Feedback: {_predict_feedback_label(case_problem_key)}"
    except Exception as e:
        return f"Error: {str(e)}"
def start_app():
    """Start the Gradio UI and the FastAPI server in the same process.

    The Gradio interface is launched first without blocking, on Gradio's own
    port and with a public share link. uvicorn then serves ``app`` on
    0.0.0.0:8000 and blocks until the process is stopped.
    """
    gr_interface = gr.Interface(
        fn=gradio_interface,
        inputs=gr.Textbox(label="Enter Case Problem"),
        outputs=gr.Textbox(label="Predicted Feedback"),
        live=False,  # run only when the user presses Submit
        allow_flagging="never",
    )
    # launch() normally blocks the calling thread (and always does with
    # debug=True), which kept uvicorn.run() from starting until the UI was
    # shut down. prevent_thread_lock=True lets launch() return immediately so
    # the UI and the API are available at the same time.
    gr_interface.launch(share=True, prevent_thread_lock=True)
    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    start_app()