import torch
import torch.nn as nn
import kornia
import open_clip
from einops import rearrange, repeat


def expand_dims_like(x, y):
    """Unsqueeze trailing dims of x until it has the same rank as y, so the two
    broadcast together. This helper is called in
    FrozenOpenCLIPImageEmbedder.forward below but was not defined in this file;
    this is a minimal implementation consistent with that call site."""
    while x.dim() != y.dim():
        x = x.unsqueeze(-1)
    return x

class ImgEmbContextResampler(nn.Module):
    """Expands a pooled image embedding into a fixed-length cross-attention
    context of shape (batch, expansion_factor, cross_attention_dim)."""

    def __init__(
        self,
        inner_dim=1280,
        cross_attention_dim=1024,
        expansion_factor=16,
        **kwargs,
    ):
        super().__init__()
        self.context_embedding = nn.Sequential(
            nn.Linear(cross_attention_dim, inner_dim),
            nn.SiLU(),
            nn.Linear(inner_dim, cross_attention_dim * expansion_factor),
        )
        self.expansion_factor = expansion_factor
        self.cross_attention_dim = cross_attention_dim

    def forward(self, x, batch_size=0):
        # Accept flattened (B*F, C) per-frame embeddings and regroup by batch.
        if x.ndim == 2:
            x = rearrange(x, "(B F) C -> B F C", B=batch_size)
        assert x.ndim == 3
        x = torch.mean(x, dim=1, keepdim=True)  # average over frames
        x = self.context_embedding(x)
        x = x.view(-1, self.expansion_factor, self.cross_attention_dim)
        return x

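# Usage sketch (illustrative shapes, not from the original file): with the
# defaults above, pooled per-frame CLIP embeddings of shape (B*F, 1024) are
# averaged over F frames and expanded into a (B, 16, 1024) context, e.g.:
#
#   resampler = ImgEmbContextResampler()
#   emb = torch.randn(8, 1024)          # 2 clips x 4 frames
#   ctx = resampler(emb, batch_size=2)  # -> torch.Size([2, 16, 1024])
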
class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_dim = -1
        self.num_tokens = -1

    def encode(self, *args, **kwargs):
        raise NotImplementedError

class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """Uses the OpenCLIP vision transformer encoder for images."""

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
    ):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        # Only the visual tower is used; drop the text transformer.
        del model.transformer
        self.model = model
        self.max_crops = num_image_crops
        self.pad_to_max_len = self.max_crops > 0
        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.antialias = antialias
        # Standard CLIP normalization constants.
        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )
        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        # Keep the visual tower's token-output mode in sync with this module;
        # the asserts in encode_with_vision_transformer depend on it.
        self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        # Resize to the CLIP input resolution; x is expected in [-1, 1].
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        # Map from [-1, 1] to [0, 1], then apply CLIP's normalization.
        x = (x + 1.0) / 2.0
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

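    # Note: mean/std above are the standard CLIP normalization constants.
    # Unlike open_clip's PIL pipeline (shortest-side resize + center crop),
    # preprocess() resizes directly to 224x224 and operates on tensors already
    # scaled to [-1, 1], avoiding a round-trip through PIL.
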
    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image, no_dropout=False):
        z = self.encode_with_vision_transformer(image)
        tokens = None
        if self.output_tokens:
            z, tokens = z[0], z[1]
        z = z.to(image.dtype)
        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            # Unconditional-guidance dropout: zero whole embeddings with
            # probability ucg_rate, independently per sample.
            z = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
                )[:, None]
                * z
            )
            if tokens is not None:
                tokens = (
                    expand_dims_like(
                        torch.bernoulli(
                            (1.0 - self.ucg_rate)
                            * torch.ones(tokens.shape[0], device=tokens.device)
                        ),
                        tokens,
                    )
                    * tokens
                )
        if self.unsqueeze_dim:
            z = z[:, None, :]
        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z
        if self.repeat_to_max_len:
            if z.dim() == 2:
                z_ = z[:, None, :]
            else:
                z_ = z
            return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
        elif self.pad_to_max_len:
            assert z.dim() == 3
            z_pad = torch.cat(
                (
                    z,
                    torch.zeros(
                        z.shape[0],
                        self.max_length - z.shape[1],
                        z.shape[2],
                        device=z.device,
                    ),
                ),
                1,
            )
            return z_pad, z_pad[:, 0, ...]
        return z

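    # forward() has three mutually exclusive output modes:
    #   - output_tokens:      returns (tokens, pooled) from the visual tower;
    #   - repeat_to_max_len:  returns the pooled embedding tiled to
    #                         (b, max_length, d), plus the original pooled z;
    #   - pad_to_max_len:     zero-pads the sequence axis to max_length and
    #                         returns (z_pad, z_pad[:, 0]).
    # Otherwise the pooled embedding is returned as-is.
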
    def encode_with_vision_transformer(self, img):
        if self.max_crops > 0:
            # NOTE: preprocess_by_cropping is referenced here but not defined
            # in this file; multi-crop mode requires supplying it.
            img = self.preprocess_by_cropping(img)
        if img.dim() == 5:
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)
        if not self.output_tokens:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None
        else:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)
        if self.max_crops > 0:
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # Randomly drop between zero and all crops per sample
            # (ucg-style dropout along the crop axis).
            x = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate)
                    * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
                )
                * x
            )
            if tokens is not None:
                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message."
                )
        if self.output_tokens:
            return x, tokens
        return x

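    # Multi-crop mode (max_crops > 0) shape walk-through: a (b, n, c, h, w)
    # batch is flattened to (b*n, c, h, w) for the visual tower, pooled
    # embeddings are regrouped to (b, n, d) with per-crop dropout, and tokens
    # (if any) are concatenated across crops to (b, t, n*d).
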
    def encode(self, image):
        # This encoder operates on images; the parameter was previously named
        # `text`, apparently copied from a text-encoder counterpart.
        return self(image)
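
# Minimal end-to-end sketch (kept commented out: instantiating the embedder
# downloads the OpenCLIP ViT-H-14 weights). Input images are expected in
# [-1, 1], NCHW:
#
#   embedder = FrozenOpenCLIPImageEmbedder(device="cpu", ucg_rate=0.1)
#   images = torch.rand(2, 3, 256, 256) * 2.0 - 1.0
#   z = embedder(images)          # pooled embeddings, shape (2, 1024)
#   z = embedder.encode(images)   # equivalent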