subbunanepalli committed on
Commit
5968853
·
verified ·
1 Parent(s): 54dc338

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -0
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import torch
4
+ import pandas as pd
5
+ import numpy as np
6
+ import os
7
+ from typing import Optional
8
+ from torch.utils.data import DataLoader
9
+
10
+ from transformers import DebertaTokenizer
11
+ from config import (
12
+ TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE,
13
+ DEBERTA_MODEL_NAME, METADATA_COLUMNS
14
+ )
15
+ from dataset_utils import load_label_encoders, get_tokenizer, ComplianceDataset
16
+ from train_utils import predict_probabilities
17
+ from models.deberta_model import DebertaMultiOutputModel
18
+
19
app = FastAPI()

# ✅ Load model and tokenizer
model_path = "saved_models/DEBERTA_model.pth"

# Fail fast: verify the checkpoint exists before paying for tokenizer,
# label-encoder and model construction, so a missing file is reported
# immediately at import time.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"❌ Model not found at {model_path}")

tokenizer = get_tokenizer(DEBERTA_MODEL_NAME)
label_encoders = load_label_encoders()
# One output head size per label column, taken from each encoder's classes.
num_classes = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

model = DebertaMultiOutputModel(num_classes).to(DEVICE)
# NOTE(review): depending on the installed torch version, torch.load may
# unpickle arbitrary objects. Only load trusted checkpoints; consider
# passing weights_only=True if the torch version in use supports it.
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
# Switch to inference mode (disables dropout / batch-norm updates).
model.eval()
32
+
33
+ # ✅ Request schema
34
class TransactionData(BaseModel):
    """Request payload describing one screened transaction.

    Every field is required unless a default is given. Only a subset of
    these fields is interpolated into the model input text by the
    ``/predict`` handler in this file; the remaining fields are validated
    but not otherwise read here.
    """

    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # Optional: both names default to an empty string when omitted.
    Payment_Sender_Name: Optional[str] = ""
    # The "Reciever" spelling is the field name clients must send and the
    # name /predict reads; do not correct it without a schema migration.
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    # Dates and timestamps below are accepted as plain strings; no format
    # is enforced by this schema.
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
98
+
99
class PredictionRequest(BaseModel):
    """Body of ``POST /predict``: wraps a single transaction record."""

    # The transaction to classify, nested under the "transaction_data" key.
    transaction_data: TransactionData
101
+
102
+ # ✅ Health check routes
103
@app.get("/")
async def root():
    """Return a static status payload for the service root."""
    payload = {
        "status": "healthy",
        "message": "DeBERTa API is running",
    }
    return payload
106
+
107
@app.get("/health")
async def health_check():
    """Return a static health payload; performs no model or I/O checks."""
    return dict(status="healthy")
110
+
111
+ # ✅ Inference endpoint
112
def _to_native(value):
    """Return a plain Python value for a numpy scalar; pass others through."""
    return value.item() if isinstance(value, np.generic) else value


@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify one transaction and return per-label predictions.

    The transaction fields are rendered into a single text prompt, wrapped
    in a one-item ``ComplianceDataset`` with placeholder labels, and passed
    to ``predict_probabilities``.

    Args:
        request: Validated request body holding the transaction record.

    Returns:
        A dict keyed by each entry of ``LABEL_COLUMNS``. Each value holds
        ``"prediction"`` (the decoded class with the highest probability)
        and ``"probabilities"`` (a mapping of every class to its
        probability as a Python float).

    Raises:
        HTTPException: With status 500 and the original error message as
            detail when any step fails. An ``HTTPException`` raised inside
            the handler is propagated unchanged.
    """
    # NOTE(review): predict_probabilities is called synchronously inside
    # this coroutine, so the event loop is blocked while inference runs.
    try:
        tx = request.transaction_data

        # 🧠 Construct input text
        # NOTE(review): the whitespace inside this literal is part of the
        # text handed to the tokenizer. The original indentation could not
        # be recovered from the diff view; confirm it matches the format
        # used when the model was trained.
        text_input = f"""
        Transaction ID: {tx.Transaction_Id}
        Origin: {tx.Origin}
        Designation: {tx.Designation}
        Keywords: {tx.Keywords}
        Name: {tx.Name}
        SWIFT Tag: {tx.SWIFT_Tag}
        Currency: {tx.Currency}
        Entity: {tx.Entity}
        Message: {tx.Message}
        City: {tx.City}
        Country: {tx.Country}
        State: {tx.State}
        Hit Type: {tx.Hit_Type}
        Record Matching String: {tx.Record_Matching_String}
        WatchList Match String: {tx.WatchList_Match_String}
        Payment Sender: {tx.Payment_Sender_Name}
        Payment Receiver: {tx.Payment_Reciever_Name}
        Swift Message Type: {tx.Swift_Message_Type}
        Text Sanction Data: {tx.Text_Sanction_Data}
        Matched Sanctioned Entity: {tx.Matched_Sanctioned_Entity}
        Red Flag Reason: {tx.Red_Flag_Reason}
        Risk Level: {tx.Risk_Level}
        Risk Score: {tx.Risk_Score}
        CDD Level: {tx.CDD_Level}
        PEP Status: {tx.PEP_Status}
        Sanction Description: {tx.Sanction_Description}
        Checker Notes: {tx.Checker_Notes}
        Sanction Context: {tx.Sanction_Context}
        Maker Action: {tx.Maker_Action}
        Customer Type: {tx.Customer_Type}
        Industry: {tx.Industry}
        Transaction Type: {tx.Transaction_Type}
        Transaction Channel: {tx.Transaction_Channel}
        Geographic Origin: {tx.Geographic_Origin}
        Geographic Destination: {tx.Geographic_Destination}
        Risk Category: {tx.Risk_Category}
        Risk Drivers: {tx.Risk_Drivers}
        Alert Status: {tx.Alert_Status}
        Investigation Outcome: {tx.Investigation_Outcome}
        Source of Funds: {tx.Source_Of_Funds}
        Purpose of Transaction: {tx.Purpose_Of_Transaction}
        Beneficial Owner: {tx.Beneficial_Owner}
        """

        # ⚙️ Create dataset
        # Labels are all-zero placeholders: the dataset requires them, but
        # only the predicted probabilities are used below.
        dataset = ComplianceDataset(
            texts=[text_input],
            labels=[[0] * len(LABEL_COLUMNS)],
            tokenizer=tokenizer,
            max_len=MAX_LEN
        )
        loader = DataLoader(dataset, batch_size=1)

        all_probabilities = predict_probabilities(model, loader)

        # 📊 Format response
        response = {}
        for col, probs in zip(LABEL_COLUMNS, all_probabilities):
            encoder = label_encoders[col]
            row = probs[0]  # single-item batch: first (only) row
            best_index = int(np.argmax(row))
            decoded_pred = encoder.inverse_transform([best_index])[0]
            response[col] = {
                # Convert numpy scalars so non-string labels (e.g. np.int64)
                # serialize cleanly in the JSON response.
                "prediction": _to_native(decoded_pred),
                "probabilities": {
                    _to_native(label): float(row[j])
                    for j, label in enumerate(encoder.classes_)
                },
            }

        return response

    except HTTPException:
        # Do not re-wrap deliberate HTTP errors as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
191
+
192
+ # 🖥️ Entry point for Spaces
193
if __name__ == "__main__":
    # Imported lazily so uvicorn is only needed when run as a script.
    import uvicorn

    # Honour the PORT environment variable, falling back to 7860.
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, port=port, host="0.0.0.0")