TaliDror committed on
Commit
780f1aa
·
1 Parent(s): 39db2c4

fix to _make_causal_mask and _expand_mask

Browse files
Files changed (1) hide show
  1. external/arc2face/models.py +17 -18
external/arc2face/models.py CHANGED
@@ -6,28 +6,27 @@ from transformers.modeling_outputs import BaseModelOutputWithPooling
6
  try:
7
  from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
8
  except ImportError:
9
- from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask
10
-
 
11
  def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
12
- batch_size, tgt_len = input_ids_shape
13
- return _create_4d_causal_attention_mask(
14
- input_shape=(batch_size, tgt_len),
15
- dtype=dtype,
16
- device=device,
17
- past_key_values_length=past_key_values_length,
18
- )
 
 
 
19
 
20
  def _expand_mask(mask, dtype, tgt_len=None):
21
- batch_size, src_len = mask.shape
22
  tgt_len = tgt_len if tgt_len is not None else src_len
23
-
24
- expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len)
25
- inverted_mask = 1.0 - expanded_mask.to(dtype)
26
-
27
- return inverted_mask.masked_fill(
28
- inverted_mask.to(torch.bool),
29
- torch.finfo(dtype).min,
30
- )
31
 
32
  class CLIPTextModelWrapper(CLIPTextModel):
33
  # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
 
6
  try:
7
  from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
8
  except ImportError:
9
+ # transformers >=4.35 removed these internal helpers from modeling_clip.
10
+ # Reimplement them directly from the transformers 4.34 source so the mask
11
+ # format (additive, shape [bsz,1,tgt,src]) matches what CLIPEncoder expects.
12
  def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
13
+ bsz, tgt_len = input_ids_shape
14
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
15
+ mask_cond = torch.arange(tgt_len, device=device)
16
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
17
+ mask = mask.to(dtype)
18
+ if past_key_values_length > 0:
19
+ mask = torch.cat(
20
+ [torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1
21
+ )
22
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
23
 
24
  def _expand_mask(mask, dtype, tgt_len=None):
25
+ bsz, src_len = mask.shape
26
  tgt_len = tgt_len if tgt_len is not None else src_len
27
+ expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
28
+ inverted = 1.0 - expanded
29
+ return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)
 
 
 
 
 
30
 
31
  class CLIPTextModelWrapper(CLIPTextModel):
32
  # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812