Question about the inference flow

#2
by Spedon - opened

Sorry, this might be a bit off-topic, but I've been trying to implement an inference flow for texify. The goal is to use as few large third-party libraries as possible, but I've run into a problem and have been stuck on it for a while. The model's output is only correct for the first few steps (the key_values from the first two or three forward passes match those from optimum.pipeline), so I did some research and looked at the implementation in transformers.js (https://github.com/xenova/transformers.js/blob/880a2ccde11e0f1568894b8f671051ec7efa6ddd/src/models.js#L355), but I'm still not quite clear on it. I'd be grateful if you could take a look and see where the issue might be.

Initially, we get the hidden_state from the encoder and prepare to feed data to the decoder (an onnxruntime session created from the merged decoder model). There are two scenarios:

  • No past_key_values: This is the first forward pass of the decoder, so use_cache_branch should be set to false. We input hidden_state, input_ids (which is just the bos_token_id), and a dummy input (padded with zeros) for past_key_values. From this pass we get a new token_id and present_key_values, which are saved for later use.
  • With past_key_values: Every forward pass after the first follows this step. The inputs are hidden_state, the new input_ids, and the present_key_values obtained from the previous step, with use_cache_branch set to true. The outputs are again a new token_id and present_key_values.
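The two branches above can be sketched as a small helper that builds the decoder feed. This is a minimal sketch, assuming the input names used by typical optimum-exported merged decoders (input_ids, encoder_hidden_states, use_cache_branch, past_key_values.N.{decoder,encoder}.{key,value}); the layer/head/dim counts here are illustrative, not verified against texify's export.

```python
import numpy as np

# Assumed dimensions for illustration only.
NUM_LAYERS, NUM_HEADS, HEAD_DIM = 8, 16, 64

def build_decoder_feed(input_ids, encoder_hidden_states, past_key_values=None):
    """Build the input dict for one forward of a merged decoder session."""
    feed = {
        "input_ids": input_ids,
        "encoder_hidden_states": encoder_hidden_states,
        # Merged decoders take a boolean tensor selecting the cache branch.
        "use_cache_branch": np.array([past_key_values is not None]),
    }
    if past_key_values is None:
        # First pass: dummy past tensors, empty along the sequence axis.
        for i in range(NUM_LAYERS):
            for branch in ("decoder", "encoder"):
                for kind in ("key", "value"):
                    feed[f"past_key_values.{i}.{branch}.{kind}"] = np.zeros(
                        (1, NUM_HEADS, 0, HEAD_DIM), dtype=np.float32)
    else:
        # Subsequent passes: reuse the cache from the previous step.
        feed.update(past_key_values)
    return feed
```

With an actual session, each step would be `session.run(None, feed)` followed by saving the returned present tensors for the next call.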

I think this should be the basic flow, but I just cannot get the correct result.

While observing the model, I noticed one more thing: although the model's config says is_encoder_decoder is false, the model's input and output node names take the form past_key_values.0.encoder.key rather than the more common past_key_values.0.key / past_key_values.0.value. According to https://github.com/xenova/transformers.js/blob/main/src/models.js#L1293, could the issue be an incorrect shape for the dummy input in the first decoder forward pass?
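For context on why the two name families behave differently, here is a shape-only illustration. The numbers (16 heads, head_dim 64, encoder sequence length 196) are taken from the shapes mentioned in this thread and are assumptions, not verified against texify's export: the cross-attention ("encoder") cache entries stay a fixed size, while the self-attention ("decoder") entries grow by one position per generated token.

```python
import numpy as np

enc_len, step = 196, 5  # encoder sequence length; current decode step

# Cross-attention KV: computed once from encoder output, constant size.
cross_kv = np.zeros((1, 16, enc_len, 64), dtype=np.float32)

# Self-attention KV: one position appended per generated token.
self_kv = np.zeros((1, 16, step, 64), dtype=np.float32)
```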

I figured it out by reading through the transformers.js source.

When use_cache_branch is true, the decoder outputs present.X.encoder.key and present.X.encoder.value as empty tensors (shape [0, 16, 1, 64]), which can't be fed directly into the next decoding step. These keys and values should instead keep the results from the first decoder forward pass (where use_cache_branch is false), which have shape [1, 16, 196, 64].
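The fix above can be sketched as a cache-update step between decodes: take the self-attention tensors from the new outputs, but whenever an output is empty (the cross-attention case), keep the entry from the existing cache. The tensor naming here follows the present.N.{encoder,decoder}.{key,value} convention from this thread; treat it as a sketch rather than texify's exact implementation.

```python
import numpy as np

def update_cache(cache, present):
    """Merge new 'present.*' outputs into a feed-ready 'past_key_values.*' dict.

    cache:   dict of 'past_key_values.N.branch.kind' -> np.ndarray (previous step)
    present: dict of 'present.N.branch.kind' -> np.ndarray (new outputs)
    """
    new_cache = {}
    for name, tensor in present.items():
        past_name = name.replace("present", "past_key_values", 1)
        if tensor.size == 0:
            # Empty cross-attention output: reuse the tensor from the
            # first (use_cache_branch=false) forward pass.
            new_cache[past_name] = cache[past_name]
        else:
            # Self-attention output: replace with the grown tensor.
            new_cache[past_name] = tensor
    return new_cache
```

Applied every step, this keeps the [1, 16, 196, 64] cross-attention tensors alive for the whole generation while the decoder self-attention cache grows normally.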

Thank you for your outstanding work!

Spedon changed discussion status to closed
