subbunanepalli committed on
Commit
79d6b55
·
verified ·
1 Parent(s): 9470189

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +347 -0
app.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
2
+ from fastapi.responses import FileResponse
3
+ from pydantic import BaseModel
4
+ from typing import Optional, Dict, Any, List
5
+ import uvicorn
6
+ import logging
7
+ import os
8
+ import pandas as pd
9
+ from datetime import datetime
10
+ import shutil
11
+ from pathlib import Path
12
+ import numpy as np
13
+ import json
14
+ import joblib
15
+ from sklearn.metrics import classification_report
16
+ from sklearn.multioutput import MultiOutputClassifier
17
+ from sklearn.feature_extraction.text import TfidfVectorizer
18
+ from sklearn.linear_model import LogisticRegression
19
+
20
+
21
+ # Import existing utilities
22
+ from dataset_utils import (
23
+ load_and_preprocess_data,
24
+ save_label_encoders,
25
+ load_label_encoders
26
+ )
27
+ from config import (
28
+ TEXT_COLUMN,
29
+ LABEL_COLUMNS,
30
+ BATCH_SIZE,
31
+ MODEL_SAVE_DIR
32
+ )
33
+ from models.tfidf_logreg import TfidfLogisticRegression
34
+
35
# Logging: INFO level on the root logger, plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="LOGREG Compliance Predictor API")

# Working directories, created eagerly at import time.
# NOTE(review): MODEL_SAVE_DIR rebinds the name imported from `config`
# above; confirm the model classes write to the same directory.
UPLOAD_DIR = Path("uploads")
MODEL_SAVE_DIR = Path("saved_models")
for _directory in (UPLOAD_DIR, MODEL_SAVE_DIR):
    _directory.mkdir(parents=True, exist_ok=True)

# Artifact locations read by the /predict endpoint.
_model_dir = str(MODEL_SAVE_DIR)
TFIDF_PATH = os.path.join(_model_dir, "tfidf_vectorizer.pkl")
MODEL_PATH = os.path.join(_model_dir, "logreg_models.pkl")
ENCODERS_PATH = os.path.join(os.path.dirname(__file__), "label_encoders.pkl")

# Process-wide, mutable snapshot of the latest training run; returned
# verbatim by the training-status endpoint.
training_status = dict(
    is_training=False,
    current_epoch=0,
    total_epochs=0,
    current_loss=0.0,
    start_time=None,
    end_time=None,
    status="idle",
    metrics=None,
)
61
+
62
class TrainingConfig(BaseModel):
    """Training options parsed from the JSON ``config`` form field of the train endpoint."""

    # NOTE(review): none of these values is read by the training task in
    # this file; they appear to be kept so the request shape stays stable.
    batch_size: int = 32
    num_epochs: int = 1
    random_state: int = 42
66
+
67
class TrainingResponse(BaseModel):
    """Body returned by the train endpoint once a run has been scheduled."""

    message: str
    # Timestamp-derived identifier generated when the run is scheduled.
    training_id: str
    status: str
    # Relative URL of the download endpoint for this run, when available.
    download_url: Optional[str] = None
72
+
73
class ValidationResponse(BaseModel):
    """Body returned by the validate endpoint."""

    message: str
    # Evaluation reports as produced by the model's ``evaluate`` call.
    metrics: Dict[str, Any]
    # One entry per (label column, row): field, true label, predicted
    # label, and per-class probabilities.
    predictions: List[Dict[str, Any]]
77
+
78
class TransactionData(BaseModel):
    """Single screening record accepted by the predict endpoint.

    Field order is significant: the predict endpoint joins the values of
    this model, in declaration order, into the text fed to the vectorizer.
    Field names (including their existing spellings) are part of the
    request schema and must not be renamed.
    """

    # --- Hit identification ---
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # The only two optional fields; both default to an empty string.
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int

    # --- Risk assessment and review ---
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str

    # --- Customer and transaction details ---
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str

    # --- Screening result and case handling ---
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
142
+
143
+
144
class PredictionRequest(BaseModel):
    """JSON body for a single-record prediction."""

    transaction_data: TransactionData
    # Defaults to "logreg_models" when the client omits it.
    model_name: str = "logreg_models"
147
+
148
class BatchPredictionResponse(BaseModel):
    """Body returned by the predict endpoint for CSV (batch) input."""

    message: str
    # One entry per input row: the transaction id and its decoded labels.
    predictions: List[Dict[str, Any]]
    metrics: Optional[Dict[str, Any]] = None
152
+
153
@app.get("/")
async def root():
    """Return a static banner identifying the service."""
    banner = {"message": "LOGREG Compliance Predictor API"}
    return banner
156
+
157
@app.get("/v1/logreg/health")
async def health_check():
    """Liveness probe: always reports a healthy status."""
    return dict(status="healthy")
160
+
161
@app.get("/v1/logreg/training-status")
async def get_training_status():
    """Return the module-level ``training_status`` dictionary as-is."""
    current_state = training_status
    return current_state
