|
|
|
|
|
|
|
|
|
|
| import math
|
|
|
| import torch
|
| from torch import nn
|
| import torch.nn.functional as F
|
| from einops import rearrange
|
| from einops.layers.torch import Rearrange
|
| from transformers import CLIPVisionModel, CLIPTokenizer
|
|
|
| import numpy as np
|
| from torch import einsum
|
| from dataclasses import dataclass
|
| from typing import Optional, Tuple
|
| from transformers.utils import ModelOutput
|
| from adaface.util import arc2face_inverse_face_prompt_embs, gen_gradient_scaler
|
| from adaface.arc2face_models import CLIPTextModelWrapper
|
| import sys
|
# Register the already-imported `adaface` package under the additional name
# `ldm`, so that `import ldm` resolves to the same module object. Presumably this
# lets objects pickled under the `ldm` package name load against `adaface` —
# TODO confirm. Raises KeyError at import time if `adaface` is not yet in
# sys.modules.
sys.modules['ldm'] = sys.modules['adaface']
|
|
|
def reshape_tensor(x, num_heads):
    """Split the last (channel) dimension of ``x`` into ``num_heads`` heads.

    Args:
        x: tensor of shape (bs, length, width). ``width`` must be divisible by
            ``num_heads`` for the ``view`` below to succeed.
        num_heads: number of attention heads.

    Returns:
        Tensor of shape (bs, num_heads, length, width // num_heads).
    """
    batch_size, seq_len, _ = x.shape
    # (bs, length, width) -> (bs, length, heads, head_dim)
    per_head = x.view(batch_size, seq_len, num_heads, -1)
    # (bs, length, heads, head_dim) -> (bs, heads, length, head_dim)
    return per_head.transpose(1, 2).reshape(batch_size, num_heads, seq_len, -1)
|
|
|
|
|
def FeedForward(dim, mult=4, p_dropout=0.1):
    """Build a pre-norm two-layer MLP block.

    Layout: LayerNorm(dim) -> Linear(dim, hidden) -> GELU -> Linear(hidden, dim)
    -> Dropout(p_dropout). Both linear layers are bias-free. ``hidden`` is
    ``int(dim * mult)``, so a fractional ``mult`` is truncated.
    """
    hidden_dim = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
        nn.Dropout(p_dropout),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
