mohitsha committed on
Commit
a69f683
·
1 Parent(s): 6fd0b95

Upload whisper_eval.py

Browse files
Files changed (1) hide show
  1. whisper_eval.py +284 -0
whisper_eval.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+ import onnxruntime as onnxrt
7
+ import torch
8
+ from datasets import load_dataset
9
+ from transformers import (
10
+ AutoConfig,
11
+ AutoProcessor,
12
+ GenerationConfig,
13
+ WhisperForConditionalGeneration,
14
+ )
15
+ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
16
+
17
+
18
# Silence TensorFlow C++ logging (transformers may import TF indirectly).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


model_name = "openai/whisper-tiny.en"
config = AutoConfig.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)

# Static-export dimensions. Derived from the model config instead of being
# hard-coded; for whisper-tiny.en these resolve to the same values the script
# previously used literally (6 heads, d_model 384, 4 layers, encoder length
# 1500, decoder length 448).
batch_size = 1
encoder_num_attention_heads = config.encoder_attention_heads
decoder_num_attention_heads = config.decoder_attention_heads
hidden_size = config.d_model
encoder_sequence_length = config.max_source_positions
decoder_max_length = config.max_target_positions
num_hidden_layers = config.decoder_layers

# Per-head KV-cache shapes: (batch, heads, sequence_length, head_dim).
encoder_shape = (
    batch_size,
    encoder_num_attention_heads,
    encoder_sequence_length,
    hidden_size // encoder_num_attention_heads,
)
decoder_shape = (
    batch_size,
    decoder_num_attention_heads,
    decoder_max_length,
    hidden_size // decoder_num_attention_heads,
)


# Load one utterance from the dummy LibriSpeech split as the eval sample.
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
idx = 4
inputs = processor.feature_extractor(ds[idx]["audio"]["array"], return_tensors="pt")
input_features = inputs.input_features


# Paths to the statically-shaped, quantized ONNX export (Windows-style paths).
onnx_model_path = ".\\whisper-tiny-static-shape-quantized-SL-448"
config_file = ".\\other_libs_qdq\\vaip_config_gemm_asr_decoder.json"
encoder_model_path = ".\\whisper-tiny-static-shape-quantized-SL-448\\encoder_model.onnx"
decoder_model_path = ".\\whisper-tiny-static-shape-quantized-SL-448\\decoder_model_quantized.onnx"

print(decoder_model_path)
64
+
65
class ORTEncoder(torch.nn.Module):
    """Thin wrapper that runs the exported Whisper encoder through ONNX Runtime
    on CPU and repackages the result as a ``BaseModelOutput``."""

    def __init__(self):
        super().__init__()
        self.main_input_name = "input_features"
        self.session = onnxrt.InferenceSession(
            encoder_model_path, providers=["CPUExecutionProvider"]
        )
        # Map each session output name to its positional index in run() results.
        self.output_names = {
            out.name: pos for pos, out in enumerate(self.session.get_outputs())
        }

    def forward(
        self,
        input_features: torch.FloatTensor,
        **kwargs,
    ) -> BaseModelOutput:
        """Encode log-mel ``input_features`` and return the last hidden state."""
        feed = {"input_features": input_features.cpu().detach().numpy()}
        results = self.session.run(None, feed)
        hidden = torch.from_numpy(results[self.output_names["last_hidden_state"]])
        return BaseModelOutput(last_hidden_state=hidden)
