Update app.py
Browse files
app.py
CHANGED
|
@@ -62,6 +62,9 @@ encoder = LabelEncoder()
|
|
| 62 |
df["Case Problem Encoded"] = encoder.fit_transform(df["Case Problem"])
|
| 63 |
df["Feedback Encoded"] = encoder.fit_transform(df["Feedback"])
|
| 64 |
|
|
|
|
|
|
|
|
|
|
| 65 |
# β
Train Model
|
| 66 |
X = df[["Case Problem Encoded"]]
|
| 67 |
y = df["Feedback Encoded"]
|
|
@@ -73,9 +76,6 @@ model.fit(X_train, y_train)
|
|
| 73 |
joblib.dump(model, "feedback_model.pkl")
|
| 74 |
print("β
Model trained successfully!")
|
| 75 |
|
| 76 |
-
# β
Save encoder classes for future use
|
| 77 |
-
joblib.dump(encoder, "case_problem_encoder.pkl")
|
| 78 |
-
|
| 79 |
# β
API Input Model
|
| 80 |
class PredictionInput(BaseModel):
|
| 81 |
case_problem: str
|
|
@@ -86,11 +86,19 @@ async def predict_feedback(data: PredictionInput):
|
|
| 86 |
if model is None:
|
| 87 |
return {"error": "Model is not trained yet."}
|
| 88 |
|
|
|
|
|
|
|
|
|
|
| 89 |
# β
Convert input to lowercase to match training data
|
| 90 |
case_problem_lower = data.case_problem.lower()
|
| 91 |
|
|
|
|
| 92 |
if case_problem_lower not in df["Case Problem"].values:
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
try:
|
| 96 |
case_problem_encoded = encoder.transform([case_problem_lower])
|
|
@@ -100,16 +108,17 @@ async def predict_feedback(data: PredictionInput):
|
|
| 100 |
except Exception as e:
|
| 101 |
return {"error": str(e)}
|
| 102 |
|
| 103 |
-
# β
Gradio UI with
|
| 104 |
def gradio_interface(case_problem):
|
| 105 |
if model is None:
|
| 106 |
return "Model not trained yet."
|
| 107 |
|
| 108 |
-
|
| 109 |
case_problem_lower = case_problem.lower()
|
| 110 |
|
| 111 |
if case_problem_lower not in df["Case Problem"].values:
|
| 112 |
-
|
|
|
|
| 113 |
|
| 114 |
try:
|
| 115 |
case_problem_encoded = encoder.transform([case_problem_lower])
|
|
@@ -126,9 +135,9 @@ def start_app():
|
|
| 126 |
fn=gradio_interface,
|
| 127 |
inputs="text",
|
| 128 |
outputs="text",
|
| 129 |
-
live=True
|
| 130 |
)
|
| 131 |
-
gr_interface.launch(share=True, debug=True) # β
Debugging enabled
|
| 132 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 133 |
|
| 134 |
if __name__ == "__main__":
|
|
|
|
| 62 |
df["Case Problem Encoded"] = encoder.fit_transform(df["Case Problem"])
|
| 63 |
df["Feedback Encoded"] = encoder.fit_transform(df["Feedback"])
|
| 64 |
|
| 65 |
+
# β
Save encoder for later use
|
| 66 |
+
joblib.dump(encoder, "case_problem_encoder.pkl")
|
| 67 |
+
|
| 68 |
# β
Train Model
|
| 69 |
X = df[["Case Problem Encoded"]]
|
| 70 |
y = df["Feedback Encoded"]
|
|
|
|
| 76 |
joblib.dump(model, "feedback_model.pkl")
|
| 77 |
print("β
Model trained successfully!")
|
| 78 |
|
|
|
|
|
|
|
|
|
|
| 79 |
# β
API Input Model
|
| 80 |
class PredictionInput(BaseModel):
|
| 81 |
case_problem: str
|
|
|
|
| 86 |
if model is None:
|
| 87 |
return {"error": "Model is not trained yet."}
|
| 88 |
|
| 89 |
+
# β
Load encoder
|
| 90 |
+
encoder = joblib.load("case_problem_encoder.pkl")
|
| 91 |
+
|
| 92 |
# β
Convert input to lowercase to match training data
|
| 93 |
case_problem_lower = data.case_problem.lower()
|
| 94 |
|
| 95 |
+
# β
Check if input exists in training data
|
| 96 |
if case_problem_lower not in df["Case Problem"].values:
|
| 97 |
+
valid_problems = list(df["Case Problem"].unique()) # Get valid options
|
| 98 |
+
return {
|
| 99 |
+
"error": f"Invalid case problem. Please enter a valid category.",
|
| 100 |
+
"Valid Categories": valid_problems
|
| 101 |
+
}
|
| 102 |
|
| 103 |
try:
|
| 104 |
case_problem_encoded = encoder.transform([case_problem_lower])
|
|
|
|
| 108 |
except Exception as e:
|
| 109 |
return {"error": str(e)}
|
| 110 |
|
| 111 |
+
# β
Gradio UI with suggestions for valid categories
|
| 112 |
def gradio_interface(case_problem):
|
| 113 |
if model is None:
|
| 114 |
return "Model not trained yet."
|
| 115 |
|
| 116 |
+
encoder = joblib.load("case_problem_encoder.pkl")
|
| 117 |
case_problem_lower = case_problem.lower()
|
| 118 |
|
| 119 |
if case_problem_lower not in df["Case Problem"].values:
|
| 120 |
+
valid_problems = ", ".join(df["Case Problem"].unique())
|
| 121 |
+
return f"Invalid case problem. Please enter a valid category. Options: {valid_problems}"
|
| 122 |
|
| 123 |
try:
|
| 124 |
case_problem_encoded = encoder.transform([case_problem_lower])
|
|
|
|
| 135 |
fn=gradio_interface,
|
| 136 |
inputs="text",
|
| 137 |
outputs="text",
|
| 138 |
+
live=True
|
| 139 |
)
|
| 140 |
+
gr_interface.launch(share=True, debug=True) # β
Debugging enabled
|
| 141 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 142 |
|
| 143 |
if __name__ == "__main__":
|