# Xgboost_hm / app.py
# Provenance: ganeshkonapalli's Hugging Face Space, commit 7d1974f ("Update app.py").
# (These header lines were page-scrape residue; kept as a comment so the file parses.)
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, Dict, Any, List
import uvicorn
import logging
import os
import pandas as pd
from datetime import datetime
import shutil
from pathlib import Path
import numpy as np
import sys
import json
import joblib
from dataset_utils import load_and_preprocess_data, save_label_encoders, load_label_encoders
from config import TEXT_COLUMN, LABEL_COLUMNS, BATCH_SIZE, MODEL_SAVE_DIR
from tfidf_xgb import TfidfXGBoost # ✅ XGBoost model class
# Module-level logging: INFO level, logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="XGBoost Compliance Predictor API")

# Working directories, created eagerly so request handlers can assume they exist.
UPLOAD_DIR = Path("uploads")
# NOTE(review): this rebinds the MODEL_SAVE_DIR imported from config above —
# presumably intentional (local override), but confirm config's value is meant
# to be ignored in this module.
MODEL_SAVE_DIR = Path("saved_models")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Artifact locations loaded by /v1/xgb/predict: fitted TF-IDF vectorizer,
# pickled XGBoost model(s), and label encoders stored next to this file.
TFIDF_PATH = os.path.join(str(MODEL_SAVE_DIR), "tfidf_vectorizer.pkl")
MODEL_PATH = os.path.join(str(MODEL_SAVE_DIR), "xgb_models.pkl")
ENCODERS_PATH = os.path.join(os.path.dirname(__file__), "label_encoders.pkl")

# Mutable process-global training state: read by /v1/xgb/training-status and
# updated by the background training task. Not safe across multiple workers.
training_status = {
    "is_training": False,
    "current_epoch": 0,
    "total_epochs": 0,
    "current_loss": 0.0,
    "start_time": None,
    "end_time": None,
    "status": "idle",
    "metrics": None
}
# --- Pydantic Models (unchanged) ---
class TrainingConfig(BaseModel):
    """Training hyper-parameters parsed from the JSON `config` form field of /v1/xgb/train."""
    batch_size: int = 32
    num_epochs: int = 1
    random_state: int = 42
class TrainingResponse(BaseModel):
    """Response body of /v1/xgb/train: acknowledgement plus where to fetch the model."""
    message: str
    training_id: str  # timestamp-derived id, also used as the saved-model filename stem
    status: str
    download_url: Optional[str] = None  # /v1/xgb/download-model/{training_id}
class ValidationResponse(BaseModel):
    """Response body of /v1/xgb/validate: per-label metrics and row-level predictions."""
    message: str
    metrics: Dict[str, Any]
    predictions: List[Dict[str, Any]]
class TransactionData(BaseModel):
    """One screening/transaction record accepted by /v1/xgb/predict.

    Field names mirror the training CSV's column headers, so they keep their
    original casing and spelling (including the 'Payment_Reciever_Name' typo)
    as part of the wire contract — do not rename.
    """
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # Optional with empty-string defaults; all other fields are required.
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    # Dates arrive as strings; no format is enforced here — parsing/validation
    # presumably happens downstream (TODO confirm in dataset_utils).
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
class PredictionRequest(BaseModel):
    """JSON body for single-record prediction via /v1/xgb/predict."""
    transaction_data: TransactionData
    model_name: str = "xgb_models"  # stem of the pickled model file to load
class BatchPredictionResponse(BaseModel):
    """Response body for CSV (batch) prediction via /v1/xgb/predict."""
    message: str
    predictions: List[Dict[str, Any]]
    metrics: Optional[Dict[str, Any]] = None
# --- API Routes ---
@app.get("/")
async def root():
    """Service banner for the API root."""
    banner = {"message": "XGBoost Compliance Predictor API"}
    return banner
@app.get("/v1/xgb/health")
async def health_check():
    """Liveness probe; always reports healthy while the process is up."""
    status_payload = {"status": "healthy"}
    return status_payload
