irow committed on
Commit
ac39382
·
1 Parent(s): 9b19fb7

Added inference.py and model weights

Browse files
Files changed (2) hide show
  1. inference.py +66 -0
  2. models/model.pt +3 -0
inference.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from tqdm import tqdm
4
+ from transformers import DistilBertTokenizerFast, DistilBertModel
5
+ import numpy as np
6
+
7
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Shared tokenizer for the "distilbert-base-uncased" checkpoint, created once
# when the module is imported and reused by inference().
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
10
+
11
class DistilBERTSent(nn.Module):
    """DistilBERT encoder followed by a one-unit linear layer and a sigmoid.

    The head is ``Linear(2304, 1)`` plus ``Sigmoid``, so ``forward`` yields a
    single sigmoid score per input row (binary classification).

    Args:
        freeze_bert: When true, ``requires_grad`` is turned off for every
            parameter of the DistilBERT encoder, so only the linear head
            remains trainable.
    """

    def __init__(self, freeze_bert=False):
        super().__init__()
        # These attribute names become the key prefixes of the state dict;
        # they must stay as they are for saved checkpoints to keep loading.
        self.distil_bert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.linear = nn.Linear(2304, 1)
        self.sigmoid = nn.Sigmoid()

        if not freeze_bert:
            return
        for weight in self.distil_bert.parameters():
            weight.requires_grad = False

    def forward(self, ids, mask):
        """Return one sigmoid score per row of ``ids`` as a flat tensor.

        Args:
            ids: Token id tensor, forwarded to DistilBERT as ``input_ids``.
            mask: Attention mask tensor, forwarded as ``attention_mask``.

        Returns:
            Flattened tensor of sigmoid outputs.
        """
        encoded = self.distil_bert(
            input_ids=ids,
            attention_mask=mask,
            output_hidden_states=True,
        )
        # NOTE(review): `[:-4]` keeps every hidden state EXCEPT the last four.
        # Presumably distilbert-base-uncased returns 7 states of width 768, so
        # three remain (3 * 768 = 2304, matching `self.linear`) -- TODO confirm
        # whether `[-4:]` was intended. Do not change this without retraining:
        # the saved linear weights are tied to this feature layout.
        kept_states = encoded.hidden_states[:-4]
        stacked = torch.cat(kept_states, dim=2)
        # Plain mean over dim 1; `mask` is not applied here, so every position
        # (padding included) contributes to the average.
        pooled = stacked.mean(1)
        scores = self.sigmoid(self.linear(pooled))
        return scores.flatten()
31
+
32
def initialize(path="models/model.pt"):
    """Build a DistilBERTSent model, load saved weights and ready it for inference.

    Args:
        path: Path to a state dict file saved from a DistilBERTSent model.

    Returns:
        The model, placed on the module-level ``device`` and switched to
        eval mode.
    """
    model = DistilBERTSent()
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the module's existing parameters and
    # does not relocate them, so the module itself has to be moved to `device`
    # to actually run there (previously it always stayed on the CPU).
    model.to(device)
    model.eval()
    return model
37
+
38
def chunks(lst, n):
    """Yield consecutive slices of ``lst`` holding at most ``n`` items each.

    The start offsets are wrapped in ``tqdm`` so a progress bar advances once
    per yielded slice.
    """
    offsets = range(0, len(lst), n)
    for start in tqdm(offsets):
        stop = start + n
        yield lst[start:stop]
41
+
42
@torch.no_grad()
def inference(model, text, batch_size=32):
    """Run ``model`` over ``text`` in batches and return the outputs as NumPy.

    Args:
        model: Callable taking ``(input_ids, attention_mask)`` tensors and
            returning a 1-D tensor per batch (e.g. a DistilBERTSent instance).
        text: List of strings to score. A single string is treated as a
            one-item list.
        batch_size: Maximum number of strings tokenized and scored at once.

    Returns:
        1-D NumPy array with the model outputs of all batches concatenated in
        input order. An empty ``float32`` array when ``text`` is empty.
    """
    # A bare string would otherwise be sliced into character chunks.
    if isinstance(text, str):
        text = [text]
    # Nothing to score: concatenating zero batches would raise.
    if len(text) == 0:
        return np.empty(0, dtype=np.float32)

    # Send inputs to wherever the model's weights live so the two can never
    # disagree; fall back to the module-level `device` for parameter-less
    # models or plain callables.
    first_param = (
        next(model.parameters(), None)
        if isinstance(model, torch.nn.Module)
        else None
    )
    target_device = device if first_param is None else first_param.device

    scores = []
    for batch in chunks(text, batch_size):
        # padding='max_length' is kept on purpose: DistilBERTSent averages over
        # every position, padded ones included, so a different padded length
        # would change the scores.
        encoded = tokenizer(
            text=batch,
            add_special_tokens=True,
            padding='max_length',
            return_attention_mask=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(target_device)
        attention_mask = encoded["attention_mask"].to(target_device)
        scores.append(model(input_ids, attention_mask))

    return torch.cat(scores).cpu().numpy()
61
+
62
if __name__ == "__main__":
    # Smoke test: load the saved weights and print the scores for two
    # sample reviews, processed as a single batch of two.
    text = ["I love it so much!", "Broke on the first day"]
    model = initialize()
    scores = inference(model, text, batch_size=2)
    print(scores)
66
+
models/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd3e40a660ec86d5f1b746490852d456b68a57f664bceba3a994f6704db20143
3
+ size 265494629