Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from typing import Optional, Dict, Any, List | |
| import uvicorn | |
| import logging | |
| import os | |
| import pandas as pd | |
| from datetime import datetime | |
| import shutil | |
| from pathlib import Path | |
| import numpy as np | |
| import sys | |
| import json | |
| import joblib | |
| from dataset_utils import load_and_preprocess_data, save_label_encoders, load_label_encoders | |
| from config import TEXT_COLUMN, LABEL_COLUMNS, BATCH_SIZE, MODEL_SAVE_DIR | |
| from tfidf_xgb import TfidfXGBoost # ✅ XGBoost model class | |
# Process-wide logging setup and the module logger used by the handlers below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="XGBoost Compliance Predictor API")

# Working directories, created at import time if they do not exist yet.
# NOTE(review): this assignment rebinds the MODEL_SAVE_DIR name that is also
# imported from `config` above, so the imported value is never used here.
UPLOAD_DIR = Path("uploads")
MODEL_SAVE_DIR = Path("saved_models")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Artifact locations (plain strings) read by the prediction handler.
TFIDF_PATH = str(MODEL_SAVE_DIR / "tfidf_vectorizer.pkl")
MODEL_PATH = str(MODEL_SAVE_DIR / "xgb_models.pkl")
ENCODERS_PATH = os.path.join(os.path.dirname(__file__), "label_encoders.pkl")

# Mutable, module-level record of the most recent training run; it is updated
# in place by the training handlers and returned as-is by the status handler.
training_status = dict(
    is_training=False,
    current_epoch=0,
    total_epochs=0,
    current_loss=0.0,
    start_time=None,
    end_time=None,
    status="idle",
    metrics=None,
)
# --- Pydantic Models ---
class TrainingConfig(BaseModel):
    """Training options parsed from the JSON `config` form field.

    Every field has a default, so an empty JSON object is a valid config.
    """

    batch_size: int = 32
    num_epochs: int = 1
    random_state: int = 42
class TrainingResponse(BaseModel):
    """Payload returned when a training run has been scheduled."""

    message: str
    training_id: str
    status: str
    # Optional link from which the trained model can be fetched.
    download_url: Optional[str] = None
class ValidationResponse(BaseModel):
    """Payload returned by model validation: metrics plus per-row predictions."""

    message: str
    metrics: Dict[str, Any]
    predictions: List[Dict[str, Any]]
class TransactionData(BaseModel):
    """Flat record describing a single transaction submitted for prediction.

    Every field is required except the two payment party names, which default
    to an empty string. Field names (including the `Payment_Reciever_Name`
    spelling) are part of the request schema and must not be renamed.
    """

    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str

    # The only optional fields in the schema.
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""

    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
class PredictionRequest(BaseModel):
    """Request body for a single-transaction prediction."""

    transaction_data: TransactionData
    # Name of the model to use; defaults to "xgb_models".
    model_name: str = "xgb_models"
class BatchPredictionResponse(BaseModel):
    """Payload returned for CSV batch predictions."""

    message: str
    predictions: List[Dict[str, Any]]
    # Left as None when no metrics are supplied.
    metrics: Optional[Dict[str, Any]] = None
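# --- API Routes ---
# NOTE(review): no route decorators are visible on the handlers below; confirm
# they are registered on `app`.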
async def root():
    """Return the service banner message."""
    banner = {"message": "XGBoost Compliance Predictor API"}
    return banner
async def health_check():
    """Report that the service is up; always returns a "healthy" status."""
    payload = dict(status="healthy")
    return payload
async def get_training_status():
    """Return the module-level training status dictionary (not a copy)."""
    current = training_status
    return current
async def start_training(
    config: str = Form(...),
    background_tasks: BackgroundTasks = None,
    file: UploadFile = File(...)
):
    """Store an uploaded CSV and schedule a background training run.

    Args:
        config: JSON string whose keys are validated against `TrainingConfig`.
        background_tasks: Task queue on which `train_model_task` is scheduled.
        file: Uploaded training data; its name must end in ".csv".

    Returns:
        A `TrainingResponse` carrying the timestamp-based training id and a
        download URL built from it.

    Raises:
        HTTPException: 400 when a run is already in progress, the upload is
            not a CSV (or has no name), or `config` cannot be parsed.
    """
    if training_status["is_training"]:
        raise HTTPException(status_code=400, detail="Training is already in progress")

    # The filename comes from the client: keep only its final path component
    # so values such as "../../x.csv" cannot escape UPLOAD_DIR, and treat a
    # missing filename as "not a CSV" instead of crashing on None.
    safe_name = Path(file.filename or "").name
    if not safe_name.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    try:
        config_dict = json.loads(config)
        training_config = TrainingConfig(**config_dict)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid config: {str(e)}")

    file_path = UPLOAD_DIR / safe_name
    with file_path.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    training_id = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Drop leftovers from a previous run so the status reflects only this one.
    training_status.pop("error", None)
    training_status.update({
        "is_training": True,
        "start_time": datetime.now().isoformat(),
        "end_time": None,
        "status": "starting"
    })

    background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)

    download_url = f"/v1/xgb/download-model/{training_id}"
    return TrainingResponse(
        message="Training started",
        training_id=training_id,
        status="started",
        download_url=download_url
    )
async def validate_model(file: UploadFile = File(...), model_name: str = "xgb_models"):
    """Evaluate a saved model against an uploaded, labelled CSV.

    Args:
        file: Uploaded validation data; its name must end in ".csv".
        model_name: Name handed to `TfidfXGBoost.load_model`.
            NOTE(review): this value is caller-supplied and passed through
            unchanged — confirm `load_model` restricts where it reads from.

    Returns:
        A `ValidationResponse` with the evaluation reports and, for every
        label column and row, the true label, predicted label and per-class
        probabilities.

    Raises:
        HTTPException: 400 when the upload is not a CSV (or has no name);
            500 when loading, evaluation or decoding fails.
    """
    # The filename comes from the client: keep only its final path component
    # so it cannot point outside UPLOAD_DIR, and treat a missing filename as
    # "not a CSV" instead of crashing on None.
    safe_name = Path(file.filename or "").name
    if not safe_name.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files allowed")

    file_path = UPLOAD_DIR / safe_name
    try:
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        data_df, label_encoders = load_and_preprocess_data(str(file_path))
        model = TfidfXGBoost(label_encoders)
        model.load_model(model_name)

        X = data_df[TEXT_COLUMN]
        y = data_df[LABEL_COLUMNS]
        reports, y_true_list, y_pred_list = model.evaluate(X, y)
        all_probs = model.predict_proba(X)

        predictions = []
        for i, col in enumerate(LABEL_COLUMNS):
            le = label_encoders[col]
            true_labels = le.inverse_transform(y_true_list[i])
            pred_labels = le.inverse_transform(y_pred_list[i])
            for true, pred, probs in zip(true_labels, pred_labels, all_probs[i]):
                class_probs = {label: float(prob) for label, prob in zip(le.classes_, probs)}
                predictions.append({
                    "field": col,
                    "true_label": true,
                    "predicted_label": pred,
                    "probabilities": class_probs
                })

        return ValidationResponse(
            message="Validation completed",
            metrics=reports,
            predictions=predictions
        )
    except HTTPException:
        # Deliberate HTTP errors must keep their own status code.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Validation failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Validation error: {str(e)}") from e
    finally:
        # The upload is only needed for this request.
        if file_path.exists():
            file_path.unlink()
def _decode_prediction(encoders, encoded_row):
    """Map one row of encoded predictions back to original labels, keyed by label column."""
    return {
        col: encoders[col].inverse_transform([value])[0]
        for col, value in zip(LABEL_COLUMNS, encoded_row)
    }


async def predict(request: Optional[PredictionRequest] = None, file: UploadFile = File(None), model_name: str = "xgb_models"):
    """Predict labels for an uploaded CSV batch or for a single transaction.

    The TF-IDF vectorizer, model and label encoders are loaded with joblib
    from TFIDF_PATH, MODEL_PATH and ENCODERS_PATH on every call.

    Args:
        request: Single-transaction payload; used when no file is uploaded.
        file: Optional CSV upload; takes precedence over `request`.
        model_name: Accepted for interface compatibility; not used to select
            the artifacts loaded here.

    Returns:
        A `BatchPredictionResponse` for CSV uploads, or a plain dict mapping
        each label column to its decoded prediction for a single transaction.

    Raises:
        HTTPException: 400 when the upload is not a CSV or when neither a
            file nor a transaction is supplied; 500 for any other failure.
    """
    upload_path = None
    try:
        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)

        if file and file.filename:
            # The filename comes from the client: keep only its final path
            # component so it cannot point outside UPLOAD_DIR.
            safe_name = Path(file.filename).name
            if not safe_name.endswith('.csv'):
                raise HTTPException(status_code=400, detail="Only CSV files allowed")

            upload_path = UPLOAD_DIR / safe_name
            with upload_path.open("wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

            data_df, _ = load_and_preprocess_data(str(upload_path))
            # NOTE(review): every non-null column value is joined into the
            # text here, whereas the training task in this file fits on
            # data_df[TEXT_COLUMN] only — confirm the loaded vectorizer was
            # fitted on text built the same way.
            texts = data_df.apply(lambda row: " ".join([str(v) for v in row.values if pd.notna(v)]), axis=1)
            X_vec = tfidf.transform(texts)
            preds = model.predict(X_vec)

            predictions = [
                {
                    "transaction_id": data_df.iloc[i].get('Transaction_Id', f"row_{i}"),
                    "predictions": _decode_prediction(encoders, pred),
                }
                for i, pred in enumerate(preds)
            ]
            return BatchPredictionResponse(message="Batch prediction complete", predictions=predictions)

        elif request and request.transaction_data:
            input_df = pd.DataFrame([request.transaction_data.dict()])
            text_input = " ".join([str(v) for v in input_df.iloc[0].values if pd.notna(v)])
            X_vec = tfidf.transform([text_input])
            pred = model.predict(X_vec)[0]
            return _decode_prediction(encoders, pred)

        else:
            raise HTTPException(status_code=400, detail="Provide a transaction or upload a CSV file")
    except HTTPException:
        # Without this clause the 400 responses raised above were caught by
        # the generic handler below and returned to the client as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The uploaded batch file is only needed for this request.
        if upload_path is not None and upload_path.exists():
            upload_path.unlink()
async def download_model(model_id: str):
    """Serve the pickled model stored as `<model_id>.pkl` in MODEL_SAVE_DIR.

    Args:
        model_id: Identifier of the model file, without the ".pkl" suffix.

    Returns:
        A `FileResponse` streaming the file as `xgb_model_<model_id>.pkl`.

    Raises:
        HTTPException: 400 when `model_id` contains a path component;
            404 when no matching file exists.
    """
    file_name = f"{model_id}.pkl"
    # model_id is caller-supplied: refuse anything that is not a bare file
    # name so the lookup cannot leave MODEL_SAVE_DIR.
    if Path(file_name).name != file_name:
        raise HTTPException(status_code=400, detail="Invalid model id")

    model_path = MODEL_SAVE_DIR / file_name
    if not model_path.exists():
        raise HTTPException(status_code=404, detail="Model not found")
    return FileResponse(model_path, filename=f"xgb_model_{model_id}.pkl", media_type="application/octet-stream")
def _run_training(file_path: str, training_id: str) -> None:
    """Blocking part of a training run: load the data, fit the model, save it."""
    data_df, label_encoders = load_and_preprocess_data(file_path)
    save_label_encoders(label_encoders)

    X = data_df[TEXT_COLUMN]
    y = data_df[LABEL_COLUMNS]

    model = TfidfXGBoost(label_encoders)
    model.train(X, y)
    model.save_model(training_id)


async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
    """Train and save a model, recording the outcome in `training_status`.

    Args:
        config: Validated training options. They are not forwarded to the
            model here; the parameter is kept for the caller's signature.
        file_path: Path of the CSV to train on.
        training_id: Identifier passed to `TfidfXGBoost.save_model`.

    Failures are not re-raised: they are logged and stored in
    `training_status` under "status" = "failed" and "error".
    """
    import asyncio

    try:
        # Training is blocking work. Running it directly inside this
        # coroutine would stall the event loop (and every other request)
        # until it finishes, so it is pushed to the default thread pool.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _run_training, file_path, training_id)

        # A successful run must not keep reporting an earlier run's error.
        training_status.pop("error", None)
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "completed"
        })
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Training failed: %s", e)
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "failed",
            "error": str(e)
        })
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))