| import torch |
| from transformers import CLIPTextModel |
| from typing import Any, Callable, Dict, Optional, Tuple, Union, List |
| from transformers.modeling_outputs import BaseModelOutputWithPooling |
| from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask |
|
|
|
|
class CLIPTextModelWrapper(CLIPTextModel):
    """`CLIPTextModel` variant whose ``forward`` mirrors the upstream
    ``CLIPTextTransformer.forward`` but adds two hooks:

    * ``input_token_embs`` — precomputed token embeddings that override the
      embedding-table lookup (position embeddings are still added inside
      ``self.text_model.embeddings``).
    * ``return_token_embs`` — when true, short-circuits and returns the raw
      token-embedding lookup for ``input_ids`` instead of running the encoder.
    """

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        input_token_embs: Optional[torch.Tensor] = None,
        return_token_embs: Optional[bool] = False,
    ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
        """Run the CLIP text transformer, optionally with injected embeddings.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``. Required even
                when ``input_token_embs`` is given, because the pooling step
                below locates the EOS position from the ids.
            attention_mask: Optional padding mask of shape ``(batch, seq_len)``.
            position_ids: Optional position ids forwarded to the embeddings.
            output_attentions / output_hidden_states / return_dict: Standard
                HF output toggles; ``None`` falls back to the model config.
            input_token_embs: Optional precomputed token embeddings that
                replace the embedding-table lookup for ``input_ids``.
            return_token_embs: If true, return only the raw token-embedding
                lookup for ``input_ids`` (no encoder pass, no pooling).

        Returns:
            The token embeddings tensor when ``return_token_embs`` is set;
            otherwise a ``BaseModelOutputWithPooling`` (or plain tuple when
            ``return_dict`` is false) matching the parent class's contract.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        # Short-circuit: caller only wants the embedding-table lookup.
        if return_token_embs:
            return self.text_model.embeddings.token_embedding(input_ids)

        output_attentions = (
            output_attentions if output_attentions is not None else self.text_model.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
        )
        # Resolve return_dict exactly once. (The original resolved it twice —
        # first from self.config, then from self.text_model.config — but the
        # second assignment was dead code since use_return_dict is a bool.)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        # input_token_embs (if given) overrides the token-embedding lookup;
        # position embeddings are still added inside `embeddings`.
        hidden_states = self.text_model.embeddings(
            input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs
        )

        # CLIP's text encoder is autoregressive: build the causal mask.
        causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)

        if attention_mask is not None:
            # Expand [batch, seq_len] -> [batch, 1, tgt_len, src_len] for the encoder.
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.text_model.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)

        # Pooling mirrors upstream CLIPTextTransformer.forward: take the
        # hidden state at the EOS token position of each sequence.
        if self.text_model.eos_token_id == 2:
            # Legacy configs (eos_token_id == 2): EOS is the highest token id
            # in the vocabulary slice actually used, so argmax over the ids
            # finds its (first) position. Kept for backward compatibility
            # with checkpoints predating the eos_token_id config fix.
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
            ]
        else:
            # General case: locate the first occurrence of the configured
            # EOS id explicitly (argmax over a 0/1 match mask).
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
                .int()
                .argmax(dim=-1),
            ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )