TaliDror commited on
Commit
7806057
·
1 Parent(s): 24e094e

dependency fix

Browse files
Files changed (1) hide show
  1. external/arc2face/models.py +25 -1
external/arc2face/models.py CHANGED
@@ -2,8 +2,32 @@ import torch
2
  from transformers import CLIPTextModel
3
  from typing import Any, Callable, Dict, Optional, Tuple, Union, List
4
  from transformers.modeling_outputs import BaseModelOutputWithPooling
5
- from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
 
 
 
 
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class CLIPTextModelWrapper(CLIPTextModel):
9
  # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
 
2
  from transformers import CLIPTextModel
3
  from typing import Any, Callable, Dict, Optional, Tuple, Union, List
4
  from transformers.modeling_outputs import BaseModelOutputWithPooling
5
+ #from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
6
+ try:
7
+ from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
8
+ except ImportError:
9
+ from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask
10
 
11
+ def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
12
+ batch_size, tgt_len = input_ids_shape
13
+ return _create_4d_causal_attention_mask(
14
+ input_ids_shape=(batch_size, tgt_len),
15
+ dtype=dtype,
16
+ device=device,
17
+ past_key_values_length=past_key_values_length,
18
+ )
19
+
20
+ def _expand_mask(mask, dtype, tgt_len=None):
21
+ batch_size, src_len = mask.shape
22
+ tgt_len = tgt_len if tgt_len is not None else src_len
23
+
24
+ expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len)
25
+ inverted_mask = 1.0 - expanded_mask.to(dtype)
26
+
27
+ return inverted_mask.masked_fill(
28
+ inverted_mask.to(torch.bool),
29
+ torch.finfo(dtype).min,
30
+ )
31
 
32
  class CLIPTextModelWrapper(CLIPTextModel):
33
  # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812