"""FastAPI inference service for a multi-label BERT classifier.

Loads a fine-tuned ``bert-base-uncased`` checkpoint with one sigmoid
output head per label column and exposes a single ``POST /predict``
endpoint that maps input text to a binary (0/1) prediction per label.
"""

import os

# NOTE: must be set BEFORE importing `transformers`, otherwise the library
# may resolve its cache directory at import time and ignore this value.
# (TRANSFORMERS_CACHE is deprecated in recent transformers; HF_HOME is the
# modern equivalent — TODO confirm against the pinned transformers version.)
os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"

import torch
import torch.nn as nn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BertTokenizer, BertModel
from typing import List

# Column carrying the free-text input and the label columns predicted,
# in the same order as the model's output heads.
TEXT_COLUMN = "Sanction_Context"
LABEL_COLUMNS = [
    "Red_Flag_Reason",
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
]

app = FastAPI()


class InputText(BaseModel):
    """Request schema for ``/predict``: a single free-text field."""

    text: str


class BERTMultiOutputModel(nn.Module):
    """BERT encoder with one independent sigmoid head per label column.

    Each head is a single linear unit over the pooled [CLS] representation,
    so every label is scored as an independent binary probability.
    """

    def __init__(self) -> None:
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        # One 1-unit head per entry in LABEL_COLUMNS, in the same order.
        self.output_layers = nn.ModuleList(
            [nn.Linear(self.bert.config.hidden_size, 1) for _ in LABEL_COLUMNS]
        )

    def forward(self, input_ids, attention_mask) -> List[torch.Tensor]:
        """Return one sigmoid probability tensor per label head."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.pooler_output)
        return [torch.sigmoid(layer(pooled_output)) for layer in self.output_layers]


# Load tokenizer and model once at startup; inference is CPU-only.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BERTMultiOutputModel()
# weights_only=True: a state_dict needs no arbitrary-pickle execution, and
# this guards against a malicious/tampered checkpoint file.
model.load_state_dict(
    torch.load(
        "model/bert_model.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
model.eval()


@app.post("/predict")
def predict(input_text: InputText) -> dict:
    """Classify ``input_text.text`` and return a 0/1 flag per label column.

    Text is truncated to BERT's 512-token limit. Each head's sigmoid
    probability is thresholded at 0.5; an explicit ``>= 0.5`` comparison is
    used instead of ``round()``, whose banker's rounding would map exactly
    0.5 to 0.
    """
    inputs = tokenizer(
        input_text.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        outputs = model(**inputs)
    preds = [int(o.item() >= 0.5) for o in outputs]
    return dict(zip(LABEL_COLUMNS, preds))