subbunanepalli commited on
Commit
f8cea34
·
verified ·
1 Parent(s): c8da0d8

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +12 -0
predict.py CHANGED
@@ -4,15 +4,26 @@ import joblib
4
  from validate import TransactionData
5
  from utils import create_text_input
6
 
 
7
  MODEL_PATH = "models/logreg_model.pkl"
8
 
9
  def predict(request: TransactionData):
10
  try:
 
11
  model = joblib.load(MODEL_PATH)
 
 
 
 
 
 
12
  input_df = pd.DataFrame([request.dict()]).fillna("")
13
  text_input = create_text_input(input_df.iloc[0])
 
 
14
  prediction = model.predict([text_input])[0]
15
 
 
16
  return {
17
  "Maker_Action": prediction[0],
18
  "Escalation_Level": prediction[1],
@@ -21,5 +32,6 @@ def predict(request: TransactionData):
21
  "Investigation_Outcome": prediction[4],
22
  "Alert_Status": prediction[5]
23
  }
 
24
  except Exception as e:
25
  raise HTTPException(status_code=500, detail=str(e))
 
4
  from validate import TransactionData
5
  from utils import create_text_input
6
 
7
+ # === Path to saved model ===
8
  MODEL_PATH = "models/logreg_model.pkl"
9
 
10
  def predict(request: TransactionData):
11
  try:
12
+ # Load the model pipeline (TfidfVectorizer + MultiOutputClassifier)
13
  model = joblib.load(MODEL_PATH)
14
+
15
+ # Safety check to ensure it's a model
16
+ if not hasattr(model, "predict"):
17
+ raise ValueError("Loaded object is not a model pipeline")
18
+
19
+ # Prepare input
20
  input_df = pd.DataFrame([request.dict()]).fillna("")
21
  text_input = create_text_input(input_df.iloc[0])
22
+
23
+ # Make prediction
24
  prediction = model.predict([text_input])[0]
25
 
26
+ # Return predictions as dict
27
  return {
28
  "Maker_Action": prediction[0],
29
  "Escalation_Level": prediction[1],
 
32
  "Investigation_Outcome": prediction[4],
33
  "Alert_Status": prediction[5]
34
  }
35
+
36
  except Exception as e:
37
  raise HTTPException(status_code=500, detail=str(e))