# bertmodel / main.py
# ganeshkonapalli's picture
# Upload 4 files
# 3b69eed verified
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from fastapi import FastAPI
from pydantic import BaseModel
import pandas as pd
from sklearn.preprocessing import LabelEncoder
# File name of the serialized checkpoint that is loaded at import time.
MODEL_PATH = "bert_multioutput_model.pth"

# Use the GPU when CUDA reports one available; otherwise stay on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Output label names. Position i is paired with the i-th logits tensor
# returned by the model when predictions are decoded.
LABEL_COLUMNS = [
    "Red_Flag_Reason",
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
]
class InputText(BaseModel):
    # Request body schema for the POST /predict endpoint.
    # `text` is the raw string that gets tokenized and classified.
    text: str
class MultiOutputBERT(nn.Module):
    """BERT encoder followed by one linear classification head per label.

    Args:
        num_classes_per_label: Iterable of class counts. One ``nn.Linear``
            head is created per entry, in the same order.
        pretrained_name: Model name or path handed to
            ``BertModel.from_pretrained``. Defaults to ``'bert-base-uncased'``,
            the value that used to be hard-coded.
        dropout: Dropout probability applied to the pooled encoder output
            before the heads. Defaults to ``0.3``, the previous fixed value.
    """

    def __init__(self, num_classes_per_label, pretrained_name='bert-base-uncased', dropout=0.3):
        super().__init__()
        # The attribute names below (bert / dropout / classifiers) determine the
        # state-dict keys, so they must stay stable for saved checkpoints to load.
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.dropout = nn.Dropout(dropout)
        hidden_size = self.bert.config.hidden_size
        self.classifiers = nn.ModuleList([
            nn.Linear(hidden_size, num_classes)
            for num_classes in num_classes_per_label
        ])

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the encoder and every head.

        Returns:
            A list with one logits tensor per classification head, in the
            order the heads were created.
        """
        encoder_out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = self.dropout(encoder_out.pooler_output)
        return [head(pooled_output) for head in self.classifiers]
# Load everything the service needs once, at import time.
#
# NOTE: the checkpoint stores more than tensors -- the label encoders read
# below come out of the same file (presumably scikit-learn LabelEncoder
# instances, given the inverse_transform call made on them in predict) -- so it
# has to be fully unpickled. From PyTorch 2.6 on, torch.load defaults to
# weights_only=True and rejects such objects, hence the explicit
# weights_only=False. Full unpickling can execute arbitrary code: only point
# MODEL_PATH at a checkpoint from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
num_classes_per_label = checkpoint["num_classes_per_label"]
label_encoders = checkpoint["label_encoders"]

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Rebuild the architecture with the head sizes saved in the checkpoint, then
# restore the trained weights on top of it.
model = MultiOutputBERT(num_classes_per_label)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE)
model.eval()  # inference mode: turns dropout off

app = FastAPI()
@app.get("/")
def home():
    # Root endpoint: replies with a fixed payload so callers can confirm the
    # service is up.
    status = {"message": "✅ Multi-output BERT API is live."}
    return status
@app.post("/predict")
def predict(request: InputText):
    """Classify ``request.text`` and return one decoded label per output column.

    The text is tokenized (truncated to at most 128 tokens), run through the
    model without gradient tracking, and for every classifier head the
    highest-scoring class index is mapped back to its label with that
    column's encoder.

    Returns:
        dict mapping each name in ``LABEL_COLUMNS`` to its predicted label,
        as a plain Python value so the response can be JSON-encoded.
    """
    encoded = tokenizer(
        request.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded)

    predictions = {}
    # Heads and label columns are paired by position.
    for column, head_logits in zip(LABEL_COLUMNS, logits):
        pred_idx = torch.argmax(head_logits, dim=1).item()
        label = label_encoders[column].inverse_transform([pred_idx])[0]
        # The decoded label may be a NumPy scalar, which the JSON encoder
        # cannot serialize when it is numeric; unwrap it to a native value.
        if hasattr(label, "item"):
            label = label.item()
        predictions[column] = label
    return predictions