TaliDror commited on
Commit ·
780f1aa
1
Parent(s): 39db2c4
fix to _make_causal_mask and _expand_mask
Browse files- external/arc2face/models.py +17 -18
external/arc2face/models.py
CHANGED
|
@@ -6,28 +6,27 @@ from transformers.modeling_outputs import BaseModelOutputWithPooling
|
|
| 6 |
try:
|
| 7 |
from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
|
| 8 |
except ImportError:
|
| 9 |
-
|
| 10 |
-
|
|
|
|
| 11 |
def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def _expand_mask(mask, dtype, tgt_len=None):
|
| 21 |
-
|
| 22 |
tgt_len = tgt_len if tgt_len is not None else src_len
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
return inverted_mask.masked_fill(
|
| 28 |
-
inverted_mask.to(torch.bool),
|
| 29 |
-
torch.finfo(dtype).min,
|
| 30 |
-
)
|
| 31 |
|
| 32 |
class CLIPTextModelWrapper(CLIPTextModel):
|
| 33 |
# Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
|
|
|
|
| 6 |
try:
|
| 7 |
from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
|
| 8 |
except ImportError:
|
| 9 |
+
# transformers >=4.47 removed these internal helpers from modeling_clip.
|
| 10 |
+
# Reimplement them directly from the transformers 4.34 source so the mask
|
| 11 |
+
# format (additive, shape [bsz,1,tgt,src]) matches what CLIPEncoder expects.
|
| 12 |
def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
|
| 13 |
+
bsz, tgt_len = input_ids_shape
|
| 14 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
| 15 |
+
mask_cond = torch.arange(tgt_len, device=device)
|
| 16 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
|
| 17 |
+
mask = mask.to(dtype)
|
| 18 |
+
if past_key_values_length > 0:
|
| 19 |
+
mask = torch.cat(
|
| 20 |
+
[torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1
|
| 21 |
+
)
|
| 22 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
| 23 |
|
| 24 |
def _expand_mask(mask, dtype, tgt_len=None):
|
| 25 |
+
bsz, src_len = mask.shape
|
| 26 |
tgt_len = tgt_len if tgt_len is not None else src_len
|
| 27 |
+
expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
| 28 |
+
inverted = 1.0 - expanded
|
| 29 |
+
return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
class CLIPTextModelWrapper(CLIPTextModel):
|
| 32 |
# Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
|