Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| import os | |
| from typing import Optional | |
| from torch.utils.data import DataLoader | |
| from transformers import DebertaTokenizer | |
| from config import ( | |
| TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE, | |
| DEBERTA_MODEL_NAME, METADATA_COLUMNS | |
| ) | |
| from dataset_utils import load_label_encoders, get_tokenizer, ComplianceDataset | |
| from train_utils import predict_probabilities | |
| from models.deberta_model import DebertaMultiOutputModel | |
app = FastAPI()

# Load the tokenizer, label encoders and trained model once at import time;
# every request handler below reuses these module-level objects.
model_path = "saved_models/DEBERTA_model.pth"

# Fail fast before doing any expensive work (tokenizer/encoder loading and
# constructing the network) when the checkpoint is missing.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"β Model not found at {model_path}")

tokenizer = get_tokenizer(DEBERTA_MODEL_NAME)
label_encoders = load_label_encoders()
# Number of classes per label column, taken from each fitted encoder;
# presumably sizes one output head per column — confirm in DebertaMultiOutputModel.
num_classes = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

model = DebertaMultiOutputModel(num_classes).to(DEVICE)
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted files.
# If this is a plain state_dict, consider weights_only=True (torch >= 1.13).
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
# Inference mode: disables training-only behaviour such as dropout.
model.eval()
# Request schema.
class TransactionData(BaseModel):
    """One screened transaction record submitted for classification.

    Field names (including their spelling and mixed casing) form the public
    JSON request schema, so renaming any of them breaks existing clients.

    Only a subset of these fields is rendered into the model's input text by
    ``predict``. The remaining fields are validated but not used when
    building that text.

    Date/time-like fields are typed as plain ``str``, so no date format is
    validated here.
    """

    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # Optional: may be omitted (defaults to "") or sent as null (None).
    Payment_Sender_Name: Optional[str] = ""
    # NOTE: "Reciever" spelling is part of the public schema; keep it unless
    # an alias/migration is added.
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    # Typed as int; presumably a 0/1 flag — confirm with the data producer.
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    # A single bool despite the plural name.
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
class PredictionRequest(BaseModel):
    """Request model for ``predict``: one transaction under the ``transaction_data`` key."""

    transaction_data: TransactionData
# Health-check routes.
# NOTE(review): these handlers previously had no route decorator, so FastAPI never
# exposed them. Paths are inferred from the handler names; confirm against the
# probes/clients that call this service.
@app.get("/")
async def root():
    """Landing route: report that the DeBERTa API process is up."""
    return {"status": "healthy", "message": "DeBERTa API is running"}
# Registered explicitly so the route is exposed. The path is inferred from the
# handler name; confirm against the liveness probe.
@app.get("/health")
async def health_check():
    """Liveness probe: return a minimal healthy-status payload."""
    return {"status": "healthy"}
# Inference endpoint.

# (label rendered in the model input text, TransactionData attribute) pairs,
# in the exact order they appear in the text.
# NOTE(review): the rendered text presumably has to match the format used when
# the model was trained (labels, order, newlines); confirm against the training
# pipeline before changing anything here.
_TEXT_FIELDS = (
    ("Transaction ID", "Transaction_Id"),
    ("Origin", "Origin"),
    ("Designation", "Designation"),
    ("Keywords", "Keywords"),
    ("Name", "Name"),
    ("SWIFT Tag", "SWIFT_Tag"),
    ("Currency", "Currency"),
    ("Entity", "Entity"),
    ("Message", "Message"),
    ("City", "City"),
    ("Country", "Country"),
    ("State", "State"),
    ("Hit Type", "Hit_Type"),
    ("Record Matching String", "Record_Matching_String"),
    ("WatchList Match String", "WatchList_Match_String"),
    ("Payment Sender", "Payment_Sender_Name"),
    ("Payment Receiver", "Payment_Reciever_Name"),
    ("Swift Message Type", "Swift_Message_Type"),
    ("Text Sanction Data", "Text_Sanction_Data"),
    ("Matched Sanctioned Entity", "Matched_Sanctioned_Entity"),
    ("Red Flag Reason", "Red_Flag_Reason"),
    ("Risk Level", "Risk_Level"),
    ("Risk Score", "Risk_Score"),
    ("CDD Level", "CDD_Level"),
    ("PEP Status", "PEP_Status"),
    ("Sanction Description", "Sanction_Description"),
    ("Checker Notes", "Checker_Notes"),
    ("Sanction Context", "Sanction_Context"),
    ("Maker Action", "Maker_Action"),
    ("Customer Type", "Customer_Type"),
    ("Industry", "Industry"),
    ("Transaction Type", "Transaction_Type"),
    ("Transaction Channel", "Transaction_Channel"),
    ("Geographic Origin", "Geographic_Origin"),
    ("Geographic Destination", "Geographic_Destination"),
    ("Risk Category", "Risk_Category"),
    ("Risk Drivers", "Risk_Drivers"),
    ("Alert Status", "Alert_Status"),
    ("Investigation Outcome", "Investigation_Outcome"),
    ("Source of Funds", "Source_Of_Funds"),
    ("Purpose of Transaction", "Purpose_Of_Transaction"),
    ("Beneficial Owner", "Beneficial_Owner"),
)


