TaliDror committed on
Commit
5a68fdd
·
1 Parent(s): 7806057

transformer version clip fix

Browse files
Files changed (1) hide show
  1. external/arc2face/models.py +14 -16
external/arc2face/models.py CHANGED
@@ -32,6 +32,7 @@ except ImportError:
32
  class CLIPTextModelWrapper(CLIPTextModel):
33
  # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
34
  # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them.
 
35
  def forward(
36
  self,
37
  input_ids: Optional[torch.Tensor] = None,
@@ -44,16 +45,16 @@ class CLIPTextModelWrapper(CLIPTextModel):
44
  return_token_embs: Optional[bool] = False,
45
  ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
46
 
 
 
 
 
47
  if return_token_embs:
48
- return self.text_model.embeddings.token_embedding(input_ids)
49
 
50
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
51
-
52
- output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
53
- output_hidden_states = (
54
- output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
55
- )
56
- return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict
57
 
58
  if input_ids is None:
59
  raise ValueError("You have to specify input_ids")
@@ -61,17 +62,13 @@ class CLIPTextModelWrapper(CLIPTextModel):
61
  input_shape = input_ids.size()
62
  input_ids = input_ids.view(-1, input_shape[-1])
63
 
64
- hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs)
65
 
66
- # CLIP's text model uses causal mask, prepare it here.
67
- # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
68
  causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
69
- # expand attention_mask
70
  if attention_mask is not None:
71
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
72
  attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
73
 
74
- encoder_outputs = self.text_model.encoder(
75
  inputs_embeds=hidden_states,
76
  attention_mask=attention_mask,
77
  causal_attention_mask=causal_attention_mask,
@@ -81,9 +78,10 @@ class CLIPTextModelWrapper(CLIPTextModel):
81
  )
82
 
83
  last_hidden_state = encoder_outputs[0]
84
- last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)
85
 
86
- if self.text_model.eos_token_id == 2:
 
87
  pooled_output = last_hidden_state[
88
  torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
89
  input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
@@ -91,7 +89,7 @@ class CLIPTextModelWrapper(CLIPTextModel):
91
  else:
92
  pooled_output = last_hidden_state[
93
  torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
94
- (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
95
  .int()
96
  .argmax(dim=-1),
97
  ]
 
32
  class CLIPTextModelWrapper(CLIPTextModel):
33
  # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
34
  # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them.
35
+ # Supports both transformers <=4.46 (self.text_model sub-attribute) and >=4.47 (flat structure, no text_model).
36
  def forward(
37
  self,
38
  input_ids: Optional[torch.Tensor] = None,
 
45
  return_token_embs: Optional[bool] = False,
46
  ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
47
 
48
+ # In transformers <=4.46 the transformer lives in self.text_model;
49
+ # in >=4.47 it was inlined directly onto CLIPTextModel (flat structure).
50
+ tm = getattr(self, 'text_model', self)
51
+
52
  if return_token_embs:
53
+ return tm.embeddings.token_embedding(input_ids)
54
 
55
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
56
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
57
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
58
 
59
  if input_ids is None:
60
  raise ValueError("You have to specify input_ids")
 
62
  input_shape = input_ids.size()
63
  input_ids = input_ids.view(-1, input_shape[-1])
64
 
65
+ hidden_states = tm.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs)
66
 
 
 
67
  causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
 
68
  if attention_mask is not None:
 
69
  attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
70
 
71
+ encoder_outputs = tm.encoder(
72
  inputs_embeds=hidden_states,
73
  attention_mask=attention_mask,
74
  causal_attention_mask=causal_attention_mask,
 
78
  )
79
 
80
  last_hidden_state = encoder_outputs[0]
81
+ last_hidden_state = tm.final_layer_norm(last_hidden_state)
82
 
83
+ eos_token_id = getattr(tm, 'eos_token_id', self.config.eos_token_id)
84
+ if eos_token_id == 2:
85
  pooled_output = last_hidden_state[
86
  torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
87
  input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
 
89
  else:
90
  pooled_output = last_hidden_state[
91
  torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
92
+ (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == eos_token_id)
93
  .int()
94
  .argmax(dim=-1),
95
  ]