# bert-new / app.py
# Committed by namanpenguin
# Last commit: 7d1afe2 "Update app.py" (verified)
import functools
import logging
import os
from typing import Dict, Any, Optional

import numpy as np
import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from torch.utils.data import DataLoader
from transformers import BertTokenizer

from models.bert_model import BertMultiOutputModel
from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE, METADATA_COLUMNS
from dataset_utils import load_label_encoders, get_tokenizer, ComplianceDataset
from train_utils import predict_probabilities
app = FastAPI()

# --- Model / tokenizer bootstrap (executed once, at import time) -------------
model_path = "BERT_model.pth"
tokenizer = get_tokenizer('bert-base-uncased')


def _build_model(weights_path):
    """Create the multi-output BERT model, load its weights, and return it in eval mode.

    The model is constructed from the per-label class counts taken from
    ``load_label_encoders()`` (one count per entry of ``LABEL_COLUMNS``),
    moved to ``DEVICE``, and populated from the state dict at *weights_path*.
    """
    encoders = load_label_encoders()
    head_sizes = [len(encoders[name].classes_) for name in LABEL_COLUMNS]
    net = BertMultiOutputModel(head_sizes).to(DEVICE)
    state = torch.load(weights_path, map_location=DEVICE)
    net.load_state_dict(state)
    net.eval()  # disable dropout etc. for inference
    return net


model = _build_model(model_path)
class TransactionData(BaseModel):
    """Request schema for a single screened transaction.

    Field names are the exact JSON keys clients must send, so they are kept
    verbatim (including the ``Payment_Reciever_Name`` spelling). Every field
    is required except ``Payment_Sender_Name`` and ``Payment_Reciever_Name``,
    which default to an empty string.

    Only a subset of these fields is rendered into the text that ``/predict``
    feeds to the model. The remaining fields (e.g. ``Hit_Seq``, ``Is_Match``,
    the date/time fields) are validated but not used for inference.
    """
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # Optional: may be omitted (defaults to "") or sent as null (becomes None).
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    # Date/time fields are plain strings; no date format is validated here.
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
class PredictionRequest(BaseModel):
    """Body of ``POST /predict``: one transaction nested under ``transaction_data``."""
    transaction_data: TransactionData
@app.get("/")
async def root():
    """Landing endpoint: report that the API process is up and serving."""
    status_payload = {"status": "healthy", "message": "BERT API is running"}
    return status_payload
@app.get("/health")
async def health_check():
    """Health probe: always answers ``healthy`` while the process can serve requests."""
    return dict(status="healthy")
logger = logging.getLogger(__name__)

# (display label, TransactionData attribute) pairs, in the order they are
# rendered into the model's text input. Fields of TransactionData that are
# not listed here are accepted by the API but never shown to the model.
_TEXT_FIELDS = (
    ("Transaction ID", "Transaction_Id"),
    ("Origin", "Origin"),
    ("Designation", "Designation"),
    ("Keywords", "Keywords"),
    ("Name", "Name"),
    ("SWIFT Tag", "SWIFT_Tag"),
    ("Currency", "Currency"),
    ("Entity", "Entity"),
    ("Message", "Message"),
    ("City", "City"),
    ("Country", "Country"),
    ("State", "State"),
    ("Hit Type", "Hit_Type"),
    ("Record Matching String", "Record_Matching_String"),
    ("WatchList Match String", "WatchList_Match_String"),
    ("Payment Sender", "Payment_Sender_Name"),
    ("Payment Receiver", "Payment_Reciever_Name"),  # API field keeps its original spelling
    ("Swift Message Type", "Swift_Message_Type"),
    ("Text Sanction Data", "Text_Sanction_Data"),
    ("Matched Sanctioned Entity", "Matched_Sanctioned_Entity"),
    ("Red Flag Reason", "Red_Flag_Reason"),
    ("Risk Level", "Risk_Level"),
    ("Risk Score", "Risk_Score"),
    ("CDD Level", "CDD_Level"),
    ("PEP Status", "PEP_Status"),
    ("Sanction Description", "Sanction_Description"),
    ("Checker Notes", "Checker_Notes"),
    ("Sanction Context", "Sanction_Context"),
    ("Maker Action", "Maker_Action"),
    ("Customer Type", "Customer_Type"),
    ("Industry", "Industry"),
    ("Transaction Type", "Transaction_Type"),
    ("Transaction Channel", "Transaction_Channel"),
    ("Geographic Origin", "Geographic_Origin"),
    ("Geographic Destination", "Geographic_Destination"),
    ("Risk Category", "Risk_Category"),
    ("Risk Drivers", "Risk_Drivers"),
    ("Alert Status", "Alert_Status"),
    ("Investigation Outcome", "Investigation_Outcome"),
    ("Source of Funds", "Source_Of_Funds"),
    ("Purpose of Transaction", "Purpose_Of_Transaction"),
    ("Beneficial Owner", "Beneficial_Owner"),
)