164
+
165
@app.post("/v1/logreg/train", response_model=TrainingResponse)
async def start_training(
    config: str = Form(...),
    background_tasks: BackgroundTasks = None,
    file: UploadFile = File(...)
):
    """Store an uploaded training CSV and schedule a background training run.

    Args:
        config: JSON-encoded object validated against ``TrainingConfig``.
        background_tasks: Task queue on which ``train_model_task`` is scheduled.
        file: Uploaded training data; the filename must end with ``.csv``.

    Returns:
        A ``TrainingResponse`` with a timestamp-based ``training_id`` and
        the matching download URL.

    Raises:
        HTTPException: 400 if a run is already in progress, the upload is
            not a CSV, or ``config`` cannot be parsed/validated.
    """
    if training_status["is_training"]:
        raise HTTPException(status_code=400, detail="Training is already in progress")

    # The filename is client-controlled: keep only its final component so
    # values such as "../../x.csv" cannot escape UPLOAD_DIR, and tolerate a
    # missing filename instead of crashing on None.
    safe_name = Path(file.filename or "").name
    if not safe_name.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    try:
        config_dict = json.loads(config)
        training_config = TrainingConfig(**config_dict)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}") from e

    file_path = UPLOAD_DIR / safe_name
    with file_path.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    training_id = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Clear leftovers from a previous run so the status endpoint does not
    # report an old error or end time for the new run.
    training_status.pop("error", None)
    training_status.update({
        "is_training": True,
        "current_epoch": 0,
        "total_epochs": 1,
        "start_time": datetime.now().isoformat(),
        "end_time": None,
        "status": "starting"
    })
    background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)

    download_url = f"/v1/logreg/download-model/{training_id}"
    return TrainingResponse(
        message="Training started successfully",
        training_id=training_id,
        status="started",
        download_url=download_url
    )
199
+
200
@app.post("/v1/logreg/validate")
async def validate_model(
    file: UploadFile = File(...),
    model_name: str = "logreg_models"
):
    """Evaluate a saved model against an uploaded, labelled CSV.

    Args:
        file: Uploaded validation data; the filename must end with ``.csv``.
        model_name: Stem of the ``.pkl`` file under ``MODEL_SAVE_DIR``.

    Returns:
        A ``ValidationResponse`` with the evaluation reports and, for each
        label column and row, the true label, predicted label and per-class
        probabilities.

    Raises:
        HTTPException: 400 for a non-CSV upload or a non-string text
            column, 404 when the model file is missing, 500 for any other
            failure.
    """
    # The filename is client-controlled: keep only its final component so
    # it cannot escape UPLOAD_DIR, and tolerate a missing filename.
    safe_name = Path(file.filename or "").name
    if not safe_name.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    # Bound before the try so the cleanup in ``finally`` can always use it.
    file_path = UPLOAD_DIR / safe_name
    try:
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        data_df, label_encoders = load_and_preprocess_data(str(file_path))

        model_path = MODEL_SAVE_DIR / f"{model_name}.pkl"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail="LOGREG model file not found")

        # ``TfidfLogisticRegression`` is the class this module imports; the
        # name previously used here was never defined.
        # NOTE(review): assumes its constructor takes the label encoders,
        # as the original call did - confirm against models/tfidf_logreg.
        model = TfidfLogisticRegression(label_encoders)
        model.load_model(model_name)

        X = data_df[TEXT_COLUMN]
        y = data_df[LABEL_COLUMNS]
        # Type and shape check for X
        if not isinstance(X, pd.Series) or not pd.api.types.is_string_dtype(X):
            raise HTTPException(status_code=400, detail=f"TEXT_COLUMN ('{TEXT_COLUMN}') must be a pandas Series of strings. Got type: {type(X)}, dtype: {getattr(X, 'dtype', None)}")

        reports, y_true_list, y_pred_list = model.evaluate(X, y)
        all_probs = model.predict_proba(X)

        predictions = []
        for i, col in enumerate(LABEL_COLUMNS):
            label_encoder = label_encoders[col]
            true_labels_orig = label_encoder.inverse_transform(y_true_list[i])
            pred_labels_orig = label_encoder.inverse_transform(y_pred_list[i])
            for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
                class_probs = {label: float(prob) for label, prob in zip(label_encoder.classes_, probs)}
                predictions.append({
                    "field": col,
                    "true_label": true,
                    "predicted_label": pred,
                    "probabilities": class_probs
                })

        return ValidationResponse(
            message="Validation completed successfully",
            metrics=reports,
            predictions=predictions
        )
    except HTTPException:
        # Deliberate 4xx responses must not be rewrapped as 500.
        raise
    except Exception as e:
        logger.error(f"Validation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}") from e
    finally:
        # Always discard the temporary upload.
        if os.path.exists(file_path):
            os.remove(file_path)
