Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from transformers import BertTokenizer, BertModel | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
# ----- Configuration -----
# Name handed to `BertModel.from_pretrained` when the backbone is built.
MODEL_NAME: str = "bert-base-uncased"
# Path of the state dict that is loaded into the model at startup.
MODEL_PATH: str = "model.pth"
# Tokenizer source; use MODEL_NAME here if there is no custom tokenizer.
TOKENIZER_PATH: str = "bert-base-uncased"
# Width of the classifier output layer; update if the label set changes.
NUM_LABELS: int = 6
# ----- Define Request Schema -----
# Request body for the prediction endpoint: a single string field, `text`,
# which is the raw input handed to the tokenizer.
class InputText(BaseModel):
    text: str
# ----- Define Model Class -----
class BertMultiLabel(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    The head emits one raw logit per label; no activation is applied here,
    so callers are expected to apply their own (e.g. a sigmoid).

    Args:
        num_labels: Number of logits produced by the linear head.
    """

    def __init__(self, num_labels):
        super().__init__()
        # Attribute names and creation order are kept as-is: they define the
        # state-dict keys that saved checkpoints are matched against.
        self.bert = BertModel.from_pretrained(MODEL_NAME)
        self.dropout = nn.Dropout(0.3)
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the classifier logits for the given encoder inputs.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            The output of the linear head applied to the (dropout-regularised)
            pooled encoder output.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(self.dropout(encoded.pooler_output))
# ----- Initialize App -----
app = FastAPI()
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ----- Load Tokenizer and Model -----
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_PATH)
model = BertMultiLabel(num_labels=NUM_LABELS)
# NOTE(review): torch.load unpickles the checkpoint; only point MODEL_PATH at
# trusted files (or pass weights_only=True where the installed torch supports it).
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Move to the target device and switch to eval mode so dropout is disabled
# while serving.
model.to(device).eval()
# ----- Prediction Endpoint -----
# The handler must be registered on `app`, otherwise the service exposes no
# route at all. NOTE(review): the "/predict" path is an assumption — confirm
# against whatever clients call this API.
@app.post("/predict")
def predict(input_data: InputText):
    """Score a single text and return one sigmoid probability per logit.

    Args:
        input_data: Request body whose `text` field is the string to classify.

    Returns:
        A dict with a single key, "probabilities", mapping to a flat list of
        floats obtained by applying a sigmoid to the model's logits.
    """
    # Tokenize to PyTorch tensors, truncating anything beyond 128 tokens.
    encoded = tokenizer(
        input_data.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    # Independent sigmoid per logit (multi-label), flattened into plain Python
    # floats so the response is JSON-serialisable.
    probabilities = torch.sigmoid(logits).flatten().cpu().tolist()
    return {"probabilities": probabilities}