# File size: 1,651 Bytes
# Revisions: 111c8fb a30ecaf 111c8fb a30ecaf 111c8fb 7711bcd
import torch
import torch.nn as nn
import math
from transformers import BertModel, BertTokenizer, BertPreTrainedModel
class BertForRegression(BertPreTrainedModel):
    """BERT encoder with a single-output linear head for scalar regression.

    The head is applied to BERT's pooled output (``pooler_output``). When
    ``labels`` are passed to ``forward``, a mean-squared-error loss is also
    computed.
    """

    def __init__(self, config):
        """Build the encoder and the one-unit regression head from ``config``."""
        super().__init__(config)
        self.bert = BertModel(config)
        # One output unit: the predicted scalar value.
        self.regressor = nn.Linear(config.hidden_size, 1)
        # ``post_init`` is the documented end-of-__init__ hook in newer
        # transformers releases; fall back to ``init_weights`` on older ones.
        post_init = getattr(self, "post_init", None)
        if callable(post_init):
            post_init()
        else:
            self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Run the encoder and the regression head.

        Args:
            input_ids: Token ids for the batch.
            attention_mask: Optional mask passed through to ``BertModel``.
            token_type_ids: Optional segment ids passed through to ``BertModel``.
            labels: Optional regression targets with one value per example,
                e.g. shaped ``(batch,)`` or ``(batch, 1)``.

        Returns:
            ``(loss, logits)`` when ``labels`` is given, otherwise ``logits``
            (one output value per example from the linear head).
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        logits = self.regressor(outputs.pooler_output)

        if labels is None:
            return logits

        # Flatten predictions and targets to (batch,) so they pair up
        # element-wise for every batch size (including 1) and for labels given
        # as (batch,) or (batch, 1); otherwise MSELoss would broadcast shapes.
        # Compute the loss in float32 regardless of the model's dtype.
        loss_fct = nn.MSELoss()
        loss = loss_fct(logits.reshape(-1).float(), labels.reshape(-1).float())
        return loss, logits
class EndpointHandler:
    """Inference handler that scores a single text with ``BertForRegression``.

    The model and tokenizer are loaded once from ``path``. Each call returns
    the model's scalar prediction, truncated toward zero to two decimal places,
    as a string.
    """

    def __init__(self, path: str):
        """Load the model and tokenizer from ``path`` via ``from_pretrained``."""
        self.model = BertForRegression.from_pretrained(path)
        self.tokenizer = BertTokenizer.from_pretrained(path)
        # Inference only: switch to eval mode once here rather than per request.
        self.model.eval()

    def __call__(self, data) -> str:
        """Score one text and return the prediction as a string.

        Args:
            data: Either a JSON-style dict ``{"inputs": "text string"}`` or the
                raw text string itself.

        Returns:
            The prediction truncated toward zero to two decimal places and
            formatted with ``str`` (e.g. ``"0.45"``, ``"-1.23"``, ``"2.0"``).

        Raises:
            ValueError: If the text is missing or is not a string.
        """
        # Accept JSON input {"inputs": "text string"} or a bare string. A missing
        # "inputs" key yields None and is rejected below instead of silently
        # scoring an empty string.
        text = data.get("inputs") if isinstance(data, dict) else data
        if not isinstance(text, str):
            raise ValueError("Input text must be a string under the 'inputs' key.")

        # Inputs longer than 512 tokens are truncated.
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            # Without labels the model returns only the logits tensor.
            logits = self.model(**inputs)

        # A single text was encoded, so the first row holds the one value.
        prediction = logits[0].item()
        # Truncate (not round) to two decimal places.
        prediction = math.trunc(prediction * 100) / 100
        return str(prediction)