import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from fastapi import FastAPI
from pydantic import BaseModel

# ----- Configuration -----
MODEL_NAME = "bert-base-uncased"
MODEL_PATH = "model.pth"
TOKENIZER_PATH = "bert-base-uncased"  # Or use MODEL_NAME if no custom tokenizer
NUM_LABELS = 6  # Update this if you have more or fewer labels
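# NOTE (assumption): MODEL_PATH is expected to hold a state dict saved with
# torch.save(model.state_dict(), "model.pth") from a BertMultiLabel instance
# trained with the same NUM_LABELS; loading will fail if the head size differs.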

# ----- Define Request Schema -----
class InputText(BaseModel):
    text: str

# ----- Define Model Class -----
class BertMultiLabel(nn.Module):
    """BERT encoder with a dropout + linear head for multi-label classification."""

    def __init__(self, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(MODEL_NAME)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.pooler_output)
        return self.classifier(pooled_output)

# ----- Initialize App -----
app = FastAPI()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----- Load Tokenizer and Model -----
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_PATH)
model = BertMultiLabel(num_labels=NUM_LABELS)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.to(device)
model.eval()

# ----- Prediction Endpoint -----
@app.post("/predict")
def predict(input_data: InputText):
    inputs = tokenizer(
        input_data.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    probabilities = torch.sigmoid(logits).cpu().numpy().flatten().tolist()
    return {"probabilities": probabilities}
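
# ----- Local Run (optional) -----
# A minimal sketch for running the API directly, assuming uvicorn is installed;
# on Hugging Face Spaces the server is usually started by the Space runtime,
# so this block is only needed for local testing.
# Example request once the server is up:
#   curl -X POST http://127.0.0.1:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "sample input"}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)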