Update modeling_generanno.py

#1

FlashAttention 算子要求 bool 类型注意力掩码,否则将降级为原生实现,导致性能严重下降。

Thanks!

GenerTeam changed pull request status to merged

不,应该没关系,我们在GenerannoFlashAttention2中处理过了attention_mask

In line 650:
attention_mask = attention_mask.bool()

您说的对,但是因为华为的FA算子只适配了SDPA接口。
但是这个问题也仅影响attention_mask张量不全为1的情况,您模型中调用的_prepare_4d_attention_mask_for_sdpa
接口会将全1的张量处理成None,此时仍会调用FA算子。但是当出现有padding的输入会因为调用原生算子导致申请比较大的内存。

ok,我更新了一下modelling_generanno.py,之前的版本太老了,看下现在还有没有这种问题?还需不需要加那个?如果需要的话你再提一个pull request就好

另外我的小建议是,如果是SDPA接口的问题,最好把这个修改放在GenerannoSdpaAttention这个类里面,而不要直接写到GenerannoModel类里,否则在eager attention的情况下会报错,因为eager attention的attention mask不是布尔类型(需要attend的地方是0,需要屏蔽的地方是-float("inf")),虽然几乎不存在使用eager attention的场景,但代码还是应该尽量兼容

Sign up or log in to comment