fix inference bug when use_flash_attention=False
modeling_telechat.py (+1 -1)
@@ -257,7 +257,7 @@ class TELECHATAttention(nn.Module):
         self.pruned_heads = set()
 
         self.use_flash_attn = False
-
+        self.is_cross_attention = False
 
 
     def set_max_positions(self, max_positions, device='cuda'):
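
For context, the inference failure arises because the non-flash attention path reads `self.is_cross_attention`, which was never initialized in `__init__`; the one-line addition gives it a default value. Below is a minimal, self-contained sketch of that failure mode; the class name, constructor arguments, and attention math are illustrative stand-ins, not the actual TeleChat implementation.

import torch
import torch.nn as nn


class AttentionSketch(nn.Module):
    """Illustrative stand-in for TELECHATAttention; not the real implementation."""

    def __init__(self, hidden_size: int, fixed: bool = True):
        super().__init__()
        self.use_flash_attn = False
        if fixed:
            # The committed fix: default the flag so the eager path can read it.
            self.is_cross_attention = False
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        if not self.use_flash_attn:
            # Eager (non-flash) path: without the fix this line raises
            # AttributeError: 'AttentionSketch' object has no attribute 'is_cross_attention'
            if self.is_cross_attention:
                raise NotImplementedError("cross-attention is out of scope for this sketch")
        scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
        return torch.softmax(scores, dim=-1) @ v


x = torch.randn(1, 4, 8)
print(AttentionSketch(8, fixed=True)(x).shape)  # torch.Size([1, 4, 8])
# AttentionSketch(8, fixed=False)(x)            # reproduces the pre-fix AttributeError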