alexwengg's picture
Upload 24 files
c1b4251 verified
raw
history blame
31.7 kB
program(1.0)
[buildInfo = dict<tensor<string, []>, tensor<string, []>>({{"coremlc-component-MIL", "3510.2.1"}, {"coremlc-version", "3500.32.1"}, {"coremltools-component-torch", "2.9.1"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "9.0"}})]
{
func main<ios16>(tensor<fp32, [1, 112, 128]> chunk, tensor<int32, [1]> chunk_lengths, tensor<fp32, [1, 40, 512]> fifo, tensor<int32, [1]> fifo_lengths, tensor<fp32, [1, 188, 512]> spkcache, tensor<int32, [1]> spkcache_lengths) {
tensor<fp32, [256]> model_encoder_pre_encode_conv_0_bias = const()[name = tensor<string, []>("model_encoder_pre_encode_conv_0_bias"), val = tensor<fp32, [256]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(64)))];
tensor<fp32, [256, 1, 3, 3]> model_encoder_pre_encode_conv_0_weight = const()[name = tensor<string, []>("model_encoder_pre_encode_conv_0_weight"), val = tensor<fp32, [256, 1, 3, 3]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(1152)))];
tensor<fp32, [256]> model_encoder_pre_encode_conv_2_bias = const()[name = tensor<string, []>("model_encoder_pre_encode_conv_2_bias"), val = tensor<fp32, [256]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(10432)))];
tensor<fp32, [256, 1, 3, 3]> model_encoder_pre_encode_conv_2_weight = const()[name = tensor<string, []>("model_encoder_pre_encode_conv_2_weight"), val = tensor<fp32, [256, 1, 3, 3]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(11520)))];
tensor<fp32, [256]> model_encoder_pre_encode_conv_3_bias = const()[name = tensor<string, []>("model_encoder_pre_encode_conv_3_bias"), val = tensor<fp32, [256]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(20800)))];
tensor<fp32, [256, 256, 1, 1]> model_encoder_pre_encode_conv_3_weight = const()[name = tensor<string, []>("model_encoder_pre_encode_conv_3_weight"), val = tensor<fp32, [256, 256, 1, 1]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(21888)))];
tensor<fp32, [256]> model_encoder_pre_encode_conv_5_bias = const()[name = tensor<string, []>("model_encoder_pre_encode_conv_5_bias"), val = tensor<fp32, [256]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(284096)))];
tensor<fp32, [256, 1, 3, 3]> model_encoder_pre_encode_conv_5_weight = const()[name = tensor<string, []>("model_encoder_pre_encode_conv_5_weight"), val = tensor<fp32, [256, 1, 3, 3]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(285184)))];
tensor<fp32, [256]> model_encoder_pre_encode_conv_6_bias = const()[name = tensor<string, []>("model_encoder_pre_encode_conv_6_bias"), val = tensor<fp32, [256]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(294464)))];
tensor<fp32, [256, 256, 1, 1]> model_encoder_pre_encode_conv_6_weight = const()[name = tensor<string, []>("model_encoder_pre_encode_conv_6_weight"), val = tensor<fp32, [256, 256, 1, 1]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(295552)))];
tensor<fp32, [512]> model_encoder_pre_encode_out_bias = const()[name = tensor<string, []>("model_encoder_pre_encode_out_bias"), val = tensor<fp32, [512]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(557760)))];
tensor<fp32, [512, 4096]> model_encoder_pre_encode_out_weight = const()[name = tensor<string, []>("model_encoder_pre_encode_out_weight"), val = tensor<fp32, [512, 4096]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(559872)))];
tensor<int32, [1]> tensor_1_axes_0 = const()[name = tensor<string, []>("tensor_1_axes_0"), val = tensor<int32, [1]>([1])];
tensor<fp32, [1, 1, 112, 128]> tensor_1 = expand_dims(axes = tensor_1_axes_0, x = chunk)[name = tensor<string, []>("tensor_1")];
tensor<string, []> current_lengths_1_dtype_0 = const()[name = tensor<string, []>("current_lengths_1_dtype_0"), val = tensor<string, []>("fp32")];
tensor<int32, [1, 112]> expand_dims_0 = const()[name = tensor<string, []>("expand_dims_0"), val = tensor<int32, [1, 112]>([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111]])];
tensor<int32, [1]> var_40_axes_0 = const()[name = tensor<string, []>("op_40_axes_0"), val = tensor<int32, [1]>([1])];
tensor<int32, [1, 1]> var_40 = expand_dims(axes = var_40_axes_0, x = chunk_lengths)[name = tensor<string, []>("op_40")];
tensor<bool, [1, 112]> time_mask_1 = less(x = expand_dims_0, y = var_40)[name = tensor<string, []>("time_mask_1")];
tensor<int32, [1]> var_42_axes_0 = const()[name = tensor<string, []>("op_42_axes_0"), val = tensor<int32, [1]>([-1])];
tensor<bool, [1, 112, 1]> var_42 = expand_dims(axes = var_42_axes_0, x = time_mask_1)[name = tensor<string, []>("op_42")];
tensor<int32, [3]> var_44_reps_0 = const()[name = tensor<string, []>("op_44_reps_0"), val = tensor<int32, [3]>([1, 1, 128])];
tensor<bool, [1, 112, 128]> var_44 = tile(reps = var_44_reps_0, x = var_42)[name = tensor<string, []>("op_44")];
tensor<string, []> mask_1_dtype_0 = const()[name = tensor<string, []>("mask_1_dtype_0"), val = tensor<string, []>("fp32")];
tensor<int32, [1]> var_50_axes_0 = const()[name = tensor<string, []>("op_50_axes_0"), val = tensor<int32, [1]>([1])];
tensor<fp32, [1, 112, 128]> mask_1 = cast(dtype = mask_1_dtype_0, x = var_44)[name = tensor<string, []>("cast_11")];
tensor<fp32, [1, 1, 112, 128]> var_50 = expand_dims(axes = var_50_axes_0, x = mask_1)[name = tensor<string, []>("op_50")];
tensor<fp32, [1, 1, 112, 128]> input_1 = mul(x = tensor_1, y = var_50)[name = tensor<string, []>("input_1")];
tensor<string, []> tensor_3_pad_type_0 = const()[name = tensor<string, []>("tensor_3_pad_type_0"), val = tensor<string, []>("custom")];
tensor<int32, [4]> tensor_3_pad_0 = const()[name = tensor<string, []>("tensor_3_pad_0"), val = tensor<int32, [4]>([1, 1, 1, 1])];
tensor<int32, [2]> tensor_3_strides_0 = const()[name = tensor<string, []>("tensor_3_strides_0"), val = tensor<int32, [2]>([2, 2])];
tensor<int32, [2]> tensor_3_dilations_0 = const()[name = tensor<string, []>("tensor_3_dilations_0"), val = tensor<int32, [2]>([1, 1])];
tensor<int32, []> tensor_3_groups_0 = const()[name = tensor<string, []>("tensor_3_groups_0"), val = tensor<int32, []>(1)];
tensor<fp32, [1, 256, 56, 64]> tensor_3 = conv(bias = model_encoder_pre_encode_conv_0_bias, dilations = tensor_3_dilations_0, groups = tensor_3_groups_0, pad = tensor_3_pad_0, pad_type = tensor_3_pad_type_0, strides = tensor_3_strides_0, weight = model_encoder_pre_encode_conv_0_weight, x = input_1)[name = tensor<string, []>("tensor_3")];
tensor<fp32, []> var_61_promoted = const()[name = tensor<string, []>("op_61_promoted"), val = tensor<fp32, []>(0x1p+0)];
tensor<fp32, [1]> current_lengths_1 = cast(dtype = current_lengths_1_dtype_0, x = chunk_lengths)[name = tensor<string, []>("cast_12")];
tensor<fp32, [1]> var_62 = add(x = current_lengths_1, y = var_61_promoted)[name = tensor<string, []>("op_62")];
tensor<fp32, []> var_63_promoted = const()[name = tensor<string, []>("op_63_promoted"), val = tensor<fp32, []>(0x1p+0)];
tensor<fp32, [1]> var_64 = add(x = var_62, y = var_63_promoted)[name = tensor<string, []>("op_64")];
tensor<fp32, []> var_65_promoted = const()[name = tensor<string, []>("op_65_promoted"), val = tensor<fp32, []>(0x1.8p+1)];
tensor<fp32, [1]> var_66 = sub(x = var_64, y = var_65_promoted)[name = tensor<string, []>("op_66")];
tensor<fp32, []> var_21_promoted = const()[name = tensor<string, []>("op_21_promoted"), val = tensor<fp32, []>(0x1p+1)];
tensor<fp32, [1]> floor_div_0 = floor_div(x = var_66, y = var_21_promoted)[name = tensor<string, []>("floor_div_0")];
tensor<fp32, []> var_68_promoted = const()[name = tensor<string, []>("op_68_promoted"), val = tensor<fp32, []>(0x1p+0)];
tensor<fp32, [1]> current_lengths_3 = add(x = floor_div_0, y = var_68_promoted)[name = tensor<string, []>("current_lengths_3")];
tensor<string, []> lengths_21_dtype_0 = const()[name = tensor<string, []>("lengths_21_dtype_0"), val = tensor<string, []>("int32")];
tensor<int32, [1, 56]> expand_dims_1 = const()[name = tensor<string, []>("expand_dims_1"), val = tensor<int32, [1, 56]>([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55]])];
tensor<int32, [1]> var_77_axes_0 = const()[name = tensor<string, []>("op_77_axes_0"), val = tensor<int32, [1]>([1])];
tensor<int32, [1]> lengths_21 = cast(dtype = lengths_21_dtype_0, x = current_lengths_3)[name = tensor<string, []>("cast_10")];
tensor<int32, [1, 1]> var_77 = expand_dims(axes = var_77_axes_0, x = lengths_21)[name = tensor<string, []>("op_77")];
tensor<bool, [1, 56]> time_mask_3 = less(x = expand_dims_1, y = var_77)[name = tensor<string, []>("time_mask_3")];
tensor<int32, [1]> var_79_axes_0 = const()[name = tensor<string, []>("op_79_axes_0"), val = tensor<int32, [1]>([-1])];
tensor<bool, [1, 56, 1]> var_79 = expand_dims(axes = var_79_axes_0, x = time_mask_3)[name = tensor<string, []>("op_79")];
tensor<int32, [3]> var_81_reps_0 = const()[name = tensor<string, []>("op_81_reps_0"), val = tensor<int32, [3]>([1, 1, 64])];
tensor<bool, [1, 56, 64]> var_81 = tile(reps = var_81_reps_0, x = var_79)[name = tensor<string, []>("op_81")];
tensor<string, []> mask_3_dtype_0 = const()[name = tensor<string, []>("mask_3_dtype_0"), val = tensor<string, []>("fp32")];
tensor<int32, [1]> var_87_axes_0 = const()[name = tensor<string, []>("op_87_axes_0"), val = tensor<int32, [1]>([1])];
tensor<fp32, [1, 56, 64]> mask_3 = cast(dtype = mask_3_dtype_0, x = var_81)[name = tensor<string, []>("cast_9")];
tensor<fp32, [1, 1, 56, 64]> var_87 = expand_dims(axes = var_87_axes_0, x = mask_3)[name = tensor<string, []>("op_87")];
tensor<int32, [4]> expanded_mask_3_reps_0 = const()[name = tensor<string, []>("expanded_mask_3_reps_0"), val = tensor<int32, [4]>([1, 256, 1, 1])];
tensor<fp32, [1, 256, 56, 64]> expanded_mask_3 = tile(reps = expanded_mask_3_reps_0, x = var_87)[name = tensor<string, []>("expanded_mask_3")];
tensor<fp32, [1, 256, 56, 64]> input_3 = mul(x = tensor_3, y = expanded_mask_3)[name = tensor<string, []>("input_3")];
tensor<fp32, [1, 256, 56, 64]> tensor_5 = relu(x = input_3)[name = tensor<string, []>("tensor_5")];
tensor<fp32, [1, 256, 56, 64]> input_5 = mul(x = tensor_5, y = expanded_mask_3)[name = tensor<string, []>("input_5")];
tensor<string, []> tensor_7_pad_type_0 = const()[name = tensor<string, []>("tensor_7_pad_type_0"), val = tensor<string, []>("custom")];
tensor<int32, [4]> tensor_7_pad_0 = const()[name = tensor<string, []>("tensor_7_pad_0"), val = tensor<int32, [4]>([1, 1, 1, 1])];
tensor<int32, [2]> tensor_7_strides_0 = const()[name = tensor<string, []>("tensor_7_strides_0"), val = tensor<int32, [2]>([2, 2])];
tensor<int32, []> tensor_7_groups_0 = const()[name = tensor<string, []>("tensor_7_groups_0"), val = tensor<int32, []>(256)];
tensor<int32, [2]> tensor_7_dilations_0 = const()[name = tensor<string, []>("tensor_7_dilations_0"), val = tensor<int32, [2]>([1, 1])];
tensor<fp32, [1, 256, 28, 32]> tensor_7 = conv(bias = model_encoder_pre_encode_conv_2_bias, dilations = tensor_7_dilations_0, groups = tensor_7_groups_0, pad = tensor_7_pad_0, pad_type = tensor_7_pad_type_0, strides = tensor_7_strides_0, weight = model_encoder_pre_encode_conv_2_weight, x = input_5)[name = tensor<string, []>("tensor_7")];
tensor<fp32, []> var_107_promoted = const()[name = tensor<string, []>("op_107_promoted"), val = tensor<fp32, []>(0x1p+0)];
tensor<fp32, [1]> var_108 = add(x = current_lengths_3, y = var_107_promoted)[name = tensor<string, []>("op_108")];
tensor<fp32, []> var_109_promoted = const()[name = tensor<string, []>("op_109_promoted"), val = tensor<fp32, []>(0x1p+0)];
tensor<fp32, [1]> var_110 = add(x = var_108, y = var_109_promoted)[name = tensor<string, []>("op_110")];
tensor<fp32, []> var_111_promoted = const()[name = tensor<string, []>("op_111_promoted"), val = tensor<fp32, []>(0x1.8p+1)];
tensor<fp32, [1]> var_112 = sub(x = var_110, y = var_111_promoted)[name = tensor<string, []>("op_112")];
tensor<fp32, []> var_21_promoted_1 = const()[name = tensor<string, []>("op_21_promoted_1"), val = tensor<fp32, []>(0x1p+1)];
tensor<fp32, [1]> floor_div_1 = floor_div(x = var_112, y = var_21_promoted_1)[name = tensor<string, []>("floor_div_1")];
tensor<fp32, []> var_114_promoted = const()[name = tensor<string, []>("op_114_promoted"), val = tensor<fp32, []>(0x1p+0)];
tensor<fp32, [1]> current_lengths_5 = add(x = floor_div_1, y = var_114_promoted)[name = tensor<string, []>("current_lengths_5")];
tensor<string, []> lengths_23_dtype_0 = const()[name = tensor<string, []>("lengths_23_dtype_0"), val = tensor<string, []>("int32")];
tensor<int32, [1, 28]> expand_dims_2 = const()[name = tensor<string, []>("expand_dims_2"), val = tensor<int32, [1, 28]>([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]])];
tensor<int32, [1]> var_123_axes_0 = const()[name = tensor<string, []>("op_123_axes_0"), val = tensor<int32, [1]>([1])];
tensor<int32, [1]> lengths_23 = cast(dtype = lengths_23_dtype_0, x = current_lengths_5)[name = tensor<string, []>("cast_8")];
tensor<int32, [1, 1]> var_123 = expand_dims(axes = var_123_axes_0, x = lengths_23)[name = tensor<string, []>("op_123")];
tensor<bool, [1, 28]> time_mask_5 = less(x = expand_dims_2, y = var_123)[name = tensor<string, []>("time_mask_5")];
tensor<int32, [1]> var_125_axes_0 = const()[name = tensor<string, []>("op_125_axes_0"), val = tensor<int32, [1]>([-1])];
tensor<bool, [1, 28, 1]> var_125 = expand_dims(axes = var_125_axes_0, x = time_mask_5)[name = tensor<string, []>("op_125")];
tensor<int32, [3]> var_127_reps_0 = const()[name = tensor<string, []>("op_127_reps_0"), val = tensor<int32, [3]>([1, 1, 32])];
tensor<bool, [1, 28, 32]> var_127 = tile(reps = var_127_reps_0, x = var_125)[name = tensor<string, []>("op_127")];
tensor<string, []> mask_5_dtype_0 = const()[name = tensor<string, []>("mask_5_dtype_0"), val = tensor<string, []>("fp32")];
tensor<int32, [1]> var_133_axes_0 = const()[name = tensor<string, []>("op_133_axes_0"), val = tensor<int32, [1]>([1])];
tensor<fp32, [1, 28, 32]> mask_5 = cast(dtype = mask_5_dtype_0, x = var_127)[name = tensor<string, []>("cast_7")];
tensor<fp32, [1, 1, 28, 32]> var_133 = expand_dims(axes = var_133_axes_0, x = mask_5)[name = tensor<string, []>("op_133")];
tensor<int32, [4]> expanded_mask_7_reps_0 = const()[name = tensor<string, []>("expanded_mask_7_reps_0"), val = tensor<int32, [4]>([1, 256, 1, 1])];
tensor<fp32, [1, 256, 28, 32]> expanded_mask_7 = tile(reps = expanded_mask_7_reps_0, x = var_133)[name = tensor<string, []>("expanded_mask_7")];
tensor<fp32, [1, 256, 28, 32]> input_7 = mul(x = tensor_7, y = expanded_mask_7)[name = tensor<string, []>("input_7")];
tensor<string, []> tensor_9_pad_type_0 = const()[name = tensor<string, []>("tensor_9_pad_type_0"), val = tensor<string, []>("valid")];
tensor<int32, [2]> tensor_9_strides_0 = const()[name = tensor<string, []>("tensor_9_strides_0"), val = tensor<int32, [2]>([1, 1])];
tensor<int32, [4]> tensor_9_pad_0 = const()[name = tensor<string, []>("tensor_9_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
tensor<int32, [2]> tensor_9_dilations_0 = const()[name = tensor<string, []>("tensor_9_dilations_0"), val = tensor<int32, [2]>([1, 1])];
tensor<int32, []> tensor_9_groups_0 = const()[name = tensor<string, []>("tensor_9_groups_0"), val = tensor<int32, []>(1)];
tensor<fp32, [1, 256, 28, 32]> tensor_9 = conv(bias = model_encoder_pre_encode_conv_3_bias, dilations = tensor_9_dilations_0, groups = tensor_9_groups_0, pad = tensor_9_pad_0, pad_type = tensor_9_pad_type_0, strides = tensor_9_strides_0, weight = model_encoder_pre_encode_conv_3_weight, x = input_7)[name = tensor<string, []>("tensor_9")];
tensor<fp32, [1, 256, 28, 32]> input_9 = mul(x = tensor_9, y = expanded_mask_7)[name = tensor<string, []>("input_9")];
tensor<fp32, [1, 256, 28, 32]> tensor_11 = relu(x = input_9)[name = tensor<string, []>("tensor_11")];
tensor<fp32, [1, 256, 28, 32]> input_11 = mul(x = tensor_11, y = expanded_mask_7)[name = tensor<string, []>("input_11")];
tensor<string, []> tensor_13_pad_type_0 = const()[name = tensor<string, []>("tensor_13_pad_type_0"), val = tensor<string, []>("custom")];
tensor<int32, [4]> tensor_13_pad_0 = const()[name = tensor<string, []>("tensor_13_pad_0"), val = tensor<int32, [4]>([1, 1, 1, 1])];
tensor<int32, [2]> tensor_13_strides_0 = const()[name = tensor<string, []>("tensor_13_strides_0"), val = tensor<int32, [2]>([2, 2])];
tensor<int32, []> tensor_13_groups_0 = const()[name = tensor<string, []>("tensor_13_groups_0"), val = tensor<int32, []>(256)];
tensor<int32, [2]> tensor_13_dilations_0 = const()[name = tensor<string, []>("tensor_13_dilations_0"), val = tensor<int32, [2]>([1, 1])];
tensor<fp32, [1, 256, 14, 16]> tensor_13 = conv(bias = model_encoder_pre_encode_conv_5_bias, dilations = tensor_13_dilations_0, groups = tensor_13_groups_0, pad = tensor_13_pad_0, pad_type = tensor_13_pad_type_0, strides = tensor_13_strides_0, weight = model_encoder_pre_encode_conv_5_weight, x = input_11)[name = tensor<string, []>("tensor_13")];
tensor<fp32, []> var_168_promoted = const()[name = tensor<string, []>("op_168_promoted"), val = tensor<fp32, []>(0x1p+0)];
tensor<fp32, [1]> var_169 = add(x = current_lengths_5, y = var_168_promoted)[name = tensor<string, []>("op_169")];
tensor<fp32, []> var_170_promoted = const()[name = tensor<string, []>("op_170_promoted"), val = tensor<fp32, []>(0x1p+0)];
tensor<fp32, [1]> var_171 = add(x = var_169, y = var_170_promoted)[name = tensor<string, []>("op_171")];
tensor<fp32, []> var_172_promoted = const()[name = tensor<string, []>("op_172_promoted"), val = tensor<fp32, []>(0x1.8p+1)];
tensor<fp32, [1]> var_173 = sub(x = var_171, y = var_172_promoted)[name = tensor<string, []>("op_173")];
tensor<fp32, []> var_21_promoted_2 = const()[name = tensor<string, []>("op_21_promoted_2"), val = tensor<fp32, []>(0x1p+1)];
tensor<fp32, [1]> floor_div_2 = floor_div(x = var_173, y = var_21_promoted_2)[name = tensor<string, []>("floor_div_2")];
tensor<fp32, []> var_175_promoted = const()[name = tensor<string, []>("op_175_promoted"), val = tensor<fp32, []>(0x1p+0)];
tensor<fp32, [1]> current_lengths = add(x = floor_div_2, y = var_175_promoted)[name = tensor<string, []>("current_lengths")];
tensor<string, []> lengths_dtype_0 = const()[name = tensor<string, []>("lengths_dtype_0"), val = tensor<string, []>("int32")];
tensor<int32, [1, 14]> expand_dims_3 = const()[name = tensor<string, []>("expand_dims_3"), val = tensor<int32, [1, 14]>([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])];
tensor<int32, [1]> var_184_axes_0 = const()[name = tensor<string, []>("op_184_axes_0"), val = tensor<int32, [1]>([1])];
tensor<int32, [1]> lengths = cast(dtype = lengths_dtype_0, x = current_lengths)[name = tensor<string, []>("cast_6")];
tensor<int32, [1, 1]> var_184 = expand_dims(axes = var_184_axes_0, x = lengths)[name = tensor<string, []>("op_184")];
tensor<bool, [1, 14]> time_mask = less(x = expand_dims_3, y = var_184)[name = tensor<string, []>("time_mask")];
tensor<int32, [1]> var_186_axes_0 = const()[name = tensor<string, []>("op_186_axes_0"), val = tensor<int32, [1]>([-1])];
tensor<bool, [1, 14, 1]> var_186 = expand_dims(axes = var_186_axes_0, x = time_mask)[name = tensor<string, []>("op_186")];
tensor<int32, [3]> var_188_reps_0 = const()[name = tensor<string, []>("op_188_reps_0"), val = tensor<int32, [3]>([1, 1, 16])];
tensor<bool, [1, 14, 16]> var_188 = tile(reps = var_188_reps_0, x = var_186)[name = tensor<string, []>("op_188")];
tensor<string, []> mask_dtype_0 = const()[name = tensor<string, []>("mask_dtype_0"), val = tensor<string, []>("fp32")];
tensor<int32, [1]> var_194_axes_0 = const()[name = tensor<string, []>("op_194_axes_0"), val = tensor<int32, [1]>([1])];
tensor<fp32, [1, 14, 16]> mask = cast(dtype = mask_dtype_0, x = var_188)[name = tensor<string, []>("cast_5")];
tensor<fp32, [1, 1, 14, 16]> var_194 = expand_dims(axes = var_194_axes_0, x = mask)[name = tensor<string, []>("op_194")];
tensor<int32, [4]> expanded_mask_13_reps_0 = const()[name = tensor<string, []>("expanded_mask_13_reps_0"), val = tensor<int32, [4]>([1, 256, 1, 1])];
tensor<fp32, [1, 256, 14, 16]> expanded_mask_13 = tile(reps = expanded_mask_13_reps_0, x = var_194)[name = tensor<string, []>("expanded_mask_13")];
tensor<fp32, [1, 256, 14, 16]> input_13 = mul(x = tensor_13, y = expanded_mask_13)[name = tensor<string, []>("input_13")];
tensor<string, []> tensor_15_pad_type_0 = const()[name = tensor<string, []>("tensor_15_pad_type_0"), val = tensor<string, []>("valid")];
tensor<int32, [2]> tensor_15_strides_0 = const()[name = tensor<string, []>("tensor_15_strides_0"), val = tensor<int32, [2]>([1, 1])];
tensor<int32, [4]> tensor_15_pad_0 = const()[name = tensor<string, []>("tensor_15_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
tensor<int32, [2]> tensor_15_dilations_0 = const()[name = tensor<string, []>("tensor_15_dilations_0"), val = tensor<int32, [2]>([1, 1])];
tensor<int32, []> tensor_15_groups_0 = const()[name = tensor<string, []>("tensor_15_groups_0"), val = tensor<int32, []>(1)];
tensor<fp32, [1, 256, 14, 16]> tensor_15 = conv(bias = model_encoder_pre_encode_conv_6_bias, dilations = tensor_15_dilations_0, groups = tensor_15_groups_0, pad = tensor_15_pad_0, pad_type = tensor_15_pad_type_0, strides = tensor_15_strides_0, weight = model_encoder_pre_encode_conv_6_weight, x = input_13)[name = tensor<string, []>("tensor_15")];
tensor<fp32, [1, 256, 14, 16]> input_15 = mul(x = tensor_15, y = expanded_mask_13)[name = tensor<string, []>("input_15")];
tensor<fp32, [1, 256, 14, 16]> tensor_workaround = relu(x = input_15)[name = tensor<string, []>("tensor_workaround")];
tensor<fp32, [1, 256, 14, 16]> x = mul(x = tensor_workaround, y = expanded_mask_13)[name = tensor<string, []>("x")];
tensor<int32, [4]> var_228_perm_0 = const()[name = tensor<string, []>("op_228_perm_0"), val = tensor<int32, [4]>([0, 2, 1, 3])];
tensor<int32, [3]> var_229 = const()[name = tensor<string, []>("op_229"), val = tensor<int32, [3]>([1, 14, -1])];
tensor<fp32, [1, 14, 256, 16]> var_228 = transpose(perm = var_228_perm_0, x = x)[name = tensor<string, []>("transpose_0")];
tensor<fp32, [1, 14, 4096]> input = reshape(shape = var_229, x = var_228)[name = tensor<string, []>("input")];
tensor<fp32, [1, 14, 512]> chunk_embs_in = linear(bias = model_encoder_pre_encode_out_bias, weight = model_encoder_pre_encode_out_weight, x = input)[name = tensor<string, []>("linear_0")];
tensor<string, []> var_241_dtype_0 = const()[name = tensor<string, []>("op_241_dtype_0"), val = tensor<string, []>("int32")];
tensor<int32, [1]> size0 = const()[name = tensor<string, []>("size0"), val = tensor<int32, [1]>([188])];
tensor<int32, [1]> size1 = const()[name = tensor<string, []>("size1"), val = tensor<int32, [1]>([40])];
tensor<int32, []> var_264 = const()[name = tensor<string, []>("op_264"), val = tensor<int32, []>(1)];
tensor<bool, []> full_concat_interleave_0 = const()[name = tensor<string, []>("full_concat_interleave_0"), val = tensor<bool, []>(false)];
tensor<fp32, [1, 242, 512]> full_concat = concat(axis = var_264, interleave = full_concat_interleave_0, values = (spkcache, fifo, chunk_embs_in))[name = tensor<string, []>("full_concat")];
tensor<int32, [1]> var_273 = add(x = spkcache_lengths, y = fifo_lengths)[name = tensor<string, []>("op_273")];
tensor<int32, [1]> chunk_lens_in = cast(dtype = var_241_dtype_0, x = current_lengths)[name = tensor<string, []>("cast_4")];
tensor<int32, [1]> pre_encoder_lengths = add(x = var_273, y = chunk_lens_in)[name = tensor<string, []>("total_length")];
tensor<int32, [242]> out_pos = const()[name = tensor<string, []>("out_pos"), val = tensor<int32, [242]>([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241])];
tensor<bool, [242]> var_284 = greater_equal(x = out_pos, y = spkcache_lengths)[name = tensor<string, []>("op_284")];
tensor<string, []> in_seg1_or_2_dtype_0 = const()[name = tensor<string, []>("in_seg1_or_2_dtype_0"), val = tensor<string, []>("int32")];
tensor<bool, [242]> var_290 = greater_equal(x = out_pos, y = var_273)[name = tensor<string, []>("op_290")];
tensor<string, []> in_seg2_dtype_0 = const()[name = tensor<string, []>("in_seg2_dtype_0"), val = tensor<string, []>("int32")];
tensor<int32, [1]> var_297 = sub(x = size0, y = spkcache_lengths)[name = tensor<string, []>("op_297")];
tensor<int32, [242]> in_seg1_or_2 = cast(dtype = in_seg1_or_2_dtype_0, x = var_284)[name = tensor<string, []>("cast_3")];
tensor<int32, [242]> var_298 = mul(x = in_seg1_or_2, y = var_297)[name = tensor<string, []>("op_298")];
tensor<int32, [1]> var_300 = sub(x = size1, y = fifo_lengths)[name = tensor<string, []>("op_300")];
tensor<int32, [242]> in_seg2 = cast(dtype = in_seg2_dtype_0, x = var_290)[name = tensor<string, []>("cast_2")];
tensor<int32, [242]> var_301 = mul(x = in_seg2, y = var_300)[name = tensor<string, []>("op_301")];
tensor<int32, [242]> offset = add(x = var_298, y = var_301)[name = tensor<string, []>("offset")];
tensor<int32, [242]> var_305 = add(x = out_pos, y = offset)[name = tensor<string, []>("op_305")];
tensor<int32, []> var_309 = const()[name = tensor<string, []>("op_309"), val = tensor<int32, []>(241)];
tensor<int32, []> var_310 = const()[name = tensor<string, []>("op_310"), val = tensor<int32, []>(0)];
tensor<int32, [242]> minimum_0 = minimum(x = var_305, y = var_309)[name = tensor<string, []>("minimum_0")];
tensor<int32, [242]> maximum_0 = maximum(x = minimum_0, y = var_310)[name = tensor<string, []>("maximum_0")];
tensor<int32, [1]> var_313_axes_0 = const()[name = tensor<string, []>("op_313_axes_0"), val = tensor<int32, [1]>([0])];
tensor<int32, [1, 242]> var_313 = expand_dims(axes = var_313_axes_0, x = maximum_0)[name = tensor<string, []>("op_313")];
tensor<int32, [1]> var_315_axes_0 = const()[name = tensor<string, []>("op_315_axes_0"), val = tensor<int32, [1]>([-1])];
tensor<int32, [1, 242, 1]> var_315 = expand_dims(axes = var_315_axes_0, x = var_313)[name = tensor<string, []>("op_315")];
tensor<int32, [3]> gather_idx_reps_0 = const()[name = tensor<string, []>("gather_idx_reps_0"), val = tensor<int32, [3]>([1, 1, 512])];
tensor<int32, [1, 242, 512]> gather_idx = tile(reps = gather_idx_reps_0, x = var_315)[name = tensor<string, []>("gather_idx")];
tensor<int32, []> var_320 = const()[name = tensor<string, []>("op_320"), val = tensor<int32, []>(1)];
tensor<fp32, [1, 242, 512]> output = gather_along_axis(axis = var_320, indices = gather_idx, x = full_concat)[name = tensor<string, []>("output")];
tensor<bool, [242]> var_323 = less(x = out_pos, y = pre_encoder_lengths)[name = tensor<string, []>("op_323")];
tensor<string, []> var_328_dtype_0 = const()[name = tensor<string, []>("op_328_dtype_0"), val = tensor<string, []>("fp32")];
tensor<int32, [1]> var_330_axes_0 = const()[name = tensor<string, []>("op_330_axes_0"), val = tensor<int32, [1]>([0])];
tensor<fp32, [242]> var_328 = cast(dtype = var_328_dtype_0, x = var_323)[name = tensor<string, []>("cast_1")];
tensor<fp32, [1, 242]> var_330 = expand_dims(axes = var_330_axes_0, x = var_328)[name = tensor<string, []>("op_330")];
tensor<int32, [1]> var_332_axes_0 = const()[name = tensor<string, []>("op_332_axes_0"), val = tensor<int32, [1]>([-1])];
tensor<fp32, [1, 242, 1]> var_332 = expand_dims(axes = var_332_axes_0, x = var_330)[name = tensor<string, []>("op_332")];
tensor<fp32, [1, 242, 512]> pre_encoder_embs = mul(x = output, y = var_332)[name = tensor<string, []>("op_333")];
} -> (pre_encoder_embs, pre_encoder_lengths, chunk_embs_in, chunk_lens_in);
}