program(1.0) [buildInfo = dict, tensor>({{"coremlc-component-MIL", "3510.2.1"}, {"coremlc-version", "3500.32.1"}, {"coremltools-component-torch", "2.9.1"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "9.0"}})] { func main(tensor chunk, tensor chunk_lengths, tensor fifo, tensor fifo_lengths, tensor spkcache, tensor spkcache_lengths) { tensor model_encoder_pre_encode_conv_0_bias = const()[name = tensor("model_encoder_pre_encode_conv_0_bias"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(64)))]; tensor model_encoder_pre_encode_conv_0_weight = const()[name = tensor("model_encoder_pre_encode_conv_0_weight"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(1152)))]; tensor model_encoder_pre_encode_conv_2_bias = const()[name = tensor("model_encoder_pre_encode_conv_2_bias"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(10432)))]; tensor model_encoder_pre_encode_conv_2_weight = const()[name = tensor("model_encoder_pre_encode_conv_2_weight"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(11520)))]; tensor model_encoder_pre_encode_conv_3_bias = const()[name = tensor("model_encoder_pre_encode_conv_3_bias"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(20800)))]; tensor model_encoder_pre_encode_conv_3_weight = const()[name = tensor("model_encoder_pre_encode_conv_3_weight"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(21888)))]; tensor model_encoder_pre_encode_conv_5_bias = const()[name = tensor("model_encoder_pre_encode_conv_5_bias"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(284096)))]; tensor model_encoder_pre_encode_conv_5_weight = const()[name = tensor("model_encoder_pre_encode_conv_5_weight"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(285184)))]; tensor model_encoder_pre_encode_conv_6_bias = const()[name = tensor("model_encoder_pre_encode_conv_6_bias"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(294464)))]; tensor model_encoder_pre_encode_conv_6_weight = const()[name = tensor("model_encoder_pre_encode_conv_6_weight"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(295552)))]; tensor model_encoder_pre_encode_out_bias = const()[name = tensor("model_encoder_pre_encode_out_bias"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(557760)))]; tensor model_encoder_pre_encode_out_weight = const()[name = tensor("model_encoder_pre_encode_out_weight"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(559872)))]; tensor tensor_1_axes_0 = const()[name = tensor("tensor_1_axes_0"), val = tensor([1])]; tensor tensor_1 = expand_dims(axes = tensor_1_axes_0, x = chunk)[name = tensor("tensor_1")]; tensor current_lengths_1_dtype_0 = const()[name = tensor("current_lengths_1_dtype_0"), val = tensor("fp32")]; tensor expand_dims_0 = const()[name = tensor("expand_dims_0"), val = tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111]])]; tensor var_40_axes_0 = const()[name = tensor("op_40_axes_0"), val = tensor([1])]; tensor var_40 = expand_dims(axes = var_40_axes_0, x = chunk_lengths)[name = tensor("op_40")]; tensor time_mask_1 = less(x = expand_dims_0, y = var_40)[name = tensor("time_mask_1")]; tensor var_42_axes_0 = const()[name = tensor("op_42_axes_0"), val = tensor([-1])]; tensor var_42 = expand_dims(axes = var_42_axes_0, x = time_mask_1)[name = tensor("op_42")]; tensor var_44_reps_0 = const()[name = tensor("op_44_reps_0"), val = tensor([1, 1, 128])]; tensor var_44 = tile(reps = var_44_reps_0, x = var_42)[name = tensor("op_44")]; tensor mask_1_dtype_0 = const()[name = tensor("mask_1_dtype_0"), val = tensor("fp32")]; tensor var_50_axes_0 = const()[name = tensor("op_50_axes_0"), val = tensor([1])]; tensor mask_1 = cast(dtype = mask_1_dtype_0, x = var_44)[name = tensor("cast_11")]; tensor var_50 = expand_dims(axes = var_50_axes_0, x = mask_1)[name = tensor("op_50")]; tensor input_1 = mul(x = tensor_1, y = var_50)[name = tensor("input_1")]; tensor tensor_3_pad_type_0 = const()[name = tensor("tensor_3_pad_type_0"), val = tensor("custom")]; tensor tensor_3_pad_0 = const()[name = tensor("tensor_3_pad_0"), val = tensor([1, 1, 1, 1])]; tensor tensor_3_strides_0 = const()[name = tensor("tensor_3_strides_0"), val = tensor([2, 2])]; tensor tensor_3_dilations_0 = const()[name = tensor("tensor_3_dilations_0"), val = tensor([1, 1])]; tensor tensor_3_groups_0 = const()[name = tensor("tensor_3_groups_0"), val = tensor(1)]; tensor tensor_3 = conv(bias = model_encoder_pre_encode_conv_0_bias, dilations = tensor_3_dilations_0, groups = tensor_3_groups_0, pad = tensor_3_pad_0, pad_type = tensor_3_pad_type_0, strides = tensor_3_strides_0, weight = model_encoder_pre_encode_conv_0_weight, x = input_1)[name = tensor("tensor_3")]; tensor var_61_promoted = const()[name = tensor("op_61_promoted"), val = tensor(0x1p+0)]; tensor current_lengths_1 = cast(dtype = current_lengths_1_dtype_0, x = chunk_lengths)[name = tensor("cast_12")]; tensor var_62 = add(x = current_lengths_1, y = var_61_promoted)[name = tensor("op_62")]; tensor var_63_promoted = const()[name = tensor("op_63_promoted"), val = tensor(0x1p+0)]; tensor var_64 = add(x = var_62, y = var_63_promoted)[name = tensor("op_64")]; tensor var_65_promoted = const()[name = tensor("op_65_promoted"), val = tensor(0x1.8p+1)]; tensor var_66 = sub(x = var_64, y = var_65_promoted)[name = tensor("op_66")]; tensor var_21_promoted = const()[name = tensor("op_21_promoted"), val = tensor(0x1p+1)]; tensor floor_div_0 = floor_div(x = var_66, y = var_21_promoted)[name = tensor("floor_div_0")]; tensor var_68_promoted = const()[name = tensor("op_68_promoted"), val = tensor(0x1p+0)]; tensor current_lengths_3 = add(x = floor_div_0, y = var_68_promoted)[name = tensor("current_lengths_3")]; tensor lengths_21_dtype_0 = const()[name = tensor("lengths_21_dtype_0"), val = tensor("int32")]; tensor expand_dims_1 = const()[name = tensor("expand_dims_1"), val = tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55]])]; tensor var_77_axes_0 = const()[name = tensor("op_77_axes_0"), val = tensor([1])]; tensor lengths_21 = cast(dtype = lengths_21_dtype_0, x = current_lengths_3)[name = tensor("cast_10")]; tensor var_77 = expand_dims(axes = var_77_axes_0, x = lengths_21)[name = tensor("op_77")]; tensor time_mask_3 = less(x = expand_dims_1, y = var_77)[name = tensor("time_mask_3")]; tensor var_79_axes_0 = const()[name = tensor("op_79_axes_0"), val = tensor([-1])]; tensor var_79 = expand_dims(axes = var_79_axes_0, x = time_mask_3)[name = tensor("op_79")]; tensor var_81_reps_0 = const()[name = tensor("op_81_reps_0"), val = tensor([1, 1, 64])]; tensor var_81 = tile(reps = var_81_reps_0, x = var_79)[name = tensor("op_81")]; tensor mask_3_dtype_0 = const()[name = tensor("mask_3_dtype_0"), val = tensor("fp32")]; tensor var_87_axes_0 = const()[name = tensor("op_87_axes_0"), val = tensor([1])]; tensor mask_3 = cast(dtype = mask_3_dtype_0, x = var_81)[name = tensor("cast_9")]; tensor var_87 = expand_dims(axes = var_87_axes_0, x = mask_3)[name = tensor("op_87")]; tensor expanded_mask_3_reps_0 = const()[name = tensor("expanded_mask_3_reps_0"), val = tensor([1, 256, 1, 1])]; tensor expanded_mask_3 = tile(reps = expanded_mask_3_reps_0, x = var_87)[name = tensor("expanded_mask_3")]; tensor input_3 = mul(x = tensor_3, y = expanded_mask_3)[name = tensor("input_3")]; tensor tensor_5 = relu(x = input_3)[name = tensor("tensor_5")]; tensor input_5 = mul(x = tensor_5, y = expanded_mask_3)[name = tensor("input_5")]; tensor tensor_7_pad_type_0 = const()[name = tensor("tensor_7_pad_type_0"), val = tensor("custom")]; tensor tensor_7_pad_0 = const()[name = tensor("tensor_7_pad_0"), val = tensor([1, 1, 1, 1])]; tensor tensor_7_strides_0 = const()[name = tensor("tensor_7_strides_0"), val = tensor([2, 2])]; tensor tensor_7_groups_0 = const()[name = tensor("tensor_7_groups_0"), val = tensor(256)]; tensor tensor_7_dilations_0 = const()[name = tensor("tensor_7_dilations_0"), val = tensor([1, 1])]; tensor tensor_7 = conv(bias = model_encoder_pre_encode_conv_2_bias, dilations = tensor_7_dilations_0, groups = tensor_7_groups_0, pad = tensor_7_pad_0, pad_type = tensor_7_pad_type_0, strides = tensor_7_strides_0, weight = model_encoder_pre_encode_conv_2_weight, x = input_5)[name = tensor("tensor_7")]; tensor var_107_promoted = const()[name = tensor("op_107_promoted"), val = tensor(0x1p+0)]; tensor var_108 = add(x = current_lengths_3, y = var_107_promoted)[name = tensor("op_108")]; tensor var_109_promoted = const()[name = tensor("op_109_promoted"), val = tensor(0x1p+0)]; tensor var_110 = add(x = var_108, y = var_109_promoted)[name = tensor("op_110")]; tensor var_111_promoted = const()[name = tensor("op_111_promoted"), val = tensor(0x1.8p+1)]; tensor var_112 = sub(x = var_110, y = var_111_promoted)[name = tensor("op_112")]; tensor var_21_promoted_1 = const()[name = tensor("op_21_promoted_1"), val = tensor(0x1p+1)]; tensor floor_div_1 = floor_div(x = var_112, y = var_21_promoted_1)[name = tensor("floor_div_1")]; tensor var_114_promoted = const()[name = tensor("op_114_promoted"), val = tensor(0x1p+0)]; tensor current_lengths_5 = add(x = floor_div_1, y = var_114_promoted)[name = tensor("current_lengths_5")]; tensor lengths_23_dtype_0 = const()[name = tensor("lengths_23_dtype_0"), val = tensor("int32")]; tensor expand_dims_2 = const()[name = tensor("expand_dims_2"), val = tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]])]; tensor var_123_axes_0 = const()[name = tensor("op_123_axes_0"), val = tensor([1])]; tensor lengths_23 = cast(dtype = lengths_23_dtype_0, x = current_lengths_5)[name = tensor("cast_8")]; tensor var_123 = expand_dims(axes = var_123_axes_0, x = lengths_23)[name = tensor("op_123")]; tensor time_mask_5 = less(x = expand_dims_2, y = var_123)[name = tensor("time_mask_5")]; tensor var_125_axes_0 = const()[name = tensor("op_125_axes_0"), val = tensor([-1])]; tensor var_125 = expand_dims(axes = var_125_axes_0, x = time_mask_5)[name = tensor("op_125")]; tensor var_127_reps_0 = const()[name = tensor("op_127_reps_0"), val = tensor([1, 1, 32])]; tensor var_127 = tile(reps = var_127_reps_0, x = var_125)[name = tensor("op_127")]; tensor mask_5_dtype_0 = const()[name = tensor("mask_5_dtype_0"), val = tensor("fp32")]; tensor var_133_axes_0 = const()[name = tensor("op_133_axes_0"), val = tensor([1])]; tensor mask_5 = cast(dtype = mask_5_dtype_0, x = var_127)[name = tensor("cast_7")]; tensor var_133 = expand_dims(axes = var_133_axes_0, x = mask_5)[name = tensor("op_133")]; tensor expanded_mask_7_reps_0 = const()[name = tensor("expanded_mask_7_reps_0"), val = tensor([1, 256, 1, 1])]; tensor expanded_mask_7 = tile(reps = expanded_mask_7_reps_0, x = var_133)[name = tensor("expanded_mask_7")]; tensor input_7 = mul(x = tensor_7, y = expanded_mask_7)[name = tensor("input_7")]; tensor tensor_9_pad_type_0 = const()[name = tensor("tensor_9_pad_type_0"), val = tensor("valid")]; tensor tensor_9_strides_0 = const()[name = tensor("tensor_9_strides_0"), val = tensor([1, 1])]; tensor tensor_9_pad_0 = const()[name = tensor("tensor_9_pad_0"), val = tensor([0, 0, 0, 0])]; tensor tensor_9_dilations_0 = const()[name = tensor("tensor_9_dilations_0"), val = tensor([1, 1])]; tensor tensor_9_groups_0 = const()[name = tensor("tensor_9_groups_0"), val = tensor(1)]; tensor tensor_9 = conv(bias = model_encoder_pre_encode_conv_3_bias, dilations = tensor_9_dilations_0, groups = tensor_9_groups_0, pad = tensor_9_pad_0, pad_type = tensor_9_pad_type_0, strides = tensor_9_strides_0, weight = model_encoder_pre_encode_conv_3_weight, x = input_7)[name = tensor("tensor_9")]; tensor input_9 = mul(x = tensor_9, y = expanded_mask_7)[name = tensor("input_9")]; tensor tensor_11 = relu(x = input_9)[name = tensor("tensor_11")]; tensor input_11 = mul(x = tensor_11, y = expanded_mask_7)[name = tensor("input_11")]; tensor tensor_13_pad_type_0 = const()[name = tensor("tensor_13_pad_type_0"), val = tensor("custom")]; tensor tensor_13_pad_0 = const()[name = tensor("tensor_13_pad_0"), val = tensor([1, 1, 1, 1])]; tensor tensor_13_strides_0 = const()[name = tensor("tensor_13_strides_0"), val = tensor([2, 2])]; tensor tensor_13_groups_0 = const()[name = tensor("tensor_13_groups_0"), val = tensor(256)]; tensor tensor_13_dilations_0 = const()[name = tensor("tensor_13_dilations_0"), val = tensor([1, 1])]; tensor tensor_13 = conv(bias = model_encoder_pre_encode_conv_5_bias, dilations = tensor_13_dilations_0, groups = tensor_13_groups_0, pad = tensor_13_pad_0, pad_type = tensor_13_pad_type_0, strides = tensor_13_strides_0, weight = model_encoder_pre_encode_conv_5_weight, x = input_11)[name = tensor("tensor_13")]; tensor var_168_promoted = const()[name = tensor("op_168_promoted"), val = tensor(0x1p+0)]; tensor var_169 = add(x = current_lengths_5, y = var_168_promoted)[name = tensor("op_169")]; tensor var_170_promoted = const()[name = tensor("op_170_promoted"), val = tensor(0x1p+0)]; tensor var_171 = add(x = var_169, y = var_170_promoted)[name = tensor("op_171")]; tensor var_172_promoted = const()[name = tensor("op_172_promoted"), val = tensor(0x1.8p+1)]; tensor var_173 = sub(x = var_171, y = var_172_promoted)[name = tensor("op_173")]; tensor var_21_promoted_2 = const()[name = tensor("op_21_promoted_2"), val = tensor(0x1p+1)]; tensor floor_div_2 = floor_div(x = var_173, y = var_21_promoted_2)[name = tensor("floor_div_2")]; tensor var_175_promoted = const()[name = tensor("op_175_promoted"), val = tensor(0x1p+0)]; tensor current_lengths = add(x = floor_div_2, y = var_175_promoted)[name = tensor("current_lengths")]; tensor lengths_dtype_0 = const()[name = tensor("lengths_dtype_0"), val = tensor("int32")]; tensor expand_dims_3 = const()[name = tensor("expand_dims_3"), val = tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])]; tensor var_184_axes_0 = const()[name = tensor("op_184_axes_0"), val = tensor([1])]; tensor lengths = cast(dtype = lengths_dtype_0, x = current_lengths)[name = tensor("cast_6")]; tensor var_184 = expand_dims(axes = var_184_axes_0, x = lengths)[name = tensor("op_184")]; tensor time_mask = less(x = expand_dims_3, y = var_184)[name = tensor("time_mask")]; tensor var_186_axes_0 = const()[name = tensor("op_186_axes_0"), val = tensor([-1])]; tensor var_186 = expand_dims(axes = var_186_axes_0, x = time_mask)[name = tensor("op_186")]; tensor var_188_reps_0 = const()[name = tensor("op_188_reps_0"), val = tensor([1, 1, 16])]; tensor var_188 = tile(reps = var_188_reps_0, x = var_186)[name = tensor("op_188")]; tensor mask_dtype_0 = const()[name = tensor("mask_dtype_0"), val = tensor("fp32")]; tensor var_194_axes_0 = const()[name = tensor("op_194_axes_0"), val = tensor([1])]; tensor mask = cast(dtype = mask_dtype_0, x = var_188)[name = tensor("cast_5")]; tensor var_194 = expand_dims(axes = var_194_axes_0, x = mask)[name = tensor("op_194")]; tensor expanded_mask_13_reps_0 = const()[name = tensor("expanded_mask_13_reps_0"), val = tensor([1, 256, 1, 1])]; tensor expanded_mask_13 = tile(reps = expanded_mask_13_reps_0, x = var_194)[name = tensor("expanded_mask_13")]; tensor input_13 = mul(x = tensor_13, y = expanded_mask_13)[name = tensor("input_13")]; tensor tensor_15_pad_type_0 = const()[name = tensor("tensor_15_pad_type_0"), val = tensor("valid")]; tensor tensor_15_strides_0 = const()[name = tensor("tensor_15_strides_0"), val = tensor([1, 1])]; tensor tensor_15_pad_0 = const()[name = tensor("tensor_15_pad_0"), val = tensor([0, 0, 0, 0])]; tensor tensor_15_dilations_0 = const()[name = tensor("tensor_15_dilations_0"), val = tensor([1, 1])]; tensor tensor_15_groups_0 = const()[name = tensor("tensor_15_groups_0"), val = tensor(1)]; tensor tensor_15 = conv(bias = model_encoder_pre_encode_conv_6_bias, dilations = tensor_15_dilations_0, groups = tensor_15_groups_0, pad = tensor_15_pad_0, pad_type = tensor_15_pad_type_0, strides = tensor_15_strides_0, weight = model_encoder_pre_encode_conv_6_weight, x = input_13)[name = tensor("tensor_15")]; tensor input_15 = mul(x = tensor_15, y = expanded_mask_13)[name = tensor("input_15")]; tensor tensor_workaround = relu(x = input_15)[name = tensor("tensor_workaround")]; tensor x = mul(x = tensor_workaround, y = expanded_mask_13)[name = tensor("x")]; tensor var_228_perm_0 = const()[name = tensor("op_228_perm_0"), val = tensor([0, 2, 1, 3])]; tensor var_229 = const()[name = tensor("op_229"), val = tensor([1, 14, -1])]; tensor var_228 = transpose(perm = var_228_perm_0, x = x)[name = tensor("transpose_0")]; tensor input = reshape(shape = var_229, x = var_228)[name = tensor("input")]; tensor chunk_embs_in = linear(bias = model_encoder_pre_encode_out_bias, weight = model_encoder_pre_encode_out_weight, x = input)[name = tensor("linear_0")]; tensor var_241_dtype_0 = const()[name = tensor("op_241_dtype_0"), val = tensor("int32")]; tensor size0 = const()[name = tensor("size0"), val = tensor([188])]; tensor size1 = const()[name = tensor("size1"), val = tensor([40])]; tensor var_264 = const()[name = tensor("op_264"), val = tensor(1)]; tensor full_concat_interleave_0 = const()[name = tensor("full_concat_interleave_0"), val = tensor(false)]; tensor full_concat = concat(axis = var_264, interleave = full_concat_interleave_0, values = (spkcache, fifo, chunk_embs_in))[name = tensor("full_concat")]; tensor var_273 = add(x = spkcache_lengths, y = fifo_lengths)[name = tensor("op_273")]; tensor chunk_lens_in = cast(dtype = var_241_dtype_0, x = current_lengths)[name = tensor("cast_4")]; tensor pre_encoder_lengths = add(x = var_273, y = chunk_lens_in)[name = tensor("total_length")]; tensor out_pos = const()[name = tensor("out_pos"), val = tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241])]; tensor var_284 = greater_equal(x = out_pos, y = spkcache_lengths)[name = tensor("op_284")]; tensor in_seg1_or_2_dtype_0 = const()[name = tensor("in_seg1_or_2_dtype_0"), val = tensor("int32")]; tensor var_290 = greater_equal(x = out_pos, y = var_273)[name = tensor("op_290")]; tensor in_seg2_dtype_0 = const()[name = tensor("in_seg2_dtype_0"), val = tensor("int32")]; tensor var_297 = sub(x = size0, y = spkcache_lengths)[name = tensor("op_297")]; tensor in_seg1_or_2 = cast(dtype = in_seg1_or_2_dtype_0, x = var_284)[name = tensor("cast_3")]; tensor var_298 = mul(x = in_seg1_or_2, y = var_297)[name = tensor("op_298")]; tensor var_300 = sub(x = size1, y = fifo_lengths)[name = tensor("op_300")]; tensor in_seg2 = cast(dtype = in_seg2_dtype_0, x = var_290)[name = tensor("cast_2")]; tensor var_301 = mul(x = in_seg2, y = var_300)[name = tensor("op_301")]; tensor offset = add(x = var_298, y = var_301)[name = tensor("offset")]; tensor var_305 = add(x = out_pos, y = offset)[name = tensor("op_305")]; tensor var_309 = const()[name = tensor("op_309"), val = tensor(241)]; tensor var_310 = const()[name = tensor("op_310"), val = tensor(0)]; tensor minimum_0 = minimum(x = var_305, y = var_309)[name = tensor("minimum_0")]; tensor maximum_0 = maximum(x = minimum_0, y = var_310)[name = tensor("maximum_0")]; tensor var_313_axes_0 = const()[name = tensor("op_313_axes_0"), val = tensor([0])]; tensor var_313 = expand_dims(axes = var_313_axes_0, x = maximum_0)[name = tensor("op_313")]; tensor var_315_axes_0 = const()[name = tensor("op_315_axes_0"), val = tensor([-1])]; tensor var_315 = expand_dims(axes = var_315_axes_0, x = var_313)[name = tensor("op_315")]; tensor gather_idx_reps_0 = const()[name = tensor("gather_idx_reps_0"), val = tensor([1, 1, 512])]; tensor gather_idx = tile(reps = gather_idx_reps_0, x = var_315)[name = tensor("gather_idx")]; tensor var_320 = const()[name = tensor("op_320"), val = tensor(1)]; tensor output = gather_along_axis(axis = var_320, indices = gather_idx, x = full_concat)[name = tensor("output")]; tensor var_323 = less(x = out_pos, y = pre_encoder_lengths)[name = tensor("op_323")]; tensor var_328_dtype_0 = const()[name = tensor("op_328_dtype_0"), val = tensor("fp32")]; tensor var_330_axes_0 = const()[name = tensor("op_330_axes_0"), val = tensor([0])]; tensor var_328 = cast(dtype = var_328_dtype_0, x = var_323)[name = tensor("cast_1")]; tensor var_330 = expand_dims(axes = var_330_axes_0, x = var_328)[name = tensor("op_330")]; tensor var_332_axes_0 = const()[name = tensor("op_332_axes_0"), val = tensor([-1])]; tensor var_332 = expand_dims(axes = var_332_axes_0, x = var_330)[name = tensor("op_332")]; tensor pre_encoder_embs = mul(x = output, y = var_332)[name = tensor("op_333")]; } -> (pre_encoder_embs, pre_encoder_lengths, chunk_embs_in, chunk_lens_in); }