import math

import torch
from torch import nn
from torch.nn import functional as nnf
from torchvision import transforms

# CLIP's image normalization constants (RGB mean/std of its training data)
normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))


def process_prompts(conditional, prompt_list, conditional_map):
    """Turn class indices into text prompts, sampling one synonym per class."""

    # look up the synonym list for each class index, then sample one synonym uniformly
    words = [conditional_map[int(i)] for i in conditional]
    words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words]
    words = [w.replace('_', ' ') for w in words]

    if prompt_list is not None:
        # sample one prompt template per word, uniformly with replacement
        prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
        prompts = [prompt_list[i] for i in prompt_indices]
    else:
        prompts = ['a photo of {}'] * len(words)

    return [prompt.format(w) for prompt, w in zip(prompts, words)]

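# Usage sketch for process_prompts (hypothetical inputs): with
# conditional_map = {0: ['cat', 'kitty'], 1: ['dog']},
# process_prompts(torch.tensor([0, 1]), ['a photo of a {}.'], conditional_map)
# could return ['a photo of a cat.', 'a photo of a dog.'], depending on which
# synonym is sampled.
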
class VITDenseBase(nn.Module):

    def rescaled_pos_emb(self, new_size):
        """Bicubically rescale the patch positional embeddings to a new token grid."""
        assert len(new_size) == 2

        # drop the class-token embedding, reshape the rest into a 2D grid, interpolate
        a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
        b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
        # re-attach the class-token embedding in front of the rescaled grid
        return torch.cat([self.model.positional_embedding[:1], b])

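    # Shape sketch (assuming token_shape == (14, 14) and width 768): the patch grid
    # embedding (768, 14, 14) is interpolated to, e.g., (768, 24, 24), flattened back
    # to (576, 768), and the class-token row is prepended, giving (577, 768).
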
    def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
        """Run the frozen ViT backbone, collecting activations of selected blocks."""

        with torch.no_grad():

            # the timm ViT backbone expects a fixed 384x384 input
            x_inp = nnf.interpolate(x_inp, (384, 384))

            x = self.model.patch_embed(x_inp)
            cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)
            if self.model.dist_token is None:
                x = torch.cat((cls_token, x), dim=1)
            else:
                x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = self.model.pos_drop(x + self.model.pos_embed)

            activations = []
            for i, block in enumerate(self.model.blocks):
                x = block(x)

                if i in extract_layers:
                    # permute to (tokens, batch, dim) for the decoder's transformer layers
                    activations += [x.permute(1, 0, 2)]

            x = self.model.norm(x)
            x = self.model.head(self.model.pre_logits(x[:, 0]))

            return x, activations, None

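    # Shape sketch: for a batch of size B, visual_forward returns the head output
    # (B, head_dim), a list with one (577, B, 768) tensor per requested layer
    # (576 patch tokens for a 384x384 input with 16x16 patches, plus the class
    # token), and None as a placeholder.
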
    def sample_prompts(self, words, prompt_list=None):
        """Format each word with a prompt template sampled uniformly with replacement."""

        prompt_list = prompt_list if prompt_list is not None else self.prompt_list

        prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
        prompts = [prompt_list[i] for i in prompt_indices]
        return [prompt.format(w) for prompt, w in zip(prompts, words)]

    def get_cond_vec(self, conditional, batch_size):
        """Build the conditioning vector from a string, a list of strings, or a tensor."""

        # a single string: encode once, repeat for the whole batch
        if conditional is not None and isinstance(conditional, str):
            cond = self.compute_conditional(conditional)
            cond = cond.repeat(batch_size, 1)

        # a list/tuple of strings: one prompt per sample
        elif conditional is not None and isinstance(conditional, (list, tuple)) and isinstance(conditional[0], str):
            assert len(conditional) == batch_size
            cond = self.compute_conditional(conditional)

        # a 2D tensor: already a batch of conditioning vectors
        elif conditional is not None and isinstance(conditional, torch.Tensor) and conditional.ndim == 2:
            cond = conditional

        # any other tensor: treat it as an image batch and encode it visually
        elif conditional is not None and isinstance(conditional, torch.Tensor):
            with torch.no_grad():
                cond, _, _ = self.visual_forward(conditional)
        else:
            raise ValueError('invalid conditional')
        return cond

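    # Dispatch sketch (hypothetical calls): get_cond_vec('a cat', 4) yields a (4, 512)
    # tensor by repeating one text embedding; get_cond_vec(['cat'] * 4, 4) encodes four
    # prompts; a (4, 512) tensor passes through unchanged; a (4, 3, H, W) image batch
    # is routed through visual_forward.
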
    def compute_conditional(self, conditional):
        """Encode text prompt(s) with the frozen CLIP text encoder."""
        import clip

        dev = next(self.parameters()).device

        if isinstance(conditional, (list, tuple)):
            text_tokens = clip.tokenize(conditional).to(dev)
            cond = self.clip_model.encode_text(text_tokens)
        else:
            # reuse a cached embedding if this prompt was precomputed
            if conditional in self.precomputed_prompts:
                cond = self.precomputed_prompts[conditional].float().to(dev)
            else:
                text_tokens = clip.tokenize([conditional]).to(dev)
                cond = self.clip_model.encode_text(text_tokens)[0]

        return cond

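    # Output sketch: for a list of N prompts this returns an (N, 512) tensor of CLIP
    # ViT-B/16 text embeddings; for a single string it returns a (512,) vector.
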
class VITDensePredT(VITDenseBase):

    def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
                 depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False,
                 learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False,
                 add_calibration=False, process_cond=None, not_pretrained=False):
        super().__init__()

        self.extract_layers = extract_layers
        self.cond_layer = cond_layer
        self.limit_to_clip_only = limit_to_clip_only
        self.process_cond = None

        if add_calibration:
            self.calibration_conds = 1

        self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None

        self.add_activation1 = True

        # frozen ViT backbone; its head projects to the conditioning dimension
        import timm
        self.model = timm.create_model('vit_base_patch16_384', pretrained=True)
        self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond)

        for p in self.model.parameters():
            p.requires_grad_(False)

        # frozen CLIP model for encoding text conditionals
        import clip
        self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False)

        self.token_shape = (14, 14)

        # optional projection of the conditioning vector to a smaller dimension
        if reduce_cond is not None:
            self.reduce_cond = nn.Linear(512, reduce_cond)
            for p in self.reduce_cond.parameters():
                p.requires_grad_(False)
        else:
            self.reduce_cond = None

        # FiLM conditioning: per-channel scale and shift computed from the conditional
        self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
        self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)

        assert len(self.extract_layers) == depth

        # decoder: one reduction and one transformer layer per extracted backbone layer
        self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
        self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
        self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])

        # transposed convolution maps the token grid back to pixel space (16x16 pixels per token)
        trans_conv_ks = (16, 16)
        self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)

        if learn_trans_conv_only:
            for p in self.parameters():
                p.requires_grad_(False)

            for p in self.trans_conv.parameters():
                p.requires_grad_(True)

        if prompt == 'fixed':
            self.prompt_list = ['a photo of a {}.']
        elif prompt == 'shuffle':
            self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
        elif prompt == 'shuffle+':
            self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
                                'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
                                'a bad photo of a {}.', 'a photo of the {}.']
        elif prompt == 'shuffle_clip':
            from models.clip_prompts import imagenet_templates
            self.prompt_list = imagenet_templates

        if process_cond is not None:
            if process_cond == 'clamp' or process_cond[0] == 'clamp':

                val = process_cond[1] if isinstance(process_cond, (list, tuple)) else 0.2

                def clamp_vec(x):
                    return torch.clamp(x, -val, val)

                self.process_cond = clamp_vec

            elif process_cond.endswith('.pth'):

                # load a fixed shift vector and add it to every conditioning vector
                shift = torch.load(process_cond)

                def add_shift(x):
                    return x + shift.to(x.device)

                self.process_cond = add_shift

        import pickle
        with open('precomputed_prompt_vectors.pickle', 'rb') as f:
            precomp = pickle.load(f)
        self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}

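    # Construction sketch (hypothetical settings): VITDensePredT(extract_layers=(3, 6, 9),
    # reduce_dim=128, prompt='shuffle+') builds a decoder with three 128-dim transformer
    # layers over ViT blocks 3, 6 and 9, FiLM-conditioned at decoder layer 0.
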
    def forward(self, inp_image, conditional=None, return_features=False, mask=None):

        assert isinstance(return_features, bool)

        if mask is not None:
            raise ValueError('mask not supported')

        x_inp = inp_image

        bs, dev = inp_image.shape[0], x_inp.device

        inp_image_size = inp_image.shape[2:]

        cond = self.get_cond_vec(conditional, bs)

        # layer 0 is always extracted in addition to the configured layers
        visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))

        activation1 = activations[0]
        activations = activations[1:]

        # decode from the deepest extracted layer to the shallowest
        a = None
        for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)):

            if a is not None:
                a = reduce(activation) + a
            else:
                a = reduce(activation)

            if i == self.cond_layer:
                if self.reduce_cond is not None:
                    cond = self.reduce_cond(cond)

                # FiLM: modulate the activations with the conditioning vector
                a = self.film_mul(cond) * a + self.film_add(cond)

            a = block(a)

        for block in self.extra_blocks:
            a = a + block(a)

        # drop the class token and reshape the patch tokens into a square feature map
        a = a[1:].permute(1, 2, 0)  # (tokens, batch, dim) -> (batch, dim, tokens)

        size = int(math.sqrt(a.shape[2]))

        a = a.view(bs, a.shape[1], size, size)

        if self.trans_conv is not None:
            a = self.trans_conv(a)

        if self.upsample_proj is not None:
            a = self.upsample_proj(a)
            a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')

        a = nnf.interpolate(a, inp_image_size)

        if return_features:
            return a, visual_q, cond, [activation1] + activations
        else:
            return a,

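
# Minimal smoke test, as a sketch: constructing the model downloads the timm and
# CLIP weights and loads 'precomputed_prompt_vectors.pickle' from the working
# directory, so this only runs where those are available.
if __name__ == '__main__':
    model = VITDensePredT(prompt='fixed')
    model.eval()
    with torch.no_grad():
        pred, = model(torch.randn(1, 3, 352, 352), conditional=['a photo of a cat.'])
    print(pred.shape)  # expected: torch.Size([1, 1, 352, 352])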