Spaces: Sleeping

Commit: "Deal with float values"
Files changed: llama_diffusion_model.py (+2 -1)

llama_diffusion_model.py — CHANGED
@@ -31,7 +31,8 @@ class BidirectionalLlamaAttention(LlamaAttention):
Before:

    31      attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    32
    33      if attention_mask is not None:
    34 -        attn_mask = …   (right-hand side truncated in the page extraction)
    35          attn_weights = attn_weights + attn_mask
    36
    37      attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)
After:

    31      attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    32
    33      if attention_mask is not None:
    34 +        attn_mask = (1.0 - attention_mask) * float('-inf')
    35 +        attn_mask = attn_mask.to(dtype=query.dtype)
    36          attn_weights = attn_weights + attn_mask
    37
    38      attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)

NOTE(review): if `attention_mask` is a 0/1 mask, `(1.0 - attention_mask) * float('-inf')`
evaluates to `0 * -inf = NaN` at the positions that should be attended, which then
propagates through the softmax. The usual safe form is
`(1.0 - attention_mask) * torch.finfo(query.dtype).min` — worth confirming against the
mask convention used by the caller.