Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| import pandas as pd | |
| import joblib | |
| import os | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.multioutput import MultiOutputClassifier | |
| from sklearn.pipeline import Pipeline | |
# ========== Config ==========
# Both paths are relative, so they resolve against the process working directory.
DATA_PATH = "data/synthetic_transactions_samples_5000.csv"  # CSV read by train_model
MODEL_DIR = "models"  # created on demand by train_model
MODEL_PATH = os.path.join(MODEL_DIR, "logreg_model.pkl")  # written by train_model, read by predict
# ========== FastAPI Init ==========
# Single application object for this module.
app = FastAPI()
# ========== Input Schema ==========
class TransactionData(BaseModel):
    """Request schema describing one transaction / screening-hit record.

    Every field is required except ``Payment_Sender_Name`` and
    ``Payment_Reciever_Name``, which default to an empty string.
    Only a subset of these fields is rendered into the model's text input
    (see ``create_text_input``); the rest are accepted but not used for
    prediction.
    """

    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # Optional fields: fall back to "" when the caller omits them.
    Payment_Sender_Name: Optional[str] = ""
    # NOTE: the "Reciever" spelling is part of the field name that
    # create_text_input looks up; do not "correct" it here alone.
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
# ========== Utils ==========
# (label, column) pairs, in the exact order they are rendered by
# create_text_input. Training and prediction must see the same layout, so
# change this table only together with retraining the model.
_TEXT_INPUT_FIELDS = (
    ("Transaction ID", "Transaction_Id"),
    ("Origin", "Origin"),
    ("Designation", "Designation"),
    ("Keywords", "Keywords"),
    ("Name", "Name"),
    ("SWIFT Tag", "SWIFT_Tag"),
    ("Currency", "Currency"),
    ("Entity", "Entity"),
    ("Message", "Message"),
    ("City", "City"),
    ("Country", "Country"),
    ("State", "State"),
    ("Hit Type", "Hit_Type"),
    ("Record Matching String", "Record_Matching_String"),
    ("WatchList Match String", "WatchList_Match_String"),
    ("Payment Sender", "Payment_Sender_Name"),
    ("Payment Receiver", "Payment_Reciever_Name"),
    ("Swift Message Type", "Swift_Message_Type"),
    ("Text Sanction Data", "Text_Sanction_Data"),
    ("Matched Sanctioned Entity", "Matched_Sanctioned_Entity"),
    ("Red Flag Reason", "Red_Flag_Reason"),
    ("Risk Level", "Risk_Level"),
    ("Risk Score", "Risk_Score"),
    ("CDD Level", "CDD_Level"),
    ("PEP Status", "PEP_Status"),
    ("Sanction Description", "Sanction_Description"),
    ("Checker Notes", "Checker_Notes"),
    ("Sanction Context", "Sanction_Context"),
    ("Maker Action", "Maker_Action"),
    ("Customer Type", "Customer_Type"),
    ("Industry", "Industry"),
    ("Transaction Type", "Transaction_Type"),
    ("Transaction Channel", "Transaction_Channel"),
    ("Geographic Origin", "Geographic_Origin"),
    ("Geographic Destination", "Geographic_Destination"),
    ("Risk Category", "Risk_Category"),
    ("Risk Drivers", "Risk_Drivers"),
    ("Alert Status", "Alert_Status"),
    ("Investigation Outcome", "Investigation_Outcome"),
    ("Source of Funds", "Source_Of_Funds"),
    ("Purpose of Transaction", "Purpose_Of_Transaction"),
    ("Beneficial Owner", "Beneficial_Owner"),
)


def create_text_input(row):
    """Render one record as the multi-line text fed to the TF-IDF pipeline.

    Args:
        row: Any object supporting ``row[column_name]`` lookup (for example a
            pandas Series or a plain dict) that contains every column listed
            in ``_TEXT_INPUT_FIELDS``.

    Returns:
        A string with one ``"Label: value"`` line per field, preceded and
        followed by a single newline. Values are formatted exactly as an
        f-string would format them.

    Raises:
        KeyError: If ``row`` lacks one of the expected columns.
    """
    rendered = [f"{label}: {row[column]}" for label, column in _TEXT_INPUT_FIELDS]
    return "\n" + "\n".join(rendered) + "\n"
# ========== Root ==========
# NOTE(review): no route decorator is visible on this function in the
# reviewed text; confirm it is registered on `app` where intended.
def root():
    """Return a fixed status message indicating the API is up."""
    return {"message": "TF-IDF Logistic Regression API is running."}
