SailajaS commited on
Commit
a9baa2a
Β·
verified Β·
1 Parent(s): db86ab6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -9
app.py CHANGED
@@ -62,6 +62,9 @@ encoder = LabelEncoder()
62
  df["Case Problem Encoded"] = encoder.fit_transform(df["Case Problem"])
63
  df["Feedback Encoded"] = encoder.fit_transform(df["Feedback"])
64
 
 
 
 
65
  # βœ… Train Model
66
  X = df[["Case Problem Encoded"]]
67
  y = df["Feedback Encoded"]
@@ -73,9 +76,6 @@ model.fit(X_train, y_train)
73
  joblib.dump(model, "feedback_model.pkl")
74
  print("βœ… Model trained successfully!")
75
 
76
- # βœ… Save encoder classes for future use
77
- joblib.dump(encoder, "case_problem_encoder.pkl")
78
-
79
  # βœ… API Input Model
80
  class PredictionInput(BaseModel):
81
  case_problem: str
@@ -86,11 +86,19 @@ async def predict_feedback(data: PredictionInput):
86
  if model is None:
87
  return {"error": "Model is not trained yet."}
88
 
 
 
 
89
  # βœ… Convert input to lowercase to match training data
90
  case_problem_lower = data.case_problem.lower()
91
 
 
92
  if case_problem_lower not in df["Case Problem"].values:
93
- return {"error": "Invalid case problem. Please enter a valid category from the dataset."}
 
 
 
 
94
 
95
  try:
96
  case_problem_encoded = encoder.transform([case_problem_lower])
@@ -100,16 +108,17 @@ async def predict_feedback(data: PredictionInput):
100
  except Exception as e:
101
  return {"error": str(e)}
102
 
103
- # βœ… Gradio UI with async execution
104
  def gradio_interface(case_problem):
105
  if model is None:
106
  return "Model not trained yet."
107
 
108
- # βœ… Convert input to lowercase for consistency
109
  case_problem_lower = case_problem.lower()
110
 
111
  if case_problem_lower not in df["Case Problem"].values:
112
- return "Invalid case problem. Please enter a valid category from the dataset."
 
113
 
114
  try:
115
  case_problem_encoded = encoder.transform([case_problem_lower])
@@ -126,9 +135,9 @@ def start_app():
126
  fn=gradio_interface,
127
  inputs="text",
128
  outputs="text",
129
- live=True # βœ… Ensures Gradio UI updates properly
130
  )
131
- gr_interface.launch(share=True, debug=True) # βœ… Debugging enabled to see errors
132
  uvicorn.run(app, host="0.0.0.0", port=8000)
133
 
134
  if __name__ == "__main__":
 
62
  df["Case Problem Encoded"] = encoder.fit_transform(df["Case Problem"])
63
  df["Feedback Encoded"] = encoder.fit_transform(df["Feedback"])
64
 
65
+ # βœ… Save encoder for later use
66
+ joblib.dump(encoder, "case_problem_encoder.pkl")
67
+
68
  # βœ… Train Model
69
  X = df[["Case Problem Encoded"]]
70
  y = df["Feedback Encoded"]
 
76
  joblib.dump(model, "feedback_model.pkl")
77
  print("βœ… Model trained successfully!")
78
 
 
 
 
79
  # βœ… API Input Model
80
  class PredictionInput(BaseModel):
81
  case_problem: str
 
86
  if model is None:
87
  return {"error": "Model is not trained yet."}
88
 
89
+ # βœ… Load encoder
90
+ encoder = joblib.load("case_problem_encoder.pkl")
91
+
92
  # βœ… Convert input to lowercase to match training data
93
  case_problem_lower = data.case_problem.lower()
94
 
95
+ # βœ… Check if input exists in training data
96
  if case_problem_lower not in df["Case Problem"].values:
97
+ valid_problems = list(df["Case Problem"].unique()) # Get valid options
98
+ return {
99
+ "error": f"Invalid case problem. Please enter a valid category.",
100
+ "Valid Categories": valid_problems
101
+ }
102
 
103
  try:
104
  case_problem_encoded = encoder.transform([case_problem_lower])
 
108
  except Exception as e:
109
  return {"error": str(e)}
110
 
111
+ # βœ… Gradio UI with suggestions for valid categories
112
  def gradio_interface(case_problem):
113
  if model is None:
114
  return "Model not trained yet."
115
 
116
+ encoder = joblib.load("case_problem_encoder.pkl")
117
  case_problem_lower = case_problem.lower()
118
 
119
  if case_problem_lower not in df["Case Problem"].values:
120
+ valid_problems = ", ".join(df["Case Problem"].unique())
121
+ return f"Invalid case problem. Please enter a valid category. Options: {valid_problems}"
122
 
123
  try:
124
  case_problem_encoded = encoder.transform([case_problem_lower])
 
135
  fn=gradio_interface,
136
  inputs="text",
137
  outputs="text",
138
+ live=True
139
  )
140
+ gr_interface.launch(share=True, debug=True) # βœ… Debugging enabled
141
  uvicorn.run(app, host="0.0.0.0", port=8000)
142
 
143
  if __name__ == "__main__":