namanpenguin committed
Commit 7d1afe2 · verified · Parent: e03a7c0

Update app.py

Files changed (1): app.py (+146 −27)
app.py CHANGED
@@ -3,22 +3,91 @@ from pydantic import BaseModel
 import torch
 from transformers import BertTokenizer
 from models.bert_model import BertMultiOutputModel
-from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE
-from dataset_utils import load_label_encoders
+from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE, METADATA_COLUMNS
+from dataset_utils import load_label_encoders, get_tokenizer, ComplianceDataset
+from train_utils import predict_probabilities
 import numpy as np
 import os
+import pandas as pd
+from typing import Dict, Any, Optional
+from torch.utils.data import DataLoader

 app = FastAPI()

 # Load the model and tokenizer
 model_path = "BERT_model.pth"
-tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+tokenizer = get_tokenizer('bert-base-uncased')
 model = BertMultiOutputModel([len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
 model.load_state_dict(torch.load(model_path, map_location=DEVICE))
 model.eval()

+class TransactionData(BaseModel):
+    Transaction_Id: str
+    Hit_Seq: int
+    Hit_Id_List: str
+    Origin: str
+    Designation: str
+    Keywords: str
+    Name: str
+    SWIFT_Tag: str
+    Currency: str
+    Entity: str
+    Message: str
+    City: str
+    Country: str
+    State: str
+    Hit_Type: str
+    Record_Matching_String: str
+    WatchList_Match_String: str
+    Payment_Sender_Name: Optional[str] = ""
+    Payment_Reciever_Name: Optional[str] = ""
+    Swift_Message_Type: str
+    Text_Sanction_Data: str
+    Matched_Sanctioned_Entity: str
+    Is_Match: int
+    Red_Flag_Reason: str
+    Risk_Level: str
+    Risk_Score: float
+    Risk_Score_Description: str
+    CDD_Level: str
+    PEP_Status: str
+    Value_Date: str
+    Last_Review_Date: str
+    Next_Review_Date: str
+    Sanction_Description: str
+    Checker_Notes: str
+    Sanction_Context: str
+    Maker_Action: str
+    Customer_ID: int
+    Customer_Type: str
+    Industry: str
+    Transaction_Date_Time: str
+    Transaction_Type: str
+    Transaction_Channel: str
+    Originating_Bank: str
+    Beneficiary_Bank: str
+    Geographic_Origin: str
+    Geographic_Destination: str
+    Match_Score: float
+    Match_Type: str
+    Sanctions_List_Version: str
+    Screening_Date_Time: str
+    Risk_Category: str
+    Risk_Drivers: str
+    Alert_Status: str
+    Investigation_Outcome: str
+    Case_Owner_Analyst: str
+    Escalation_Level: str
+    Escalation_Date: str
+    Regulatory_Reporting_Flags: bool
+    Audit_Trail_Timestamp: str
+    Source_Of_Funds: str
+    Purpose_Of_Transaction: str
+    Beneficial_Owner: str
+    Sanctions_Exposure_History: bool
+
 class PredictionRequest(BaseModel):
-    sanction_context: str
+    transaction_data: TransactionData

 @app.get("/")
 async def root():
@@ -31,38 +100,88 @@ async def health_check():
 @app.post("/predict")
 async def predict(request: PredictionRequest):
     try:
-        # Tokenize the input text
-        inputs = tokenizer(
-            request.sanction_context,
-            padding='max_length',
-            truncation=True,
-            max_length=MAX_LEN,
-            return_tensors="pt"
+        # Convert transaction data to DataFrame
+        input_data = pd.DataFrame([request.transaction_data.dict()])
+
+        # Create the text input by combining relevant fields
+        text_input = f"""
+        Transaction ID: {input_data['Transaction_Id'].iloc[0]}
+        Origin: {input_data['Origin'].iloc[0]}
+        Designation: {input_data['Designation'].iloc[0]}
+        Keywords: {input_data['Keywords'].iloc[0]}
+        Name: {input_data['Name'].iloc[0]}
+        SWIFT Tag: {input_data['SWIFT_Tag'].iloc[0]}
+        Currency: {input_data['Currency'].iloc[0]}
+        Entity: {input_data['Entity'].iloc[0]}
+        Message: {input_data['Message'].iloc[0]}
+        City: {input_data['City'].iloc[0]}
+        Country: {input_data['Country'].iloc[0]}
+        State: {input_data['State'].iloc[0]}
+        Hit Type: {input_data['Hit_Type'].iloc[0]}
+        Record Matching String: {input_data['Record_Matching_String'].iloc[0]}
+        WatchList Match String: {input_data['WatchList_Match_String'].iloc[0]}
+        Payment Sender: {input_data['Payment_Sender_Name'].iloc[0]}
+        Payment Receiver: {input_data['Payment_Reciever_Name'].iloc[0]}
+        Swift Message Type: {input_data['Swift_Message_Type'].iloc[0]}
+        Text Sanction Data: {input_data['Text_Sanction_Data'].iloc[0]}
+        Matched Sanctioned Entity: {input_data['Matched_Sanctioned_Entity'].iloc[0]}
+        Red Flag Reason: {input_data['Red_Flag_Reason'].iloc[0]}
+        Risk Level: {input_data['Risk_Level'].iloc[0]}
+        Risk Score: {input_data['Risk_Score'].iloc[0]}
+        CDD Level: {input_data['CDD_Level'].iloc[0]}
+        PEP Status: {input_data['PEP_Status'].iloc[0]}
+        Sanction Description: {input_data['Sanction_Description'].iloc[0]}
+        Checker Notes: {input_data['Checker_Notes'].iloc[0]}
+        Sanction Context: {input_data['Sanction_Context'].iloc[0]}
+        Maker Action: {input_data['Maker_Action'].iloc[0]}
+        Customer Type: {input_data['Customer_Type'].iloc[0]}
+        Industry: {input_data['Industry'].iloc[0]}
+        Transaction Type: {input_data['Transaction_Type'].iloc[0]}
+        Transaction Channel: {input_data['Transaction_Channel'].iloc[0]}
+        Geographic Origin: {input_data['Geographic_Origin'].iloc[0]}
+        Geographic Destination: {input_data['Geographic_Destination'].iloc[0]}
+        Risk Category: {input_data['Risk_Category'].iloc[0]}
+        Risk Drivers: {input_data['Risk_Drivers'].iloc[0]}
+        Alert Status: {input_data['Alert_Status'].iloc[0]}
+        Investigation Outcome: {input_data['Investigation_Outcome'].iloc[0]}
+        Source of Funds: {input_data['Source_Of_Funds'].iloc[0]}
+        Purpose of Transaction: {input_data['Purpose_Of_Transaction'].iloc[0]}
+        Beneficial Owner: {input_data['Beneficial_Owner'].iloc[0]}
+        """
+
+        # Create dataset instance
+        dataset = ComplianceDataset(
+            texts=[text_input],
+            labels=[[0] * len(LABEL_COLUMNS)],  # Dummy labels for prediction
+            tokenizer=tokenizer,
+            max_len=MAX_LEN
         )

-        # Move inputs to device
-        input_ids = inputs['input_ids'].to(DEVICE)
-        attention_mask = inputs['attention_mask'].to(DEVICE)
-
-        # Get predictions
-        with torch.no_grad():
-            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-            probabilities = [torch.softmax(output, dim=1).cpu().numpy() for output in outputs]
-            predictions = [np.argmax(prob, axis=1) for prob in probabilities]
-
+        # Create DataLoader
+        loader = DataLoader(dataset, batch_size=1, shuffle=False)
+
+        # Get prediction probabilities using the predict_probabilities function
+        all_probabilities = predict_probabilities(model, loader)
+
         # Load label encoders to decode predictions
         label_encoders = load_label_encoders()

         # Format response
         response = {}
-        for i, (col, pred, prob) in enumerate(zip(LABEL_COLUMNS, predictions, probabilities)):
-            decoded_pred = label_encoders[col].inverse_transform(pred)[0]
+        for i, (col, probs) in enumerate(zip(LABEL_COLUMNS, all_probabilities)):
+            # Get the prediction (argmax of probabilities)
+            pred = np.argmax(probs[0])
+            decoded_pred = label_encoders[col].inverse_transform([pred])[0]
+
+            # Get probabilities for each class
+            class_probs = {
+                label: float(probs[0][j])
+                for j, label in enumerate(label_encoders[col].classes_)
+            }
+
             response[col] = {
                 "prediction": decoded_pred,
-                "probabilities": {
-                    label: float(prob[0][j])
-                    for j, label in enumerate(label_encoders[col].classes_)
-                }
+                "probabilities": class_probs
             }

         return response
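Note on the new request shape: /predict no longer accepts a bare sanction_context string; it expects the full transaction record nested under transaction_data, with every non-Optional field present. A hypothetical smoke test is sketched below; it is not part of the commit. The placeholder values, the Pydantic v2 model_fields introspection (under Pydantic v1 it would be TransactionData.__fields__), and FastAPI's TestClient are assumptions, and the model weights plus label encoders must be on disk for the call to succeed.

# Hypothetical smoke test for the new request schema (not part of this commit).
# Builds a complete payload by defaulting every TransactionData field.
from fastapi.testclient import TestClient  # requires httpx

from app import app, TransactionData  # assumes the file above is saved as app.py

DEFAULTS = {str: "N/A", int: 0, float: 0.0, bool: False}

# Pydantic v2 introspection (assumption); Optional[str] fields fall back to "N/A".
sample = {
    name: DEFAULTS.get(field.annotation, "N/A")
    for name, field in TransactionData.model_fields.items()
}

client = TestClient(app)
resp = client.post("/predict", json={"transaction_data": sample})
print(resp.status_code)
print(resp.json())  # expected: {label_column: {"prediction": ..., "probabilities": {...}}, ...}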
 
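train_utils.predict_probabilities is not shown in this commit, but its contract can be inferred: the response loop indexes all_probabilities[i][0][j], i.e. one per-class probability array per label column, and the inline code it replaces applied torch.softmax to each output head. A minimal sketch under those assumptions follows; the batch keys "input_ids" and "attention_mask" are a guess at what ComplianceDataset yields.

# Minimal sketch of what train_utils.predict_probabilities plausibly does,
# reconstructed from the inline inference code this commit removes.
import numpy as np
import torch

from config import DEVICE

def predict_probabilities(model, loader):
    model.eval()
    per_head_probs = None  # one accumulator per output head / label column
    with torch.no_grad():
        for batch in loader:
            # Assumed batch layout; the real keys depend on ComplianceDataset.
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            # One softmax per output head, as in the removed inline code.
            probs = [torch.softmax(out, dim=1).cpu().numpy() for out in outputs]
            if per_head_probs is None:
                per_head_probs = probs
            else:
                per_head_probs = [np.concatenate([acc, p])
                                  for acc, p in zip(per_head_probs, probs)]
    # Returns one (n_samples, n_classes) array per label column,
    # matching the probs[0][j] indexing in the endpoint above.
    return per_head_probs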