TaliDror commited on
Commit ·
7806057
1
Parent(s): 24e094e
dependency fix
Browse files- external/arc2face/models.py +25 -1
external/arc2face/models.py
CHANGED
|
@@ -2,8 +2,32 @@ import torch
|
|
| 2 |
from transformers import CLIPTextModel
|
| 3 |
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
|
| 4 |
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
| 5 |
-
from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
class CLIPTextModelWrapper(CLIPTextModel):
|
| 9 |
# Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
|
|
|
|
| 2 |
from transformers import CLIPTextModel
|
| 3 |
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
|
| 4 |
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
| 5 |
+
#from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
|
| 6 |
+
try:
|
| 7 |
+
from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
|
| 8 |
+
except ImportError:
|
| 9 |
+
from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask
|
| 10 |
|
| 11 |
+
def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
|
| 12 |
+
batch_size, tgt_len = input_ids_shape
|
| 13 |
+
return _create_4d_causal_attention_mask(
|
| 14 |
+
input_ids_shape=(batch_size, tgt_len),
|
| 15 |
+
dtype=dtype,
|
| 16 |
+
device=device,
|
| 17 |
+
past_key_values_length=past_key_values_length,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
def _expand_mask(mask, dtype, tgt_len=None):
|
| 21 |
+
batch_size, src_len = mask.shape
|
| 22 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
| 23 |
+
|
| 24 |
+
expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len)
|
| 25 |
+
inverted_mask = 1.0 - expanded_mask.to(dtype)
|
| 26 |
+
|
| 27 |
+
return inverted_mask.masked_fill(
|
| 28 |
+
inverted_mask.to(torch.bool),
|
| 29 |
+
torch.finfo(dtype).min,
|
| 30 |
+
)
|
| 31 |
|
| 32 |
class CLIPTextModelWrapper(CLIPTextModel):
|
| 33 |
# Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
|