import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from fastapi import FastAPI
from pydantic import BaseModel

# ----- Configuration -----
MODEL_NAME = "bert-base-uncased"
MODEL_PATH = "model.pth"
TOKENIZER_PATH = "bert-base-uncased"  # Or use MODEL_NAME if no custom tokenizer
NUM_LABELS = 6  # Update this if you have more or fewer labels

# ----- Define Request Schema -----
class InputText(BaseModel):
    text: str

# ----- Define Model Class -----
class BertMultiLabel(nn.Module):
    def __init__(self, num_labels):
        super(BertMultiLabel, self).__init__()
        self.bert = BertModel.from_pretrained(MODEL_NAME)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Classify from the pooled [CLS] representation
        pooled_output = self.dropout(outputs.pooler_output)
        return self.classifier(pooled_output)

# ----- Initialize App -----
app = FastAPI()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----- Load Tokenizer and Model -----
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_PATH)
model = BertMultiLabel(num_labels=NUM_LABELS)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.to(device)
model.eval()  # Disable dropout for deterministic inference

# ----- Prediction Endpoint -----
@app.post("/predict")
def predict(input_data: InputText):
    inputs = tokenizer(
        input_data.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    # Sigmoid, not softmax: in multi-label classification each label is an
    # independent binary decision, so probabilities need not sum to 1
    probabilities = torch.sigmoid(logits).cpu().numpy().flatten().tolist()
    return {"probabilities": probabilities}
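
# ----- Usage (sketch) -----
# A minimal sketch of how a client might exercise the endpoint above. It
# assumes this file is saved as main.py and served with uvicorn on the
# default port 8000; the module name, host, and port are assumptions, not
# something the code above fixes.
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "example input to classify"}'
#
# The response contains NUM_LABELS independent probabilities. A common (but
# application-dependent) post-processing step is to threshold each one, e.g.
# treating label i as active when probabilities[i] >= 0.5.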