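"""
FastAPI service for a TF-IDF + logistic regression compliance screening model.

Exposes endpoints to train the model on an uploaded CSV, validate a saved
model, run single or batch predictions, and download trained artifacts.
"""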
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, Dict, Any, List
import uvicorn
import logging
import os
import pandas as pd
from datetime import datetime
import shutil
from pathlib import Path
import numpy as np
import json
import joblib
from sklearn.metrics import classification_report
from sklearn.multioutput import MultiOutputClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from models.tfidf_logreg import TfidfLogisticRegression

# Import existing utilities
from dataset_utils import (
    load_and_preprocess_data,
    save_label_encoders,
    load_label_encoders
)
from config import (
    TEXT_COLUMN,
    LABEL_COLUMNS,
    BATCH_SIZE
)
# Note: MODEL_SAVE_DIR is defined locally below rather than imported from
# config, so the local Path takes precedence.
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="LOGREG Compliance Predictor API")

UPLOAD_DIR = Path("uploads")
MODEL_SAVE_DIR = Path("saved_models")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Define paths for vectorizer, model, and encoders
TFIDF_PATH = os.path.join(str(MODEL_SAVE_DIR), "tfidf_vectorizer.pkl")
MODEL_PATH = os.path.join(str(MODEL_SAVE_DIR), "logreg_models.pkl")
ENCODERS_PATH = os.path.join(os.path.dirname(__file__), "label_encoders.pkl")
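# Module-level training state shared across requests. Because this is a
# plain dict, progress reporting is only reliable when the app runs in a
# single worker process.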
training_status = {
    "is_training": False,
    "current_epoch": 0,
    "total_epochs": 0,
    "current_loss": 0.0,
    "start_time": None,
    "end_time": None,
    "status": "idle",
    "metrics": None
}
class TrainingConfig(BaseModel):
    batch_size: int = 32
    num_epochs: int = 1  # Not used by logistic regression; kept for API compatibility
    random_state: int = 42

class TrainingResponse(BaseModel):
    message: str
    training_id: str
    status: str
    download_url: Optional[str] = None

class ValidationResponse(BaseModel):
    message: str
    metrics: Dict[str, Any]
    predictions: List[Dict[str, Any]]
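# Input schema for a single transaction. Field names (including the
# "Payment_Reciever_Name" spelling) are kept exactly as-is, since they
# presumably mirror the training data's column names.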
class TransactionData(BaseModel):
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
class PredictionRequest(BaseModel):
    transaction_data: TransactionData
    model_name: str = "logreg_models"  # Defaults to the saved TF-IDF logistic regression artifact

class BatchPredictionResponse(BaseModel):
    message: str
    predictions: List[Dict[str, Any]]
    metrics: Optional[Dict[str, Any]] = None
# NOTE: the route paths below are assumed, except the download-model path,
# which matches download_url in start_training.
@app.get("/")
async def root():
    return {"message": "LOGREG Compliance Predictor API"}

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

@app.get("/v1/logreg/training-status")
async def get_training_status():
    return training_status
@app.post("/v1/logreg/train")  # path assumed
async def start_training(
    background_tasks: BackgroundTasks,
    config: str = Form(...),
    file: UploadFile = File(...)
):
    if training_status["is_training"]:
        raise HTTPException(status_code=400, detail="Training is already in progress")
    if not file.filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")
    try:
        config_dict = json.loads(config)
        training_config = TrainingConfig(**config_dict)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")
    file_path = UPLOAD_DIR / file.filename
    with file_path.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    training_status.update({
        "is_training": True,
        "current_epoch": 0,
        "total_epochs": 1,
        "start_time": datetime.now().isoformat(),
        "status": "starting"
    })
    background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
    download_url = f"/v1/logreg/download-model/{training_id}"
    return TrainingResponse(
        message="Training started successfully",
        training_id=training_id,
        status="started",
        download_url=download_url
    )
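# Example request (illustrative; host, port, and filename are assumptions):
#   curl -X POST http://localhost:7860/v1/logreg/train \
#     -F 'config={"batch_size": 32, "num_epochs": 1, "random_state": 42}' \
#     -F 'file=@training_data.csv'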
@app.post("/v1/logreg/validate")  # path assumed
async def validate_model(
    file: UploadFile = File(...),
    model_name: str = "logreg_models"
):
    if not file.filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")
    file_path = UPLOAD_DIR / file.filename
    try:
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        data_df, label_encoders = load_and_preprocess_data(str(file_path))
        model_path = MODEL_SAVE_DIR / f"{model_name}.pkl"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail="LOGREG model file not found")
        model = TfidfLogisticRegression(label_encoders)
        model.load_model(model_name)
        X = data_df[TEXT_COLUMN]
        y = data_df[LABEL_COLUMNS]
        # Type and shape check for X
        if not isinstance(X, pd.Series) or not pd.api.types.is_string_dtype(X):
            raise HTTPException(status_code=400, detail=f"TEXT_COLUMN ('{TEXT_COLUMN}') must be a pandas Series of strings. Got type: {type(X)}, dtype: {getattr(X, 'dtype', None)}")
        reports, y_true_list, y_pred_list = model.evaluate(X, y)
        all_probs = model.predict_proba(X)
        predictions = []
        for i, col in enumerate(LABEL_COLUMNS):
            label_encoder = label_encoders[col]
            true_labels_orig = label_encoder.inverse_transform(y_true_list[i])
            pred_labels_orig = label_encoder.inverse_transform(y_pred_list[i])
            for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
                class_probs = {label: float(prob) for label, prob in zip(label_encoder.classes_, probs)}
                predictions.append({
                    "field": col,
                    "true_label": true,
                    "predicted_label": pred,
                    "probabilities": class_probs
                })
        return ValidationResponse(
            message="Validation completed successfully",
            metrics=reports,
            predictions=predictions
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Validation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
    finally:
        if os.path.exists(file_path):
            os.remove(file_path)
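# Example request (illustrative):
#   curl -X POST http://localhost:7860/v1/logreg/validate \
#     -F 'file=@validation_data.csv'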
@app.post("/v1/logreg/predict")  # path assumed
async def predict(
    request: Optional[PredictionRequest] = None,
    file: UploadFile = File(None),
    model_name: str = "logreg_models"
):
    try:
        # Load vectorizer, model, and encoders
        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)
        # Batch prediction from CSV
        if file and file.filename:
            if not file.filename.endswith('.csv'):
                raise HTTPException(status_code=400, detail="Only CSV files are allowed")
            file_path = UPLOAD_DIR / file.filename
            with file_path.open("wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
            try:
                data_df, _ = load_and_preprocess_data(str(file_path))
                # Concatenate all fields into a single string for each row
                texts = data_df.apply(lambda row: " ".join([str(val) for val in row.values if pd.notna(val)]), axis=1)
                X_vec = tfidf.transform(texts)
                preds = model.predict(X_vec)
                predictions = []
                for i, pred in enumerate(preds):
                    decoded = {
                        col: encoders[col].inverse_transform([label])[0]
                        for col, label in zip(LABEL_COLUMNS, pred)
                    }
                    predictions.append({
                        "transaction_id": data_df.iloc[i].get('Transaction_Id', f"transaction_{i}"),
                        "predictions": decoded
                    })
                return BatchPredictionResponse(
                    message="Batch prediction completed successfully",
                    predictions=predictions
                )
            finally:
                if os.path.exists(file_path):
                    os.remove(file_path)
        # Single prediction
        elif request and request.transaction_data:
            input_data = pd.DataFrame([request.transaction_data.dict()])
            text_input = " ".join([
                str(val) for val in input_data.iloc[0].values if pd.notna(val)
            ])
            X_vec = tfidf.transform([text_input])
            pred = model.predict(X_vec)[0]
            decoded = {
                col: encoders[col].inverse_transform([p])[0]
                for col, p in zip(LABEL_COLUMNS, pred)
            }
            return decoded
        else:
            raise HTTPException(
                status_code=400,
                detail="Either provide a transaction in the request body or upload a CSV file"
            )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
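# Example batch request (illustrative):
#   curl -X POST http://localhost:7860/v1/logreg/predict \
#     -F 'file=@transactions.csv'
# For a single prediction, POST a JSON body of the form
# {"transaction_data": {...}, "model_name": "logreg_models"} instead.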
@app.get("/v1/logreg/download-model/{model_id}")
async def download_model(model_id: str):
    model_path = MODEL_SAVE_DIR / f"{model_id}.pkl"
    if not model_path.exists():
        raise HTTPException(status_code=404, detail="Model not found")
    return FileResponse(
        path=model_path,
        filename=f"logreg_model_{model_id}.pkl",
        media_type="application/octet-stream"
    )
async def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
    try:
        data_df_original, label_encoders = load_and_preprocess_data(file_path)
        save_label_encoders(label_encoders)
        X = data_df_original[TEXT_COLUMN]
        y = data_df_original[LABEL_COLUMNS]
        model = TfidfLogisticRegression(label_encoders)
        model.train(X, y)
        model.save_model(training_id)
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "completed"
        })
    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "failed",
            "error": str(e)
        })
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
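# To run locally (assuming this file is saved as app.py):
#   PORT=7860 python app.py
#   curl http://localhost:7860/health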