ganeshkonapalli committed on
Commit
c778365
·
verified ·
1 Parent(s): 4281e33

Create train_test_validate.py

Browse files
Files changed (1) hide show
  1. train_test_validate.py +242 -0
train_test_validate.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import joblib
4
+ import requests
5
+ import pandas as pd
6
+ from typing import List
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.feature_extraction.text import TfidfVectorizer
9
+ from sklearn.multioutput import MultiOutputClassifier
10
+ from sklearn.pipeline import Pipeline
11
+ from sklearn.preprocessing import LabelEncoder
12
+ from sklearn.linear_model import LogisticRegression
13
+ from pydantic import BaseModel, ValidationError
14
+ import argparse
15
+
16
+ # --- CONFIG ---
17
# Dataset location and the columns used for training.
DATA_PATH = "data.csv"
TEXT_COLUMN = "Sanction_Context"  # free-text column fed to the TF-IDF vectorizer
LABEL_COLUMNS = [
    "Red_Flag_Reason",
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
]

# Where trained artifacts are written.
MODEL_SAVE_DIR = "models"
LABEL_ENCODERS_PATH = os.path.join(MODEL_SAVE_DIR, "label_encoders.pkl")

# Vectorizer and training settings.
TFIDF_MAX_FEATURES = 1000
NGRAM_RANGE = (1, 2)
USE_STOPWORDS = True  # when True the vectorizer is built with stop_words="english"
RANDOM_STATE = 42
TEST_SIZE = 0.2

# Endpoint the API check posts to.
API_URL = "https://your-hf-api-url.hf.space/predict"  # Replace with actual URL

# Create the artifact directory up front so later joblib.dump calls cannot
# fail on a missing folder.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
33
+
34
+ # --- Pydantic schema ---
35
class TransactionData(BaseModel):
    # Schema for one transaction record, used to validate input dicts before
    # they are sent to the prediction API. Every field is required (no
    # defaults are declared). Date/time values are typed as plain strings, so
    # their format is not checked here.
    #
    # The blank-line grouping below is for readability only; field order is
    # unchanged.

    # Hit and message details
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str

    # Matching details
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: str
    # NOTE: "Reciever" is misspelled, but it is the key the sample payload
    # uses, so it must stay as-is.
    Payment_Reciever_Name: str
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int

    # Risk assessment and review
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str

    # Customer and transaction
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str

    # Screening result
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str

    # Case handling
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str

    # Funds and ownership
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
99
+
100
+ # --- Train function ---
101
def train_pipeline():
    """Train the multi-output text classifier and persist its artifacts.

    Steps:
      1. Read ``DATA_PATH`` and drop rows with a missing text or label value.
      2. Label-encode every column in ``LABEL_COLUMNS`` (one encoder each).
      3. Fit a TF-IDF + ``MultiOutputClassifier(LogisticRegression)`` pipeline
         on the training split.
      4. Save the pipeline, the label encoders and the fitted vectorizer
         into ``MODEL_SAVE_DIR``.
      5. Print per-label accuracy on the held-out split.

    Raises:
        KeyError: If the text column or any label column is absent from the
            dataset.
        ValueError: If no complete rows remain after dropping missing values.
    """
    print("📥 Loading dataset...")
    df = pd.read_csv(DATA_PATH)

    required_columns = [TEXT_COLUMN] + LABEL_COLUMNS
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        # Fail early with the offending names instead of a bare pandas KeyError.
        raise KeyError(f"{DATA_PATH} is missing required columns: {missing_columns}")

    df = df.dropna(subset=required_columns)
    if df.empty:
        raise ValueError(
            f"{DATA_PATH} has no rows with values in all of: {required_columns}"
        )

    label_encoders = {}
    for col in LABEL_COLUMNS:
        le = LabelEncoder()
        df[col] = le.fit_transform(df[col])
        label_encoders[col] = le

    # Coerce to str so a non-string cell cannot crash the vectorizer.
    X = df[TEXT_COLUMN].astype(str)
    Y = df[LABEL_COLUMNS]

    print("✂️ Splitting train/test...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, Y, test_size=TEST_SIZE, random_state=RANDOM_STATE
    )

    print("🔧 Building pipeline with Logistic Regression...")
    stop_words = "english" if USE_STOPWORDS else None
    pipeline = Pipeline([
        (
            "tfidf",
            TfidfVectorizer(
                max_features=TFIDF_MAX_FEATURES,
                ngram_range=NGRAM_RANGE,
                stop_words=stop_words,
            ),
        ),
        (
            "clf",
            MultiOutputClassifier(
                LogisticRegression(random_state=RANDOM_STATE, max_iter=1000)
            ),
        ),
    ])

    print("🏋️ Training...")
    pipeline.fit(X_train, y_train)

    model_path = os.path.join(MODEL_SAVE_DIR, "logreg_model.pkl")
    print(f"💾 Saving model to {model_path}")
    joblib.dump(pipeline, model_path)

    print(f"💾 Saving label encoders to {LABEL_ENCODERS_PATH}")
    joblib.dump(label_encoders, LABEL_ENCODERS_PATH)

    # The fitted vectorizer is also stored on its own, alongside the pipeline.
    tfidf_path = os.path.join(MODEL_SAVE_DIR, "tfidf_vectorizer.pkl")
    joblib.dump(pipeline.named_steps["tfidf"], tfidf_path)

    # Evaluate only after the artifacts are on disk, so a problem while
    # scoring can never cost us the trained model.
    print("Held-out accuracy per label:")
    y_pred = pipeline.predict(X_test)
    for idx, col in enumerate(LABEL_COLUMNS):
        accuracy = (y_pred[:, idx] == y_test[col].to_numpy()).mean()
        print(f"  {col}: {accuracy:.3f}")

    print("✅ Training complete.")