91
+
92
+
93
class ORTDecoder(torch.nn.Module):
    """ONNX Runtime wrapper for the statically-shaped, quantized Whisper
    decoder, executed on the Vitis AI execution provider.

    The exported decoder has a fixed decoder length (``decoder_max_length``)
    and always expects past key/value inputs, so this wrapper maintains the
    running ``position_ids`` / ``decoder_attention_mask`` itself and feeds
    dummy KV tensors on the very first step (they are masked out by the
    attention mask).
    """

    def __init__(self):
        super().__init__()
        sess_options = onnxrt.SessionOptions()
        self.provider = "VitisAIExecutionProvider"
        self.provider_options = {"config_file": config_file}
        # Disable graph optimizations so the QDQ structure produced by the
        # quantizer is preserved for the Vitis AI compiler.
        sess_options.graph_optimization_level = (
            onnxrt.GraphOptimizationLevel.ORT_DISABLE_ALL
        )
        sess_options.add_session_config_entry("session.disable_quant_qdq", "1")
        self.session = onnxrt.InferenceSession(
            decoder_model_path,
            providers=[self.provider],
            sess_options=sess_options,
            provider_options=[self.provider_options],
        )

        self.generation_config = GenerationConfig.from_model_config(config)
        self.max_length = decoder_max_length

        # Map input/output names to their positional index in the session.
        self.input_names = {
            input_key.name: idx
            for idx, input_key in enumerate(self.session.get_inputs())
        }
        self.output_names = {
            output_key.name: idx
            for idx, output_key in enumerate(self.session.get_outputs())
        }
        self.key_value_input_names = [
            key for key in self.input_names if (".key" in key) or (".value" in key)
        ]
        self.key_value_output_names = [
            key for key in self.output_names if (".key" in key) or (".value" in key)
        ]
        # Each layer carries 4 KV tensors: decoder key/value + encoder key/value.
        # (Constant, so it is set here rather than in reset().)
        self.num_pkv = 4

        self.reset()

    def reset(self):
        """Re-initialise the per-generation state (attention mask, position)."""
        # Only the first (BOS) position is visible at step 0.
        self.decoder_attention_mask = np.zeros((batch_size, self.max_length)).astype(
            np.int64
        )
        self.decoder_attention_mask[0, 0] = 1
        self.position_ids = np.array([[0]]).astype(np.int64)

    def prepare_pkv(self):
        """Build dummy past key/values for the first decoding step.

        Returns a tuple of per-layer 4-tuples
        (decoder_key, decoder_value, encoder_key, encoder_value); the random
        contents are irrelevant because step 0's attention mask hides them.
        """
        decoder_key_value = torch.rand(*decoder_shape).to(torch.float32)
        encoder_key_value = torch.rand(*encoder_shape).to(torch.float32)

        repeat_count = len(self.key_value_input_names) // self.num_pkv
        return tuple(
            (decoder_key_value, decoder_key_value, encoder_key_value, encoder_key_value)
            for _ in range(repeat_count)
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        encoder_hidden_states: torch.FloatTensor,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> Seq2SeqLMOutput:
        """Run one decoding step and advance the internal position/mask state."""
        if past_key_values is None:
            # generate() passes no cache on the first call of a new sequence.
            self.reset()

        # Current step; position_ids is only incremented at the end of forward.
        step = self.position_ids[0][0]

        if step == self.max_length:
            # Ran out of static positions: emit EOS so generation terminates.
            logits = torch.zeros((len(input_ids), 1, config.vocab_size))
            logits[:, :, config.eos_token_id] = 1

            return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)

        onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()}

        onnx_inputs["position_ids"] = self.position_ids
        onnx_inputs["decoder_attention_mask"] = self.decoder_attention_mask
        onnx_inputs["encoder_hidden_states"] = (
            encoder_hidden_states.cpu().detach().numpy()
        )

        if step == 0:
            # No real cache yet; feed dummy tensors (see prepare_pkv).
            past_key_values = self.prepare_pkv()

        # Flatten the (layer, 4) structure into one tuple matching input order.
        past_key_values = tuple(
            past_key_value
            for pkv_per_layer in past_key_values
            for past_key_value in pkv_per_layer
        )

        for input_name, past_key_value in zip(
            self.key_value_input_names, past_key_values
        ):
            onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()

        # Run inference
        outputs = self.session.run(None, onnx_inputs)

        logits = torch.from_numpy(outputs[self.output_names["logits"]])

        out_past_key_values = tuple(
            torch.from_numpy(outputs[self.output_names[key]])
            for key in self.key_value_output_names
        )

        if step == 0:
            # Regroup the flat outputs back into per-layer 4-tuples.
            out_past_key_values = tuple(
                out_past_key_values[i : i + self.num_pkv]
                for i in range(0, len(out_past_key_values), self.num_pkv)
            )
        else:
            # Keep the freshly computed decoder self-attention KV; reuse the
            # (unchanging) encoder cross-attention KV from the previous step.
            out_past_key_values = tuple(
                out_past_key_values[i : i + 2] + past_key_values[i + 2 : i + 4]
                for i in range(0, len(out_past_key_values), self.num_pkv)
            )

        if step < self.max_length - 1:
            # Unmask the next position and advance the step counter.
            self.decoder_attention_mask[:, step + 1] = 1
            self.position_ids += 1

        return Seq2SeqLMOutput(logits=logits, past_key_values=out_past_key_values)
217
+
218
+
219
class ORTModelForWhisper(WhisperForConditionalGeneration):
    """Whisper model whose encoder and decoder run through ONNX Runtime while
    reusing the Hugging Face ``generate()`` machinery."""

    def __init__(self, *args, **kwargs):
        # Build the HF skeleton from the hub config, then swap in ORT modules.
        super().__init__(AutoConfig.from_pretrained(model_name))

        self.encoder = ORTEncoder()
        self.decoder = ORTDecoder()

    def get_encoder(self):
        return self.encoder

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        # Encode only when the caller has not already cached encoder outputs.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(input_features=input_features)

        # The static decoder consumes a single token per step.
        step_output = self.decoder(
            input_ids=decoder_input_ids[:, -1:],
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            past_key_values=past_key_values,
        )

        return Seq2SeqLMOutput(
            logits=step_output.logits,
            past_key_values=step_output.past_key_values,
        )

    def can_generate(self):
        return True

    def reset(self):
        self.decoder.reset()
258
+
259
+
260
# Instantiate both backends up front: the ONNX-Runtime-backed model and the
# reference PyTorch checkpoint.
model_ort = ORTModelForWhisper()
model = WhisperForConditionalGeneration.from_pretrained(model_name)
262
+
263
+
264
def test_ort():
    """Transcribe the eval sample with the ONNX Runtime model and print it."""
    model = ORTModelForWhisper()

    generated_ids = model.generate(input_features)
    # processor.batch_decode delegates to the tokenizer; using it here keeps
    # the decoding call consistent with test_original.
    model_output = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    print("ORT: ", model_output, generated_ids)
273
+
274
+
275
def test_original():
    """Transcribe the same sample with the reference PyTorch checkpoint."""
    reference = WhisperForConditionalGeneration.from_pretrained(model_name)

    token_ids = reference.generate(input_features)
    transcript = processor.batch_decode(token_ids, skip_special_tokens=True)[0]

    print("Torch: ", transcript, token_ids)
282
+
283
+
284
# Run the ONNX evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    test_ort()