program(1.0)
[buildInfo = dict, tensor>({{"coremlc-component-MIL", "3510.2.1"}, {"coremlc-version", "3500.32.1"}, {"coremltools-component-torch", "2.7.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "9.0b1"}})]
{
    func main(tensor decoder_step, tensor encoder_step) {
        tensor input_1_perm_0 = const()[name = tensor("input_1_perm_0"), val = tensor([0, 2, 1])];
        tensor encoder_step_to_fp16_dtype_0 = const()[name = tensor("encoder_step_to_fp16_dtype_0"), val = tensor("fp16")];
        tensor joint_module_enc_weight_to_fp16 = const()[name = tensor("joint_module_enc_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(64)))];
        tensor joint_module_enc_bias_to_fp16 = const()[name = tensor("joint_module_enc_bias_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(1310848)))];
        tensor encoder_step_to_fp16 = cast(dtype = encoder_step_to_fp16_dtype_0, x = encoder_step)[name = tensor("cast_3")];
        tensor input_1_cast_fp16 = transpose(perm = input_1_perm_0, x = encoder_step_to_fp16)[name = tensor("transpose_1")];
        tensor linear_0_cast_fp16 = linear(bias = joint_module_enc_bias_to_fp16, weight = joint_module_enc_weight_to_fp16, x = input_1_cast_fp16)[name = tensor("linear_0_cast_fp16")];

        tensor input_3_perm_0 = const()[name = tensor("input_3_perm_0"), val = tensor([0, 2, 1])];
        tensor decoder_step_to_fp16_dtype_0 = const()[name = tensor("decoder_step_to_fp16_dtype_0"), val = tensor("fp16")];
        tensor joint_module_pred_weight_to_fp16 = const()[name = tensor("joint_module_pred_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(1312192)))];
        tensor joint_module_pred_bias_to_fp16 = const()[name = tensor("joint_module_pred_bias_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(2131456)))];
        tensor decoder_step_to_fp16 = cast(dtype = decoder_step_to_fp16_dtype_0, x = decoder_step)[name = tensor("cast_2")];
        tensor input_3_cast_fp16 = transpose(perm = input_3_perm_0, x = decoder_step_to_fp16)[name = tensor("transpose_0")];
        tensor linear_1_cast_fp16 = linear(bias = joint_module_pred_bias_to_fp16, weight = joint_module_pred_weight_to_fp16, x = input_3_cast_fp16)[name = tensor("linear_1_cast_fp16")];

        tensor var_23_axes_0 = const()[name = tensor("op_23_axes_0"), val = tensor([2])];
        tensor var_23_cast_fp16 = expand_dims(axes = var_23_axes_0, x = linear_0_cast_fp16)[name = tensor("op_23_cast_fp16")];
        tensor var_24_axes_0 = const()[name = tensor("op_24_axes_0"), val = tensor([1])];
        tensor var_24_cast_fp16 = expand_dims(axes = var_24_axes_0, x = linear_1_cast_fp16)[name = tensor("op_24_cast_fp16")];
        tensor input_5_cast_fp16 = add(x = var_23_cast_fp16, y = var_24_cast_fp16)[name = tensor("input_5_cast_fp16")];
        tensor input_7_cast_fp16 = relu(x = input_5_cast_fp16)[name = tensor("input_7_cast_fp16")];

        tensor joint_module_joint_net_2_weight_to_fp16 = const()[name = tensor("joint_module_joint_net_2_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(2132800)))];
        tensor joint_module_joint_net_2_bias_to_fp16 = const()[name = tensor("joint_module_joint_net_2_bias_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(12626304)))];
        tensor linear_2_cast_fp16 = linear(bias = joint_module_joint_net_2_bias_to_fp16, weight = joint_module_joint_net_2_weight_to_fp16, x = input_7_cast_fp16)[name = tensor("linear_2_cast_fp16")];

        tensor var_35_begin_0 = const()[name = tensor("op_35_begin_0"), val = tensor([0, 0, 0, 0])];
        tensor var_35_end_0 = const()[name = tensor("op_35_end_0"), val = tensor([1, 1, 1, 8193])];
        tensor var_35_end_mask_0 = const()[name = tensor("op_35_end_mask_0"), val = tensor([true, true, true, false])];
        tensor var_35_cast_fp16 = slice_by_index(begin = var_35_begin_0, end = var_35_end_0, end_mask = var_35_end_mask_0, x = linear_2_cast_fp16)[name = tensor("op_35_cast_fp16")];
        tensor var_35_cast_fp16_to_fp32_dtype_0 = const()[name = tensor("op_35_cast_fp16_to_fp32_dtype_0"), val = tensor("fp32")];

        tensor var_40_begin_0 = const()[name = tensor("op_40_begin_0"), val = tensor([0, 0, 0, 8193])];
        tensor var_40_end_0 = const()[name = tensor("op_40_end_0"), val = tensor([1, 1, 1, 8198])];
        tensor var_40_end_mask_0 = const()[name = tensor("op_40_end_mask_0"), val = tensor([true, true, true, true])];
        tensor var_40_cast_fp16 = slice_by_index(begin = var_40_begin_0, end = var_40_end_0, end_mask = var_40_end_mask_0, x = linear_2_cast_fp16)[name = tensor("op_40_cast_fp16")];
        tensor var_40_cast_fp16_to_fp32_dtype_0 = const()[name = tensor("op_40_cast_fp16_to_fp32_dtype_0"), val = tensor("fp32")];

        tensor duration_logits = cast(dtype = var_40_cast_fp16_to_fp32_dtype_0, x = var_40_cast_fp16)[name = tensor("cast_0")];
        tensor token_logits = cast(dtype = var_35_cast_fp16_to_fp32_dtype_0, x = var_35_cast_fp16)[name = tensor("cast_1")];
    } -> (token_logits, duration_logits);
}