@app.get("/v1/xgb/training-status")
async def get_training_status():
    """Return the live training-state dict.

    Note: this returns the mutable module-level dict itself, so updates made
    by the background training task are reflected in subsequent calls.
    """
    return training_status
@app.post("/v1/xgb/train", response_model=TrainingResponse)
async def start_training(
    config: str = Form(...),
    background_tasks: BackgroundTasks = None,
    file: UploadFile = File(...)
):
    """Save the uploaded CSV and start training in a background task.

    `config` is a JSON string matching TrainingConfig. Returns immediately
    with a training id and the URL where the finished model can be fetched.
    Raises 400 if training is already running, the file is not a CSV, or the
    config does not parse.
    """
    if training_status["is_training"]:
        raise HTTPException(status_code=400, detail="Training is already in progress")
    # Fix: case-insensitive extension check ('DATA.CSV' was previously rejected).
    if not file.filename.lower().endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")
    try:
        config_dict = json.loads(config)
        training_config = TrainingConfig(**config_dict)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid config: {str(e)}")
    # Fix: Path(...).name strips any directory components a client might put in
    # the filename, preventing writes outside UPLOAD_DIR (path traversal).
    file_path = UPLOAD_DIR / Path(file.filename).name
    with file_path.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    training_status.update({
        "is_training": True,
        "start_time": datetime.now().isoformat(),
        "status": "starting"
    })
    background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
    # The artifact only exists once the background task completes; the URL is
    # returned eagerly so clients can poll it.
    download_url = f"/v1/xgb/download-model/{training_id}"
    return TrainingResponse(
        message="Training started",
        training_id=training_id,
        status="started",
        download_url=download_url
    )
@app.post("/v1/xgb/validate")
async def validate_model(file: UploadFile = File(...), model_name: str = "xgb_models"):
    """Evaluate a saved model against an uploaded, labelled CSV.

    Returns per-label classification reports plus row-level true/predicted
    labels with class probabilities. The uploaded file is always removed.
    Raises 400 for non-CSV uploads and 500 on any evaluation failure.
    """
    # Fix: case-insensitive extension check.
    if not file.filename.lower().endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files allowed")
    # Fix: pre-bind file_path so the finally block cannot raise NameError when
    # an exception fires before the file is written (which masked the real error).
    file_path = None
    try:
        file_path = UPLOAD_DIR / Path(file.filename).name
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        data_df, label_encoders = load_and_preprocess_data(str(file_path))
        model = TfidfXGBoost(label_encoders)
        model.load_model(model_name)
        X = data_df[TEXT_COLUMN]
        y = data_df[LABEL_COLUMNS]
        reports, y_true_list, y_pred_list = model.evaluate(X, y)
        all_probs = model.predict_proba(X)
        predictions = []
        # Flatten per-label results into one row per (record, label-column).
        for i, col in enumerate(LABEL_COLUMNS):
            le = label_encoders[col]
            true_labels = le.inverse_transform(y_true_list[i])
            pred_labels = le.inverse_transform(y_pred_list[i])
            for true, pred, probs in zip(true_labels, pred_labels, all_probs[i]):
                class_probs = {label: float(prob) for label, prob in zip(le.classes_, probs)}
                predictions.append({
                    "field": col,
                    "true_label": true,
                    "predicted_label": pred,
                    "probabilities": class_probs
                })
        return ValidationResponse(
            message="Validation completed",
            metrics=reports,
            predictions=predictions
        )
    except Exception as e:
        logger.error(f"Validation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Validation error: {str(e)}")
    finally:
        if file_path is not None and os.path.exists(file_path):
            os.remove(file_path)
