program(1.0) [buildInfo = dict, tensor>({{"coremlc-component-MIL", "3510.2.1"}, {"coremlc-version", "3500.32.1"}, {"coremltools-component-torch", "2.10.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.1"}})] { func main(tensor decoder_output, tensor encoder_output) { tensor var_6 = const()[name = tensor("op_6"), val = tensor(-1)]; tensor joint_enc_weight_to_fp16 = const()[name = tensor("joint_enc_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(64)))]; tensor joint_enc_bias_to_fp16 = const()[name = tensor("joint_enc_bias_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(1310848)))]; tensor linear_0_cast_fp16 = linear(bias = joint_enc_bias_to_fp16, weight = joint_enc_weight_to_fp16, x = encoder_output)[name = tensor("linear_0_cast_fp16")]; tensor joint_pred_weight_to_fp16 = const()[name = tensor("joint_pred_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(1312192)))]; tensor joint_pred_bias_to_fp16 = const()[name = tensor("joint_pred_bias_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(2131456)))]; tensor linear_1_cast_fp16 = linear(bias = joint_pred_bias_to_fp16, weight = joint_pred_weight_to_fp16, x = decoder_output)[name = tensor("linear_1_cast_fp16")]; tensor f_3_axes_0 = const()[name = tensor("f_3_axes_0"), val = tensor([2])]; tensor f_3_cast_fp16 = expand_dims(axes = f_3_axes_0, x = linear_0_cast_fp16)[name = tensor("f_3_cast_fp16")]; tensor g_3_axes_0 = const()[name = tensor("g_3_axes_0"), val = tensor([1])]; tensor g_3_cast_fp16 = expand_dims(axes = g_3_axes_0, x = linear_1_cast_fp16)[name = tensor("g_3_cast_fp16")]; tensor input_3_cast_fp16 = add(x = f_3_cast_fp16, y = g_3_cast_fp16)[name = tensor("input_3_cast_fp16")]; tensor var_28_cast_fp16 = relu(x = input_3_cast_fp16)[name = tensor("op_28_cast_fp16")]; tensor joint_joint_net_2_weight_to_fp16 = const()[name = tensor("joint_joint_net_2_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(2132800)))]; tensor joint_joint_net_2_bias_to_fp16 = const()[name = tensor("joint_joint_net_2_bias_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(12626304)))]; tensor linear_2_cast_fp16 = linear(bias = joint_joint_net_2_bias_to_fp16, weight = joint_joint_net_2_weight_to_fp16, x = var_28_cast_fp16)[name = tensor("linear_2_cast_fp16")]; tensor combined_1_softmax_cast_fp16 = softmax(axis = var_6, x = linear_2_cast_fp16)[name = tensor("combined_1_softmax_cast_fp16")]; tensor combined_1_epsilon_0 = const()[name = tensor("combined_1_epsilon_0"), val = tensor(0x1p-149)]; tensor combined_1_cast_fp16 = log(epsilon = combined_1_epsilon_0, x = combined_1_softmax_cast_fp16)[name = tensor("combined_1_cast_fp16")]; tensor combined0_1_axes_0 = const()[name = tensor("combined0_1_axes_0"), val = tensor([2])]; tensor combined0_1_cast_fp16 = squeeze(axes = combined0_1_axes_0, x = combined_1_cast_fp16)[name = tensor("combined0_1_cast_fp16")]; tensor var_35_begin_0 = const()[name = tensor("op_35_begin_0"), val = tensor([0, 0, 0])]; tensor var_35_end_0 = const()[name = tensor("op_35_end_0"), val = tensor([1, 1, 8193])]; tensor var_35_end_mask_0 = const()[name = tensor("op_35_end_mask_0"), val = tensor([true, true, false])]; tensor token_logits = slice_by_index(begin = var_35_begin_0, end = var_35_end_0, end_mask = var_35_end_mask_0, x = combined0_1_cast_fp16)[name = tensor("op_35_cast_fp16")]; tensor var_36_begin_0 = const()[name = tensor("op_36_begin_0"), val = tensor([0, 0, 8193])]; tensor var_36_end_0 = const()[name = tensor("op_36_end_0"), val = tensor([1, 1, 8198])]; tensor var_36_end_mask_0 = const()[name = tensor("op_36_end_mask_0"), val = tensor([true, true, true])]; tensor duration_logits = slice_by_index(begin = var_36_begin_0, end = var_36_end_0, end_mask = var_36_end_mask_0, x = combined0_1_cast_fp16)[name = tensor("op_36_cast_fp16")]; } -> (token_logits, duration_logits); }