# Captured from a Hugging Face Spaces page; the Space was shown as "Sleeping" at capture time.
import argparse
import json
import os
from collections.abc import Mapping
from typing import List

import joblib
import pandas as pd
import requests
from pydantic import BaseModel, ValidationError
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
# --- CONFIG ---
# Training data: a CSV holding one free-text input column and several label columns.
DATA_PATH = "data.csv"
TEXT_COLUMN = "Sanction_Context"  # free-text column fed to the TF-IDF vectoriser
LABEL_COLUMNS = [  # each column becomes one output of the multi-output classifier
    "Red_Flag_Reason", "Maker_Action", "Escalation_Level",
    "Risk_Category", "Risk_Drivers", "Investigation_Outcome",
]
# Where trained artefacts (pipeline, label encoders, vectoriser) are written.
MODEL_SAVE_DIR = "models"
LABEL_ENCODERS_PATH = os.path.join(MODEL_SAVE_DIR, "label_encoders.pkl")
# TF-IDF settings.
TFIDF_MAX_FEATURES = 1000
NGRAM_RANGE = (1, 2)  # unigrams and bigrams
USE_STOPWORDS = True  # True -> use scikit-learn's built-in "english" stop-word list
RANDOM_STATE = 42  # seeds both the train/test split and LogisticRegression
TEST_SIZE = 0.2  # fraction of rows held out from training
# Prediction endpoint used by test_api(). The literal below is only a placeholder:
# set the PREDICT_API_URL environment variable instead of editing this file.
# An unset or empty variable falls back to the placeholder.
API_URL = os.environ.get("PREDICT_API_URL") or "https://your-hf-api-url.hf.space/predict"
# Created at import time so the artefact paths above are writable.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
# --- Pydantic schema ---
class TransactionData(BaseModel):
    """Schema for one transaction record, as validated by ``validate_sample_input``.

    Every field is required (none declares a default). Field names are the
    exact keys accepted from the input mapping, so renaming a field changes
    which payloads validate. This includes the ``Payment_Reciever_Name``
    spelling, which ``sample_payload`` also uses.

    ``Sanction_Context`` is the training text column (TEXT_COLUMN). The six
    fields marked "training target" below are the LABEL_COLUMNS.
    """
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: str
    Payment_Reciever_Name: str  # "Reciever" spelling is part of the accepted key
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str  # training target (LABEL_COLUMNS)
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    # Date/time fields are typed as plain str: no date or timestamp format is
    # validated by this schema.
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str  # training text column (TEXT_COLUMN)
    Maker_Action: str  # training target (LABEL_COLUMNS)
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str  # training target (LABEL_COLUMNS)
    Risk_Drivers: str  # training target (LABEL_COLUMNS)
    Alert_Status: str
    Investigation_Outcome: str  # training target (LABEL_COLUMNS)
    Case_Owner_Analyst: str
    Escalation_Level: str  # training target (LABEL_COLUMNS)
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool  # a single boolean despite the plural name
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
# --- Train function ---
def _encode_labels(df):
    """Label-encode every LABEL_COLUMNS column of *df* in place.

    Returns a dict mapping column name -> fitted LabelEncoder, so integer
    predictions can later be mapped back to the original label values.
    """
    label_encoders = {}
    for col in LABEL_COLUMNS:
        encoder = LabelEncoder()
        df[col] = encoder.fit_transform(df[col])
        label_encoders[col] = encoder
    return label_encoders


def _report_test_accuracy(pipeline, X_test, y_test):
    """Print the per-label accuracy of *pipeline* on the held-out rows."""
    # MultiOutputClassifier.predict returns one column per target, in the
    # same order as the columns of y_test.
    y_pred = pipeline.predict(X_test)
    for idx, col in enumerate(y_test.columns):
        accuracy = (y_pred[:, idx] == y_test[col].to_numpy()).mean()
        print(f"   {col}: held-out accuracy = {accuracy:.3f}")


