import math
import json

from posixpath import basename, dirname, join

# import clip
from clip.model import convert_weights
import torch
from torch import nn
from torch.nn import functional as nnf
from torch.nn.modules import activation
from torch.nn.modules.activation import ReLU
from torchvision import transforms
from torchvision.models import ResNet

normalize = transforms.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
)


def process_prompts(conditional, prompt_list, conditional_map):
    # DEPRECATED

    # randomly sample a synonym
    words = [conditional_map[int(i)] for i in conditional]
    words = [
        syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()]
        for syns in words
    ]
    words = [w.replace("_", " ") for w in words]

    if prompt_list is not None:
        prompt_indices = torch.multinomial(
            torch.ones(len(prompt_list)), len(words), replacement=True
        )
        prompts = [prompt_list[i] for i in prompt_indices]
    else:
        prompts = ["a photo of {}"] * (len(words))

    return [prompt.format(w) for prompt, w in zip(prompts, words)]


class VITDenseBase(nn.Module):

    def rescaled_pos_emb(self, new_size):
        # Bicubically rescale the learned positional embedding (excluding the
        # class-token entry) to a new token-grid size.
        assert len(new_size) == 2

        a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
        b = (
            nnf.interpolate(a, new_size, mode="bicubic", align_corners=False)
            .squeeze(0)
            .view(768, new_size[0] * new_size[1])
            .T
        )
        return torch.cat([self.model.positional_embedding[:1], b])
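
    # Example (hypothetical sizes): for a 22x22 token grid instead of the
    # default 14x14, `self.rescaled_pos_emb((22, 22))` returns a positional
    # embedding of shape (22 * 22 + 1, 768), class-token row included.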

    def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):

        with torch.no_grad():

            x_inp = nnf.interpolate(x_inp, (384, 384))

            x = self.model.patch_embed(x_inp)
            cls_token = self.model.cls_token.expand(
                x.shape[0], -1, -1
            )  # stole cls_tokens impl from Phil Wang, thanks
            if self.model.dist_token is None:
                x = torch.cat((cls_token, x), dim=1)
            else:
                x = torch.cat(
                    (cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x),
                    dim=1,
                )
            x = self.model.pos_drop(x + self.model.pos_embed)

            activations = []
            for i, block in enumerate(self.model.blocks):
                x = block(x)

                if i in extract_layers:
                    # permute to be compatible with CLIP
                    activations += [x.permute(1, 0, 2)]

            x = self.model.norm(x)
            x = self.model.head(self.model.pre_logits(x[:, 0]))

            # again for CLIP compatibility
            # x = x.permute(1, 0, 2)

        return x, activations, None

    def sample_prompts(self, words, prompt_list=None):
        prompt_list = prompt_list if prompt_list is not None else self.prompt_list

        prompt_indices = torch.multinomial(
            torch.ones(len(prompt_list)), len(words), replacement=True
        )
        prompts = [prompt_list[i] for i in prompt_indices]
        return [prompt.format(w) for prompt, w in zip(prompts, words)]
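
    # Usage sketch (output is illustrative; templates are drawn at random
    # from self.prompt_list, e.g. when built with prompt="shuffle"):
    #   model.sample_prompts(["cat", "dog"])
    #   -> ["an image of a cat.", "a photo of a dog."]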

    def get_cond_vec(self, conditional, batch_size):
        # compute conditional from a single string
        if conditional is not None and type(conditional) == str:
            cond = self.compute_conditional(conditional)
            cond = cond.repeat(batch_size, 1)

        # compute conditional from string list/tuple
        elif (
            conditional is not None
            and type(conditional) in {list, tuple}
            and type(conditional[0]) == str
        ):
            assert len(conditional) == batch_size
            cond = self.compute_conditional(conditional)

        # use conditional directly
        elif (
            conditional is not None
            and type(conditional) == torch.Tensor
            and conditional.ndim == 2
        ):
            cond = conditional

        # compute conditional from image
        elif conditional is not None and type(conditional) == torch.Tensor:
            with torch.no_grad():
                cond, _, _ = self.visual_forward(conditional)
        else:
            raise ValueError("invalid conditional")

        return cond
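
    # The conditional may be a single string, a list of strings (one per
    # batch element), an already-computed 2D embedding tensor, or a batch of
    # support images. A sketch, assuming a loaded model:
    #   cond = model.get_cond_vec("a photo of a cat", batch_size=4)
    #   cond.shape  # -> (4, 512)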

    def compute_conditional(self, conditional):
        import clip

        dev = next(self.parameters()).device

        if type(conditional) in {list, tuple}:
            text_tokens = clip.tokenize(conditional).to(dev)
            cond = self.clip_model.encode_text(text_tokens)
        else:
            if conditional in self.precomputed_prompts:
                cond = self.precomputed_prompts[conditional].float().to(dev)
            else:
                text_tokens = clip.tokenize([conditional]).to(dev)
                cond = self.clip_model.encode_text(text_tokens)[0]

        return cond


class VITDensePredT(VITDenseBase):

    def __init__(
        self,
        extract_layers=(3, 6, 9),
        cond_layer=0,
        reduce_dim=128,
        n_heads=4,
        prompt="fixed",
        depth=3,
        extra_blocks=0,
        reduce_cond=None,
        fix_shift=False,
        learn_trans_conv_only=False,
        refine=None,
        limit_to_clip_only=False,
        upsample=False,
        add_calibration=False,
        process_cond=None,
        not_pretrained=False,
    ):
        super().__init__()
        # device = 'cpu'

        self.extract_layers = extract_layers
        self.cond_layer = cond_layer
        self.limit_to_clip_only = limit_to_clip_only
        self.process_cond = None

        if add_calibration:
            self.calibration_conds = 1

        self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None

        self.add_activation1 = True

        import timm

        self.model = timm.create_model("vit_base_patch16_384", pretrained=True)
        self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond)

        for p in self.model.parameters():
            p.requires_grad_(False)

        import clip

        self.clip_model, _ = clip.load("ViT-B/16", device="cpu", jit=False)
        # del self.clip_model.visual

        self.token_shape = (14, 14)

        # conditional
        if reduce_cond is not None:
            self.reduce_cond = nn.Linear(512, reduce_cond)
            for p in self.reduce_cond.parameters():
                p.requires_grad_(False)
        else:
            self.reduce_cond = None

        # self.film = AVAILABLE_BLOCKS['film'](512, 128)
        self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
        self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)

        # DEPRECATED
        # self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))}

        assert len(self.extract_layers) == depth

        self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
        self.blocks = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads)
                for _ in range(len(self.extract_layers))
            ]
        )
        self.extra_blocks = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads)
                for _ in range(extra_blocks)
            ]
        )

        trans_conv_ks = (16, 16)
        self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
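
        # With kernel size equal to stride (16), the transposed convolution
        # inverts the 16x16 patch embedding: each token becomes a 16x16 pixel
        # block, so a 24x24 token grid maps back to 384x384 logits.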

        # refinement and trans conv
        if learn_trans_conv_only:
            for p in self.parameters():
                p.requires_grad_(False)

            for p in self.trans_conv.parameters():
                p.requires_grad_(True)

        if prompt == "fixed":
            self.prompt_list = ["a photo of a {}."]
        elif prompt == "shuffle":
            self.prompt_list = [
                "a photo of a {}.",
                "a photograph of a {}.",
                "an image of a {}.",
                "{}.",
            ]
        elif prompt == "shuffle+":
            self.prompt_list = [
                "a photo of a {}.",
                "a photograph of a {}.",
                "an image of a {}.",
                "{}.",
                "a cropped photo of a {}.",
                "a good photo of a {}.",
                "a photo of one {}.",
                "a bad photo of a {}.",
                "a photo of the {}.",
            ]
        elif prompt == "shuffle_clip":
            from models.clip_prompts import imagenet_templates

            self.prompt_list = imagenet_templates

        if process_cond is not None:
            if process_cond == "clamp" or process_cond[0] == "clamp":
                val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2

                def clamp_vec(x):
                    return torch.clamp(x, -val, val)

                self.process_cond = clamp_vec

            elif process_cond.endswith(".pth"):
                shift = torch.load(process_cond)

                def add_shift(x):
                    return x + shift.to(x.device)

                self.process_cond = add_shift

        import pickle

        precomp = pickle.load(open("precomputed_prompt_vectors.pickle", "rb"))
        self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
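
        # Note: "precomputed_prompt_vectors.pickle" is loaded unconditionally
        # from the current working directory; instantiation fails if the file
        # is missing.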

    def forward(self, inp_image, conditional=None, return_features=False, mask=None):

        assert type(return_features) == bool

        # inp_image = inp_image.to(self.model.positional_embedding.device)

        if mask is not None:
            raise ValueError("mask not supported")

        # x_inp = normalize(inp_image)
        x_inp = inp_image

        bs, dev = inp_image.shape[0], x_inp.device

        inp_image_size = inp_image.shape[2:]

        cond = self.get_cond_vec(conditional, bs)

        visual_q, activations, _ = self.visual_forward(
            x_inp, extract_layers=[0] + list(self.extract_layers)
        )

        activation1 = activations[0]
        activations = activations[1:]

        a = None
        for i, (activation, block, reduce) in enumerate(
            zip(activations[::-1], self.blocks, self.reduces)
        ):
            if a is not None:
                a = reduce(activation) + a
            else:
                a = reduce(activation)

            if i == self.cond_layer:
                if self.reduce_cond is not None:
                    cond = self.reduce_cond(cond)

                a = self.film_mul(cond) * a + self.film_add(cond)

            a = block(a)

        for block in self.extra_blocks:
            a = a + block(a)

        a = a[1:].permute(1, 2, 0)  # rm cls token and -> BS, Feats, Tokens

        size = int(math.sqrt(a.shape[2]))

        a = a.view(bs, a.shape[1], size, size)

        if self.trans_conv is not None:
            a = self.trans_conv(a)

        if self.upsample_proj is not None:
            a = self.upsample_proj(a)
            a = nnf.interpolate(a, x_inp.shape[2:], mode="bilinear")

        a = nnf.interpolate(a, inp_image_size)

        if return_features:
            return a, visual_q, cond, [activation1] + activations
        else:
            return (a,)
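

# Minimal end-to-end sketch (shapes are illustrative; assumes the pretrained
# timm and CLIP weights download successfully and the precomputed-prompt
# pickle is present in the working directory):
#   model = VITDensePredT(extract_layers=(3, 6, 9), reduce_dim=128).eval()
#   image = torch.randn(2, 3, 352, 352)
#   with torch.no_grad():
#       pred, = model(image, conditional=["a cat", "a dog"])
#   pred.shape  # -> (2, 1, 352, 352): one mask-logit map per prompt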