|
|
class IP_MLPProjModel(nn.Module):
    """Project one ID embedding into ``num_tokens`` cross-attention tokens.

    An MLP (Linear -> GELU -> Linear, with biases) maps ``id_embeddings_dim``
    features to ``num_tokens * cross_attention_dim`` features. These are split
    into ``num_tokens`` tokens and layer-normalised.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        hidden_dim = id_embeddings_dim * 2
        self.proj = nn.Sequential(
            nn.Linear(id_embeddings_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, cross_attention_dim * num_tokens),
        )
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        """Map ``id_embeds`` (..., id_embeddings_dim) to (N, num_tokens, cross_attention_dim).

        All leading dimensions are folded into N by the reshape.
        """
        flat = self.proj(id_embeds)
        tokens = flat.reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)
|
|
|
|
|
class LearnedSoftAggregate(nn.Module):
    """Pool ``x`` along ``group_dim`` with learned, input-dependent softmax weights.

    A bias-free ``Linear(num_feat, 1)`` scores each element. The scores are
    softmax-normalised over ``group_dim`` and used as the weights of a weighted
    sum of ``x`` over that dimension.
    """

    def __init__(self, num_feat, group_dim, keepdim=False):
        super().__init__()
        self.group_dim = group_dim
        self.num_feat = num_feat
        self.feat2score = nn.Linear(num_feat, 1, bias=False)
        self.keepdim = keepdim

    def forward(self, x, score_basis=None):
        """Aggregate ``x`` over ``self.group_dim``.

        Args:
            x: tensor to pool.
            score_basis: optional tensor the scores are computed from; defaults
                to ``x``. Its last dimension is fed to ``feat2score``. When
                ``num_feat == 1``, a trailing feature axis is added first and
                removed afterwards.

        Returns:
            ``x`` reduced over ``group_dim``, or kept as size 1 if ``keepdim``.
        """
        # A singleton group needs no weighting; only drop or keep the axis.
        if x.shape[self.group_dim] == 1:
            return x if self.keepdim else x.squeeze(self.group_dim)

        basis = x if score_basis is None else score_basis

        if self.num_feat != 1:
            scores = self.feat2score(basis)
        else:
            scores = self.feat2score(basis.unsqueeze(-1)).squeeze(-1)

        weights = torch.softmax(scores, dim=self.group_dim)
        return torch.sum(x * weights, dim=self.group_dim, keepdim=self.keepdim)
|
|
|
def LoRA_ExpandEmbs(input_dim, lora_rank, output_dim, num_modes,
                    num_output_vecs, elementwise_affine=True, p_dropout=0.1):
    """Expand one vector per sample into ``num_output_vecs`` vectors through a rank bottleneck.

    Pipeline (b = batch):
      [b, input_dim] -> Linear -> [b, num_modes, lora_rank, output_dim]
      -> LayerNorm -> soft-aggregate over modes (or drop the singleton mode axis)
      -> [b, lora_rank, output_dim] -> Dropout
      -> bias-free Linear mixing the lora_rank vectors -> [b, num_output_vecs, output_dim]
      -> LayerNorm -> Dropout.
    """
    # Modules are created in the same order as before, so parameter
    # initialisation consumes the RNG identically.
    layers = [
        nn.Linear(input_dim, lora_rank * output_dim * num_modes, bias=False),
        Rearrange('b (m q d) -> b m q d', q=lora_rank, m=num_modes, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
    ]
    if num_modes > 1:
        layers.append(LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False))
    else:
        layers.append(Rearrange('b () q d -> b q d'))
    layers += [
        nn.Dropout(p_dropout),
        # Put the rank axis last so the next Linear mixes across it.
        Rearrange('b q d -> b d q'),
        nn.Linear(lora_rank, num_output_vecs, bias=False),
        Rearrange('b d q -> b q d'),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        nn.Dropout(p_dropout),
    ]
    return nn.Sequential(*layers)
|
|
|
def ExpandEmbs(input_dim, output_dim, expansion_ratio, elementwise_affine=True, p_dropout=0.1):
    """Expand each input vector into ``expansion_ratio`` output vectors.

    [b, input_dim] -> bias-free Linear -> [b, expansion_ratio * output_dim]
    -> [b, expansion_ratio, output_dim] -> LayerNorm -> Dropout.
    """
    expanded_width = expansion_ratio * output_dim
    layers = [
        nn.Linear(input_dim, expanded_width, bias=False),
        # Split the flat projection into expansion_ratio vectors of width output_dim.
        Rearrange('b (e d) -> b e d', e=expansion_ratio, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        nn.Dropout(p_dropout),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
def MultimodeProjection(input_dim, output_dim=-1, num_modes=4, elementwise_affine=True, p_dropout=0.1):
    """Project every token into ``num_modes`` candidates and softly merge them.

    [b, n, input_dim] -> bias-free Linear -> [b, n, num_modes, output_dim]
    -> LayerNorm -> soft-aggregate over the mode axis (or drop it if
    num_modes == 1) -> [b, n, output_dim] -> Dropout.
    ``output_dim == -1`` means "same as ``input_dim``".
    """
    out_dim = input_dim if output_dim == -1 else output_dim

    layers = [
        nn.Linear(input_dim, out_dim * num_modes, bias=False),
        Rearrange('b n (m d) -> b n m d', m=num_modes, d=out_dim),
        nn.LayerNorm(out_dim, elementwise_affine=elementwise_affine),
    ]
    if num_modes > 1:
        layers.append(LearnedSoftAggregate(num_feat=out_dim, group_dim=2, keepdim=False))
    else:
        layers.append(Rearrange('b n () d -> b n d'))
    layers.append(nn.Dropout(p_dropout))
    return nn.Sequential(*layers)
|
|
|
|
|
def Lora2Hira(lora_rank, hira_rank, output_dim, num_modes, elementwise_affine=True, p_dropout=0.1):
    """Map ``lora_rank`` vectors to ``hira_rank`` vectors via ``num_modes`` soft-merged candidates.

    [b, lora_rank, output_dim] -> [b, output_dim, lora_rank] -> bias-free Linear
    -> [b, output_dim, num_modes * hira_rank] -> [b, num_modes, hira_rank, output_dim]
    -> LayerNorm -> soft-aggregate over modes (or drop the singleton mode axis)
    -> [b, hira_rank, output_dim] -> Dropout.
    """
    layers = [
        # Move the rank axis last so the Linear mixes the lora_rank vectors.
        Rearrange('b q d -> b d q'),
        nn.Linear(lora_rank, hira_rank * num_modes, bias=False),
        Rearrange('b d (m q) -> b m q d', m=num_modes, q=hira_rank),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
    ]
    if num_modes > 1:
        layers.append(LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False))
    else:
        layers.append(Rearrange('b () q d -> b q d'))
    layers.append(nn.Dropout(p_dropout))
    return nn.Sequential(*layers)
|
|
|
class PerceiverAttention(nn.Module):
    """Perceiver-style attention from latent queries to ``[x; latent_queries]``.

    Both inputs are layer-normalised. Queries come from the latents. Keys and
    values come from the features ``x`` concatenated with the latents along the
    sequence axis. All projections are bias-free.
    """

    def __init__(self, *, dim, dim_head=64, num_heads=8, elementwise_affine=True):
        super().__init__()
        # Kept as a public attribute for backward compatibility. forward()
        # applies the equivalent scale split between q and k instead.
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.num_heads = num_heads
        inner_dim = dim_head * num_heads

        self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latent_queries):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent_queries (torch.Tensor): latent features, used as the queries
                and also appended to the keys/values
                shape (b, n2, D)

        Returns:
            torch.Tensor of shape (b, n2, D).
        """
        x = self.norm1(x)
        latent_queries = self.norm2(latent_queries)

        b, l, _ = latent_queries.shape

        q = self.to_q(latent_queries)
        # Keys/values cover both the features and the latents themselves.
        kv_input = torch.cat((x, latent_queries), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.num_heads)
        k = reshape_tensor(k, self.num_heads)
        v = reshape_tensor(v, self.num_heads)

        # Apply dim_head ** -0.25 to each of q and k (overall 1/sqrt(dim_head)),
        # which keeps the raw products smaller; presumably for fp16 stability.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)
        # Softmax in float32, then cast back to the working dtype.
        attn = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = attn @ v

        # (b, heads, l, head_dim) -> (b, l, heads * head_dim)
        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