def train_pipeline(data_path=None, model_dir=None):
    """Train the TF-IDF + logistic-regression model and save its artefacts.

    The steps are:

    1. Load the CSV.
    2. Drop rows missing the text or any label.
    3. Label-encode each target column.
    4. Split into train and test sets.
    5. Fit the pipeline.
    6. Report per-label accuracy on the held-out split.
    7. Save the fitted pipeline, the label encoders and the fitted TF-IDF
       vectoriser.

    Args:
        data_path: CSV file to train on; defaults to DATA_PATH.
        model_dir: Output directory for artefacts; defaults to MODEL_SAVE_DIR.
            In that case the label encoders are written to LABEL_ENCODERS_PATH.

    Returns:
        The fitted scikit-learn Pipeline.

    Raises:
        ValueError: If no rows remain after dropping missing values, or if a
            label column has fewer than two distinct values.
            LogisticRegression cannot be fitted on a single class.
    """
    data_path = DATA_PATH if data_path is None else data_path
    if model_dir is None:
        model_dir = MODEL_SAVE_DIR
        encoders_path = LABEL_ENCODERS_PATH
    else:
        encoders_path = os.path.join(model_dir, "label_encoders.pkl")
    os.makedirs(model_dir, exist_ok=True)

    print("π₯ Loading dataset...")
    df = pd.read_csv(data_path)
    df.dropna(subset=[TEXT_COLUMN] + LABEL_COLUMNS, inplace=True)
    if df.empty:
        raise ValueError(
            f"No usable rows in {data_path}: every row is missing "
            f"{TEXT_COLUMN} or one of {LABEL_COLUMNS}."
        )
    # Fail early with a clear message instead of an opaque solver error.
    single_class = [col for col in LABEL_COLUMNS if df[col].nunique() < 2]
    if single_class:
        raise ValueError(
            f"Label column(s) {single_class} have fewer than two distinct values; "
            "a classifier cannot be trained on them."
        )

    label_encoders = _encode_labels(df)
    # TfidfVectorizer requires strings; coerce numeric cells in the text column.
    X = df[TEXT_COLUMN].astype(str)
    Y = df[LABEL_COLUMNS]

    print("βοΈ Splitting train/test...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, Y, test_size=TEST_SIZE, random_state=RANDOM_STATE
    )

    print("π§ Building pipeline with Logistic Regression...")
    stop_words = "english" if USE_STOPWORDS else None
    pipeline = Pipeline([
        ('tfidf', TfidfVectorizer(max_features=TFIDF_MAX_FEATURES, ngram_range=NGRAM_RANGE, stop_words=stop_words)),
        ('clf', MultiOutputClassifier(LogisticRegression(random_state=RANDOM_STATE, max_iter=1000))),
    ])

    print("ποΈ Training...")
    pipeline.fit(X_train, y_train)

    # Use the held-out split: report how well each target is predicted.
    print("Evaluating on held-out split...")
    _report_test_accuracy(pipeline, X_test, y_test)

    model_path = os.path.join(model_dir, "logreg_model.pkl")
    print(f"πΎ Saving model to {model_path}")
    joblib.dump(pipeline, model_path)
    print(f"πΎ Saving label encoders to {encoders_path}")
    joblib.dump(label_encoders, encoders_path)
    # Standalone copy of the fitted vectoriser (it is also inside the pipeline).
    tfidf_path = os.path.join(model_dir, "tfidf_vectorizer.pkl")
    joblib.dump(pipeline.named_steps["tfidf"], tfidf_path)
    print("β Training complete.")
    return pipeline
# --- Input Validator ---
def validate_sample_input(sample_input):
    """Validate one transaction record against the TransactionData schema.

    Prints the outcome. On failure it also prints pydantic's error details
    as JSON.

    Args:
        sample_input: Mapping of field name -> value, for example
            ``sample_payload["transaction_data"]``.

    Returns:
        True if the record validates, False otherwise.
    """
    if not isinstance(sample_input, Mapping):
        # TransactionData(**x) would raise a bare TypeError for non-mappings.
        print("β Validation error:")
        print(
            "Expected a mapping of field names to values, "
            f"got {type(sample_input).__name__}."
        )
        return False
    try:
        TransactionData(**sample_input)
    except ValidationError as e:
        print("β Validation error:")
        print(e.json(indent=2))
        return False
    print("β Input is valid.")
    return True
