Batched generation (batch_size > 1) produces incorrect outputs: possible causal mask issue?

#9
by vconchel - opened

Generation isn't working properly for me when batch_size > 1: the longest sample in the batch is generated normally, but the others are full of spaces and repeat many words. Is this a common issue?
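The symptom is consistent with padding leaking into attention. In a batch, only the longest prompt needs no padding, so it would be the only row unaffected if pad positions were not masked out. As background, here is a minimal sketch (plain Python lists, and a hypothetical pad id of 0) of how a batch is typically left-padded for decoder-only generation, and which `attention_mask` and `position_ids` should go with it. The position rule (cumulative sum of the mask minus one, with pad slots filled with 1) is similar to what several transformers models do in `prepare_inputs_for_generation`. Treat it as an illustration, not as Ouro's exact code.

```python
PAD_ID = 0  # hypothetical pad token id, for illustration only


def left_pad(batch):
    """Left-pad token-id lists to a common length; return (input_ids, attention_mask)."""
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        input_ids.append([PAD_ID] * n_pad + list(seq))
        # 0 marks padding that attention must ignore, 1 marks real tokens
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


def position_ids_from_mask(attention_mask):
    """Count positions over real tokens only: cumsum(mask) - 1, pad slots set to 1."""
    result = []
    for row in attention_mask:
        running, pos_row = 0, []
        for m in row:
            running += m
            pos_row.append(running - 1 if m else 1)
        result.append(pos_row)
    return result


if __name__ == "__main__":
    ids, mask = left_pad([[11, 12, 13, 14], [21, 22]])
    print(ids)   # [[11, 12, 13, 14], [0, 0, 21, 22]]
    print(mask)  # [[1, 1, 1, 1], [0, 0, 1, 1]]
    print(position_ids_from_mask(mask))  # [[0, 1, 2, 3], [1, 1, 0, 1]]
```

The first (longest) row has no pad slots at all, so its mask and positions are the same as for unbatched generation; only the shorter rows depend on the padding being handled correctly.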

I solved it by changing lines 567-583 in modeling_ouro.py from

```python
mask_kwargs = {
    "config": self.config,
    "input_embeds": inputs_embeds,
    "attention_mask": attention_mask,
    "cache_position": cache_position,
    "past_key_values": past_key_values,
    "position_ids": position_ids,
}
# Create the masks
causal_mask_mapping = {
    "full_attention": create_causal_mask(**mask_kwargs),
}
# The sliding window alternating layers are not always activated depending on the config
if self.has_sliding_layers:
    causal_mask_mapping["sliding_attention"] = (
        create_sliding_window_causal_mask(**mask_kwargs)
    )
```

to

```python
# At the top of modeling_ouro.py
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

mask_kwargs = {
    "attention_mask": attention_mask,
    "input_shape": inputs_embeds.shape[:2],
    "inputs_embeds": inputs_embeds,
    "past_key_values_length": past_key_values.get_seq_length() if past_key_values is not None else 0,
}
# Create the masks (causal mask combined with the 2D padding mask)
causal_mask_mapping = {
    "full_attention": _prepare_4d_causal_attention_mask(**mask_kwargs),
}

# The sliding window alternating layers are not always activated depending on the config
if self.has_sliding_layers:
    causal_mask_mapping["sliding_attention"] = _prepare_4d_causal_attention_mask(
        **mask_kwargs,
        # The config is an object, so read the attribute instead of indexing it like a dict
        sliding_window=self.config.sliding_window,
    )
```
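The replacement relies on `_prepare_4d_causal_attention_mask`, which merges the causal mask with the 2D padding mask into a 4D additive mask. As a rough, non-authoritative sketch of what such a combined mask should contain, here it is in plain Python lists. `causal_padding_mask` is a hypothetical helper, not a transformers function, and it uses `float("-inf")` where transformers uses the dtype's minimum value. The sliding-window boundary is an assumption too: implementations can differ by one position, so check your transformers version before comparing.

```python
NEG_INF = float("-inf")  # transformers fills with torch.finfo(dtype).min instead


def causal_padding_mask(attention_mask, q_len, past_len=0, sliding_window=None):
    """Return an additive mask of shape (batch, 1, q_len, past_len + q_len).

    0.0 means "may attend", NEG_INF means "blocked". Query i sits at absolute
    position past_len + i; key j is visible only if it is not in the future,
    is not padding and (optionally) lies inside the sliding window.
    """
    kv_len = past_len + q_len
    batch_mask = []
    for row in attention_mask:
        if len(row) != kv_len:
            raise ValueError("attention_mask must cover past + current tokens")
        rows = []
        for i in range(q_len):
            q_pos = past_len + i
            line = []
            for j in range(kv_len):
                visible = j <= q_pos and row[j] == 1  # causal + padding
                if sliding_window is not None:
                    # Assumed convention: keep only the last `sliding_window` keys
                    visible = visible and j > q_pos - sliding_window
                line.append(0.0 if visible else NEG_INF)
            rows.append(line)
        batch_mask.append([rows])  # singleton head dimension
    return batch_mask


if __name__ == "__main__":
    # Second sample is left-padded by one token
    for line in causal_padding_mask([[1, 1, 1], [0, 1, 1]], q_len=3)[1][0]:
        print(line)
```

With `past_len > 0` the same function describes a decoding step, where `attention_mask` spans the past plus the current tokens. Rows for pad queries come out fully blocked; their outputs are discarded, but some attention backends need special handling for fully masked rows.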

Is there a more straightforward solution?