248
+
249
@app.post("/v1/logreg/predict")
async def predict(
    request: Optional[PredictionRequest] = None,
    file: UploadFile = File(None),
    model_name: str = "logreg_models"
):
    """Predict compliance labels for one transaction or for a CSV batch.

    Args:
        request: Optional JSON body carrying a single ``TransactionData``.
        file: Optional CSV upload; takes precedence over ``request``.
        model_name: Accepted for API compatibility; this endpoint always
            loads the artifacts at ``TFIDF_PATH``, ``MODEL_PATH`` and
            ``ENCODERS_PATH``.

    Returns:
        ``BatchPredictionResponse`` for CSV input, or a dict mapping each
        label column to its decoded prediction for a single record.

    Raises:
        HTTPException: 400 for a non-CSV upload or when neither input is
            supplied; 500 for any other failure.
    """
    try:
        # Load vectorizer, model, and encoders.
        # NOTE: joblib.load unpickles these files; they must be trusted
        # artifacts and never client-supplied.
        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)

        # Batch prediction from CSV
        if file and file.filename:
            # The filename is client-controlled: keep only its final
            # component so it cannot escape UPLOAD_DIR.
            safe_name = Path(file.filename).name
            if not safe_name.endswith('.csv'):
                raise HTTPException(status_code=400, detail="Only CSV files are allowed")
            file_path = UPLOAD_DIR / safe_name
            with file_path.open("wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
            try:
                data_df, _ = load_and_preprocess_data(str(file_path))
                # Concatenate all fields into a single string for each row.
                # NOTE(review): training in this module uses only
                # TEXT_COLUMN; confirm the loaded vectorizer was fitted on
                # this all-columns text.
                texts = data_df.apply(lambda row: " ".join([str(val) for val in row.values if pd.notna(val)]), axis=1)
                X_vec = tfidf.transform(texts)
                preds = model.predict(X_vec)
                predictions = []
                for i, pred in enumerate(preds):
                    decoded = {
                        col: encoders[col].inverse_transform([label])[0]
                        for col, label in zip(LABEL_COLUMNS, pred)
                    }
                    predictions.append({
                        "transaction_id": data_df.iloc[i].get('Transaction_Id', f"transaction_{i}"),
                        "predictions": decoded
                    })
                return BatchPredictionResponse(
                    message="Batch prediction completed successfully",
                    predictions=predictions
                )
            finally:
                # Always discard the temporary upload.
                if os.path.exists(file_path):
                    os.remove(file_path)
        # Single prediction
        elif request and request.transaction_data:
            input_data = pd.DataFrame([request.transaction_data.dict()])
            text_input = " ".join([
                str(val) for val in input_data.iloc[0].values if pd.notna(val)
            ])
            X_vec = tfidf.transform([text_input])
            pred = model.predict(X_vec)[0]
            decoded = {
                col: encoders[col].inverse_transform([p])[0]
                for col, p in zip(LABEL_COLUMNS, pred)
            }
            return decoded
        else:
            raise HTTPException(
                status_code=400,
                detail="Either provide a transaction in the request body or upload a CSV file"
            )
    except HTTPException:
        # Preserve intentional status codes (the 400s above) instead of
        # converting them into 500 responses.
        raise
    except Exception as e:
        logger.error(f"Prediction failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
310
+
311
@app.get("/v1/logreg/download-model/{model_id}")
async def download_model(model_id: str):
    """Send ``<model_id>.pkl`` from ``MODEL_SAVE_DIR`` as a binary download.

    Raises:
        HTTPException: 404 when no such file exists.
    """
    candidate = MODEL_SAVE_DIR / f"{model_id}.pkl"
    if candidate.exists():
        return FileResponse(
            path=candidate,
            filename=f"logreg_model_{model_id}.pkl",
            media_type="application/octet-stream",
        )
    raise HTTPException(status_code=404, detail="Model not found")
321
+
322
async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
    """Train and persist a model, recording progress in ``training_status``.

    Args:
        config: Validated training options (currently not read here).
        file_path: Path of the uploaded training CSV.
        training_id: Identifier passed to ``save_model`` for the artifact.

    Any exception is logged and recorded under ``training_status["error"]``
    rather than propagated, because this runs as a background task with no
    caller to receive it.
    """
    # NOTE(review): the calls below are synchronous inside an ``async``
    # function, so the event loop is presumably blocked while training.
    try:
        training_status.update({"status": "training"})
        data_df_original, label_encoders = load_and_preprocess_data(file_path)
        save_label_encoders(label_encoders)
        X = data_df_original[TEXT_COLUMN]
        y = data_df_original[LABEL_COLUMNS]
        # ``TfidfLogisticRegression`` is the class this module imports; the
        # name previously used here was never defined, so every run failed
        # with NameError.
        model = TfidfLogisticRegression(label_encoders)
        model.train(X, y)
        model.save_model(training_id)
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "completed"
        })
    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "failed",
            "error": str(e)
        })
344
+
345
if __name__ == "__main__":
    # Listen on all interfaces; the PORT environment variable overrides
    # the default of 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))