# DistilBertSentiment / inference.py
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import DistilBertTokenizerFast, DistilBertModel

# Pick the GPU when available; the model and all input tensors are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")


class DistilBERTSent(nn.Module):
    """
    DistilBERT with a linear head attached to perform binary sentiment
    classification.
    """

    def __init__(self, freeze_bert=False):
        super().__init__()
        self.distil_bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        # 2304 = 3 * 768: the head consumes three concatenated hidden states
        # (see forward()).
        self.linear = nn.Linear(2304, 1)
        self.sigmoid = nn.Sigmoid()
        if freeze_bert:
            # Freeze the backbone so only the linear head trains.
            for param in self.distil_bert.parameters():
                param.requires_grad = False

    def forward(self, ids, mask):
        outputs = self.distil_bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        # Concatenate the first three hidden states along the feature axis,
        # then mean-pool over the token dimension.
        x = torch.cat(outputs.hidden_states[:-4], dim=2).mean(1)
        x = self.linear(x)
        x = self.sigmoid(x)
        return x.flatten()
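
# Shape walk-through for DistilBERTSent.forward (B = batch size, T = sequence
# length): DistilBERT emits seven hidden states (the embedding output plus one
# per transformer layer), each of shape (B, T, 768). hidden_states[:-4] keeps
# the first three, so the concatenation is (B, T, 2304) and the mean over
# dim 1 gives (B, 2304), matching the Linear head's in_features. Note that the
# mean runs over every token position, padding included, so the amount of
# padding influences the pooled features.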


def initialize(path="models/model.pt"):
    """Load the trained weights, move the model to `device`, and set eval mode."""
    model = DistilBERTSent()
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()
    return model


def chunks(lst, n):
    """Yield successive n-sized chunks of a list of strings, with a tqdm progress bar."""
    for i in tqdm(range(0, len(lst), n)):
        yield lst[i:i + n]
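
# For example: list(chunks(["a", "b", "c"], 2)) -> [["a", "b"], ["c"]]
# (the final chunk may be shorter than n).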


@torch.no_grad()
def inference(model, text, batch_size=32):
    """
    Run `model` over a list of strings in batches of `batch_size` and return a
    numpy array with one sentiment probability per input string.
    """
    to_return = []
    for batch in chunks(text, batch_size):
        encoded = tokenizer(
            text=batch,
            add_special_tokens=True,
            padding="max_length",  # pad to the tokenizer's model max length
            return_attention_mask=True,
            truncation=True,
        )
        input_ids = torch.tensor(encoded["input_ids"]).to(device)
        attention_masks = torch.tensor(encoded["attention_mask"]).to(device)
        to_return.append(model(input_ids, attention_masks))
    return torch.cat(to_return).cpu().numpy()
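
# Design note: passing return_tensors="pt" to the tokenizer would yield torch
# tensors directly, removing the torch.tensor(...) conversions above.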


if __name__ == "__main__":
    model = initialize()
    text = ["I love it so much!", "Broke on the first day"]
    print(inference(model, text, 2))
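
    # A minimal sketch extending the demo: map probabilities to labels. Both
    # the 0.5 decision threshold and the "high probability = positive" mapping
    # are assumptions; this file does not specify them.
    probs = inference(model, text, 2)
    for sentence, p in zip(text, probs):
        label = "positive" if p >= 0.5 else "negative"
        print(f"{sentence!r} -> {label} ({p:.3f})")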