|
|
|
|
|
class CrossAttention(nn.Module):
    """Multi-head cross-attention with optional identity projections and skip connections.

    Value modes:
      * ``q_aware_to_v=False``: standard attention with one value per key token.
      * ``q_aware_to_v=True``: queries are split into ``num_q // v_repeat``
        groups. Each group gets its own value projection: a Linear into a
        low-rank bottleneck followed by a grouped 1x1 Conv1d. As a result, the
        value vectors depend on which query reads them.
    """

    def __init__(self, input_dim, num_heads=6, p_dropout=0.05,
                 identity_to_q=False, identity_to_k=False, identity_to_v=False, v_has_skip=True,
                 q_aware_to_v=True, num_q=416, v_repeat=4, q_aware_to_v_lora_rank=64,
                 identity_to_out=False, out_has_skip=False):
        """
        Args:
            input_dim: token width of both ``x`` and ``context``.
            num_heads: number of attention heads (head dim = input_dim // num_heads).
            p_dropout: dropout on the attention matrix and inside ``to_out``.
            identity_to_q / identity_to_k / identity_to_v: use ``nn.Identity``
                instead of Linear+LayerNorm. ``identity_to_v`` is ignored when
                ``q_aware_to_v``.
            v_has_skip: add ``context`` back onto the projected values.
            q_aware_to_v: use per-query-group value projections (see class doc).
            num_q: number of queries. Only used when ``q_aware_to_v``.
            v_repeat: how many times the query groups are tiled to cover ``num_q``.
            q_aware_to_v_lora_rank: bottleneck width per query group.
            identity_to_out: skip the output projection.
            out_has_skip: add a residual around ``to_out``. Must be False when
                ``identity_to_out``.
        """
        super().__init__()
        dim_head = input_dim // num_heads
        inner_dim = dim_head * num_heads

        self.num_heads = num_heads
        self.q_aware_to_v = q_aware_to_v
        self.v_has_skip = v_has_skip
        self.to_q = nn.Sequential(
            nn.Linear(input_dim, inner_dim, bias=False),
            nn.LayerNorm(inner_dim, elementwise_affine=True)
        ) if not identity_to_q else nn.Identity()
        self.to_k = nn.Sequential(
            nn.Linear(input_dim, inner_dim, bias=False),
            nn.LayerNorm(inner_dim, elementwise_affine=True)
        ) if not identity_to_k else nn.Identity()

        self.v_repeat = v_repeat
        self.num_q_group = num_q_group = num_q // v_repeat

        if q_aware_to_v:
            all_q_mid = num_q_group * q_aware_to_v_lora_rank
            self.to_v = nn.Sequential(
                # [b, n, input_dim] -> [b, n, num_q_group * lora_rank]
                nn.Linear(input_dim, all_q_mid, bias=False),
                nn.LayerNorm(all_q_mid, elementwise_affine=True),
                # Channels first for Conv1d: [b, all_q_mid, n]
                Rearrange('b n q -> b q n', q=all_q_mid),
                # groups=num_q_group: each group maps its own lora_rank channels to
                # input_dim channels -> [b, num_q_group * input_dim, n]
                nn.Conv1d(
                    in_channels=all_q_mid,
                    out_channels=num_q_group * input_dim,
                    kernel_size=1,
                    groups=num_q_group,
                    bias=False,
                ),
                # -> [b, num_q_group, n, input_dim]
                Rearrange('b (q d) n -> b q n d', q=num_q_group, d=input_dim),
                nn.LayerNorm(input_dim, elementwise_affine=True),
            )
        else:
            self.to_v = nn.Sequential(
                nn.Linear(input_dim, inner_dim, bias=False),
                nn.LayerNorm(inner_dim, elementwise_affine=True)
            ) if not identity_to_v else nn.Identity()

        if identity_to_out:
            assert not out_has_skip, "identity_to_out=True, then out_has_skip has to be False."

        if identity_to_out:
            self.to_out = nn.Identity()
        else:
            # The attention output has inner_dim (= heads * dim_head) features.
            # Project it back to input_dim and normalise at that width. When
            # input_dim is divisible by num_heads the two widths are equal.
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, input_dim, bias=False),
                nn.Dropout(p_dropout),
                nn.LayerNorm(input_dim, elementwise_affine=True)
            )

        self.out_has_skip = out_has_skip
        self.attn_drop = nn.Dropout(p_dropout)

    def forward(self, x, context=None, attn_mat=None, return_attn=False):
        """Attend from ``x`` (queries) to ``context`` (keys/values).

        Args:
            x: query tokens, [b, n_q, input_dim].
            context: key/value tokens, [b, n_kv, input_dim]. Defaults to ``x``.
            attn_mat: optional precomputed attention matrix of shape
                [b * num_heads, n_q, n_kv]. When given, q/k are not computed and
                attention dropout is not applied.
            return_attn: also return the attention matrix.

        Returns:
            ``out`` of shape [b, n_q, ...] (input_dim after ``to_out``), or
            ``(out, attn)`` if ``return_attn``.
        """
        h = self.num_heads

        if context is None:
            context = x

        if attn_mat is None:
            q = self.to_q(x)
            k = self.to_k(context)
            # [b, n, h * d] -> [b * h, n, d]
            q, k = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k))

        if self.q_aware_to_v:
            # v: [b, num_q_group, n, input_dim], one value set per query group.
            v = self.to_v(context)
            if self.v_has_skip:
                v = v + context.unsqueeze(1)
        else:
            v = self.to_v(context)
            if self.v_has_skip:
                v = v + context

        if self.q_aware_to_v:
            # [b, num_q_group, n, h * d] -> [b * h, num_q_group, n, d]
            v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h)
            # Tile the groups v_repeat times along the query axis, giving one
            # value set per query: query i uses group i % num_q_group.
            v = v.repeat(1, self.v_repeat, 1, 1)
        else:
            v = rearrange(v, 'b n (h d) -> (b h) n d', h=h)

        if attn_mat is None:
            # Apply d ** -0.25 to each of q and k (overall 1/sqrt(d)).
            scale = q.size(-1) ** -0.25
            sim = einsum('b i d, b j d -> b i j', q * scale, k * scale)
            attn = sim.softmax(dim=-1)
            attn = self.attn_drop(attn)
        else:
            attn = attn_mat

        if self.q_aware_to_v:
            # Query i reads its own value set v[:, i].
            out = einsum('b i j, b i j d -> b i d', attn, v)
        else:
            out = einsum('b i j, b j d -> b i d', attn, v)

        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        if self.out_has_skip:
            out = self.to_out(out) + out
        else:
            out = self.to_out(out)

        if return_attn:
            return out, attn
        else:
            return out
