import logging
import os
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from packaging import version
from torch import nn

from .simple_tokenizer import SimpleTokenizer as _Tokenizer

logger = logging.getLogger(__name__)

# On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
MODEL_PATH = 'https://huggingface.co/laion'
_MODELS = {
    "ViT-L/14": os.path.join(MODEL_PATH, "CLIP-ViT-L-14-DataComp.XL-s13B-b90K", "vit_l14_text.pth"),
    "ViT-B/16": os.path.join(MODEL_PATH, "CLIP-ViT-B-16-DataComp.XL-s13B-b90K", "vit_b16_text.pth"),
}
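

# Note: torch.load cannot read these URL-style paths directly over HTTP. A
# minimal, hypothetical helper (the cache directory and the assumption that
# the extracted .pth files were downloaded into it are not part of this repo):
def _local_checkpoint(name: str, ckpt_dir: str = "~/.cache/clip_text") -> str:
    """Map a _MODELS key (e.g. "ViT-B/16") to an assumed local file path."""
    filename = os.path.basename(_MODELS[name])
    return os.path.join(os.path.expanduser(ckpt_dir), filename)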


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
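

def _demo_quick_gelu():
    # Illustrative sanity check: QuickGELU is the sigmoid-based approximation
    # used by the original CLIP; it tracks the exact GELU closely but not
    # perfectly (the worst-case gap on this range is roughly 1e-2).
    x = torch.linspace(-3, 3, steps=61)
    gap = (QuickGELU()(x) - nn.GELU()(x)).abs().max().item()
    print(f"max |QuickGELU - GELU| on [-3, 3]: {gap:.4f}")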


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # Keep the (optional) additive mask on the same device/dtype as the input.
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        # Pre-norm residual layout, as in the original CLIP text encoder.
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
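

def _demo_residual_block():
    # Illustrative only: nn.MultiheadAttention defaults to batch_first=False,
    # so the block expects input shaped [seq_len, batch, d_model]; the sizes
    # below are arbitrary.
    block = ResidualAttentionBlock(d_model=512, n_head=8)
    x = torch.randn(77, 2, 512)  # [seq_len, batch, width]
    assert block(x).shape == (77, 2, 512)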


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None,
                 checkpoint_num: int = 0):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
        self.checkpoint_num = checkpoint_num

    def forward(self, x: torch.Tensor):
        if self.checkpoint_num > 0:
            segments = min(self.checkpoint_num, len(self.resblocks))
            return checkpoint.checkpoint_sequential(self.resblocks, segments, x)
        else:
            return self.resblocks(x)
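

def _demo_transformer_checkpointing():
    # Illustrative only (arbitrary sizes): checkpoint_num > 0 routes the stack
    # through checkpoint_sequential, which re-computes activations during
    # backward to save memory; the forward result is numerically unchanged.
    plain = Transformer(width=256, layers=4, heads=8)
    ckpt = Transformer(width=256, layers=4, heads=8, checkpoint_num=2)
    ckpt.load_state_dict(plain.state_dict())
    x = torch.randn(16, 2, 256, requires_grad=True)
    assert torch.allclose(plain(x), ckpt(x), atol=1e-6)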


class CLIP_TEXT(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
        checkpoint_num: int,
    ):
        super().__init__()
        self.context_length = context_length
        self._tokenizer = _Tokenizer()

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            checkpoint_num=checkpoint_num,
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

    def no_weight_decay(self):
        return {'token_embedding', 'positional_embedding'}

    def build_attention_mask(self):
        # lazily create the causal attention mask for the text tokens
        # pytorch uses an additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower triangle, so tokens attend only to earlier positions
        return mask

    def tokenize(self, texts, context_length=77, truncate=True):
        """
        Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length
        truncate : bool
            Whether to truncate the text in case its encoding is longer than the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
        We return a LongTensor when the torch version is <1.8.0, since older index_select requires indices to be long.
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self._tokenizer.encoder["<|startoftext|>"]
        eot_token = self._tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
        if version.parse(torch.__version__) < version.parse("1.8.0"):
            result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        else:
            result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if truncate:
                    tokens = tokens[:context_length]
                    tokens[-1] = eot_token
                else:
                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result

    def forward(self, text):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest token id in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x
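

def _demo_clip_text():
    # Illustrative end-to-end sketch. The tiny hyperparameters are arbitrary
    # (not a released configuration), the manual initialization merely stands
    # in for pretrained weights, and SimpleTokenizer's BPE vocabulary must be
    # available for tokenize() to work.
    model = CLIP_TEXT(
        embed_dim=64,
        context_length=16,
        vocab_size=49408,
        transformer_width=128,
        transformer_heads=4,
        transformer_layers=2,
        checkpoint_num=0,
    )
    nn.init.normal_(model.positional_embedding, std=0.01)
    nn.init.normal_(model.text_projection, std=0.01)
    # The causal mask is additive: zeros on/below the diagonal, -inf above it.
    print(model.build_attention_mask()[:3, :3])
    tokens = model.tokenize(["a photo of a cat"], context_length=16)
    features = model(tokens)
    assert features.shape == (1, 64)  # [batch, embed_dim]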


def clip_text_b16(
    embed_dim=512,
    context_length=77,
    vocab_size=49408,
    transformer_width=512,
    transformer_heads=8,
    transformer_layers=12,
    checkpoint_num=0,
    pretrained=True,
):
    model = CLIP_TEXT(
        embed_dim,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        checkpoint_num,
    )
    if pretrained:
        if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
            pretrained = _MODELS[pretrained]
        else:
            pretrained = _MODELS["ViT-B/16"]
        logger.info(f"Load pretrained weights from {pretrained}")
        state_dict = torch.load(pretrained, map_location='cpu')
        if context_length != state_dict["positional_embedding"].size(0):
            logger.info(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
            if context_length < state_dict["positional_embedding"].size(0):
                # Shrink: keep the first context_length positions.
                state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
            else:
                # Grow: zero-pad new positions at the end.
                state_dict["positional_embedding"] = F.pad(
                    state_dict["positional_embedding"],
                    (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
                    value=0,
                )
        message = model.load_state_dict(state_dict, strict=False)
        logger.info(f"Load pretrained weights from {pretrained}: {message}")
    return model.eval()
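

def _demo_builder():
    # Illustrative only: pretrained=False skips checkpoint loading, leaving the
    # positional embedding and text projection uninitialized (torch.empty), so
    # this checks shapes rather than output quality.
    model = clip_text_b16(pretrained=False)
    tokens = model.tokenize(["a diagram", "a dog"])
    assert model(tokens).shape == (2, 512)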


def clip_text_l14(
    embed_dim=768,
    context_length=77,
    vocab_size=49408,
    transformer_width=768,
    transformer_heads=12,
    transformer_layers=12,
    checkpoint_num=0,
    pretrained=True,
):
    model = CLIP_TEXT(
        embed_dim,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        checkpoint_num,
    )
    if pretrained:
        if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
            pretrained = _MODELS[pretrained]
        else:
            pretrained = _MODELS["ViT-L/14"]
        logger.info(f"Load pretrained weights from {pretrained}")
        state_dict = torch.load(pretrained, map_location='cpu')
        if context_length != state_dict["positional_embedding"].size(0):
            logger.info(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
            if context_length < state_dict["positional_embedding"].size(0):
                # Shrink: keep the first context_length positions.
                state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
            else:
                # Grow: zero-pad new positions at the end.
                state_dict["positional_embedding"] = F.pad(
                    state_dict["positional_embedding"],
                    (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
                    value=0,
                )
        message = model.load_state_dict(state_dict, strict=False)
        logger.info(f"Load pretrained weights from {pretrained}: {message}")
    return model.eval()


def clip_text_l14_336(
    embed_dim=768,
    context_length=77,
    vocab_size=49408,
    transformer_width=768,
    transformer_heads=12,
    transformer_layers=12,
):
    # No extracted text weights are registered for ViT-L/14-336 (there is no
    # "ViT-L/14_336" entry in _MODELS), so this builder is not implemented.
    raise NotImplementedError


def build_clip(config):
    # Look up the builder named in the config, e.g. "clip_text_b16".
    model_cls = config.text_encoder.clip_teacher
    model = eval(model_cls)()
    return model
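

def _demo_build_clip():
    # Illustrative only: any object exposing text_encoder.clip_teacher works;
    # the SimpleNamespace config below is hypothetical. The builders load
    # pretrained weights by default, so the checkpoint file must be reachable.
    from types import SimpleNamespace
    config = SimpleNamespace(text_encoder=SimpleNamespace(clip_teacher="clip_text_b16"))
    return build_clip(config)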