# ========== API Routes ==========
def train_model():
    """Train the TF-IDF + multi-output logistic-regression pipeline and save it.

    Reads the CSV at ``DATA_PATH``, renders each row to text with
    ``create_text_input``, fits the pipeline on an 80/20 split
    (``random_state=42``), writes the fitted pipeline to ``MODEL_PATH`` and
    then scores it on the held-out 20%.

    Returns:
        dict with a confirmation ``message`` and the held-out ``accuracy`` as
        a plain float (the value of ``pipeline.score`` on the test split).

    Raises:
        HTTPException: status 500 with the underlying error text if any step
            fails (missing CSV, missing columns, fit failure, ...), matching
            how ``predict`` reports errors.
    """
    target_columns = [
        "Maker_Action",
        "Escalation_Level",
        "Risk_Category",
        "Risk_Drivers",
        "Investigation_Outcome",
        "Red_Flag_Reason",
    ]
    try:
        # Empty strings instead of NaN so every field renders cleanly as text.
        df = pd.read_csv(DATA_PATH).fillna("")
        df["text_input"] = df.apply(create_text_input, axis=1)

        # NOTE(review): create_text_input also renders Maker_Action,
        # Risk_Category, Risk_Drivers, Investigation_Outcome and
        # Red_Flag_Reason, so five of the six targets below are present in
        # the features (label leakage). The reported accuracy is therefore
        # optimistic; confirm whether this is intended before relying on it.
        X = df["text_input"]
        y = df[target_columns]
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        pipeline = Pipeline([
            ("vectorizer", TfidfVectorizer()),
            ("classifier", MultiOutputClassifier(LogisticRegression(max_iter=1000))),
        ])
        pipeline.fit(X_train, y_train)

        os.makedirs(MODEL_DIR, exist_ok=True)
        joblib.dump(pipeline, MODEL_PATH)

        # Convert the NumPy scalar to a built-in float for the JSON response.
        accuracy = float(pipeline.score(X_test, y_test))
        return {"message": "Model trained and saved.", "accuracy": accuracy}
    except Exception as e:
        # Top-level request boundary: surface the cause to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
def predict(request: TransactionData):
    """Predict the six target labels for a single transaction record.

    Loads the pipeline saved at ``MODEL_PATH``, renders the request to text
    with ``create_text_input`` and runs one prediction.

    Args:
        request: The validated transaction payload.

    Returns:
        dict mapping each target name (``Maker_Action``, ``Escalation_Level``,
        ``Risk_Category``, ``Risk_Drivers``, ``Investigation_Outcome``,
        ``Red_Flag_Reason``) to its predicted label.

    Raises:
        HTTPException: status 500 with the underlying error text if loading
            the model or predicting fails (for example when no model has
            been trained yet).
    """
    # Must stay in the same order as the target columns used by train_model;
    # the model's output columns are positional.
    target_names = (
        "Maker_Action",
        "Escalation_Level",
        "Risk_Category",
        "Risk_Drivers",
        "Investigation_Outcome",
        "Red_Flag_Reason",
    )
    try:
        # NOTE: joblib.load unpickles the file; only load model files this
        # service wrote itself.
        model = joblib.load(MODEL_PATH)

        # Optional fields may arrive as None; render them as empty strings,
        # the same way training fills missing values.
        input_data = pd.DataFrame([request.dict()]).fillna("")
        text_input = create_text_input(input_data.iloc[0])

        prediction = model.predict([text_input])[0]
        if len(prediction) != len(target_names):
            raise ValueError(
                f"Model returned {len(prediction)} outputs, "
                f"expected {len(target_names)}."
            )

        # NumPy scalars expose .item(); convert to built-in Python values so
        # the response serializes regardless of label dtype.
        return {
            name: value.item() if hasattr(value, "item") else value
            for name, value in zip(target_names, prediction)
        }
    except Exception as e:
        # Top-level request boundary: surface the cause to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
def validate_input(request: TransactionData):
    """Return a fixed confirmation message for a payload.

    The body never inspects ``request``; it only returns the message.
    NOTE(review): presumably any real validation happens when the framework
    parses the body into ``TransactionData`` before calling this handler;
    confirm against the route registration, which is not visible here.
    """
    return {"message": "Input is valid."}
def test_api():
    """Return a fixed message confirming the handler can be reached."""
    return {"message": "Test successful."}