@app.post("/v1/xgb/predict")
async def predict(request: Optional[PredictionRequest] = None, file: UploadFile = File(None), model_name: str = "xgb_models"):
    """Predict label columns for either one JSON transaction or a CSV batch.

    CSV upload -> BatchPredictionResponse (one decoded prediction per row).
    JSON body  -> plain dict of {label_column: decoded_label}.
    Raises 400 when neither input is provided (or the file is not a CSV) and
    500 on model-loading/inference failures.
    """
    file_path = None
    try:
        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)
        if file and file.filename:
            # Fix: case-insensitive extension check.
            if not file.filename.lower().endswith('.csv'):
                raise HTTPException(status_code=400, detail="Only CSV files allowed")
            # Fix: sanitize the client filename (no directory components).
            file_path = UPLOAD_DIR / Path(file.filename).name
            with file_path.open("wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
            data_df, _ = load_and_preprocess_data(str(file_path))
            # Concatenate every non-null cell of a row into one text document.
            texts = data_df.apply(lambda row: " ".join([str(v) for v in row.values if pd.notna(v)]), axis=1)
            X_vec = tfidf.transform(texts)
            preds = model.predict(X_vec)
            predictions = []
            for i, pred in enumerate(preds):
                decoded = {
                    col: encoders[col].inverse_transform([label])[0]
                    for col, label in zip(LABEL_COLUMNS, pred)
                }
                predictions.append({
                    "transaction_id": data_df.iloc[i].get('Transaction_Id', f"row_{i}"),
                    "predictions": decoded
                })
            return BatchPredictionResponse(message="Batch prediction complete", predictions=predictions)
        elif request and request.transaction_data:
            input_df = pd.DataFrame([request.transaction_data.dict()])
            text_input = " ".join([str(v) for v in input_df.iloc[0].values if pd.notna(v)])
            X_vec = tfidf.transform([text_input])
            pred = model.predict(X_vec)[0]
            decoded = {
                col: encoders[col].inverse_transform([p])[0]
                for col, p in zip(LABEL_COLUMNS, pred)
            }
            return decoded
        else:
            raise HTTPException(status_code=400, detail="Provide a transaction or upload a CSV file")
    except HTTPException:
        # Bug fix: the generic handler below used to catch the deliberate 400s
        # raised above and re-emit them as 500s. Let them propagate unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Fix: remove the uploaded CSV (validate_model already did; this route
        # previously leaked one file per batch request into uploads/).
        if file_path is not None and os.path.exists(file_path):
            os.remove(file_path)
@app.get("/v1/xgb/download-model/{model_id}")
async def download_model(model_id: str):
    """Stream a previously trained model artifact as a binary download."""
    artifact = MODEL_SAVE_DIR / f"{model_id}.pkl"
    if artifact.exists():
        return FileResponse(
            artifact,
            filename=f"xgb_model_{model_id}.pkl",
            media_type="application/octet-stream",
        )
    raise HTTPException(status_code=404, detail="Model not found")
async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
    """Background task: train a TfidfXGBoost model on the uploaded CSV.

    Updates the module-level `training_status` dict as it progresses and
    always deletes the uploaded file when done. `config` is accepted for
    interface compatibility but is not currently consumed here — the model's
    own train() presumably reads config.py (TODO confirm).
    """
    try:
        # Surface the in-progress state (previously the status jumped straight
        # from "starting" to "completed"/"failed").
        training_status.update({"status": "training"})
        data_df, label_encoders = load_and_preprocess_data(file_path)
        save_label_encoders(label_encoders)
        X = data_df[TEXT_COLUMN]
        y = data_df[LABEL_COLUMNS]
        model = TfidfXGBoost(label_encoders)
        model.train(X, y)
        model.save_model(training_id)
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "completed"
        })
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Training failed: {str(e)}")
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "failed",
            "error": str(e)
        })
    finally:
        # Bug fix: the uploaded CSV was never removed, so uploads/ grew
        # without bound across training runs. Best-effort cleanup.
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
        except OSError:
            logger.warning("Could not remove uploaded file %s", file_path)
if __name__ == "__main__":
    # Default to 7860 (the Hugging Face Spaces convention) unless PORT is set.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)