def _build_text_input(txn):
    """Render the selected transaction fields as ``"Label: value"`` lines.

    The result starts and ends with a newline and has one line per entry of
    ``_TEXT_FIELDS``. Values are formatted with ``str()`` (so a null optional
    field renders as ``None``).
    """
    # NOTE(review): TEXT_COLUMN / METADATA_COLUMNS are imported from config but
    # not used here -- confirm this template matches how the training texts
    # were built, otherwise inference sees a different input distribution.
    lines = [f"{label}: {getattr(txn, field)}" for label, field in _TEXT_FIELDS]
    return "\n" + "\n".join(lines) + "\n"


@functools.lru_cache(maxsize=1)
def _get_label_encoders():
    """Load the label encoders once and reuse them for every request.

    The model's output sizes are fixed when it is built at import time, so
    re-reading the encoders per request only costs I/O.
    """
    return load_label_encoders()


def _to_builtin(value):
    """Convert NumPy scalars (``numpy.str_``, ``numpy.int64``, ...) to plain
    Python values so FastAPI can JSON-encode them; other values pass through."""
    return value.item() if isinstance(value, np.generic) else value


@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify one transaction with the multi-output BERT model.

    Returns a dict keyed by each entry of ``LABEL_COLUMNS``. Each value holds
    ``"prediction"`` (the decoded argmax class) and ``"probabilities"``
    (class label -> probability as float).

    Raises:
        HTTPException: status 500 with the underlying error message as
            ``detail`` if any step of preprocessing or inference fails.
    """
    # NOTE(review): inference runs synchronously inside an ``async def``
    # handler, so it blocks the event loop (and other requests) while it
    # runs. Consider a plain ``def`` endpoint or a thread-pool offload if
    # concurrency matters.
    try:
        text_input = _build_text_input(request.transaction_data)

        dataset = ComplianceDataset(
            texts=[text_input],
            labels=[[0] * len(LABEL_COLUMNS)],  # dummy labels; not needed for prediction
            tokenizer=tokenizer,
            max_len=MAX_LEN,
        )
        loader = DataLoader(dataset, batch_size=1, shuffle=False)

        # Presumably one probability array per label column, each indexed by
        # sample first -- verify against train_utils.predict_probabilities.
        all_probabilities = predict_probabilities(model, loader)
        label_encoders = _get_label_encoders()

        response = {}
        for col, probs in zip(LABEL_COLUMNS, all_probabilities):
            encoder = label_encoders[col]
            sample_probs = probs[0]  # the only sample in this batch
            pred = np.argmax(sample_probs)
            decoded_pred = encoder.inverse_transform([pred])[0]
            response[col] = {
                "prediction": _to_builtin(decoded_pred),
                "probabilities": {
                    _to_builtin(label): float(sample_probs[j])
                    for j, label in enumerate(encoder.classes_)
                },
            }
        return response
    except Exception as e:
        # Keep the original 500 contract, but record the traceback server-side
        # and preserve the cause for debugging.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces routes traffic to port 7860; $PORT overrides it when set.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", 7860)),
    )