# bert-output / app.py — Hugging Face Space by subbunanepalli
# (commit bab2e3f "Create app.py", verified)
import os
import torch
import torch.nn as nn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BertTokenizer, BertModel
from typing import List
# Set local Hugging Face cache
# Keep transformers downloads inside the app directory (writable on a Space).
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favour of HF_HOME — confirm the installed version honours it.
os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
# Name of the free-text input column the model was trained on.
TEXT_COLUMN = "Sanction_Context"
# One prediction head per label, in this exact order (also the order of
# the lists returned by BERTMultiOutputModel.forward and of /predict's keys).
LABEL_COLUMNS = ["Red_Flag_Reason", "Maker_Action", "Escalation_Level",
"Risk_Category", "Risk_Drivers", "Investigation_Outcome"]
# FastAPI application exposing the /predict endpoint.
app = FastAPI()
# Define input schema
class InputText(BaseModel):
    """JSON request body for the /predict endpoint."""

    # Free text to classify (e.g. a sanction-context narrative).
    text: str
# Model definition
class BERTMultiOutputModel(nn.Module):
    """BERT encoder with one independent sigmoid head per label column.

    forward() returns a list of probability tensors — one per entry in
    LABEL_COLUMNS, each of shape (batch, 1).
    """

    def __init__(self):
        super().__init__()
        # Attribute names must stay exactly as-is: the saved state_dict
        # keys ("bert.*", "output_layers.*") are bound to them.
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        hidden = self.bert.config.hidden_size
        self.output_layers = nn.ModuleList(
            nn.Linear(hidden, 1) for _ in LABEL_COLUMNS
        )

    def forward(self, input_ids, attention_mask):
        """Encode the batch and score it with every per-label head."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(encoded.pooler_output)
        probs = []
        for head in self.output_layers:
            probs.append(torch.sigmoid(head(pooled)))
        return probs
# Load tokenizer and model
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

model = BERTMultiOutputModel()
# weights_only=True restricts torch.load's unpickler to tensor/primitive
# data. The checkpoint is a plain state_dict (only tensors), and torch.load
# is pickle-based, so this closes the arbitrary-code-execution hole that a
# tampered .pth file would otherwise open. Requires torch >= 1.13.
model.load_state_dict(
    torch.load(
        "model/bert_model.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
# Inference only: switch off dropout.
model.eval()
@app.post("/predict")
def predict(input_text: InputText):
    """Classify one text and return a 0/1 prediction per label column.

    Returns a dict mapping each name in LABEL_COLUMNS to 0 or 1
    (sigmoid probability thresholded at 0.5).
    """
    inputs = tokenizer(
        input_text.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        # BUGFIX: BertTokenizer also returns token_type_ids, so the previous
        # model(**inputs) raised TypeError — forward() only accepts
        # input_ids and attention_mask. Pass the two expected tensors.
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    # Explicit 0.5 threshold instead of round(): Python's banker's rounding
    # maps an exact 0.5 probability to 0, which is not the intended cutoff.
    preds = [int(o.item() >= 0.5) for o in outputs]
    return dict(zip(LABEL_COLUMNS, preds))