139
+
140
+ # --- Input Validator ---
141
def validate_sample_input(sample_input):
    """Validate a transaction record against the ``TransactionData`` schema.

    Prints a confirmation on success, or the validation errors as indented
    JSON on failure.

    Args:
        sample_input: Mapping of field name to value; it is unpacked as
            keyword arguments into ``TransactionData``.

    Returns:
        ``True`` if the input satisfies the schema, ``False`` otherwise.
    """
    try:
        TransactionData(**sample_input)
    except ValidationError as e:
        print("❌ Validation error:")
        print(e.json(indent=2))
        return False
    print("✅ Input is valid.")
    return True
148
+
149
+ # --- API Test ---
150
def test_api(sample_payload, timeout=10.0):
    """POST ``sample_payload`` as JSON to ``API_URL`` and print the outcome.

    This is a manual check invoked from the command line (``--test``); it
    prints the HTTP status code and the pretty-printed JSON response, and
    reports connection or parsing problems instead of raising.

    Args:
        sample_payload: JSON-serializable object sent as the request body.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely.
    """
    headers = {"Content-Type": "application/json"}
    print(f"🚀 Posting to {API_URL}")
    try:
        response = requests.post(
            API_URL,
            headers=headers,
            data=json.dumps(sample_payload),
            timeout=timeout,
        )
    except requests.RequestException as e:
        # Covers DNS failures, refused connections and timeouts.
        print("❌ Request failed:", str(e))
        return

    print("📥 Status Code:", response.status_code)
    try:
        print("📤 Response:", json.dumps(response.json(), indent=2))
    except ValueError as e:
        # response.json() raises a ValueError subclass on a non-JSON body.
        print("❌ Failed to parse response:", str(e))
