from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import pandas as pd
import joblib
import os
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import LabelEncoder
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

import config
# --- Configuration ---
LABEL_COLUMNS = [
    "Red_Flag_Reason", "Maker_Action", "Escalation_Level",
    "Risk_Category", "Risk_Drivers", "Investigation_Outcome"
]
TEXT_COLUMN = "Sanction_Context"

MODEL_DIR = "/tmp"
MODEL_PATH = os.path.join(MODEL_DIR, "logreg_model.pkl")
TFIDF_PATH = os.path.join(MODEL_DIR, "tfidf_vectorizer.pkl")
ENCODERS_PATH = os.path.join(MODEL_DIR, "label_encoders.pkl")
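
# NOTE: train() below reads its settings from the separate config module imported
# above. That file is not shown here; a minimal sketch of the attributes it is
# assumed to define (values are illustrative placeholders, and the three paths are
# assumed to point at the same files as MODEL_PATH, TFIDF_PATH and ENCODERS_PATH):
#
#   DATA_PATH = "data/sanctions_alerts.csv"
#   TEXT_COLUMN = "Sanction_Context"
#   LABEL_COLUMNS = [...same six label columns as above...]
#   TEST_SIZE = 0.2
#   RANDOM_STATE = 42
#   TFIDF_MAX_FEATURES = 5000
#   NGRAM_RANGE = (1, 2)
#   USE_STOPWORDS = True
#   MODEL_PATH = "/tmp/logreg_model.pkl"
#   TFIDF_VECTORIZER_PATH = "/tmp/tfidf_vectorizer.pkl"
#   LABEL_ENCODERS_PATH = "/tmp/label_encoders.pkl"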

# --- FastAPI App ---
app = FastAPI()

# --- Input Schema ---
class TransactionData(BaseModel):
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool

class PredictionRequest(BaseModel):
    transaction_data: TransactionData


class DataPathInput(BaseModel):
    data_path: str

# --- Root ---
@app.get("/")  # route path assumed
def health_check():
    return {"status": "healthy", "message": "LOGREG TF-IDF API is running"}

# --- Train ---
@app.post("/train")  # route path assumed
def train():
    try:
        # Load data
        df = pd.read_csv(config.DATA_PATH)

        # Prepare features and labels
        X = df[config.TEXT_COLUMN]
        y = df[config.LABEL_COLUMNS]

        # Encode labels using a LabelEncoder per column
        label_encoders = {}
        y_encoded = pd.DataFrame()
        for col in config.LABEL_COLUMNS:
            le = LabelEncoder()
            y_encoded[col] = le.fit_transform(y[col])
            label_encoders[col] = le

        # Train/test split
        X_train, X_test, y_train, y_test = train_test_split(
            X, y_encoded, test_size=config.TEST_SIZE, random_state=config.RANDOM_STATE
        )

        # TF-IDF vectorizer
        vectorizer = TfidfVectorizer(
            max_features=config.TFIDF_MAX_FEATURES,
            ngram_range=config.NGRAM_RANGE,
            stop_words='english' if config.USE_STOPWORDS else None
        )
        X_train_vec = vectorizer.fit_transform(X_train)
        X_test_vec = vectorizer.transform(X_test)

        # Model: one logistic regression per label column
        model = MultiOutputClassifier(LogisticRegression(max_iter=1000))
        model.fit(X_train_vec, y_train)

        # Predictions on the held-out split
        y_pred = model.predict(X_test_vec)

        # Per-label accuracy
        accuracy = {
            col: accuracy_score(y_test[col], y_pred[:, i])
            for i, col in enumerate(config.LABEL_COLUMNS)
        }

        # Persist model, vectorizer and label encoders
        joblib.dump(model, config.MODEL_PATH)
        joblib.dump(vectorizer, config.TFIDF_VECTORIZER_PATH)
        joblib.dump(label_encoders, config.LABEL_ENCODERS_PATH)

        return {
            "message": "Training completed successfully.",
            "accuracy": accuracy
        }
    except Exception as e:
        return {"error": str(e)}

# --- Validate (structure check only) ---
@app.post("/validate")  # route path assumed
def validate_model(input: DataPathInput):
    try:
        df = pd.read_csv(input.data_path)
        required_columns = [TEXT_COLUMN] + LABEL_COLUMNS
        missing = [col for col in required_columns if col not in df.columns]
        if missing:
            return {"status": "Invalid input", "missing_columns": missing}
        else:
            return {"status": "Input is valid."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Validation error: {str(e)}")

# --- Test ---
@app.post("/test")  # route path assumed
def test_model(input: DataPathInput):
    try:
        df = pd.read_csv(input.data_path)
        df = df.dropna(subset=[TEXT_COLUMN])

        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)

        X_vec = tfidf.transform(df[TEXT_COLUMN])
        preds = model.predict(X_vec)

        decoded_preds = []
        for pred in preds:
            decoded = {
                col: encoders[col].inverse_transform([label])[0]
                for col, label in zip(LABEL_COLUMNS, pred)
            }
            decoded_preds.append(decoded)

        return {"predictions": decoded_preds[:8]}  # return the first 8 predictions as a sample
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# --- Predict ---
@app.post("/predict")  # route path assumed
def predict(request: PredictionRequest):
    try:
        data = request.transaction_data.dict()

        # Structured field-based formatted text input
        text_input = f"""
        Transaction ID: {data.get('Transaction_Id', '')}
        Origin: {data.get('Origin', '')}
        Designation: {data.get('Designation', '')}
        Keywords: {data.get('Keywords', '')}
        Name: {data.get('Name', '')}
        SWIFT Tag: {data.get('SWIFT_Tag', '')}
        Currency: {data.get('Currency', '')}
        Entity: {data.get('Entity', '')}
        Message: {data.get('Message', '')}
        City: {data.get('City', '')}
        Country: {data.get('Country', '')}
        State: {data.get('State', '')}
        Hit Type: {data.get('Hit_Type', '')}
        Record Matching String: {data.get('Record_Matching_String', '')}
        WatchList Match String: {data.get('WatchList_Match_String', '')}
        Payment Sender: {data.get('Payment_Sender_Name', '')}
        Payment Receiver: {data.get('Payment_Reciever_Name', '')}
        Swift Message Type: {data.get('Swift_Message_Type', '')}
        Text Sanction Data: {data.get('Text_Sanction_Data', '')}
        Matched Sanctioned Entity: {data.get('Matched_Sanctioned_Entity', '')}
        Red Flag Reason: {data.get('Red_Flag_Reason', '')}
        Risk Level: {data.get('Risk_Level', '')}
        Risk Score: {data.get('Risk_Score', '')}
        CDD Level: {data.get('CDD_Level', '')}
        PEP Status: {data.get('PEP_Status', '')}
        Sanction Description: {data.get('Sanction_Description', '')}
        Checker Notes: {data.get('Checker_Notes', '')}
        Sanction Context: {data.get('Sanction_Context', '')}
        Maker Action: {data.get('Maker_Action', '')}
        Customer Type: {data.get('Customer_Type', '')}
        Industry: {data.get('Industry', '')}
        Transaction Type: {data.get('Transaction_Type', '')}
        Transaction Channel: {data.get('Transaction_Channel', '')}
        Geographic Origin: {data.get('Geographic_Origin', '')}
        Geographic Destination: {data.get('Geographic_Destination', '')}
        Risk Category: {data.get('Risk_Category', '')}
        Risk Drivers: {data.get('Risk_Drivers', '')}
        Alert Status: {data.get('Alert_Status', '')}
        Investigation Outcome: {data.get('Investigation_Outcome', '')}
        Source of Funds: {data.get('Source_Of_Funds', '')}
        Purpose of Transaction: {data.get('Purpose_Of_Transaction', '')}
        Beneficial Owner: {data.get('Beneficial_Owner', '')}
        """

        # Load TF-IDF vectorizer, model and label encoders
        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)

        # Predict
        X_vec = tfidf.transform([text_input])
        pred = model.predict(X_vec)[0]

        # Decode predictions back to their original label strings
        decoded = {
            col: encoders[col].inverse_transform([p])[0]
            for col, p in zip(LABEL_COLUMNS, pred)
        }

        return {"prediction": decoded}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
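
# A minimal local entry point, as a sketch: the original Space may rely on an
# external launcher instead, and port 7860 is only the Hugging Face Spaces
# convention, so host and port here are assumptions.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)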