File size: 1,882 Bytes
17c6d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
from transformers import pipeline, AutoTokenizer, AutoModel, AutoModelForMaskedLM
import time

# Prompt with a single mask token that each fill-mask pipeline completes.
test_sentence = 'Do you [MASK] the muffin man?'


def _show_sequences(fill_mask, sentence):
    """Run *fill_mask* on *sentence* and print each predicted sequence on its own line."""
    predictions = fill_mask(sentence)
    print('\n'.join(prediction['sequence'] for prediction in predictions))


# for comparison
bert = pipeline('fill-mask', model='bert-base-uncased')
_show_sequences(bert, test_sentence)


deberta = pipeline(
    'fill-mask',
    model='microsoft/deberta-v3-base',
    model_kwargs={"legacy": False},
)
_show_sequences(deberta, test_sentence)


# Tokenize one sentence pair as PyTorch tensors; the same encoded inputs are
# reused for every timed forward pass in this script.
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")

tokenized_dict = tokenizer(
    ["Is this working",], ["Not yet",],
    return_tensors="pt"
)

# Replace the pipeline model's forward with a torch.compile-wrapped version.
deberta.model.forward = torch.compile(deberta.model.forward)

# Time three consecutive forward passes and print each elapsed time in
# seconds. Printing every run separately keeps the first call (which is
# expected to include the compilation cost) distinguishable from later calls.
# perf_counter() is used instead of time() because it is monotonic and has the
# highest available resolution for measuring short durations; no_grad() keeps
# autograd bookkeeping out of the measurement since the outputs are discarded.
with torch.no_grad():
    for _ in range(3):
        start = time.perf_counter()
        deberta.model(**tokenized_dict)
        end = time.perf_counter()
        print(end - start)


# NOTE(review): this loads 'microsoft/deberta-base', while tokenized_dict was
# produced by the 'microsoft/deberta-v3-base' tokenizer. The two checkpoints
# are different models -- confirm the mismatch is intentional before relying
# on the traced model's outputs.
model = AutoModel.from_pretrained('microsoft/deberta-base')
# Ask the model for tuple outputs without hidden states before tracing
# (presumably so torch.jit.trace only sees plain tensors -- TODO confirm).
model.config.return_dict = False
model.config.output_hidden_states = False
input_tuple = (tokenized_dict['input_ids'], tokenized_dict['attention_mask'])

# Time the trace itself and print the elapsed seconds. perf_counter() is
# monotonic and high-resolution, which time() does not guarantee.
start = time.perf_counter()
traced_model = torch.jit.trace(model, input_tuple)
end = time.perf_counter()
print(end - start)

# Time four consecutive calls of the traced model, printing each elapsed time
# in seconds. no_grad() keeps autograd bookkeeping out of the measurement
# since the outputs are discarded.
with torch.no_grad():
    for _ in range(4):
        start = time.perf_counter()
        traced_model(*input_tuple)
        end = time.perf_counter()
        print(end - start)

# Persist the traced module so it can be reloaded with torch.jit.load.
torch.jit.save(traced_model, "compiled_deberta.pt")



# my_script_module = torch.jit.script(model)