repeat input_ids.size(1) in dim1

#3
by kangyang - opened

In modeling_longcat_next.py, function prepare_inputs_for_generation:

input_ids = input_ids.repeat((2, input_ids.size(1)))
attention_mask = attention_mask.repeat((2, attention_mask.size(1)))

or

input_ids = input_ids.repeat((2, 1))
attention_mask = attention_mask.repeat((2, 1))

Why do you repeat input_ids.size(1) times along dim 1?
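For reference, torch.Tensor.repeat treats each argument as a repeat count per dimension, not as a target size. The minimal check below uses a hypothetical batch_size=1 prompt with made-up token IDs to show what each variant produces.

```python
import torch

# Hypothetical prompt: batch_size=1, sequence length 4 (token IDs are made up).
input_ids = torch.tensor([[101, 7, 8, 9]])

# repeat((a, b)) tiles dim 0 a times and dim 1 b times.
tiled_by_len = input_ids.repeat((2, input_ids.size(1)))  # dim 1 tiled 4 times
tiled_by_one = input_ids.repeat((2, 1))                  # dim 1 left unchanged

print(tiled_by_len.shape)  # torch.Size([2, 16])
print(tiled_by_one.shape)  # torch.Size([2, 4])
```

With a sequence of length L, the first form turns a [1, L] tensor into [2, L * L], while the second only duplicates the batch row.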

LongCat org

Actually, we repeat input_ids along dim=0 (the batch dimension).

(For simplicity of implementation, inference currently supports only batch_size=1.) During image generation, however, we use the classifier-free guidance (CFG) strategy, which usually requires an additional unconditional forward pass. We therefore repeat the batch dimension of input_ids and attention_mask to 2, masking the conditioning text tokens in input_ids[1], so that the conditional and unconditional forward passes are combined into a single one.
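The batching trick described above can be sketched as follows. This is a minimal illustration, not LongCat's implementation: it assumes the unconditional row hides the conditioning tokens by zeroing their attention_mask entries (the reply does not say whether masking is done through the attention mask or by replacing token IDs), and it uses the standard CFG combination uncond + s * (cond - uncond). The names build_cfg_batch, apply_cfg, cond_positions and guidance_scale are hypothetical.

```python
import torch

def build_cfg_batch(input_ids, attention_mask, cond_positions):
    """Duplicate a batch_size=1 prompt into a [conditional, unconditional] pair.

    cond_positions is a hypothetical bool tensor of shape [seq_len] marking the
    conditioning text tokens (an assumption made for this sketch).
    """
    assert input_ids.size(0) == 1, "sketch assumes batch_size=1"
    # Repeat along dim 0 (batch) only; dim 1 (sequence) keeps its length.
    ids = input_ids.repeat((2, 1))
    mask = attention_mask.repeat((2, 1))  # repeat() copies, so the input is untouched
    # Row 1 becomes the unconditional branch: hide the conditioning tokens.
    mask[1, cond_positions] = 0
    return ids, mask

def apply_cfg(logits, guidance_scale):
    """Combine the two rows of one batched forward pass into guided logits."""
    cond, uncond = logits[0], logits[1]
    return uncond + guidance_scale * (cond - uncond)

# Usage with made-up values.
input_ids = torch.tensor([[11, 12, 13, 14, 15]])
attention_mask = torch.ones_like(input_ids)
cond_positions = torch.tensor([False, True, True, False, False])
ids, mask = build_cfg_batch(input_ids, attention_mask, cond_positions)
# mask is now [[1, 1, 1, 1, 1], [1, 0, 0, 1, 1]]

# Stand-in for the last-step logits of model(ids, attention_mask=mask).
logits = torch.tensor([[2.0, 0.0], [1.0, 1.0]])
guided = apply_cfg(logits, guidance_scale=3.0)  # [4.0, -2.0]
```

Packing both branches into one batch means a single forward call per step instead of two, at the cost of doubling the batch and the KV cache.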

However, the current implementation is limited: after an image has been generated, the model cannot output text or other content following that image, because the batch size would need to be switched back to 1 and the KV cache would need to be managed again. We will address these challenges in the next iteration.
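To illustrate what "switching the batch size back to 1" involves, here is one hypothetical way to shrink a batch-of-2 cache to the conditional row only. It assumes the legacy tuple layout (one (key, value) pair per layer, each of shape [batch, num_heads, seq_len, head_dim]) and that row 0 is the conditional branch; keep_conditional_row is a made-up helper, and the thread does not say whether this would be sufficient for LongCat's own cache handling.

```python
import torch

def keep_conditional_row(past_key_values, attention_mask):
    """Shrink a batch-of-2 CFG cache back to batch_size=1 (hypothetical helper).

    Assumes one (key, value) pair per layer, each shaped
    [batch, num_heads, seq_len, head_dim], with row 0 as the conditional branch.
    """
    # Slicing with 0:1 keeps the batch dimension, so shapes become [1, ...].
    new_cache = tuple((k[0:1], v[0:1]) for k, v in past_key_values)
    return new_cache, attention_mask[0:1]

# Toy cache: 2 layers, batch 2, 4 heads, 6 cached positions, head_dim 8.
cache = tuple((torch.randn(2, 4, 6, 8), torch.randn(2, 4, 6, 8)) for _ in range(2))
mask = torch.ones(2, 6, dtype=torch.long)
new_cache, new_mask = keep_conditional_row(cache, mask)
print(new_cache[0][0].shape, new_mask.shape)  # torch.Size([1, 4, 6, 8]) torch.Size([1, 6])
```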
