Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| import pandas as pd | |
| import joblib | |
| import os | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.preprocessing import LabelEncoder | |
| from sklearn.multioutput import MultiOutputClassifier | |
| from sklearn.linear_model import LogisticRegression | |
# --- Configuration ---
# Target columns, in the order the multi-output classifier is fitted on them
# and in which predictions are decoded.
LABEL_COLUMNS = [
    "Red_Flag_Reason",
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
]

# Free-text column that is TF-IDF vectorized to form the model input.
TEXT_COLUMN = "Sanction_Context"

# Directory holding the persisted artifacts, followed by the three files that
# training writes and that the test/predict handlers read back.
MODEL_DIR = "/tmp"
MODEL_PATH, TFIDF_PATH, ENCODERS_PATH = (
    os.path.join(MODEL_DIR, artifact_name)
    for artifact_name in (
        "logreg_model.pkl",
        "tfidf_vectorizer.pkl",
        "label_encoders.pkl",
    )
)
# --- FastAPI App ---
# Application object for this service, created with FastAPI's defaults.
# NOTE(review): none of the handler functions in this file carries a visible
# route decorator (e.g. @app.get / @app.post) in this view — confirm that each
# one is actually registered on `app`, otherwise the service exposes no routes.
app = FastAPI()
| # --- Input Schema --- | |
class TransactionData(BaseModel):
    """Schema of one transaction record as accepted in a prediction request.

    Every field is required except the two payment-party names, which may be
    omitted (defaulting to an empty string) or sent as ``None``. Field names
    are part of the request contract and are kept exactly as spelled,
    including ``Payment_Reciever_Name``.
    """

    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # Optional parties: omitted -> "", explicit null is accepted as well.
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str  # same name as an entry of LABEL_COLUMNS
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    # Dates and timestamps throughout this model are carried as plain strings;
    # no format is enforced by the schema.
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str  # same name as TEXT_COLUMN (the model's text input)
    Maker_Action: str  # same name as an entry of LABEL_COLUMNS
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str  # same name as an entry of LABEL_COLUMNS
    Risk_Drivers: str  # same name as an entry of LABEL_COLUMNS
    Alert_Status: str
    Investigation_Outcome: str  # same name as an entry of LABEL_COLUMNS
    Case_Owner_Analyst: str
    Escalation_Level: str  # same name as an entry of LABEL_COLUMNS
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
class PredictionRequest(BaseModel):
    """Request body for a prediction: a single transaction record."""

    # The record whose labels are to be predicted.
    transaction_data: TransactionData
class DataPathInput(BaseModel):
    """Request body naming the CSV dataset a handler should read."""

    # Passed as-is to ``pd.read_csv`` by the train/validate/test handlers.
    data_path: str
def health_check():
    """Return a static liveness payload.

    Returns:
        A dict with a fixed ``status`` of ``"healthy"`` and a descriptive
        ``message``; nothing is actually probed.
    """
    payload = dict(
        status="healthy",
        message="logistic regression complience predictor API ",
    )
    return payload
