ganeshkonapalli committed on
Commit
8d75fd3
·
verified ·
1 Parent(s): 97051d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -55
app.py CHANGED
@@ -4,22 +4,24 @@ from typing import Optional
4
  import pandas as pd
5
  import joblib
6
  import os
7
- from sklearn.metrics import classification_report
8
- from sklearn.model_selection import train_test_split
9
 
 
10
  app = FastAPI()
11
 
12
  # --- Model paths ---
13
  TFIDF_VECTORIZER_PATH = "models/tfidf_vectorizer.pkl"
14
- MODELS_PATH = "models/xgb_model.pkl"
15
  LABEL_ENCODERS_PATH = "models/label_encoders.pkl"
16
 
17
- # --- Load models ---
18
- tfidf_vectorizer = joblib.load(TFIDF_VECTORIZER_PATH)
19
- models = joblib.load(MODELS_PATH)
20
- label_encoders = joblib.load(LABEL_ENCODERS_PATH)
 
 
 
21
 
22
- # --- Input Schema ---
23
  class TransactionData(BaseModel):
24
  Transaction_Id: str
25
  Hit_Seq: int
@@ -88,81 +90,46 @@ class TransactionData(BaseModel):
88
  class PredictionRequest(BaseModel):
89
  transaction_data: TransactionData
90
 
 
91
  @app.get("/")
92
- async def root():
93
  return {"status": "healthy", "message": "XGBoost TF-IDF API is running"}
94
 
 
95
  @app.post("/predict")
96
  async def predict(request: PredictionRequest):
97
  try:
98
  input_data = pd.DataFrame([request.transaction_data.dict()])
99
 
 
100
  text_input = "\n".join([
101
- str(input_data[col].iloc[0])
102
- for col in input_data.columns
103
- if pd.notna(input_data[col].iloc[0])
104
  ])
105
 
 
106
  X_tfidf = tfidf_vectorizer.transform([text_input])
107
 
 
108
  response = {}
109
  for label, model in models.items():
110
  proba = model.predict_proba(X_tfidf)[0]
111
- pred = proba.argmax()
112
- decoded_label = label_encoders[label].inverse_transform([pred])[0]
113
  class_probs = {
114
  label_encoders[label].classes_[i]: float(prob)
115
  for i, prob in enumerate(proba)
116
  }
117
  response[label] = {
118
- "prediction": decoded_label,
119
  "probabilities": class_probs
120
  }
121
 
122
  return response
123
 
124
  except Exception as e:
125
- raise HTTPException(status_code=500, detail=str(e))
126
-
127
- @app.get("/validate")
128
- async def validate_model():
129
- try:
130
- DATA_PATH = "data.csv" # Ensure this file is present
131
- df = pd.read_csv(DATA_PATH)
132
- df.dropna(subset=["Sanction_Context"] + list(models.keys()), inplace=True)
133
-
134
- def concat_text(row):
135
- return "\n".join([
136
- str(row.get(col, "")) for col in row.index
137
- ])
138
-
139
- df["combined_text"] = df.apply(concat_text, axis=1)
140
- X = tfidf_vectorizer.transform(df["combined_text"])
141
-
142
- results = {}
143
-
144
- for label in models:
145
- encoder = label_encoders[label]
146
- y = encoder.transform(df[label])
147
- _, X_test, _, y_test = train_test_split(
148
- X, y, test_size=0.2, random_state=42
149
- )
150
- model = models[label]
151
- y_pred = model.predict(X_test)
152
-
153
- report = classification_report(
154
- encoder.inverse_transform(y_test),
155
- encoder.inverse_transform(y_pred),
156
- output_dict=True
157
- )
158
-
159
- results[label] = report
160
-
161
- return {"validation_reports": results}
162
-
163
- except Exception as e:
164
- raise HTTPException(status_code=500, detail=f"Validation failed: {e}")
165
 
 
166
  if __name__ == "__main__":
167
  import uvicorn
168
  port = int(os.environ.get("PORT", 7860))
 
4
  import pandas as pd
5
  import joblib
6
  import os
 
 
7
 
8
+ # Initialize FastAPI app
9
  app = FastAPI()
10
 
11
  # --- Model paths ---
12
  TFIDF_VECTORIZER_PATH = "models/tfidf_vectorizer.pkl"
13
+ MODELS_PATH = "models/xgb_models.pkl"
14
  LABEL_ENCODERS_PATH = "models/label_encoders.pkl"
15
 
16
+ # --- Load Models ---
17
+ try:
18
+ tfidf_vectorizer = joblib.load(TFIDF_VECTORIZER_PATH)
19
+ models = joblib.load(MODELS_PATH)
20
+ label_encoders = joblib.load(LABEL_ENCODERS_PATH)
21
+ except Exception as e:
22
+ raise RuntimeError(f"Model loading failed: {e}")
23
 
24
+ # --- Input Schemas ---
25
  class TransactionData(BaseModel):
26
  Transaction_Id: str
27
  Hit_Seq: int
 
90
  class PredictionRequest(BaseModel):
91
  transaction_data: TransactionData
92
 
93
+ # --- Root Route ---
94
  @app.get("/")
95
+ def health_check():
96
  return {"status": "healthy", "message": "XGBoost TF-IDF API is running"}
97
 
98
+ # --- Predict Route ---
99
  @app.post("/predict")
100
  async def predict(request: PredictionRequest):
101
  try:
102
  input_data = pd.DataFrame([request.transaction_data.dict()])
103
 
104
+ # Concatenate relevant fields into a single string
105
  text_input = "\n".join([
106
+ str(input_data[col].iloc[0]) for col in input_data.columns if pd.notna(input_data[col].iloc[0])
 
 
107
  ])
108
 
109
+ # TF-IDF transform
110
  X_tfidf = tfidf_vectorizer.transform([text_input])
111
 
112
+ # Predict for each label
113
  response = {}
114
  for label, model in models.items():
115
  proba = model.predict_proba(X_tfidf)[0]
116
+ pred_idx = proba.argmax()
117
+ pred_label = label_encoders[label].inverse_transform([pred_idx])[0]
118
  class_probs = {
119
  label_encoders[label].classes_[i]: float(prob)
120
  for i, prob in enumerate(proba)
121
  }
122
  response[label] = {
123
+ "prediction": pred_label,
124
  "probabilities": class_probs
125
  }
126
 
127
  return response
128
 
129
  except Exception as e:
130
+ raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ # --- Run Locally (optional) ---
133
  if __name__ == "__main__":
134
  import uvicorn
135
  port = int(os.environ.get("PORT", 7860))