|
|
|
class SubjBasisGenerator(nn.Module):
    """Generate per-layer subject (or background) token embeddings.

    Two modes, selected by ``placeholder_is_bg``:

    * Subject (False): ``arc2face_id_embs`` are inverted into prompt embeddings
      by ``arc2face_inverse_face_prompt_embs``, using a CLIP text model wrapper
      (``prompt2token_proj``). The returned ``core_id_embs`` are repeated for
      every output layer.
    * Background (True): image features (``clip_features``) are projected by
      ``bg_proj_in`` and given learned positional embeddings. Learned latent
      queries then read them out through a ``CrossAttention`` translator.
    """

    def __init__(
        self,
        num_heads=6,
        num_id_vecs=None,
        num_out_embs_per_layer=4,
        num_out_layers=16,
        image_embedding_dim=768,
        dino_embedding_dim=384,
        output_dim=768,
        placeholder_is_bg: bool = False,
        prompt2token_proj_grad_scale: float = 0.4,
        zs_extra_words_scale: float = 0.5,
        learnable_hidden_state_weights_scheme: str = 'per-layer',
        bg_prompt_translator_has_to_out_proj: bool = False,
    ):
        """
        Args:
            num_heads: attention heads of the background ``prompt_translator``.
            num_id_vecs: dict with keys ``'subj'`` and ``'bg'``, giving the number
                of input ID vectors per mode. ``None`` (the default) means
                ``{'subj': 77, 'bg': 257}``; ``None`` is used to avoid a shared
                mutable default.
            num_out_embs_per_layer: output embeddings per layer.
            num_out_layers: number of output layers.
            image_embedding_dim: width of ``clip_features`` (background mode).
            dino_embedding_dim: width of ``raw_id_embs`` (non-face subjects).
            output_dim: width of the output embeddings.
            placeholder_is_bg: build the background variant.
            prompt2token_proj_grad_scale: gradient scale for ``prompt2token_proj``.
                A value of 0 freezes it.
            zs_extra_words_scale: forwarded to ``arc2face_inverse_face_prompt_embs``.
            learnable_hidden_state_weights_scheme: ``'none'`` or ``'per-layer'``.
            bg_prompt_translator_has_to_out_proj: give the background translator
                an output projection (with a skip connection).
        """
        super().__init__()
        if num_id_vecs is None:
            num_id_vecs = {'subj': 77, 'bg': 257}

        self.placeholder_is_bg = placeholder_is_bg
        self.num_out_layers = num_out_layers
        self.num_out_embs_per_layer = num_out_embs_per_layer
        self.num_out_embs = num_out_layers * num_out_embs_per_layer
        self.output_dim = output_dim

        self.num_id_vecs = num_id_vecs['bg'] if placeholder_is_bg else num_id_vecs['subj']
        self.pos_embs = nn.Parameter(torch.randn(1, self.num_id_vecs, output_dim))
        self.pos_embs_ln = nn.LayerNorm(output_dim)
        self.zs_extra_words_scale = zs_extra_words_scale
        self.output_scale = output_dim ** -0.5
        self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        if not self.placeholder_is_bg:
            # [BS, dino_embedding_dim] -> [BS, num_id_vecs, output_dim]
            self.obj_proj_in = ExpandEmbs(dino_embedding_dim, output_dim, expansion_ratio=self.num_id_vecs)

            self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14')
            self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale
            self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale)
            print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.")

            if prompt2token_proj_grad_scale == 0:
                self.freeze_prompt2token_proj()

            self.prompt2token_proj_attention_multiplier = -1
            self.initialize_hidden_state_layer_weights(learnable_hidden_state_weights_scheme, 'cpu')
            # Computed lazily on the first face forward pass.
            self.pad_embeddings = None
            self.bg_proj_in = None
        else:
            # Background mode uses neither the object projection nor prompt2token_proj.
            self.obj_proj_in = None
            self.prompt2token_proj = None
            print("Bg prompt2token_proj is set to None.")

            self.bg_proj_in = nn.Sequential(
                nn.Linear(image_embedding_dim, output_dim, bias=False),
                nn.LayerNorm(output_dim),
            )

            # One learned query per output embedding.
            self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, output_dim))
            self.latent_queries_ln = nn.LayerNorm(output_dim)

            self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj
            identity_to_v = False
            v_has_skip = not identity_to_v
            identity_to_out = not bg_prompt_translator_has_to_out_proj
            out_has_skip = not identity_to_out

            # prompt_translator: a CrossAttention from the latent queries to the
            # projected background features. It has a to_v projection with a skip
            # connection, and a to_out projection only if
            # bg_prompt_translator_has_to_out_proj. num_q is unused because
            # q_aware_to_v=False.
            self.prompt_translator = \
                CrossAttention(input_dim=output_dim, num_heads=num_heads, p_dropout=0.05,
                               identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v,
                               q_aware_to_v=False, v_has_skip=v_has_skip,
                               num_q=0,
                               identity_to_out=identity_to_out,
                               out_has_skip=out_has_skip)

        print(repr(self))

    def forward(self, arc2face_id_embs, clip_features=None, raw_id_embs=None, out_id_embs_scale=1.0,
                is_face=True, is_training=False, adaface_prompt_embs_inf_type='full_half_pad'):
        """Produce the per-layer embeddings.

        Args:
            arc2face_id_embs: ID embeddings (subject mode). Its first dim is
                taken as the batch size in subject mode.
            clip_features: image features (background mode). Its first dim is
                taken as the batch size in background mode.
            raw_id_embs: non-face subject features, projected by ``obj_proj_in``.
            out_id_embs_scale: if != 1, the output is blended with pad-token
                embeddings as ``out * s + pad * (1 - s)``.
            is_face: subject mode only; take the Arc2Face path.
            is_training: request ``['full_pad', 'core']`` embeddings instead of
                ``[adaface_prompt_embs_inf_type, 'core']``.
            adaface_prompt_embs_inf_type: prompt-embedding type used at inference.

        Returns:
            ``(adaface_subj_embs, adaface_prompt_embs)``. ``adaface_prompt_embs``
            is None except on the subject/face path.

        Raises:
            ValueError: in subject mode with ``is_face=False`` and no ``raw_id_embs``.
        """
        if not self.placeholder_is_bg:
            BS = arc2face_id_embs.shape[0]
        else:
            BS = clip_features.shape[0]

        adaface_prompt_embs = None
        # Backfill for instances that lack the tokenizer attribute (presumably
        # objects restored from older pickles).
        if not hasattr(self, 'clip_tokenizer'):
            self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        if not self.placeholder_is_bg:
            if is_face:
                assert arc2face_id_embs is not None
                hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)

                if is_training:
                    return_emb_types = ['full_pad', 'core']
                else:
                    return_emb_types = [adaface_prompt_embs_inf_type, 'core']

                # Generate the pad embeddings on first use; afterwards keep them on
                # the input's device.
                if self.pad_embeddings is None:
                    self.generate_pad_embeddings()
                else:
                    self.pad_embeddings = self.pad_embeddings.to(arc2face_id_embs.device)

                # Only build a graph through prompt2token_proj when it is being trained.
                with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0):
                    adaface_prompt_embs, core_id_embs = \
                        arc2face_inverse_face_prompt_embs(self.clip_tokenizer,
                                                          self.prompt2token_proj,
                                                          arc2face_id_embs,
                                                          list_extra_words=None,
                                                          return_emb_types=return_emb_types,
                                                          pad_embeddings=self.pad_embeddings,
                                                          hidden_state_layer_weights=hidden_state_layer_weights,
                                                          input_max_length=77, zs_extra_words_scale=self.zs_extra_words_scale)

                # Presumably rescales gradients flowing back into prompt2token_proj
                # (see gen_gradient_scaler) — confirm.
                adaface_prompt_embs = self.prompt2token_proj_grad_scaler(adaface_prompt_embs)
                core_id_embs = self.prompt2token_proj_grad_scaler(core_id_embs)
            elif raw_id_embs is not None:
                # NOTE(review): core_id_embs is never assigned on this path and
                # id_embs is not used in subject mode, so the subject branch below
                # raises UnboundLocalError. Confirm the intended behavior.
                id_embs = self.obj_proj_in(raw_id_embs)
            else:
                raise ValueError("SubjBasisGenerator.forward: in subject mode, is_face=False "
                                 "requires raw_id_embs.")
        else:
            id_embs = self.bg_proj_in(clip_features)

        if self.placeholder_is_bg:
            id_embs = id_embs + self.pos_embs_ln(self.pos_embs)
            latent_queries = self.latent_queries_ln(self.latent_queries).repeat(BS, 1, 1)
            # The num_out_embs latent queries read out the background features.
            with torch.set_grad_enabled(self.training):
                id_embs_out = self.prompt_translator(latent_queries, id_embs)

            # [BS, num_out_embs, output_dim] -> [BS, num_out_layers, num_out_embs_per_layer, output_dim]
            id_embs_out = id_embs_out.reshape(BS, self.num_out_layers, -1, self.output_dim)
            adaface_subj_embs = id_embs_out * self.output_scale
        else:
            # The same embeddings for every output layer: add a layer axis and tile it.
            adaface_subj_embs = core_id_embs.unsqueeze(1).repeat(1, self.num_out_layers, 1, 1)

        if out_id_embs_scale != 1:
            # NOTE(review): self.pad_embeddings is only created in subject mode.
            pad_embeddings = self.pad_embeddings[4:4+self.num_out_embs_per_layer].unsqueeze(0).unsqueeze(0)
            adaface_subj_embs = adaface_subj_embs * out_id_embs_scale \
                                + pad_embeddings * (1 - out_id_embs_scale)

        return adaface_subj_embs, adaface_prompt_embs

    def initialize_hidden_state_layer_weights(self, learnable_hidden_state_weights_scheme, device):
        """Create ``hidden_state_layer_weights`` and its gradient scaler.

        Args:
            learnable_hidden_state_weights_scheme: ``'none'`` (no weights, grad
                scale 1) or ``'per-layer'`` (a learnable [3, 1] parameter
                initialised to [1, 2, 4], grad scale 5).
            device: device for the new parameter.

        Raises:
            ValueError: for any other scheme.
        """
        if learnable_hidden_state_weights_scheme == 'none':
            self.hidden_state_layer_weights = None
            self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(1)
            print("hidden_state_layer_weights is set to None.")

        elif learnable_hidden_state_weights_scheme == 'per-layer':
            self.hidden_state_layer_weights = nn.Parameter(torch.tensor([[1.0], [2.0], [4.0]], device=device),
                                                          requires_grad=True)
            self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(5)
            print("hidden_state_layer_weights initialized as per-layer [1, 2, 4], with grad scaler 5.")
        else:
            raise ValueError("Unknown learnable_hidden_state_weights_scheme: "
                             f"{learnable_hidden_state_weights_scheme!r}")

    def generate_pad_embeddings(self):
        """Compute and cache embeddings of a 77-token sequence made only of pad tokens.

        Runs ``prompt2token_proj.text_model.embeddings`` on the device of its
        token-embedding weights. The result is detached and stored in
        ``self.pad_embeddings``.
        """
        clip_embeddings = self.prompt2token_proj.text_model.embeddings
        pad_tokens = torch.tensor([self.clip_tokenizer.pad_token_id]).to(clip_embeddings.token_embedding.weight.device).repeat(77)
        # Presumably the embeddings module returns [1, 77, dim]; [0] takes the sequence.
        pad_embeddings = clip_embeddings(pad_tokens)[0]
        self.pad_embeddings = pad_embeddings.detach()

    def extend_prompt2token_proj_attention(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
        """Extend the attention of ``prompt2token_proj`` by ``multiplier`` (no-op if multiplier <= 1)."""
        if multiplier > 1:
            num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(begin_layer_idx, end_layer_idx, multiplier, noise_std)
            self.prompt2token_proj_attention_multiplier = multiplier
            print(f"{num_extended_layers} layers in prompt2token_proj_attention are x{multiplier}")

    def freeze_prompt2token_proj(self):
        """Set ``requires_grad=False`` on every trainable parameter of ``prompt2token_proj``.

        Does nothing when ``prompt2token_proj`` is None (background mode).
        """
        if self.prompt2token_proj is not None:
            frozen_param_names = []
            for param_name, param in self.prompt2token_proj.named_parameters():
                if param.requires_grad:
                    param.requires_grad = False
                    frozen_param_names.append(param_name)

            print(f"{len(frozen_param_names)} params in Subj prompt2token_proj is frozen.")

    def __repr__(self):
        # getattr defaults cover instances that lack these attributes (e.g. the
        # subject variant, or older pickles) without mutating the object.
        type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
        has_to_out_proj = getattr(self, 'bg_prompt_translator_has_to_out_proj', False)
        num_out_embs = getattr(self, 'num_out_embs', -1)
        return f"{type_sig} SubjBasisGenerator: num_out_embs={num_out_embs}, " \
               f"bg_prompt_translator_has_to_out_proj={has_to_out_proj}"
|
|
|
@dataclass
class BaseModelOutputWithPooling2(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states,
    plus an optional attention mask returned alongside them.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Pooled representation of the sequence. As filled by `CLIPVisionTransformer_forward`
            in this file, it is the first (CLS) token of `last_hidden_state` passed through
            `post_layernorm`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        attn_mask (`torch.FloatTensor`, *optional*):
            Token-level mask returned together with the outputs; `None` when no mask was supplied.
            As filled by `CLIPVisionTransformer_forward`, it has shape `(batch_size, sequence_length, 1)`.
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Extra field carried alongside the standard pooled-output fields.
    attn_mask: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
def CLIPVisionTransformer_forward(self, pixel_values=None, attn_mask=None,
                                  output_attentions=None,
                                  output_hidden_states=None, return_dict=None):
    """Replacement ``forward`` for a CLIP vision transformer that accepts a spatial mask.

    Intended to be bound to a module that provides ``embeddings``,
    ``pre_layrnorm``, ``encoder``, ``post_layernorm`` and ``config``.

    Args:
        pixel_values: input images; required.
        attn_mask: optional [BS, H, W] mask. It is resized (nearest) to the patch
            grid, flattened, and prefixed with a 1 for the leading CLS token. The
            outer product with itself, [BS, 1, 1+N, 1+N], is passed to the encoder
            as ``attention_mask``.
        output_attentions / output_hidden_states / return_dict: default to the
            corresponding ``self.config`` values when None.

    Returns:
        ``(last_hidden_state, pooled_output, *rest)`` if not ``return_dict``.
        Otherwise a ``BaseModelOutputWithPooling2`` whose ``attn_mask`` is the
        resized token mask as [BS, 1+N, 1], or None.

    Raises:
        ValueError: if ``pixel_values`` is None.
    """
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    if pixel_values is None:
        raise ValueError("You have to specify pixel_values")

    embeds = self.pre_layrnorm(self.embeddings(pixel_values))

    pair_mask = None
    if attn_mask is not None:
        # Assumes the tokens after the leading CLS token form a square patch grid.
        num_patches = embeds.shape[1] - 1
        grid_side = int(np.sqrt(num_patches))
        # [BS, H, W] -> [BS, 1, grid_side, grid_side]
        resized = F.interpolate(attn_mask.unsqueeze(1), size=(grid_side, grid_side), mode='nearest')
        # -> [BS, 1, grid_side * grid_side]
        resized = resized.flatten(2)
        # The CLS token is always kept: prepend a 1 -> [BS, 1, 1 + N].
        cls_ones = torch.ones_like(resized[:, :, :1])
        attn_mask = torch.cat([cls_ones, resized], dim=-1)
        # Outer product of the token mask with itself -> [BS, 1, 1 + N, 1 + N].
        pair_mask = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1)

    encoder_outputs = self.encoder(
        inputs_embeds=embeds,
        # NOTE(review): a 0/1 pair matrix is passed here. Check whether the
        # installed transformers version treats `attention_mask` as an additive
        # bias or as a keep-mask before relying on it to block attention.
        attention_mask=pair_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    last_hidden_state = encoder_outputs[0]
    # Pool by taking the first (CLS) token, then apply the post layernorm.
    pooled_output = self.post_layernorm(last_hidden_state[:, 0, :])

    if not return_dict:
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]

    returned_mask = None if attn_mask is None else attn_mask.permute(0, 2, 1)
    return BaseModelOutputWithPooling2(
        last_hidden_state=last_hidden_state,
        pooler_output=pooled_output,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
        attn_mask=returned_mask,
    )
|
|
|
|
|
class CLIPVisionModelWithMask(CLIPVisionModel):
    """``CLIPVisionModel`` whose ``forward`` also accepts an optional ``attn_mask``.

    On each instance, ``vision_model.forward`` is replaced by
    ``CLIPVisionTransformer_forward`` bound to that ``vision_model``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Bind the module-level function as a method of this vision_model instance.
        patched_forward = CLIPVisionTransformer_forward.__get__(self.vision_model)
        self.vision_model.forward = patched_forward

    def forward(self, pixel_values=None, attn_mask=None, output_attentions=None,
                output_hidden_states=None, return_dict=None):
        """Run the vision model with the mask.

        ``return_dict`` defaults to ``config.use_return_dict``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        return self.vision_model(
            pixel_values=pixel_values,
            attn_mask=attn_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
|
|
|
|
|