Question about the inference flow
Sorry if this is a bit off-topic, but I've been trying to implement an inference flow for texify, with the goal of using as few large third-party libraries as possible. I've run into a problem and have been stuck for a while: the model's output is only correct for the first few steps (the key/values from the first two or three forward passes match those from `optimum.pipeline`). I did some research and looked at the implementation in transformers.js (https://github.com/xenova/transformers.js/blob/880a2ccde11e0f1568894b8f671051ec7efa6ddd/src/models.js#L355), but I'm still not quite clear on it. I would be grateful if you could take a look and see where the issue might be.
First, we get the `hidden_state` from the encoder and prepare the inputs for the decoder (an onnxruntime session created from the merged decoder model). There are two scenarios:
- Without `past_key_values`: this is the first forward pass of the decoder, so `use_cache_branch` should be set to false. We feed in `hidden_state`, `input_ids` (which is just the `bos_token_id`), and dummy inputs (padded with zeros) for `past_key_values`. This pass produces a new `token_id` and `present_key_values`, which we save for later use.
- With `past_key_values`: every forward pass after the first follows this step. The inputs are `hidden_state`, the `input_ids` and the `present_key_values` obtained from the previous step, and `use_cache_branch` should be set to true. The outputs are again a new `token_id` and `present_key_values`.
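In case it helps, here is a minimal sketch of that two-branch loop. `run_decoder` is a stand-in for the real onnxruntime `session.run` call on the merged decoder (the fake token stream here is purely for illustration):

```python
def run_decoder(hidden_state, input_ids, past_kv, use_cache_branch):
    # Stub for session.run on the merged decoder: returns the next token id
    # and the updated cache. Replace with the real ONNX call.
    next_token = past_kv["step"] + 1 if past_kv else 1
    return next_token, {"step": next_token}

def generate(hidden_state, bos_token_id, eos_token_id, max_len=5):
    input_ids = [bos_token_id]
    past_kv = None  # first pass: use_cache_branch=False, dummy past tensors
    while len(input_ids) < max_len:
        token, past_kv = run_decoder(
            hidden_state,
            # once the cache exists, only the newest token is fed in
            [input_ids[-1]] if past_kv else input_ids,
            past_kv,
            use_cache_branch=past_kv is not None,
        )
        input_ids.append(token)
        if token == eos_token_id:
            break
    return input_ids
```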
I think this should be the basic flow, but I just can't get correct results.
While inspecting the model, I noticed one more thing: although the model's config says `is_encoder_decoder` is false, the model's input and output node names have the form `past_key_values.0.encoder.key` rather than the more common `past_key_values.0.key`/`past_key_values.0.value`. According to https://github.com/xenova/transformers.js/blob/main/src/models.js#L1293, could the issue be an incorrect shape for the dummy past inputs in the first decoder forward pass?
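For reference, this is roughly how I understand the dummy past inputs to be laid out. The head count (16) and head dim (64) come from the shapes quoted above; the layer count (8 here) is a placeholder and should be read from the model config. Note that, if I read the transformers.js code correctly, it feeds *empty* tensors (sequence axis of length 0) rather than zero-padded ones:

```python
# Placeholder config values; replace with your model's real config.
NUM_LAYERS, NUM_HEADS, HEAD_DIM, BATCH = 8, 16, 64, 1

def dummy_past_shapes():
    # One (key, value) pair per layer for both the self-attention ("decoder")
    # and cross-attention ("encoder") branches of the merged decoder.
    shapes = {}
    for layer in range(NUM_LAYERS):
        for branch in ("decoder", "encoder"):
            for kv in ("key", "value"):
                # sequence axis is 0: an empty tensor, not zero padding
                shapes[f"past_key_values.{layer}.{branch}.{kv}"] = (BATCH, NUM_HEADS, 0, HEAD_DIM)
    return shapes
```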
I figured it out by reading through the transformers.js source.
When `use_cache_branch` is true, the decoder outputs `present.X.encoder.key` and `present.X.encoder.value` as empty tensors (shape `[0, 16, 1, 64]`), and these can't be fed directly into the next decoding step. For these cross-attention entries you should keep reusing the results from the first decoder forward pass (where `use_cache_branch` is false), which have shape `[1, 16, 196, 64]`.
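The cache update can be sketched like this (a minimal illustration with a stand-in tensor class; the name `update_past` and the empty-tensor check are mine, not from any library):

```python
class T:
    """Minimal stand-in for an ONNX output tensor; only .shape matters here."""
    def __init__(self, shape):
        self.shape = shape

def update_past(past, present):
    # Build the past_key_values feed for the next step from the current
    # present.* outputs, keeping the first-pass cross-attention KV whenever
    # the cache branch returned an empty tensor for it.
    new_past = {}
    for name, tensor in present.items():
        past_name = name.replace("present", "past_key_values", 1)
        if ".encoder." in name and tensor.shape[0] == 0:
            # empty cross-attention output: reuse the first-pass entry
            new_past[past_name] = past[past_name]
        else:
            new_past[past_name] = tensor
    return new_past
```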
Thank you for your outstanding work!