# --- API Test ---
def test_api(sample_payload, url=None, timeout=30):
    """POST *sample_payload* as JSON to the prediction API and print the reply.

    Args:
        sample_payload: JSON-serialisable payload. This script sends the
            ``{"transaction_data": {...}}`` dict defined in this module.
        url: Endpoint to call; defaults to the module-level API_URL.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on an unresponsive host.

    Returns:
        The ``requests.Response`` if the request was sent and answered.
        None if the request itself failed (DNS, connection, timeout, ...).
    """
    target = API_URL if url is None else url
    headers = {"Content-Type": "application/json"}
    print(f"π Posting to {target}")
    try:
        response = requests.post(
            target, headers=headers, data=json.dumps(sample_payload), timeout=timeout
        )
    except requests.exceptions.RequestException as e:
        print("Request failed:", str(e))
        return None
    print("π₯ Status Code:", response.status_code)
    try:
        print("π€ Response:", json.dumps(response.json(), indent=2))
    except ValueError as e:
        # JSON decode errors from both json and requests subclass ValueError.
        print("β Failed to parse response:", str(e))
    return response


# Despite its name this is not a pytest test. Stop pytest from collecting it
# (it would fail looking for a "sample_payload" fixture) when imported by a test run.
test_api.__test__ = False
# --- Sample Payload ---
# Example request body: a single record under "transaction_data", with one entry
# per TransactionData field, in the schema's field order.
sample_payload = {
    "transaction_data": dict(
        Transaction_Id="TXN12345",
        Hit_Seq=1,
        Hit_Id_List="HIT789",
        Origin="India",
        Designation="Manager",
        Keywords="fraud",
        Name="John Doe",
        SWIFT_Tag="TAG001",
        Currency="INR",
        Entity="ABC Ltd",
        Message="Payment for services",
        City="Hyderabad",
        Country="India",
        State="Telangana",
        Hit_Type="Individual",
        Record_Matching_String="John Doe",
        WatchList_Match_String="Doe, John",
        Payment_Sender_Name="John Doe",
        Payment_Reciever_Name="Jane Smith",
        Swift_Message_Type="MT103",
        Text_Sanction_Data="Suspicious transfer to offshore account",
        Matched_Sanctioned_Entity="John Doe",
        Is_Match=1,
        Red_Flag_Reason="High value transaction",
        Risk_Level="High",
        Risk_Score=87.5,
        Risk_Score_Description="Very High",
        CDD_Level="Enhanced",
        PEP_Status="Yes",
        Value_Date="2023-01-01",
        Last_Review_Date="2023-06-01",
        Next_Review_Date="2024-06-01",
        Sanction_Description="OFAC List",
        Checker_Notes="Urgent check required",
        Sanction_Context="Payment matched with OFAC entry",
        Maker_Action="Escalate",
        Customer_ID=1001,
        Customer_Type="Corporate",
        Industry="Finance",
        Transaction_Date_Time="2023-12-15T10:00:00",
        Transaction_Type="Credit",
        Transaction_Channel="Online",
        Originating_Bank="ABC Bank",
        Beneficiary_Bank="XYZ Bank",
        Geographic_Origin="India",
        Geographic_Destination="USA",
        Match_Score=96.2,
        Match_Type="Exact",
        Sanctions_List_Version="2023-V5",
        Screening_Date_Time="2023-12-15T09:55:00",
        Risk_Category="Sanctions",
        Risk_Drivers="PEP, High Value",
        Alert_Status="Open",
        Investigation_Outcome="Pending",
        Case_Owner_Analyst="analyst1",
        Escalation_Level="L2",
        Escalation_Date="2023-12-16",
        Regulatory_Reporting_Flags=True,
        Audit_Trail_Timestamp="2023-12-15T10:05:00",
        Source_Of_Funds="Corporate Account",
        Purpose_Of_Transaction="Service Payment",
        Beneficial_Owner="John Doe",
        Sanctions_Exposure_History=False,
    )
}
# --- Main Entry ---
def main(argv=None):
    """Command-line entry point: run every action selected by a flag.

    Args:
        argv: Argument list to parse. None means argparse reads ``sys.argv[1:]``.

    Exits with argparse's usage-error status (2) when no action flag is given.
    """
    parser = argparse.ArgumentParser(
        description="Train the model, validate the sample input, or test the prediction API."
    )
    parser.add_argument("--train", action="store_true", help="Train the model")
    parser.add_argument("--validate", action="store_true", help="Validate sample input")
    parser.add_argument("--test", action="store_true", help="Test prediction API")
    args = parser.parse_args(argv)
    if not (args.train or args.validate or args.test):
        # Without a flag there is nothing to do; say so instead of exiting silently.
        parser.error("no action requested; pass at least one of --train, --validate, --test")
    if args.train:
        train_pipeline()
    if args.validate:
        validate_sample_input(sample_payload["transaction_data"])
    if args.test:
        test_api(sample_payload)


if __name__ == "__main__":
    main()