TaliDror commited on
Commit ·
626735d
1
Parent(s): 5a68fdd
causal_attention_mask fix
Browse files
external/arc2face/models.py
CHANGED
|
@@ -11,7 +11,7 @@ except ImportError:
|
|
| 11 |
def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
|
| 12 |
batch_size, tgt_len = input_ids_shape
|
| 13 |
return _create_4d_causal_attention_mask(
|
| 14 |
-
|
| 15 |
dtype=dtype,
|
| 16 |
device=device,
|
| 17 |
past_key_values_length=past_key_values_length,
|
|
|
|
| 11 |
def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
|
| 12 |
batch_size, tgt_len = input_ids_shape
|
| 13 |
return _create_4d_causal_attention_mask(
|
| 14 |
+
input_shape=(batch_size, tgt_len),
|
| 15 |
dtype=dtype,
|
| 16 |
device=device,
|
| 17 |
past_key_values_length=past_key_values_length,
|