# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""A one-layer Whisper decoder model test case.

Inputs: decoder_input_ids, encoder_hidden_states, and one layer of past
key/value states. This model contains one layer of self-attention and one
layer of cross-attention. This is an onnxscript version of the model.
"""

import numpy as np
import onnx_ir as ir

from onnxscript import script
from onnxscript.onnx_opset import opset18
from onnxscript.onnx_types import FLOAT, INT32


def make_model(
    decoder_embed_positions_weight,
    proj_out_weight,
    decoder_layers_0_self_attn_layer_norm_weight,
    decoder_layers_0_self_attn_layer_norm_bias,
    decoder_layers_0_self_attn_q_proj_weight,
    decoder_layers_0_self_attn_q_proj_bias,
    decoder_layers_0_self_attn_k_proj_weight,
    decoder_layers_0_self_attn_v_proj_weight,
    decoder_layers_0_self_attn_v_proj_bias,
    decoder_layers_0_self_attn_out_proj_weight,
    decoder_layers_0_self_attn_out_proj_bias,
    decoder_layers_0_encoder_attn_layer_norm_weight,
    decoder_layers_0_encoder_attn_layer_norm_bias,
    decoder_layers_0_encoder_attn_q_proj_weight,
    decoder_layers_0_encoder_attn_q_proj_bias,
    decoder_layers_0_encoder_attn_out_proj_weight,
    decoder_layers_0_encoder_attn_out_proj_bias,
    decoder_layers_0_final_layer_norm_weight,
    decoder_layers_0_final_layer_norm_bias,
    decoder_layers_0_fc1_weight,
    decoder_layers_0_fc1_bias,
    decoder_layers_0_fc2_weight,
    decoder_layers_0_fc2_bias,
    decoder_layer_norm_weight,
    decoder_layer_norm_bias,
):
    @script()
    def main_graph(
        # TODO: Fix test case for dynamic batch size and past sequence length
        decoder_input_ids: INT32[1, 1],
        encoder_hidden_states: FLOAT[1, 1500, 384],
        past_key_values_0_0: FLOAT[1, 6, 32, 64],
        past_key_values_0_1: FLOAT[1, 6, 32, 64],
        past_key_values_0_2: FLOAT[1, 6, 32, 64],
        past_key_values_0_3: FLOAT[1, 6, 32, 64],
    ) -> (
        FLOAT[1, 1, 51865],
        FLOAT[1, 6, 33, 64],
        FLOAT[1, 6, 33, 64],
    ):
        # Token embedding (the embedding table is tied to the output projection
        # weight) plus a learned positional embedding at offset past_sequence_length.
        val_0 = opset18.Shape(decoder_input_ids, end=1, start=0)
        val_1 = opset18.Shape(past_key_values_0_0, end=3, start=2)
        sym_size_int_42 = opset18.Squeeze(val_1)
        view = opset18.Reshape(decoder_input_ids, [-1, 1], allowzero=0)
        embedding = opset18.Gather(proj_out_weight, view, axis=0)
        add_7 = opset18.Add(sym_size_int_42, 1)
        arange = opset18.Range(sym_size_int_42, add_7, 1)
        unsqueeze = opset18.Unsqueeze(arange, [0])
        val_16 = opset18.Concat(val_0, [1], axis=0)
        repeat = opset18.Tile(unsqueeze, val_16)
        val_22 = opset18.Unsqueeze(repeat, [-1])
        val_24 = opset18.GatherND(decoder_embed_positions_weight, val_22, batch_dims=0)
        add_15 = opset18.Add(embedding, val_24)
        # Build the causal attention mask: future positions get a large negative value.
        add_24 = opset18.Add(add_7, 1)
        val_28 = opset18.Reshape(add_24, [-1], allowzero=0)
        val_29 = opset18.Concat([1], val_28, axis=0)
        full = opset18.Expand(-3.4028235e38, val_29)
        arange_1 = opset18.Range(0, add_24, 1)
        view_1 = opset18.Reshape(arange, [-1, 1], allowzero=0)
        gt = opset18.Greater(arange_1, view_1)
        convert_element_type_default = opset18.Cast(gt, to=1)
        mul_17 = opset18.Mul(full, convert_element_type_default)
        # Self-attention block (6 heads of size 64; queries pre-scaled by 1/sqrt(64)).
        layer_norm = opset18.LayerNormalization(
            add_15,
            decoder_layers_0_self_attn_layer_norm_weight,
            decoder_layers_0_self_attn_layer_norm_bias,
            stash_type=1,
            epsilon=9.999999747378752e-06,
            axis=-1,
        )
        val_37 = opset18.Transpose(decoder_layers_0_self_attn_q_proj_weight, perm=[1, 0])
        val_38 = opset18.MatMul(layer_norm, val_37)
        linear = opset18.Add(val_38, decoder_layers_0_self_attn_q_proj_bias)
        mul_43 = opset18.Mul(linear, 0.125)
        val_44 = opset18.Concat(val_0, [1], [6], [64], axis=0)
        view_2 = opset18.Reshape(mul_43, val_44, allowzero=0)
        transpose = opset18.Transpose(view_2, perm=[0, 2, 1, 3])
        val_46 = opset18.Transpose(decoder_layers_0_self_attn_k_proj_weight, perm=[1, 0])
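        # Note: unlike q_proj and v_proj, the K projection has no bias add below,
        # matching Whisper, whose k_proj is bias-free.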
        linear_1 = opset18.MatMul(layer_norm, val_46)
        val_49 = opset18.Concat(val_0, [-1], [6], [64], axis=0)
        view_3 = opset18.Reshape(linear_1, val_49, allowzero=0)
        transpose_1 = opset18.Transpose(view_3, perm=[0, 2, 1, 3])
        val_51 = opset18.Transpose(decoder_layers_0_self_attn_v_proj_weight, perm=[1, 0])
        val_52 = opset18.MatMul(layer_norm, val_51)
        linear_2 = opset18.Add(val_52, decoder_layers_0_self_attn_v_proj_bias)
        val_55 = opset18.Concat(val_0, [-1], [6], [64], axis=0)
        view_4 = opset18.Reshape(linear_2, val_55, allowzero=0)
        transpose_2 = opset18.Transpose(view_4, perm=[0, 2, 1, 3])
        # Append the new key/value to the cache; these become the "present" outputs.
        cat = opset18.Concat(past_key_values_0_0, transpose_1, axis=-2)
        cat_1 = opset18.Concat(past_key_values_0_1, transpose_2, axis=-2)
        transpose_3 = opset18.Transpose(cat, perm=[0, 1, 3, 2])
        matmul = opset18.MatMul(transpose, transpose_3)
        unsqueeze_4 = opset18.Unsqueeze(mul_17, [0, 1])
        val_83 = opset18.Concat(val_0, [1], [-1], [-1], axis=0)
        val_85 = opset18.Abs(val_83)
        expand_1 = opset18.Expand(unsqueeze_4, val_85)
        val_104 = opset18.Constant(value_ints=[0])
        val_106 = opset18.Constant(value_ints=[-1])
        val_107 = opset18.Reshape(add_7, val_106, allowzero=0)
        val_111 = opset18.Constant(value_ints=[1])
        slice_12 = opset18.Slice(expand_1, val_104, val_107, [3], val_111)
        add_125 = opset18.Add(matmul, slice_12)
        softmax = opset18.Softmax(add_125, axis=-1)
        matmul_1 = opset18.MatMul(softmax, cat_1)
        transpose_4 = opset18.Transpose(matmul_1, perm=[0, 2, 1, 3])
        val_115 = opset18.Concat(val_0, [1], [384], axis=0)
        view_5 = opset18.Reshape(transpose_4, val_115, allowzero=0)
        val_117 = opset18.Transpose(decoder_layers_0_self_attn_out_proj_weight, perm=[1, 0])
        val_118 = opset18.MatMul(view_5, val_117)
        linear_3 = opset18.Add(val_118, decoder_layers_0_self_attn_out_proj_bias)
        add_163 = opset18.Add(add_15, linear_3)
        # Cross-attention block: keys/values come from the precomputed encoder
        # cache (past_key_values_0_2 / past_key_values_0_3).
        layer_norm_1 = opset18.LayerNormalization(
            add_163,
            decoder_layers_0_encoder_attn_layer_norm_weight,
            decoder_layers_0_encoder_attn_layer_norm_bias,
            stash_type=1,
            epsilon=9.999999747378752e-06,
            axis=-1,
        )
        val_121 = opset18.Transpose(decoder_layers_0_encoder_attn_q_proj_weight, perm=[1, 0])
        val_122 = opset18.MatMul(layer_norm_1, val_121)
        linear_4 = opset18.Add(val_122, decoder_layers_0_encoder_attn_q_proj_bias)
        mul_125 = opset18.Mul(linear_4, 0.125)
        val_125 = opset18.Concat(val_0, [1], [6], [64], axis=0)
        view_6 = opset18.Reshape(mul_125, val_125, allowzero=0)
        transpose_5 = opset18.Transpose(view_6, perm=[0, 2, 1, 3])
        transpose_6 = opset18.Transpose(past_key_values_0_2, perm=[0, 1, 3, 2])
        matmul_2 = opset18.MatMul(transpose_5, transpose_6)
        softmax_1 = opset18.Softmax(matmul_2, axis=-1)
        matmul_3 = opset18.MatMul(softmax_1, past_key_values_0_3)
        transpose_7 = opset18.Transpose(matmul_3, perm=[0, 2, 1, 3])
        val_129 = opset18.Concat(val_0, [1], [384], axis=0)
        view_7 = opset18.Reshape(transpose_7, val_129, allowzero=0)
        val_131 = opset18.Transpose(decoder_layers_0_encoder_attn_out_proj_weight, perm=[1, 0])
        val_132 = opset18.MatMul(view_7, val_131)
        linear_5 = opset18.Add(val_132, decoder_layers_0_encoder_attn_out_proj_bias)
        add_232 = opset18.Add(add_163, linear_5)
        # Feed-forward block.
        layer_norm_2 = opset18.LayerNormalization(
            add_232,
            decoder_layers_0_final_layer_norm_weight,
            decoder_layers_0_final_layer_norm_bias,
            stash_type=1,
            epsilon=9.999999747378752e-06,
            axis=-1,
        )
        val_135 = opset18.Transpose(decoder_layers_0_fc1_weight, perm=[1, 0])
        val_136 = opset18.MatMul(layer_norm_2, val_135)
        linear_6 = opset18.Add(val_136, decoder_layers_0_fc1_bias)
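        # GELU decomposed into primitive ops:
        # gelu(x) = x * 0.5 * (1 + erf(x / sqrt(2))); 1.4142135 approximates sqrt(2).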
        val_138 = opset18.Div(linear_6, 1.4142135)
        val_139 = opset18.Erf(val_138)
        val_141 = opset18.Add(val_139, 1.0)
        val_143 = opset18.Mul(0.5, val_141)
        gelu = opset18.Mul(linear_6, val_143)
        val_144 = opset18.Transpose(decoder_layers_0_fc2_weight, perm=[1, 0])
        val_145 = opset18.MatMul(gelu, val_144)
        linear_7 = opset18.Add(val_145, decoder_layers_0_fc2_bias)
        add_261 = opset18.Add(add_232, linear_7)
        # Final layer norm and the (tied) output projection to vocabulary logits.
        layer_norm_12 = opset18.LayerNormalization(
            add_261,
            decoder_layer_norm_weight,
            decoder_layer_norm_bias,
            stash_type=1,
            epsilon=9.999999747378752e-06,
            axis=-1,
        )
        val_457 = opset18.Transpose(proj_out_weight, perm=[1, 0])
        linear_32 = opset18.MatMul(layer_norm_12, val_457)
        return linear_32, cat, cat_1

    model = main_graph.to_model_proto()
    return model


def make_model_with_random_weights():
    np.random.seed(10)  # Set a fixed seed
    decoder_embed_positions_weight = np.random.rand(448, 384).astype(np.float32)
    proj_out_weight = np.random.rand(51865, 384).astype(np.float32)
    decoder_layers_0_self_attn_layer_norm_weight = np.random.rand(384).astype(np.float32)
    decoder_layers_0_self_attn_layer_norm_bias = np.random.rand(384).astype(np.float32)
    decoder_layers_0_self_attn_q_proj_weight = np.random.rand(384, 384).astype(np.float32)
    decoder_layers_0_self_attn_q_proj_bias = np.random.rand(384).astype(np.float32)
    decoder_layers_0_self_attn_k_proj_weight = np.random.rand(384, 384).astype(np.float32)
    decoder_layers_0_self_attn_v_proj_weight = np.random.rand(384, 384).astype(np.float32)
    decoder_layers_0_self_attn_v_proj_bias = np.random.rand(384).astype(np.float32)
    decoder_layers_0_self_attn_out_proj_weight = np.random.rand(384, 384).astype(np.float32)
    decoder_layers_0_self_attn_out_proj_bias = np.random.rand(384).astype(np.float32)
    decoder_layers_0_encoder_attn_layer_norm_weight = np.random.rand(384).astype(np.float32)
    decoder_layers_0_encoder_attn_layer_norm_bias = np.random.rand(384).astype(np.float32)
    decoder_layers_0_encoder_attn_q_proj_weight = np.random.rand(384, 384).astype(np.float32)
    decoder_layers_0_encoder_attn_q_proj_bias = np.random.rand(384).astype(np.float32)
    decoder_layers_0_encoder_attn_out_proj_weight = np.random.rand(384, 384).astype(np.float32)
    decoder_layers_0_encoder_attn_out_proj_bias = np.random.rand(384).astype(np.float32)
    decoder_layers_0_final_layer_norm_weight = np.random.rand(384).astype(np.float32)
    decoder_layers_0_final_layer_norm_bias = np.random.rand(384).astype(np.float32)
    decoder_layers_0_fc1_weight = np.random.rand(1536, 384).astype(np.float32)
    decoder_layers_0_fc1_bias = np.random.rand(1536).astype(np.float32)
    decoder_layers_0_fc2_weight = np.random.rand(384, 1536).astype(np.float32)
    decoder_layers_0_fc2_bias = np.random.rand(384).astype(np.float32)
    decoder_layer_norm_weight = np.random.rand(384).astype(np.float32)
    decoder_layer_norm_bias = np.random.rand(384).astype(np.float32)
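    # All weights above are drawn from U[0, 1) under the fixed seed, so the model
    # is deterministic across runs; the values are not meant to be realistic.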
    model = make_model(
        decoder_embed_positions_weight,
        proj_out_weight,
        decoder_layers_0_self_attn_layer_norm_weight,
        decoder_layers_0_self_attn_layer_norm_bias,
        decoder_layers_0_self_attn_q_proj_weight,
        decoder_layers_0_self_attn_q_proj_bias,
        decoder_layers_0_self_attn_k_proj_weight,
        decoder_layers_0_self_attn_v_proj_weight,
        decoder_layers_0_self_attn_v_proj_bias,
        decoder_layers_0_self_attn_out_proj_weight,
        decoder_layers_0_self_attn_out_proj_bias,
        decoder_layers_0_encoder_attn_layer_norm_weight,
        decoder_layers_0_encoder_attn_layer_norm_bias,
        decoder_layers_0_encoder_attn_q_proj_weight,
        decoder_layers_0_encoder_attn_q_proj_bias,
        decoder_layers_0_encoder_attn_out_proj_weight,
        decoder_layers_0_encoder_attn_out_proj_bias,
        decoder_layers_0_final_layer_norm_weight,
        decoder_layers_0_final_layer_norm_bias,
        decoder_layers_0_fc1_weight,
        decoder_layers_0_fc1_bias,
        decoder_layers_0_fc2_weight,
        decoder_layers_0_fc2_bias,
        decoder_layer_norm_weight,
        decoder_layer_norm_bias,
    )
    return model


class _WhisperDecoderTest:
    def get_onnx_model(self):
        if not hasattr(self, "_onnx_model"):
            model_proto = make_model_with_random_weights()
            model = ir.serde.deserialize_model(model_proto)
            self._onnx_model = model
        return self._onnx_model

    def get_ort_inputs(self):
        if not hasattr(self, "_ort_inputs"):
            np.random.seed(10)  # Set a fixed seed
            inputs = {
                "decoder_input_ids": np.random.randint(0, 49152, (1, 1)).astype(np.int32),
                "encoder_hidden_states": np.random.rand(1, 1500, 384).astype(np.float32),
                "past_key_values_0_0": np.random.rand(1, 6, 32, 64).astype(np.float32),
                "past_key_values_0_1": np.random.rand(1, 6, 32, 64).astype(np.float32),
                "past_key_values_0_2": np.random.rand(1, 6, 32, 64).astype(np.float32),
                "past_key_values_0_3": np.random.rand(1, 6, 32, 64).astype(np.float32),
            }
            self._ort_inputs = inputs
        return self._ort_inputs


def whisper_decoder_test():
    return _WhisperDecoderTest()
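

# A minimal sketch of exercising this test case end to end with onnxruntime.
# onnxruntime is an assumed extra dependency here; it is not required by the
# test case itself.
if __name__ == "__main__":
    import onnxruntime

    session = onnxruntime.InferenceSession(
        make_model_with_random_weights().SerializeToString(),
        providers=["CPUExecutionProvider"],
    )
    logits, present_self_key, present_self_value = session.run(
        None, whisper_decoder_test().get_ort_inputs()
    )
    # Expected shapes: (1, 1, 51865), (1, 6, 33, 64), (1, 6, 33, 64).
    print(logits.shape, present_self_key.shape, present_self_value.shape)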