program(1.0) [buildInfo = dict, tensor>({{"coremlc-component-MIL", "3520.4.1"}, {"coremlc-version", "3520.5.1"}, {"coremltools-component-torch", "2.5.1"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.3.0"}})] { func main(tensor c_in, tensor encoder, tensor h_in, tensor token, tensor token_length) { tensor y_axis_0 = const()[name = tensor("y_axis_0"), val = tensor(0)]; tensor y_batch_dims_0 = const()[name = tensor("y_batch_dims_0"), val = tensor(0)]; tensor y_validate_indices_0 = const()[name = tensor("y_validate_indices_0"), val = tensor(false)]; tensor decoder_module_prediction_embed_weight_to_fp16 = const()[name = tensor("decoder_module_prediction_embed_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(64)))]; tensor token_to_int16_dtype_0 = const()[name = tensor("token_to_int16_dtype_0"), val = tensor("int16")]; tensor token_to_int16 = cast(dtype = token_to_int16_dtype_0, x = token)[name = tensor("cast_9")]; tensor y_cast_fp16_cast_uint16 = gather(axis = y_axis_0, batch_dims = y_batch_dims_0, indices = token_to_int16, validate_indices = y_validate_indices_0, x = decoder_module_prediction_embed_weight_to_fp16)[name = tensor("y_cast_fp16_cast_uint16")]; tensor input_3_perm_0 = const()[name = tensor("input_3_perm_0"), val = tensor([1, 0, 2])]; tensor split_0_num_splits_0 = const()[name = tensor("split_0_num_splits_0"), val = tensor(2)]; tensor split_0_axis_0 = const()[name = tensor("split_0_axis_0"), val = tensor(0)]; tensor h_in_to_fp16_dtype_0 = const()[name = tensor("h_in_to_fp16_dtype_0"), val = tensor("fp16")]; tensor h_in_to_fp16 = cast(dtype = h_in_to_fp16_dtype_0, x = h_in)[name = tensor("cast_8")]; tensor split_0_cast_fp16_0, tensor split_0_cast_fp16_1 = split(axis = split_0_axis_0, num_splits = split_0_num_splits_0, x = h_in_to_fp16)[name = tensor("split_0_cast_fp16")]; tensor split_1_num_splits_0 = const()[name = tensor("split_1_num_splits_0"), val = tensor(2)]; tensor split_1_axis_0 = const()[name = tensor("split_1_axis_0"), val = tensor(0)]; tensor c_in_to_fp16_dtype_0 = const()[name = tensor("c_in_to_fp16_dtype_0"), val = tensor("fp16")]; tensor c_in_to_fp16 = cast(dtype = c_in_to_fp16_dtype_0, x = c_in)[name = tensor("cast_7")]; tensor split_1_cast_fp16_0, tensor split_1_cast_fp16_1 = split(axis = split_1_axis_0, num_splits = split_1_num_splits_0, x = c_in_to_fp16)[name = tensor("split_1_cast_fp16")]; tensor input_5_lstm_layer_0_lstm_h0_squeeze_axes_0 = const()[name = tensor("input_5_lstm_layer_0_lstm_h0_squeeze_axes_0"), val = tensor([0])]; tensor input_5_lstm_layer_0_lstm_h0_squeeze_cast_fp16 = squeeze(axes = input_5_lstm_layer_0_lstm_h0_squeeze_axes_0, x = split_0_cast_fp16_0)[name = tensor("input_5_lstm_layer_0_lstm_h0_squeeze_cast_fp16")]; tensor input_5_lstm_layer_0_lstm_c0_squeeze_axes_0 = const()[name = tensor("input_5_lstm_layer_0_lstm_c0_squeeze_axes_0"), val = tensor([0])]; tensor input_5_lstm_layer_0_lstm_c0_squeeze_cast_fp16 = squeeze(axes = input_5_lstm_layer_0_lstm_c0_squeeze_axes_0, x = split_1_cast_fp16_0)[name = tensor("input_5_lstm_layer_0_lstm_c0_squeeze_cast_fp16")]; tensor input_5_lstm_layer_0_direction_0 = const()[name = tensor("input_5_lstm_layer_0_direction_0"), val = tensor("forward")]; tensor input_5_lstm_layer_0_output_sequence_0 = const()[name = tensor("input_5_lstm_layer_0_output_sequence_0"), val = tensor(true)]; tensor input_5_lstm_layer_0_recurrent_activation_0 = const()[name = tensor("input_5_lstm_layer_0_recurrent_activation_0"), val = tensor("sigmoid")]; tensor input_5_lstm_layer_0_cell_activation_0 = const()[name = tensor("input_5_lstm_layer_0_cell_activation_0"), val = tensor("tanh")]; tensor input_5_lstm_layer_0_activation_0 = const()[name = tensor("input_5_lstm_layer_0_activation_0"), val = tensor("tanh")]; tensor concat_1_to_fp16 = const()[name = tensor("concat_1_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(3621248)))]; tensor concat_2_to_fp16 = const()[name = tensor("concat_2_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(6898112)))]; tensor concat_0_to_fp16 = const()[name = tensor("concat_0_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(10174976)))]; tensor input_3_cast_fp16 = transpose(perm = input_3_perm_0, x = y_cast_fp16_cast_uint16)[name = tensor("transpose_4")]; tensor input_5_lstm_layer_0_cast_fp16_0, tensor input_5_lstm_layer_0_cast_fp16_1, tensor input_5_lstm_layer_0_cast_fp16_2 = lstm(activation = input_5_lstm_layer_0_activation_0, bias = concat_0_to_fp16, cell_activation = input_5_lstm_layer_0_cell_activation_0, direction = input_5_lstm_layer_0_direction_0, initial_c = input_5_lstm_layer_0_lstm_c0_squeeze_cast_fp16, initial_h = input_5_lstm_layer_0_lstm_h0_squeeze_cast_fp16, output_sequence = input_5_lstm_layer_0_output_sequence_0, recurrent_activation = input_5_lstm_layer_0_recurrent_activation_0, weight_hh = concat_2_to_fp16, weight_ih = concat_1_to_fp16, x = input_3_cast_fp16)[name = tensor("input_5_lstm_layer_0_cast_fp16")]; tensor input_5_lstm_h0_squeeze_axes_0 = const()[name = tensor("input_5_lstm_h0_squeeze_axes_0"), val = tensor([0])]; tensor input_5_lstm_h0_squeeze_cast_fp16 = squeeze(axes = input_5_lstm_h0_squeeze_axes_0, x = split_0_cast_fp16_1)[name = tensor("input_5_lstm_h0_squeeze_cast_fp16")]; tensor input_5_lstm_c0_squeeze_axes_0 = const()[name = tensor("input_5_lstm_c0_squeeze_axes_0"), val = tensor([0])]; tensor input_5_lstm_c0_squeeze_cast_fp16 = squeeze(axes = input_5_lstm_c0_squeeze_axes_0, x = split_1_cast_fp16_1)[name = tensor("input_5_lstm_c0_squeeze_cast_fp16")]; tensor input_5_direction_0 = const()[name = tensor("input_5_direction_0"), val = tensor("forward")]; tensor input_5_output_sequence_0 = const()[name = tensor("input_5_output_sequence_0"), val = tensor(true)]; tensor input_5_recurrent_activation_0 = const()[name = tensor("input_5_recurrent_activation_0"), val = tensor("sigmoid")]; tensor input_5_cell_activation_0 = const()[name = tensor("input_5_cell_activation_0"), val = tensor("tanh")]; tensor input_5_activation_0 = const()[name = tensor("input_5_activation_0"), val = tensor("tanh")]; tensor concat_4_to_fp16 = const()[name = tensor("concat_4_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(10180160)))]; tensor concat_5_to_fp16 = const()[name = tensor("concat_5_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(13457024)))]; tensor concat_3_to_fp16 = const()[name = tensor("concat_3_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(16733888)))]; tensor input_5_cast_fp16_0, tensor input_5_cast_fp16_1, tensor input_5_cast_fp16_2 = lstm(activation = input_5_activation_0, bias = concat_3_to_fp16, cell_activation = input_5_cell_activation_0, direction = input_5_direction_0, initial_c = input_5_lstm_c0_squeeze_cast_fp16, initial_h = input_5_lstm_h0_squeeze_cast_fp16, output_sequence = input_5_output_sequence_0, recurrent_activation = input_5_recurrent_activation_0, weight_hh = concat_5_to_fp16, weight_ih = concat_4_to_fp16, x = input_5_lstm_layer_0_cast_fp16_0)[name = tensor("input_5_cast_fp16")]; tensor obj_3_axis_0 = const()[name = tensor("obj_3_axis_0"), val = tensor(0)]; tensor obj_3_cast_fp16 = stack(axis = obj_3_axis_0, values = (input_5_lstm_layer_0_cast_fp16_1, input_5_cast_fp16_1))[name = tensor("obj_3_cast_fp16")]; tensor obj_3_cast_fp16_to_fp32_dtype_0 = const()[name = tensor("obj_3_cast_fp16_to_fp32_dtype_0"), val = tensor("fp32")]; tensor obj_axis_0 = const()[name = tensor("obj_axis_0"), val = tensor(0)]; tensor obj_cast_fp16 = stack(axis = obj_axis_0, values = (input_5_lstm_layer_0_cast_fp16_2, input_5_cast_fp16_2))[name = tensor("obj_cast_fp16")]; tensor obj_cast_fp16_to_fp32_dtype_0 = const()[name = tensor("obj_cast_fp16_to_fp32_dtype_0"), val = tensor("fp32")]; tensor transpose_1_perm_0 = const()[name = tensor("transpose_1_perm_0"), val = tensor([1, 0, 2])]; tensor input_7_perm_0 = const()[name = tensor("input_7_perm_0"), val = tensor([0, 2, 1])]; tensor encoder_to_fp16_dtype_0 = const()[name = tensor("encoder_to_fp16_dtype_0"), val = tensor("fp16")]; tensor joint_module_enc_weight_to_fp16 = const()[name = tensor("joint_module_enc_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(16739072)))]; tensor joint_module_enc_bias_to_fp16 = const()[name = tensor("joint_module_enc_bias_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(18049856)))]; tensor encoder_to_fp16 = cast(dtype = encoder_to_fp16_dtype_0, x = encoder)[name = tensor("cast_4")]; tensor input_7_cast_fp16 = transpose(perm = input_7_perm_0, x = encoder_to_fp16)[name = tensor("transpose_2")]; tensor linear_0_cast_fp16 = linear(bias = joint_module_enc_bias_to_fp16, weight = joint_module_enc_weight_to_fp16, x = input_7_cast_fp16)[name = tensor("linear_0_cast_fp16")]; tensor joint_module_pred_weight_to_fp16 = const()[name = tensor("joint_module_pred_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(18051200)))]; tensor joint_module_pred_bias_to_fp16 = const()[name = tensor("joint_module_pred_bias_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(18870464)))]; tensor transpose_1_cast_fp16 = transpose(perm = transpose_1_perm_0, x = input_5_cast_fp16_0)[name = tensor("transpose_3")]; tensor linear_1_cast_fp16 = linear(bias = joint_module_pred_bias_to_fp16, weight = joint_module_pred_weight_to_fp16, x = transpose_1_cast_fp16)[name = tensor("linear_1_cast_fp16")]; tensor var_79_axes_0 = const()[name = tensor("op_79_axes_0"), val = tensor([2])]; tensor var_79_cast_fp16 = expand_dims(axes = var_79_axes_0, x = linear_0_cast_fp16)[name = tensor("op_79_cast_fp16")]; tensor var_80_axes_0 = const()[name = tensor("op_80_axes_0"), val = tensor([1])]; tensor var_80_cast_fp16 = expand_dims(axes = var_80_axes_0, x = linear_1_cast_fp16)[name = tensor("op_80_cast_fp16")]; tensor input_11_cast_fp16 = add(x = var_79_cast_fp16, y = var_80_cast_fp16)[name = tensor("input_11_cast_fp16")]; tensor input_13_cast_fp16 = relu(x = input_11_cast_fp16)[name = tensor("input_13_cast_fp16")]; tensor joint_module_joint_net_2_weight_to_fp16 = const()[name = tensor("joint_module_joint_net_2_weight_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(18871808)))]; tensor joint_module_joint_net_2_bias_to_fp16 = const()[name = tensor("joint_module_joint_net_2_bias_to_fp16"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(22492992)))]; tensor linear_2_cast_fp16 = linear(bias = joint_module_joint_net_2_bias_to_fp16, weight = joint_module_joint_net_2_weight_to_fp16, x = input_13_cast_fp16)[name = tensor("linear_2_cast_fp16")]; tensor linear_2_cast_fp16_to_fp32_dtype_0 = const()[name = tensor("linear_2_cast_fp16_to_fp32_dtype_0"), val = tensor("fp32")]; tensor logits = cast(dtype = linear_2_cast_fp16_to_fp32_dtype_0, x = linear_2_cast_fp16)[name = tensor("cast_3")]; tensor c_out = cast(dtype = obj_cast_fp16_to_fp32_dtype_0, x = obj_cast_fp16)[name = tensor("cast_5")]; tensor h_out = cast(dtype = obj_3_cast_fp16_to_fp32_dtype_0, x = obj_3_cast_fp16)[name = tensor("cast_6")]; tensor token_length_tmp = identity(x = token_length)[name = tensor("token_length_tmp")]; } -> (logits, h_out, c_out); }