TaliDror commited on
Commit
626735d
·
1 Parent(s): 5a68fdd

causal_attention_mask fix

Browse files
Files changed (1) hide show
  1. external/arc2face/models.py +1 -1
external/arc2face/models.py CHANGED
@@ -11,7 +11,7 @@ except ImportError:
11
  def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
12
  batch_size, tgt_len = input_ids_shape
13
  return _create_4d_causal_attention_mask(
14
- input_ids_shape=(batch_size, tgt_len),
15
  dtype=dtype,
16
  device=device,
17
  past_key_values_length=past_key_values_length,
 
11
  def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
12
  batch_size, tgt_len = input_ids_shape
13
  return _create_4d_causal_attention_mask(
14
+ input_shape=(batch_size, tgt_len),
15
  dtype=dtype,
16
  device=device,
17
  past_key_values_length=past_key_values_length,