def _build_text_input(transaction):
    """Render ``transaction`` as newline-separated ``"Label: value"`` lines.

    The text starts and ends with a newline, as the original triple-quoted
    template did. Values are formatted exactly as in an f-string, so a ``None``
    optional field renders as ``"None"``.
    """
    lines = (
        f"{label}: {getattr(transaction, attr)}" for label, attr in _TEXT_FIELDS
    )
    return "\n" + "\n".join(lines) + "\n"


def _to_builtin(value):
    """Return ``value`` as a plain Python scalar when it is a NumPy scalar.

    FastAPI's JSON encoder cannot serialize NumPy integer scalars. Those are
    what ``inverse_transform`` / ``classes_`` yield when labels are integers.
    NumPy string scalars become plain ``str`` (same JSON output).
    """
    return value.item() if isinstance(value, np.generic) else value


def _format_predictions(all_probabilities):
    """Map each label column to its decoded top class and full class distribution.

    Assumes ``all_probabilities`` yields one entry per ``LABEL_COLUMNS`` item,
    each indexable by sample, so that ``probs[0]`` is the probability vector of
    the single request row — TODO confirm against ``predict_probabilities``.
    """
    response = {}
    for col, probs in zip(LABEL_COLUMNS, all_probabilities):
        row = probs[0]  # batch_size=1, so the only row
        encoder = label_encoders[col]
        top_index = np.argmax(row)
        decoded = encoder.inverse_transform([top_index])[0]
        response[col] = {
            "prediction": _to_builtin(decoded),
            "probabilities": {
                _to_builtin(label): float(row[j])
                for j, label in enumerate(encoder.classes_)
            },
        }
    return response


# Registered explicitly so the route is exposed. The path is inferred from the
# handler name; confirm against API clients.
@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify one transaction and return per-label predictions.

    Returns:
        dict keyed by each column in ``LABEL_COLUMNS``; each value is
        ``{"prediction": <decoded class>, "probabilities": {<class>: float}}``.

    Raises:
        HTTPException: status 500 carrying the error text if any step fails.
    """
    # NOTE(review): this handler is ``async`` but runs blocking tokenization and
    # model inference, which stalls the event loop while it runs. Offloading to
    # threads needs care: a shared fast tokenizer may not be safe for concurrent
    # use from several threads.
    try:
        text_input = _build_text_input(request.transaction_data)
        # Placeholder all-zero labels; only the predicted probabilities are used.
        dataset = ComplianceDataset(
            texts=[text_input],
            labels=[[0] * len(LABEL_COLUMNS)],
            tokenizer=tokenizer,
            max_len=MAX_LEN,
        )
        loader = DataLoader(dataset, batch_size=1)
        all_probabilities = predict_probabilities(model, loader)
        return _format_predictions(all_probabilities)
    except Exception as e:
        # Existing contract: any failure becomes a 500 whose detail is the error
        # text. Chaining keeps the original traceback attached for debugging.
        # NOTE(review): str(e) may expose internal details to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e
# Entry point for Spaces / local runs: start uvicorn only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    import uvicorn

    # The host environment may provide PORT; otherwise listen on 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")