Added output_attentions: bool=False to GroupedQueryAttention.forward() as a temporary fix for AWQ
attention.py CHANGED (+1 -1)

```diff
@@ -260,7 +260,7 @@ class GroupedQueryAttention(nn.Module):
         self.out_proj = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model, **fc_kwargs)
         self.out_proj._is_residual = True
 
-    def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, is_causal: bool=True, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+    def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, is_causal: bool=True, output_attentions: bool=False, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         qkv = self.Wqkv(x)
         if self.clip_qkv:
             qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
```
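For context, the sketch below illustrates why accepting the extra keyword matters: AWQ-style quantization tooling drives models through the standard Hugging Face forward path, which can pass `output_attentions=...` down to each attention module, and the old signature rejected that kwarg with a `TypeError`. The `ToyAttention` class is a hypothetical, stripped-down stand-in for `GroupedQueryAttention`, not the real MPT code; only the signature change mirrors the diff above.

```python
import torch
import torch.nn as nn
from typing import Optional, Tuple


class ToyAttention(nn.Module):
    """Hypothetical stand-in for GroupedQueryAttention (illustrative only)."""

    def __init__(self, d_model: int = 8):
        super().__init__()
        self.Wqkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    # Accepting (and ignoring) output_attentions mirrors the fix above:
    # callers that pass the kwarg no longer trigger a TypeError.
    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        is_causal: bool = True,
        output_attentions: bool = False,  # temporary no-op, as in the diff
        needs_weights: bool = False,
    ) -> torch.Tensor:
        # Minimal single-head attention, no masking, just to make this runnable.
        q, k, v = self.Wqkv(x).chunk(3, dim=-1)
        attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return self.out_proj(attn @ v)


attn = ToyAttention()
x = torch.randn(1, 4, 8)
# Before the fix, a call like this failed with:
#   TypeError: forward() got an unexpected keyword argument 'output_attentions'
out = attn(x, output_attentions=False)
print(out.shape)  # torch.Size([1, 4, 8])
```

Since the new argument is a no-op, the change only restores signature compatibility; actually retrieving attention weights still goes through the existing `needs_weights` path, which is presumably why the commit describes this as a temporary fix.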