subbunanepalli commited on
Commit
cfa3c0d
·
verified ·
1 Parent(s): 50769c6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import Optional
4
+ import pandas as pd
5
+ import joblib
6
+ import os
7
+
8
+ from sklearn.feature_extraction.text import TfidfVectorizer
9
+ from sklearn.linear_model import LogisticRegression
10
+ from sklearn.model_selection import train_test_split
11
+ from sklearn.multioutput import MultiOutputClassifier
12
+ from sklearn.pipeline import Pipeline
13
+
14
+ # ========== Config ==========
15
+ DATA_PATH = "data/synthetic_transactions_samples_5000.csv"
16
+ MODEL_DIR = "models"
17
+ MODEL_PATH = os.path.join(MODEL_DIR, "logreg_model.pkl")
18
+ VECTORIZER_PATH = os.path.join(MODEL_DIR, "tfidf_vectorizer.pkl")
19
+
20
+ # ========== FastAPI Init ==========
21
+ app = FastAPI()
22
+
23
+ # ========== Input Schema ==========
24
+ class TransactionData(BaseModel):
25
+ Transaction_Id: str
26
+ Hit_Seq: int
27
+ Hit_Id_List: str
28
+ Origin: str
29
+ Designation: str
30
+ Keywords: str
31
+ Name: str
32
+ SWIFT_Tag: str
33
+ Currency: str
34
+ Entity: str
35
+ Message: str
36
+ City: str
37
+ Country: str
38
+ State: str
39
+ Hit_Type: str
40
+ Record_Matching_String: str
41
+ WatchList_Match_String: str
42
+ Payment_Sender_Name: Optional[str] = ""
43
+ Payment_Reciever_Name: Optional[str] = ""
44
+ Swift_Message_Type: str
45
+ Text_Sanction_Data: str
46
+ Matched_Sanctioned_Entity: str
47
+ Is_Match: int
48
+ Red_Flag_Reason: str
49
+ Risk_Level: str
50
+ Risk_Score: float
51
+ Risk_Score_Description: str
52
+ CDD_Level: str
53
+ PEP_Status: str
54
+ Value_Date: str
55
+ Last_Review_Date: str
56
+ Next_Review_Date: str
57
+ Sanction_Description: str
58
+ Checker_Notes: str
59
+ Sanction_Context: str
60
+ Maker_Action: str
61
+ Customer_ID: int
62
+ Customer_Type: str
63
+ Industry: str
64
+ Transaction_Date_Time: str
65
+ Transaction_Type: str
66
+ Transaction_Channel: str
67
+ Originating_Bank: str
68
+ Beneficiary_Bank: str
69
+ Geographic_Origin: str
70
+ Geographic_Destination: str
71
+ Match_Score: float
72
+ Match_Type: str
73
+ Sanctions_List_Version: str
74
+ Screening_Date_Time: str
75
+ Risk_Category: str
76
+ Risk_Drivers: str
77
+ Alert_Status: str
78
+ Investigation_Outcome: str
79
+ Case_Owner_Analyst: str
80
+ Escalation_Level: str
81
+ Escalation_Date: str
82
+ Regulatory_Reporting_Flags: bool
83
+ Audit_Trail_Timestamp: str
84
+ Source_Of_Funds: str
85
+ Purpose_Of_Transaction: str
86
+ Beneficial_Owner: str
87
+ Sanctions_Exposure_History: bool
88
+
89
+ # ========== Utils ==========
90
+ def create_text_input(row):
91
+ return f"""
92
+ Transaction ID: {row['Transaction_Id']}
93
+ Origin: {row['Origin']}
94
+ Designation: {row['Designation']}
95
+ Keywords: {row['Keywords']}
96
+ Name: {row['Name']}
97
+ SWIFT Tag: {row['SWIFT_Tag']}
98
+ Currency: {row['Currency']}
99
+ Entity: {row['Entity']}
100
+ Message: {row['Message']}
101
+ City: {row['City']}
102
+ Country: {row['Country']}
103
+ State: {row['State']}
104
+ Hit Type: {row['Hit_Type']}
105
+ Record Matching String: {row['Record_Matching_String']}
106
+ WatchList Match String: {row['WatchList_Match_String']}
107
+ Payment Sender: {row['Payment_Sender_Name']}
108
+ Payment Receiver: {row['Payment_Reciever_Name']}
109
+ Swift Message Type: {row['Swift_Message_Type']}
110
+ Text Sanction Data: {row['Text_Sanction_Data']}
111
+ Matched Sanctioned Entity: {row['Matched_Sanctioned_Entity']}
112
+ Red Flag Reason: {row['Red_Flag_Reason']}
113
+ Risk Level: {row['Risk_Level']}
114
+ Risk Score: {row['Risk_Score']}
115
+ CDD Level: {row['CDD_Level']}
116
+ PEP Status: {row['PEP_Status']}
117
+ Sanction Description: {row['Sanction_Description']}
118
+ Checker Notes: {row['Checker_Notes']}
119
+ Sanction Context: {row['Sanction_Context']}
120
+ Maker Action: {row['Maker_Action']}
121
+ Customer Type: {row['Customer_Type']}
122
+ Industry: {row['Industry']}
123
+ Transaction Type: {row['Transaction_Type']}
124
+ Transaction Channel: {row['Transaction_Channel']}
125
+ Geographic Origin: {row['Geographic_Origin']}
126
+ Geographic Destination: {row['Geographic_Destination']}
127
+ Risk Category: {row['Risk_Category']}
128
+ Risk Drivers: {row['Risk_Drivers']}
129
+ Alert Status: {row['Alert_Status']}
130
+ Investigation Outcome: {row['Investigation_Outcome']}
131
+ Source of Funds: {row['Source_Of_Funds']}
132
+ Purpose of Transaction: {row['Purpose_Of_Transaction']}
133
+ Beneficial Owner: {row['Beneficial_Owner']}
134
+ """
135
+
136
+ # ========== API Routes ==========
137
+ @app.post("/train")
138
+ def train_model():
139
+ df = pd.read_csv(DATA_PATH)
140
+ df = df.fillna("")
141
+
142
+ df["text_input"] = df.apply(create_text_input, axis=1)
143
+ X = df["text_input"]
144
+ y = df[["Maker_Action", "Escalation_Level", "Risk_Category", "Risk_Drivers", "Investigation_Outcome"]]
145
+
146
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
147
+
148
+ vectorizer = TfidfVectorizer()
149
+ classifier = MultiOutputClassifier(LogisticRegression(max_iter=1000))
150
+ pipeline = Pipeline([
151
+ ("vectorizer", vectorizer),
152
+ ("classifier", classifier)
153
+ ])
154
+
155
+ pipeline.fit(X_train, y_train)
156
+
157
+ os.makedirs(MODEL_DIR, exist_ok=True)
158
+ joblib.dump(pipeline, MODEL_PATH)
159
+
160
+ accuracy = pipeline.score(X_test, y_test)
161
+ return {"message": "Model trained and saved.", "accuracy": accuracy}
162
+
163
+ @app.post("/predict")
164
+ def predict(request: TransactionData):
165
+ try:
166
+ model = joblib.load(MODEL_PATH)
167
+ input_data = pd.DataFrame([request.dict()])
168
+ input_data = input_data.fillna("")
169
+ text_input = create_text_input(input_data.iloc[0])
170
+ prediction = model.predict([text_input])[0]
171
+ return {
172
+ "Maker_Action": prediction[0],
173
+ "Escalation_Level": prediction[1],
174
+ "Risk_Category": prediction[2],
175
+ "Risk_Drivers": prediction[3],
176
+ "Investigation_Outcome": prediction[4],
177
+ }
178
+ except Exception as e:
179
+ raise HTTPException(status_code=500, detail=str(e))
180
+
181
+ @app.post("/validate")
182
+ def validate_input(request: TransactionData):
183
+ return {"message": "Input is valid."}
184
+
185
+ @app.get("/test")
186
+ def test_api():
187
+ return {"message": "Test successful."}