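"""AniMemory AltCLIP text encoder wrapper.

Wraps Hugging Face's CLIPTextModelWithProjection and adds a 1280 -> 1280 linear
projection on the pooled embedding; forward() returns penultimate-layer hidden
states as per-token features alongside the projected pooled output.
"""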
import os

import torch
from safetensors.torch import load_file
from transformers import CLIPTextConfig, CLIPTextModelWithProjection


class AniMemoryAltCLip(torch.nn.Module):
    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.model_hf = CLIPTextModelWithProjection(config)
        self.linear_proj = torch.nn.Linear(in_features=1280, out_features=1280)
        # make_attn_mask below needs the head count; CLIPTextConfig exposes it as
        # num_attention_heads (the original left self.num_head undefined).
        self.num_head = config.num_attention_heads

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        subfolder="",
        linear_proj_name="weights.safetensors",
        torch_dtype=torch.float16,
    ):
        config = CLIPTextModelWithProjection.config_class.from_pretrained(
            pretrained_model_name_or_path, subfolder=subfolder
        )
        model = cls(config=config)
        model.model_hf = CLIPTextModelWithProjection.from_pretrained(
            pretrained_model_name_or_path, subfolder=subfolder
        )
        # Load the extra linear projection that sits on top of the HF model.
        linear_proj_state = load_file(
            os.path.join(pretrained_model_name_or_path, subfolder, linear_proj_name)
        )
        model.linear_proj.load_state_dict(linear_proj_state)
        # Record the requested dtype on the instance (the original mutated the
        # class attribute); weights are actually converted by a later .to() call.
        model.dtype = torch_dtype
        return model

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )
        super().to(*args, **kwargs)
        # Track dtype/device explicitly; plain nn.Module exposes neither, and the
        # getattr fallbacks avoid an AttributeError on the first call.
        self.dtype = dtype if dtype is not None else getattr(self, "dtype", None)
        self.device = device if device is not None else getattr(self, "device", None)
        return self

    def expand_mask(self, mask, dtype, tgt_len=None):
| | """ |
| | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. |
| | """ |
| | bsz, src_len = mask.size() |
| | tgt_len = tgt_len if tgt_len is not None else src_len |
| |
|
| | expanded_mask = ( |
| | mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) |
| | ) |
| |
|
| | inverted_mask = 1.0 - expanded_mask |
| |
|
| | return inverted_mask.masked_fill( |
| | inverted_mask.to(torch.bool), torch.finfo(dtype).min |
| | ) |
| |
|
| | def make_attn_mask(self, attn_mask): |
        seq_len = attn_mask.shape[1]
        query = attn_mask.unsqueeze(1).float()
        # Broadcast the [bsz, seq_len] padding mask into one [seq_len, seq_len]
        # matrix per attention head, flattened to [bsz * num_head, seq_len, seq_len].
        attn_mask = (
            query.repeat([1, seq_len, 1]).unsqueeze(1).repeat([1, self.num_head, 1, 1])
        )
        attn_mask = attn_mask.view([-1, seq_len, seq_len])
        return attn_mask

    def gradient_checkpointing_enable(self):
        self.model_hf.gradient_checkpointing_enable()

    def forward(self, text, attention_mask):
        hidden_states = self.model_hf.text_model.embeddings(
            input_ids=text, position_ids=None
        )
        if attention_mask is None:
            print("Warning: attention_mask is None in altclip!")
        new_attn_mask = (
            self.expand_mask(attention_mask, hidden_states.dtype)
            if attention_mask is not None
            else None
        )
        # Run the encoder without a causal mask (bidirectional, AltCLIP-style),
        # keeping all hidden states so the penultimate layer can be returned.
        encoder_outputs = self.model_hf.text_model.encoder(
            inputs_embeds=hidden_states,
            attention_mask=new_attn_mask,
            causal_attention_mask=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.model_hf.text_model.final_layer_norm(last_hidden_state)
        # Pool on the first token, then project. Note the projection weight is
        # applied untransposed (x @ W, unlike nn.Linear's x @ W.T), presumably
        # matching how the ported checkpoint stores it.
        last_hidden_state = (
            last_hidden_state[torch.arange(last_hidden_state.shape[0]), 0]
            @ self.model_hf.text_projection.weight
        )
        pooled_output = self.linear_proj(last_hidden_state)

        # Per-token features come from the penultimate encoder layer.
        extra_features = encoder_outputs.hidden_states[-2]
        extra_features = self.model_hf.text_model.final_layer_norm(extra_features)
        return extra_features, pooled_output
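
# Minimal usage sketch (an illustration, not part of this module: the checkpoint
# path, subfolder names, and tokenizer choice below are assumptions):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("path/to/animemory", subfolder="tokenizer")
#     encoder = AniMemoryAltCLip.from_pretrained(
#         "path/to/animemory", subfolder="text_encoder"
#     ).to("cuda", torch.float16)
#     batch = tokenizer("an example prompt", padding=True, return_tensors="pt")
#     extra_features, pooled = encoder(
#         batch.input_ids.to("cuda"), batch.attention_mask.to("cuda")
#     )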