159
+
160
+ # --- Sample Payload (unchanged) ---
161
# Example request body: a single fully populated transaction record wrapped
# under the "transaction_data" key. The inner dict is validated by --validate;
# the whole object is posted by --test.
sample_payload = {
    "transaction_data": {
        # Hit and message details
        "Transaction_Id": "TXN12345",
        "Hit_Seq": 1,
        "Hit_Id_List": "HIT789",
        "Origin": "India",
        "Designation": "Manager",
        "Keywords": "fraud",
        "Name": "John Doe",
        "SWIFT_Tag": "TAG001",
        "Currency": "INR",
        "Entity": "ABC Ltd",
        "Message": "Payment for services",
        "City": "Hyderabad",
        "Country": "India",
        "State": "Telangana",
        "Hit_Type": "Individual",

        # Matching details
        "Record_Matching_String": "John Doe",
        "WatchList_Match_String": "Doe, John",
        "Payment_Sender_Name": "John Doe",
        "Payment_Reciever_Name": "Jane Smith",
        "Swift_Message_Type": "MT103",
        "Text_Sanction_Data": "Suspicious transfer to offshore account",
        "Matched_Sanctioned_Entity": "John Doe",
        "Is_Match": 1,

        # Risk assessment and review
        "Red_Flag_Reason": "High value transaction",
        "Risk_Level": "High",
        "Risk_Score": 87.5,
        "Risk_Score_Description": "Very High",
        "CDD_Level": "Enhanced",
        "PEP_Status": "Yes",
        "Value_Date": "2023-01-01",
        "Last_Review_Date": "2023-06-01",
        "Next_Review_Date": "2024-06-01",
        "Sanction_Description": "OFAC List",
        "Checker_Notes": "Urgent check required",
        "Sanction_Context": "Payment matched with OFAC entry",
        "Maker_Action": "Escalate",

        # Customer and transaction
        "Customer_ID": 1001,
        "Customer_Type": "Corporate",
        "Industry": "Finance",
        "Transaction_Date_Time": "2023-12-15T10:00:00",
        "Transaction_Type": "Credit",
        "Transaction_Channel": "Online",
        "Originating_Bank": "ABC Bank",
        "Beneficiary_Bank": "XYZ Bank",
        "Geographic_Origin": "India",
        "Geographic_Destination": "USA",

        # Screening result
        "Match_Score": 96.2,
        "Match_Type": "Exact",
        "Sanctions_List_Version": "2023-V5",
        "Screening_Date_Time": "2023-12-15T09:55:00",

        # Case handling
        "Risk_Category": "Sanctions",
        "Risk_Drivers": "PEP, High Value",
        "Alert_Status": "Open",
        "Investigation_Outcome": "Pending",
        "Case_Owner_Analyst": "analyst1",
        "Escalation_Level": "L2",
        "Escalation_Date": "2023-12-16",
        "Regulatory_Reporting_Flags": True,
        "Audit_Trail_Timestamp": "2023-12-15T10:05:00",

        # Funds and ownership
        "Source_Of_Funds": "Corporate Account",
        "Purpose_Of_Transaction": "Service Payment",
        "Beneficial_Owner": "John Doe",
        "Sanctions_Exposure_History": False,
    },
}
228
+
229
+ # --- Main Entry ---
230
if __name__ == "__main__":
    # Command-line entry point: each flag triggers one independent action,
    # and several flags may be combined in a single run.
    parser = argparse.ArgumentParser(
        description="Train the model, validate the sample input, "
                    "and/or test the prediction API."
    )
    parser.add_argument("--train", action="store_true", help="Train the model")
    parser.add_argument("--validate", action="store_true", help="Validate sample input")
    parser.add_argument("--test", action="store_true", help="Test prediction API")
    args = parser.parse_args()

    # With no flag the script would otherwise exit without doing or saying
    # anything; show the usage text instead.
    if not (args.train or args.validate or args.test):
        parser.print_help()

    if args.train:
        train_pipeline()
    if args.validate:
        validate_sample_input(sample_payload["transaction_data"])
    if args.test:
        test_api(sample_payload)