Spaces:
Runtime error
Runtime error
| from functools import partial | |
| import math | |
| from typing import Iterable | |
| from black import diff | |
| from torch import nn, einsum | |
| import numpy as np | |
| import torch as th | |
| import torch.nn as nn | |
| import functools | |
| import torch.nn.functional as F | |
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn, Tensor | |
| from einops import rearrange | |
| import copy | |
| from torchvision import transforms | |
| from torchvision.transforms import InterpolationMode | |
class MLP(nn.Module):
    """Plain multi-layer perceptron (feed-forward network).

    Builds ``num_layers`` linear layers mapping
    ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim``; a ReLU
    follows every layer whose index is below ``num_layers - 1``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_sizes = [hidden_dim] * (num_layers - 1)
        fan_in = [input_dim, *hidden_sizes]
        fan_out = [*hidden_sizes, output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(fan_in, fan_out)]
        )

    def forward(self, x):
        last_relu_index = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last_relu_index:
                x = F.relu(x)
        return x
def resize_fn(img, size):
    """Convert *img* to a PIL image and resize it to *size* with bicubic interpolation."""
    resizer = transforms.Resize(size, InterpolationMode.BICUBIC)
    pil_image = transforms.ToPILImage()(img)
    return resizer(pil_image)
| import math | |
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of *module*."""
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)
def _get_activation_fn(activation):
    """Return the ``torch.nn.functional`` activation named by *activation*.

    Supported names are ``"relu"``, ``"gelu"`` and ``"glu"``.

    Raises:
        RuntimeError: if *activation* is not one of the supported names.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # The message lists every accepted name (previously "glu" was omitted).
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` independent copies of ``decoder_layer`` applied in sequence."""

    def __init__(self, decoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, tgt, memory, pos = None, query_pos = None):
        """Feed ``tgt`` through every layer; each layer also receives ``memory``,
        ``pos`` and ``query_pos`` unchanged."""
        hidden = tgt
        for block in self.layers:
            hidden = block(hidden, memory, pos=pos, query_pos=query_pos)
        return hidden
class TransformerDecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention, cross-attention, then feed-forward.

    Each sub-block normalises its input first (``nn.LayerNorm``, or
    ``nn.Identity`` when ``no_norm`` is true) and adds its dropout-ed output
    back to the residual stream. Both attention modules use ``bias=False``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, no_norm = False,
                 activation="relu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)
        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        def make_norm():
            return nn.Identity() if no_norm else nn.LayerNorm(d_model)

        self.norm1 = make_norm()
        self.norm2 = make_norm()
        self.norm3 = make_norm()
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos):
        """Add the positional embedding *pos* to *tensor*, or return *tensor* if *pos* is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, tgt, memory, pos = None, query_pos = None):
        # Self-attention: queries/keys carry query_pos, values do not.
        normed = self.norm1(tgt)
        qk = self.with_pos_embed(normed, query_pos)
        attended = self.self_attn(qk, qk, value=normed)[0]
        tgt = tgt + self.dropout1(attended)

        # Cross-attention over memory: keys carry pos, values are raw memory.
        normed = self.norm2(tgt)
        attended = self.multihead_attn(
            query=self.with_pos_embed(normed, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
        )[0]
        tgt = tgt + self.dropout2(attended)

        # Feed-forward.
        normed = self.norm3(tgt)
        inner = self.dropout(self.activation(self.linear1(normed)))
        return tgt + self.dropout3(self.linear2(inner))
def proj(x, y):
    """Project *x* onto *y*: ``(y @ x.T) * y / (y @ y.T)``.

    Both arguments must be 2-D (``torch.mm``); in this file they are the
    ``1 x n`` row vectors produced inside ``power_iteration``.
    """
    coefficient = torch.mm(y, x.t())
    squared_norm = torch.mm(y, y.t())
    return coefficient * y / squared_norm
def gram_schmidt(x, ys):
    """Orthogonalise *x* against each vector in *ys*, in order.

    The projection onto each ``y`` is subtracted from the running residual;
    with an empty *ys*, *x* is returned unchanged.
    """
    return functools.reduce(lambda residual, y: residual - proj(residual, y), ys, x)
def power_iteration(W, u_, update=True, eps=1e-12):
    """Run one power-iteration step for every stored vector in ``u_``.

    For each ``u`` in ``u_``:

    * ``v = normalize(u W)``, orthogonalised against the earlier ``v``s;
    * ``u = normalize(v W^T)``, orthogonalised against the earlier ``u``s;
    * when *update* is true, ``u_[i]`` is overwritten in place with the new ``u``.

    The vector updates run under ``torch.no_grad()``. The singular-value
    estimate ``squeeze(v W^T u^T)`` is computed outside that context.

    Returns:
        (svs, us, vs): lists of singular-value estimates, left vectors and
        right vectors, one entry per element of ``u_``.
    """
    left, right, sigmas = [], [], []
    for idx, u_prev in enumerate(u_):
        with torch.no_grad():
            v = F.normalize(gram_schmidt(torch.matmul(u_prev, W), right), eps=eps)
            right.append(v)
            u = F.normalize(gram_schmidt(torch.matmul(v, W.t()), left), eps=eps)
            left.append(u)
            if update:
                # Write back into the caller's tensor (e.g. a registered buffer).
                u_[idx][:] = u
        # Outside no_grad: the estimate keeps its dependence on W for autograd.
        sigmas.append(torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t())))
    return sigmas, left, right
class SN(object):
    """Spectral-normalisation mixin based on power iteration.

    Must be mixed into an ``nn.Module`` subclass that owns a ``weight``
    parameter (as SNLinear / SNConv2d do). It relies on that class for
    ``register_buffer``, ``weight`` and ``training``.

    Args:
        num_svs: Number of singular values/vectors tracked.
        num_itrs: Power-iteration steps per call to ``W_``; must be >= 1.
        num_outputs: Length of each stored ``u`` vector (buffer shape ``(1, num_outputs)``).
        transpose: If true, the flattened weight is transposed before iterating.
        eps: Epsilon passed to ``F.normalize`` to avoid division by zero.
    """

    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        self.num_itrs = num_itrs
        self.num_svs = num_svs
        self.transpose = transpose
        self.eps = eps
        # One random u-vector buffer and one logging sv buffer per tracked singular value.
        for i in range(self.num_svs):
            self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
            self.register_buffer('sv%d' % i, torch.ones(1))

    def u(self):
        """Return the list of stored singular-vector buffers (``u0``, ``u1``, ...)."""
        return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]

    def sv(self):
        """Return the list of singular-value buffers (``sv0``, ``sv1``, ...).

        These buffers are only written for logging; ``W_`` does not read them.
        """
        return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]

    def W_(self):
        """Return ``weight`` divided by its estimated largest singular value.

        Runs ``num_itrs`` power-iteration steps on the weight flattened to 2-D.
        In training mode the ``u*`` buffers are updated in place and the
        estimates are copied into the ``sv*`` buffers.

        Raises:
            ValueError: if ``num_itrs`` < 1 (no estimate would be produced).
        """
        if self.num_itrs < 1:
            raise ValueError(
                f"num_itrs must be >= 1 to estimate the spectral norm, got {self.num_itrs}")
        W_mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            W_mat = W_mat.t()
        for _ in range(self.num_itrs):
            # u() is a method: it must be called to obtain the buffer list
            # (passing the bound method itself is not iterable).
            svs, us, vs = power_iteration(W_mat, self.u(), update=self.training, eps=self.eps)
        if self.training:
            # no_grad so the logging copy does not extend the autograd graph.
            with torch.no_grad():
                for sv_buffer, sv in zip(self.sv(), svs):
                    sv_buffer[:] = sv
        return self.weight / svs[0]
class SNLinear(nn.Linear, SN):
    """``nn.Linear`` whose weight is spectrally normalised (via ``SN.W_``) on every forward."""

    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features, bias=bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.linear(x, normalized_weight, self.bias)
class SNConv2d(nn.Conv2d, SN):
    """``nn.Conv2d`` whose weight is spectrally normalised (via ``SN.W_``) on every forward."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size,
                           stride=stride, padding=padding, dilation=dilation,
                           groups=groups, bias=bias)
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.conv2d(x, normalized_weight, self.bias, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        groups=self.groups)
class SegBlock(nn.Module):
    """Residual block: BN -> act -> [upsample] -> conv1 -> BN -> act -> conv2, plus shortcut.

    Batch norm uses the block's own running-stat buffers with no affine
    parameters (momentum 0.1, eps 1e-4). A 1x1 shortcut conv is created when
    the channel count changes or ``upsample`` is truthy. ``con_channels`` and
    ``which_linear`` are accepted (``which_linear`` is stored) but are not
    used by this block.
    """

    def __init__(self, in_channels, out_channels, con_channels,
                 which_conv=nn.Conv2d, which_linear=None, activation=None,
                 upsample=None):
        super(SegBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.which_conv = which_conv
        self.which_linear = which_linear
        self.activation = activation
        self.upsample = upsample
        self.conv1 = which_conv(in_channels, out_channels)
        self.conv2 = which_conv(out_channels, out_channels)
        self.learnable_sc = in_channels != out_channels or upsample
        if self.learnable_sc:
            self.conv_sc = which_conv(in_channels, out_channels,
                                      kernel_size=1, padding=0)
        self.register_buffer('stored_mean1', torch.zeros(in_channels))
        self.register_buffer('stored_var1', torch.ones(in_channels))
        self.register_buffer('stored_mean2', torch.zeros(out_channels))
        self.register_buffer('stored_var2', torch.ones(out_channels))

    def forward(self, x, y=None):
        # NOTE: ``x`` is rebound to its batch-normalised value and that same
        # tensor is handed to ``self.activation``. With an in-place activation
        # such as nn.ReLU(inplace=True), ``h`` and ``x`` are the same tensor,
        # so the shortcut also carries the activated values.
        x = F.batch_norm(x, self.stored_mean1, self.stored_var1, None, None,
                         self.training, 0.1, 1e-4)
        h = self.activation(x)
        if self.upsample:
            h, x = self.upsample(h), self.upsample(x)
        h = F.batch_norm(self.conv1(h), self.stored_mean2, self.stored_var2,
                         None, None, self.training, 0.1, 1e-4)
        h = self.conv2(self.activation(h))
        shortcut = self.conv_sc(x) if self.learnable_sc else x
        return h + shortcut
def make_coord(shape, ranges=None, flatten=True):
    """Make coordinates at grid-cell centers.

    Args:
        shape: Number of cells along each axis.
        ranges: Optional per-axis ``(low, high)`` bounds; defaults to ``(-1, 1)``
            on every axis.
        flatten: If true, return an ``(N, len(shape))`` tensor; otherwise keep
            the grid layout ``(*shape, len(shape))``.
    """
    axes = []
    for dim, size in enumerate(shape):
        lo, hi = (-1, 1) if ranges is None else ranges[dim]
        half_cell = (hi - lo) / (2 * size)
        axes.append(lo + half_cell + (2 * half_cell) * torch.arange(size).float())
    grid = torch.stack(torch.meshgrid(*axes), dim=-1)
    return grid.view(-1, grid.shape[-1]) if flatten else grid
class Embedder:
    """Fourier-feature encoder: concatenates ``p_fn(x * freq)`` over frequencies and functions.

    Required keyword configuration: ``input_dims``, ``include_input``,
    ``max_freq_log2``, ``num_freqs``, ``log_sampling`` and ``periodic_fns``.
    With ``log_sampling`` the frequencies are ``2**linspace(0, max_freq_log2)``
    (float64); otherwise they are ``linspace(1, 2**max_freq_log2)``. Periodic
    features are computed in float64.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Populate ``embed_fns`` and ``out_dim`` from ``self.kwargs``."""
        cfg = self.kwargs
        dims = cfg['input_dims']
        fns = []
        width = 0
        if cfg['include_input']:
            fns.append(lambda x : x)
            width += dims
        top_exponent = cfg['max_freq_log2']
        count = cfg['num_freqs']
        if cfg['log_sampling']:
            bands = 2. ** torch.linspace(0., top_exponent, steps=count).double()
        else:
            bands = torch.linspace(2. ** 0., 2. ** top_exponent, steps=count)
        for band in bands:
            for periodic in cfg['periodic_fns']:
                # Bind the current function/frequency via defaults (avoid late binding).
                fns.append(lambda x, p_fn=periodic, freq=band : p_fn(x.double() * freq))
                width += dims
        self.embed_fns = fns
        self.out_dim = width

    def embed(self, inputs):
        """Apply every embedding function to *inputs* and concatenate on the last axis."""
        pieces = [fn(inputs) for fn in self.embed_fns]
        return torch.cat(pieces, dim=-1)
def get_embedder(multires, i=0):
    """Build a sin/cos Fourier-feature embedder for 2-D inputs.

    Args:
        multires: Number of log-spaced frequencies (``2**0 .. 2**(multires-1)``).
        i: Pass ``-1`` to get an identity mapping instead.

    Returns:
        ``(embed_fn, out_dim)``. ``out_dim`` is ``2 * 2 * multires`` for the
        embedder (input not included), or ``3`` for the identity path.
    """
    if i == -1:
        # NOTE(review): the identity path reports 3 dims while the embedder
        # below is configured with input_dims=2 — confirm what callers expect.
        return nn.Identity(), 3
    embedder_obj = Embedder(
        include_input=False,
        input_dims=2,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim
class Segmodule(nn.Module):
    """Text-conditioned segmentation head over multi-scale diffusion features.

    ``_prepare_features`` fuses the "low"/"mid"/"high"/"highest" feature levels
    into one map with ``16 + 32 + 64 + 128 = 240`` channels. The fused map is
    cut into non-overlapping 4x4 patches, which are projected (``to_k``) into
    the decoder memory. Text embeddings are projected (``to_q``) into decoder
    queries. A transformer decoder and an MLP turn each text token into a
    240-d mask embedding. That embedding is dotted with the fused map
    upsampled to 512x512, giving one mask-logit map per text token.

    Args:
        embedding_dim: Token width inside the decoder.
        num_heads: Attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
        hidden_dim: Feed-forward width inside each decoder layer.
        dropout_rate: Dropout probability inside the decoder layers.
    """

    # Side of the square, non-overlapping patches the fused map is split into.
    _PATCH_SIZE = 4
    # Side (in pixels) of the returned mask-logit maps.
    _OUTPUT_SIZE = 512
    # Required last-dimension width of ``text_embedding``.
    _CONTEXT_DIM = 768

    def __init__(self,
                 embedding_dim=512,
                 num_heads=8,
                 num_layers=3,
                 hidden_dim=2048,
                 dropout_rate=0):
        super().__init__()
        low_feature_channel = 16
        mid_feature_channel = 32
        high_feature_channel = 64
        highest_feature_channel = 128
        # NOTE: submodules are created in a fixed order; keep it so seeded
        # parameter initialisation stays reproducible.

        # 1x1 projections of each level. Their input widths must equal the
        # total channel count of the maps concatenated at that level.
        self.low_feature_conv = nn.Sequential(
            nn.Conv2d(1280 * 6 * 2, low_feature_channel, kernel_size=1, bias=False),
        )
        self.mid_feature_conv = nn.Sequential(
            nn.Conv2d((1280 * 5 + 640) * 2, mid_feature_channel, kernel_size=1, bias=False),
        )
        self.mid_feature_mix_conv = self._make_mix_block(
            low_feature_channel + mid_feature_channel)
        self.high_feature_conv = nn.Sequential(
            nn.Conv2d((1280 + 640 * 4 + 320) * 2, high_feature_channel, kernel_size=1, bias=False),
        )
        self.high_feature_mix_conv = self._make_mix_block(
            low_feature_channel + mid_feature_channel + high_feature_channel)
        self.highest_feature_conv = nn.Sequential(
            nn.Conv2d((640 + 320 * 6) * 2, highest_feature_channel, kernel_size=1, bias=False),
        )
        feature_dim = (low_feature_channel + mid_feature_channel
                       + high_feature_channel + highest_feature_channel)
        self.highest_feature_mix_conv = self._make_mix_block(feature_dim)

        # Each unfolded patch holds feature_dim * patch_size**2 values.
        query_dim = feature_dim * self._PATCH_SIZE ** 2
        decoder_layer = TransformerDecoderLayer(embedding_dim, num_heads, hidden_dim, dropout_rate)
        # The misspelled attribute name is intentional to keep: it is part of
        # the state_dict keys, so renaming it would break loading weights
        # saved under this name.
        self.transfromer_decoder = TransformerDecoder(decoder_layer, num_layers)
        self.mlp = MLP(embedding_dim, embedding_dim, feature_dim, 3)
        self.to_k = nn.Linear(query_dim, embedding_dim, bias=False)
        self.to_q = nn.Linear(self._CONTEXT_DIM, embedding_dim, bias=False)

    @staticmethod
    def _make_mix_block(channels):
        """Build the channel-preserving spectrally-normalised SegBlock used to fuse a level."""
        return SegBlock(
            in_channels=channels,
            out_channels=channels,
            con_channels=128,
            which_conv=functools.partial(SNConv2d,
                                         kernel_size=3, padding=1,
                                         num_svs=1, num_itrs=1,
                                         eps=1e-04),
            which_linear=functools.partial(SNLinear,
                                           num_svs=1, num_itrs=1,
                                           eps=1e-04),
            activation=nn.ReLU(inplace=True),
            upsample=False,
        )

    def forward(self, diffusion_feature, text_embedding):
        """Predict one mask-logit map per text token.

        Args:
            diffusion_feature: Mapping consumed by ``_prepare_features``.
            text_embedding: 3-D tensor ``(b, n, 768)``.

        Returns:
            Tensor ``(b, n, 512, 512)``.
        """
        image_feature = self._prepare_features(diffusion_feature)
        final_image_feature = F.interpolate(image_feature, size=self._OUTPUT_SIZE,
                                            mode='bilinear', align_corners=False)
        b = final_image_feature.size()[0]
        # (b, c, H, W) -> (b, c*p*p, L) -> (b, L, c*p*p): one row per patch.
        image_feature = F.unfold(image_feature, self._PATCH_SIZE,
                                 stride=self._PATCH_SIZE).transpose(1, 2).contiguous()
        image_feature = rearrange(image_feature, 'b n d -> (b n) d ')
        text_embedding = rearrange(text_embedding, 'b n d -> (b n) d ')
        q = self.to_q(text_embedding)
        k = self.to_k(image_feature)
        # NOTE(review): batch and token axes are flattened together, so the
        # decoder sees one unbatched sequence. With b > 1, text tokens attend
        # to each other and to the patches of *every* image in the batch.
        # Behaviour kept as-is — confirm this is intended.
        output_query = self.transfromer_decoder(q, k, None)
        output_query = rearrange(output_query, '(b n) d -> b n d', b=b)
        mask_embedding = self.mlp(output_query)
        return einsum('b d h w, b n d -> b n h w', final_image_feature, mask_embedding)

    @staticmethod
    def _resize_and_concat(maps, size, mode):
        """Resize every map in *maps* to ``size`` x ``size`` and concatenate on channels."""
        return torch.cat(
            [F.interpolate(m, size=size, mode=mode, align_corners=False) for m in maps],
            dim=1,
        )

    def _prepare_features(self, features, upsample='bilinear'):
        """Fuse the multi-scale feature levels into one 240-channel map.

        Args:
            features: Mapping with keys "low", "mid", "high" (sequences of 4-D
                maps, resized here to 16, 32 and 64 respectively) and "highest"
                (sequence of 4-D maps concatenated as-is; they must already be
                64x64 for the final concatenation to succeed).
            upsample: Interpolation mode for the initial per-level resizing.
                Later inter-level resizing is always bilinear.

        Side effect: sets ``self.low_feature_size``, ``self.mid_feature_size``
        and ``self.high_feature_size`` (16/32/64).
        """
        self.low_feature_size = 16
        self.mid_feature_size = 32
        self.high_feature_size = 64

        low_features = self._resize_and_concat(features["low"], self.low_feature_size, upsample)
        mid_features = self._resize_and_concat(features["mid"], self.mid_feature_size, upsample)
        high_features = self._resize_and_concat(features["high"], self.high_feature_size, upsample)
        highest_features = torch.cat(features["highest"], dim=1)

        # Coarse-to-fine: project each level, upsample the running fusion to
        # the next level's size, concatenate, and mix.
        low_feat = self.low_feature_conv(low_features)
        low_feat = F.interpolate(low_feat, size=self.mid_feature_size, mode='bilinear', align_corners=False)
        mid_feat = self.mid_feature_conv(mid_features)
        mid_feat = torch.cat([low_feat, mid_feat], dim=1)
        mid_feat = self.mid_feature_mix_conv(mid_feat, y=None)
        mid_feat = F.interpolate(mid_feat, size=self.high_feature_size, mode='bilinear', align_corners=False)
        high_feat = self.high_feature_conv(high_features)
        high_feat = torch.cat([mid_feat, high_feat], dim=1)
        high_feat = self.high_feature_mix_conv(high_feat, y=None)
        highest_feat = self.highest_feature_conv(highest_features)
        highest_feat = torch.cat([high_feat, highest_feat], dim=1)
        return self.highest_feature_mix_conv(highest_feat, y=None)