xiaoanyu123's picture
Add files using upload-large-folder tool
f134ab5 verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
A one-layer Whisper decoder model test case, with inputs: audio_features.
This model contains one layer of self-attention and one layer of cross-attention.
This is an onnxscript version of the model.
"""
import numpy as np
import onnx_ir as ir
from onnxscript import script
from onnxscript.onnx_opset import opset18
from onnxscript.onnx_types import FLOAT, INT32
def make_model(
    decoder_embed_positions_weight,
    proj_out_weight,
    decoder_layers_0_self_attn_layer_norm_weight,
    decoder_layers_0_self_attn_layer_norm_bias,
    decoder_layers_0_self_attn_q_proj_weight,
    decoder_layers_0_self_attn_q_proj_bias,
    decoder_layers_0_self_attn_k_proj_weight,
    decoder_layers_0_self_attn_v_proj_weight,
    decoder_layers_0_self_attn_v_proj_bias,
    decoder_layers_0_self_attn_out_proj_weight,
    decoder_layers_0_self_attn_out_proj_bias,
    decoder_layers_0_encoder_attn_layer_norm_weight,
    decoder_layers_0_encoder_attn_layer_norm_bias,
    decoder_layers_0_encoder_attn_q_proj_weight,
    decoder_layers_0_encoder_attn_q_proj_bias,
    decoder_layers_0_encoder_attn_out_proj_weight,
    decoder_layers_0_encoder_attn_out_proj_bias,
    decoder_layers_0_final_layer_norm_weight,
    decoder_layers_0_final_layer_norm_bias,
    decoder_layers_0_fc1_weight,
    decoder_layers_0_fc1_bias,
    decoder_layers_0_fc2_weight,
    decoder_layers_0_fc2_bias,
    decoder_layer_norm_weight,
    decoder_layer_norm_bias,
):
    """Build the one-layer Whisper decoder test model and return its ModelProto.

    The arguments are the decoder's weight and bias tensors. They are not
    graph inputs: ``main_graph`` below refers to them from this enclosing
    scope, and onnxscript embeds such outer-scope values in the generated
    graph (NOTE(review): confirm the exact embedding mechanism for the
    onnxscript version in use). ``make_model_with_random_weights`` passes
    float32 NumPy arrays; other callers presumably must do the same.

    Notes on the parameter set, as used by the graph:
      * ``proj_out_weight`` serves both as the token-embedding table (Gather)
        and, transposed, as the output projection that produces the logits.
      * The self-attention key projection has a weight but no bias.
      * Cross-attention has no key/value projection weights; its keys and
        values are read directly from ``past_key_values_0_2`` / ``_0_3``.

    Returns:
        The model produced by ``main_graph.to_model_proto()``.
    """

    @script()
    def main_graph(
        # TODO: Fix test case for dynamic batch size and past sequence length
        decoder_input_ids: INT32[1, 1],
        encoder_hidden_states: FLOAT[1, 1500, 384],
        past_key_values_0_0: FLOAT[1, 6, 32, 64],
        past_key_values_0_1: FLOAT[1, 6, 32, 64],
        past_key_values_0_2: FLOAT[1, 6, 32, 64],
        past_key_values_0_3: FLOAT[1, 6, 32, 64],
    ) -> (
        FLOAT[1, 1, 51865],
        FLOAT[1, 6, 33, 64],
        FLOAT[1, 6, 33, 64],
    ):
        # How the inputs are used below:
        #   past_key_values_0_0 / _0_1: self-attention key / value cache. This
        #     step's key/value is appended on axis -2 and the result is
        #     returned as outputs 2 and 3 (hence 32 -> 33 on that axis).
        #   past_key_values_0_2 / _0_3: used directly as cross-attention
        #     key / value.
        #   encoder_hidden_states: declared but not referenced in this graph.
        # Output 1 holds the logits from the projection onto proj_out_weight.

        # val_0: batch dimension as a 1-D shape tensor.
        # sym_size_int_42: past sequence length (axis 2 of the key cache).
        val_0 = opset18.Shape(decoder_input_ids, end=1, start=0)
        val_1 = opset18.Shape(past_key_values_0_0, end=3, start=2)
        sym_size_int_42 = opset18.Squeeze(val_1)
        # Token embedding: rows of proj_out_weight (the same matrix is reused
        # for the output projection at the end of the graph).
        view = opset18.Reshape(decoder_input_ids, [-1, 1], allowzero=0)
        embedding = opset18.Gather(proj_out_weight, view, axis=0)
        # Position ids [past_len, past_len + 1), tiled over the batch and
        # looked up in the positional-embedding table.
        add_7 = opset18.Add(sym_size_int_42, 1)
        arange = opset18.Range(sym_size_int_42, add_7, 1)
        unsqueeze = opset18.Unsqueeze(arange, [0])
        val_16 = opset18.Concat(val_0, [1], axis=0)
        repeat = opset18.Tile(unsqueeze, val_16)
        val_22 = opset18.Unsqueeze(repeat, [-1])
        val_24 = opset18.GatherND(decoder_embed_positions_weight, val_22, batch_dims=0)
        add_15 = opset18.Add(embedding, val_24)
        # Additive causal mask of shape [1, past_len + 2]: -3.4028235e38 (the
        # lowest finite float32) where the key index is greater than the query
        # position past_len, and 0 elsewhere. It is sliced down to past_len + 1
        # columns before it is added to the scores (see slice_12).
        add_24 = opset18.Add(add_7, 1)
        val_28 = opset18.Reshape(add_24, [-1], allowzero=0)
        val_29 = opset18.Concat([1], val_28, axis=0)
        full = opset18.Expand(-3.4028235e38, val_29)
        arange_1 = opset18.Range(0, add_24, 1)
        view_1 = opset18.Reshape(arange, [-1, 1], allowzero=0)
        gt = opset18.Greater(arange_1, view_1)
        convert_element_type_default = opset18.Cast(gt, to=1)
        mul_17 = opset18.Mul(full, convert_element_type_default)

        # ---- Self-attention block (LayerNorm applied before attention) ----
        # epsilon is 1e-5 as represented in float32; stash_type=1 selects
        # float32 for the normalization statistics.
        layer_norm = opset18.LayerNormalization(
            add_15,
            decoder_layers_0_self_attn_layer_norm_weight,
            decoder_layers_0_self_attn_layer_norm_bias,
            stash_type=1,
            epsilon=9.999999747378752e-06,
            axis=-1,
        )
        # Query: linear projection scaled by 0.125 (= 1/sqrt(64), the per-head
        # size used in the reshape), split into 6 heads -> [batch, 6, 1, 64].
        val_37 = opset18.Transpose(decoder_layers_0_self_attn_q_proj_weight, perm=[1, 0])
        val_38 = opset18.MatMul(layer_norm, val_37)
        linear = opset18.Add(val_38, decoder_layers_0_self_attn_q_proj_bias)
        mul_43 = opset18.Mul(linear, 0.125)
        val_44 = opset18.Concat(val_0, [1], [6], [64], axis=0)
        view_2 = opset18.Reshape(mul_43, val_44, allowzero=0)
        transpose = opset18.Transpose(view_2, perm=[0, 2, 1, 3])
        # Key: projection without a bias term, split into heads.
        val_46 = opset18.Transpose(decoder_layers_0_self_attn_k_proj_weight, perm=[1, 0])
        linear_1 = opset18.MatMul(layer_norm, val_46)
        val_49 = opset18.Concat(val_0, [-1], [6], [64], axis=0)
        view_3 = opset18.Reshape(linear_1, val_49, allowzero=0)
        transpose_1 = opset18.Transpose(view_3, perm=[0, 2, 1, 3])
        # Value: projection with bias, split into heads.
        val_51 = opset18.Transpose(decoder_layers_0_self_attn_v_proj_weight, perm=[1, 0])
        val_52 = opset18.MatMul(layer_norm, val_51)
        linear_2 = opset18.Add(val_52, decoder_layers_0_self_attn_v_proj_bias)
        val_55 = opset18.Concat(val_0, [-1], [6], [64], axis=0)
        view_4 = opset18.Reshape(linear_2, val_55, allowzero=0)
        transpose_2 = opset18.Transpose(view_4, perm=[0, 2, 1, 3])
        # Append this step's key/value to the cache along the sequence axis;
        # these concatenations are also the graph's 2nd and 3rd outputs.
        cat = opset18.Concat(past_key_values_0_0, transpose_1, axis=-2)
        cat_1 = opset18.Concat(past_key_values_0_1, transpose_2, axis=-2)
        # Scores = Q @ K^T -> [batch, 6, 1, past_len + 1].
        transpose_3 = opset18.Transpose(cat, perm=[0, 1, 3, 2])
        matmul = opset18.MatMul(transpose, transpose_3)
        # Broadcast the mask to [batch, 1, 1, past_len + 2] (Abs turns the -1
        # placeholders in [batch, 1, -1, -1] into 1s), then keep the first
        # past_len + 1 columns so it matches the key length.
        unsqueeze_4 = opset18.Unsqueeze(mul_17, [0, 1])
        val_83 = opset18.Concat(val_0, [1], [-1], [-1], axis=0)
        val_85 = opset18.Abs(val_83)
        expand_1 = opset18.Expand(unsqueeze_4, val_85)
        val_104 = opset18.Constant(value_ints=[0])
        val_106 = opset18.Constant(value_ints=[-1])
        val_107 = opset18.Reshape(add_7, val_106, allowzero=0)
        val_111 = opset18.Constant(value_ints=[1])
        slice_12 = opset18.Slice(expand_1, val_104, val_107, [3], val_111)
        add_125 = opset18.Add(matmul, slice_12)
        softmax = opset18.Softmax(add_125, axis=-1)
        matmul_1 = opset18.MatMul(softmax, cat_1)
        # Merge the heads back to [batch, 1, 384] and apply the output projection.
        transpose_4 = opset18.Transpose(matmul_1, perm=[0, 2, 1, 3])
        val_115 = opset18.Concat(val_0, [1], [384], axis=0)
        view_5 = opset18.Reshape(transpose_4, val_115, allowzero=0)
        val_117 = opset18.Transpose(decoder_layers_0_self_attn_out_proj_weight, perm=[1, 0])
        val_118 = opset18.MatMul(view_5, val_117)
        linear_3 = opset18.Add(val_118, decoder_layers_0_self_attn_out_proj_bias)
        # Residual connection.
        add_163 = opset18.Add(add_15, linear_3)

        # ---- Cross-attention block (no mask; K/V taken from the inputs) ----
        layer_norm_1 = opset18.LayerNormalization(
            add_163,
            decoder_layers_0_encoder_attn_layer_norm_weight,
            decoder_layers_0_encoder_attn_layer_norm_bias,
            stash_type=1,
            epsilon=9.999999747378752e-06,
            axis=-1,
        )
        val_121 = opset18.Transpose(decoder_layers_0_encoder_attn_q_proj_weight, perm=[1, 0])
        val_122 = opset18.MatMul(layer_norm_1, val_121)
        linear_4 = opset18.Add(val_122, decoder_layers_0_encoder_attn_q_proj_bias)
        mul_125 = opset18.Mul(linear_4, 0.125)
        val_125 = opset18.Concat(val_0, [1], [6], [64], axis=0)
        view_6 = opset18.Reshape(mul_125, val_125, allowzero=0)
        transpose_5 = opset18.Transpose(view_6, perm=[0, 2, 1, 3])
        transpose_6 = opset18.Transpose(past_key_values_0_2, perm=[0, 1, 3, 2])
        matmul_2 = opset18.MatMul(transpose_5, transpose_6)
        softmax_1 = opset18.Softmax(matmul_2, axis=-1)
        matmul_3 = opset18.MatMul(softmax_1, past_key_values_0_3)
        transpose_7 = opset18.Transpose(matmul_3, perm=[0, 2, 1, 3])
        val_129 = opset18.Concat(val_0, [1], [384], axis=0)
        view_7 = opset18.Reshape(transpose_7, val_129, allowzero=0)
        val_131 = opset18.Transpose(decoder_layers_0_encoder_attn_out_proj_weight, perm=[1, 0])
        val_132 = opset18.MatMul(view_7, val_131)
        linear_5 = opset18.Add(val_132, decoder_layers_0_encoder_attn_out_proj_bias)
        add_232 = opset18.Add(add_163, linear_5)

        # ---- Feed-forward block: fc1 -> GELU -> fc2, with residual ----
        layer_norm_2 = opset18.LayerNormalization(
            add_232,
            decoder_layers_0_final_layer_norm_weight,
            decoder_layers_0_final_layer_norm_bias,
            stash_type=1,
            epsilon=9.999999747378752e-06,
            axis=-1,
        )
        val_135 = opset18.Transpose(decoder_layers_0_fc1_weight, perm=[1, 0])
        val_136 = opset18.MatMul(layer_norm_2, val_135)
        linear_6 = opset18.Add(val_136, decoder_layers_0_fc1_bias)
        # Erf-based GELU: 0.5 * x * (1 + erf(x / sqrt(2))), with 1.4142135 ~= sqrt(2).
        val_138 = opset18.Div(linear_6, 1.4142135)
        val_139 = opset18.Erf(val_138)
        val_141 = opset18.Add(val_139, 1.0)
        val_143 = opset18.Mul(0.5, val_141)
        gelu = opset18.Mul(linear_6, val_143)
        val_144 = opset18.Transpose(decoder_layers_0_fc2_weight, perm=[1, 0])
        val_145 = opset18.MatMul(gelu, val_144)
        linear_7 = opset18.Add(val_145, decoder_layers_0_fc2_bias)
        add_261 = opset18.Add(add_232, linear_7)

        # ---- Final LayerNorm and logits via the (shared) proj_out_weight ----
        layer_norm_12 = opset18.LayerNormalization(
            add_261,
            decoder_layer_norm_weight,
            decoder_layer_norm_bias,
            stash_type=1,
            epsilon=9.999999747378752e-06,
            axis=-1,
        )
        val_457 = opset18.Transpose(proj_out_weight, perm=[1, 0])
        linear_32 = opset18.MatMul(layer_norm_12, val_457)
        # (logits, updated self-attn key cache, updated self-attn value cache)
        return linear_32, cat, cat_1

    model = main_graph.to_model_proto()
    return model
def make_model_with_random_weights():
    """Build the Whisper decoder test model with reproducible pseudo-random weights.

    Every parameter is a float32 array of uniform samples on [0, 1) drawn from
    a private ``np.random.RandomState`` seeded with 10. A local generator is
    used instead of ``np.random.seed`` so that building the model does not
    reseed NumPy's global RNG as a hidden side effect. ``RandomState(10)``
    replays exactly the stream that ``np.random.seed(10)`` produced, so the
    generated weights are unchanged.

    Returns:
        The model returned by ``make_model``.
    """
    hidden = 384
    ffn = 1536
    # Shapes keyed by make_model's parameter names. Insertion order is the
    # draw order and must not change, because all tensors share one seeded stream.
    shapes = {
        "decoder_embed_positions_weight": (448, hidden),
        "proj_out_weight": (51865, hidden),
        "decoder_layers_0_self_attn_layer_norm_weight": (hidden,),
        "decoder_layers_0_self_attn_layer_norm_bias": (hidden,),
        "decoder_layers_0_self_attn_q_proj_weight": (hidden, hidden),
        "decoder_layers_0_self_attn_q_proj_bias": (hidden,),
        "decoder_layers_0_self_attn_k_proj_weight": (hidden, hidden),
        "decoder_layers_0_self_attn_v_proj_weight": (hidden, hidden),
        "decoder_layers_0_self_attn_v_proj_bias": (hidden,),
        "decoder_layers_0_self_attn_out_proj_weight": (hidden, hidden),
        "decoder_layers_0_self_attn_out_proj_bias": (hidden,),
        "decoder_layers_0_encoder_attn_layer_norm_weight": (hidden,),
        "decoder_layers_0_encoder_attn_layer_norm_bias": (hidden,),
        "decoder_layers_0_encoder_attn_q_proj_weight": (hidden, hidden),
        "decoder_layers_0_encoder_attn_q_proj_bias": (hidden,),
        "decoder_layers_0_encoder_attn_out_proj_weight": (hidden, hidden),
        "decoder_layers_0_encoder_attn_out_proj_bias": (hidden,),
        "decoder_layers_0_final_layer_norm_weight": (hidden,),
        "decoder_layers_0_final_layer_norm_bias": (hidden,),
        "decoder_layers_0_fc1_weight": (ffn, hidden),
        "decoder_layers_0_fc1_bias": (ffn,),
        "decoder_layers_0_fc2_weight": (hidden, ffn),
        "decoder_layers_0_fc2_bias": (hidden,),
        "decoder_layer_norm_weight": (hidden,),
        "decoder_layer_norm_bias": (hidden,),
    }
    rng = np.random.RandomState(10)  # Fixed seed, scoped to this call only.
    weights = {name: rng.rand(*shape).astype(np.float32) for name, shape in shapes.items()}
    return make_model(**weights)
class _WhisperDecoderTest:
def get_onnx_model(self):
if not hasattr(self, "_onnx_model"):
model_proto = make_model_with_random_weights()
model = ir.serde.deserialize_model(model_proto)
self._onnx_model = model
return self._onnx_model
def get_ort_inputs(self):
if not hasattr(self, "_ort_inputs"):
np.random.seed(10) # Set a fixed seed
inputs = {
"decoder_input_ids": np.random.randint(0, 49152, (1, 1)).astype(np.int32),
"encoder_hidden_states": np.random.rand(1, 1500, 384).astype(np.float32),
"past_key_values_0_0": np.random.rand(1, 6, 32, 64).astype(np.float32),
"past_key_values_0_1": np.random.rand(1, 6, 32, 64).astype(np.float32),
"past_key_values_0_2": np.random.rand(1, 6, 32, 64).astype(np.float32),
"past_key_values_0_3": np.random.rand(1, 6, 32, 64).astype(np.float32),
}
self._ort_inputs = inputs
return self._ort_inputs
def whisper_decoder_test():
    """Create a new ``_WhisperDecoderTest`` helper instance.

    Each call returns a distinct instance, so the model and inputs it caches
    are not shared with other callers.
    """
    case = _WhisperDecoderTest()
    return case