def train_model(input: DataPathInput):
    """Fit the TF-IDF + multi-output logistic-regression pipeline and save it.

    Reads the CSV at ``input.data_path``, drops rows missing the text column
    or any label column, label-encodes every target, fits the vectorizer and
    classifier, and writes the three artifacts to ``MODEL_PATH``,
    ``TFIDF_PATH`` and ``ENCODERS_PATH``.

    Args:
        input: Request body carrying the CSV location.

    Returns:
        A dict with a single ``status`` message on success.

    Raises:
        HTTPException: Status 500, with the underlying error text as detail,
            for any failure while reading, fitting or saving.
    """
    try:
        # NOTE(review): data_path comes from the request and is handed to
        # pd.read_csv unvalidated — confirm callers are trusted.
        frame = pd.read_csv(input.data_path)
        frame.dropna(subset=[TEXT_COLUMN, *LABEL_COLUMNS], inplace=True)

        # One encoder per target; each target column is replaced in place by
        # its integer codes so the classifier sees numeric labels.
        label_encoders = {target: LabelEncoder() for target in LABEL_COLUMNS}
        for target, encoder in label_encoders.items():
            frame[target] = encoder.fit_transform(frame[target])

        vectorizer = TfidfVectorizer(
            max_features=1000,
            ngram_range=(1, 2),
            stop_words="english",
        )
        features = vectorizer.fit_transform(frame[TEXT_COLUMN])

        classifier = MultiOutputClassifier(LogisticRegression(max_iter=1000))
        classifier.fit(features, frame[LABEL_COLUMNS])

        # Persist model, vectorizer and encoders, in that order.
        artifacts = (
            (classifier, MODEL_PATH),
            (vectorizer, TFIDF_PATH),
            (label_encoders, ENCODERS_PATH),
        )
        for artifact, destination in artifacts:
            joblib.dump(artifact, destination)

        return {"status": "β Logistic Regression model trained and saved."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def validate_model(input: DataPathInput):
    """Check that a CSV contains the text column and every label column.

    Only column presence is checked; row contents are not inspected.

    Args:
        input: Request body carrying the CSV location.

    Returns:
        ``{"status": ...}`` when all required columns exist, otherwise a dict
        that additionally lists the absent names under ``missing_columns``.

    Raises:
        HTTPException: Status 500 with a ``"Validation error: ..."`` detail
            if the file cannot be read or parsed.
    """
    try:
        columns = pd.read_csv(input.data_path).columns
        absent = [
            name for name in [TEXT_COLUMN, *LABEL_COLUMNS] if name not in columns
        ]
        if not absent:
            return {"status": "β Input is valid."}
        return {"status": "β Invalid input", "missing_columns": absent}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Validation error: {e!s}")
def test_model(input: DataPathInput):
    """Run the saved model on a CSV and return a preview of its predictions.

    Loads the persisted vectorizer, classifier and label encoders, predicts
    on the first rows of the text column that are not null, and decodes the
    integer predictions back to their original label values.

    Only the rows that are actually returned are vectorized and predicted;
    the response is the same preview of at most five rows as before, without
    scoring and decoding the rest of the file just to discard it.

    Args:
        input: Request body carrying the CSV location.

    Returns:
        ``{"predictions": [...]}`` with at most five dicts, each mapping every
        name in ``LABEL_COLUMNS`` to its decoded label (native Python values).
        The list is empty when the file has no usable text rows.

    Raises:
        HTTPException: Status 500, with the underlying error text as detail,
            for any failure (unreadable CSV, missing column, missing or
            incompatible artifacts).
    """
    preview_rows = 5
    try:
        df = pd.read_csv(input.data_path)
        texts = df[TEXT_COLUMN].dropna().head(preview_rows)
        if texts.empty:
            # Nothing to score; avoid sklearn's opaque "0 samples" error.
            return {"predictions": []}

        # NOTE(review): joblib.load unpickles arbitrary objects; these files
        # live under MODEL_DIR and must only ever be written by train_model.
        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)

        preds = model.predict(tfidf.transform(texts))

        # Decode one whole output column at a time, then regroup by row.
        # .tolist() yields plain Python values instead of NumPy scalars.
        decoded_columns = [
            encoders[col].inverse_transform(preds[:, idx]).tolist()
            for idx, col in enumerate(LABEL_COLUMNS)
        ]
        decoded_preds = [
            dict(zip(LABEL_COLUMNS, row)) for row in zip(*decoded_columns)
        ]
        return {"predictions": decoded_preds}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def predict(request: PredictionRequest):
    """Predict every label in ``LABEL_COLUMNS`` for one transaction.

    The model input is the transaction's ``TEXT_COLUMN`` field only. That is
    the same text the vectorizer is fitted on in ``train_model`` and scored
    on in ``test_model``. Previously every field of the record — including
    the label fields themselves — was concatenated into the text, so
    inference saw a different feature distribution from training.

    Args:
        request: Request body wrapping the transaction record.

    Returns:
        ``{"prediction": {...}}`` mapping each label column to its decoded
        value.

    Raises:
        HTTPException: Status 500, with the underlying error text as detail,
            for any failure (e.g. artifacts not yet trained/saved).
    """
    try:
        # Same feature the vectorizer was fitted on; the schema declares it
        # as a required str, so no null handling is needed here.
        text_input = getattr(request.transaction_data, TEXT_COLUMN)

        # NOTE(review): joblib.load unpickles arbitrary objects; these files
        # live under MODEL_DIR and must only ever be written by train_model.
        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)

        X_vec = tfidf.transform([text_input])
        pred = model.predict(X_vec)[0]  # single row -> one code per label

        decoded = {
            col: encoders[col].inverse_transform([code])[0]
            for col, code in zip(LABEL_COLUMNS, pred)
        }
        return {"prediction": decoded}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def train_test_validate(input: DataPathInput):
    """Validate a dataset, train on it, then preview predictions on it.

    Validation now runs first, so a dataset with missing columns is rejected
    with the list of absent columns before any training is attempted.
    Previously training ran first and failed with a bare column error.

    Args:
        input: Request body carrying the CSV location; the same file is used
            for all three steps.

    Returns:
        A dict with keys ``train`` (fixed marker string), ``validate`` (the
        result of ``validate_model``) and ``test`` (the result of
        ``test_model``).

    Raises:
        HTTPException: Status 400 with the validation result as detail when
            required columns are missing; otherwise whatever HTTPException a
            sub-step raised, passed through unchanged; status 500 for any
            other unexpected error.
    """
    try:
        validate_result = validate_model(input)
        if validate_result.get("missing_columns"):
            raise HTTPException(status_code=400, detail=validate_result)

        train_model(input)
        test_result = test_model(input)
        return {
            "train": "β Done",
            "validate": validate_result,
            "test": test_result,
        }
    except HTTPException:
        # Sub-steps already produce a proper HTTP error; re-wrapping it via
        # str(e) would replace its detail with the exception's repr-like text.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e