# Grounded-Diffusion / ldm/models/seg_module.py
# Purple11's picture
# Update ldm/models/seg_module.py
# 4906d2d
from functools import partial
import math
from typing import Iterable
from black import diff
from torch import nn, einsum
import numpy as np
import torch as th
import torch.nn as nn
import functools
import torch.nn.functional as F
import math
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from einops import rearrange
import copy
from torchvision import transforms
from torchvision.transforms import InterpolationMode
class MLP(nn.Module):
    """Feed-forward network: a stack of linear layers with ReLU in between.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of every intermediate layer.
        output_dim: Size of the last dimension of the output.
        num_layers: Total number of linear layers in the stack.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_dims = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden_dims
        out_dims = hidden_dims + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)]
        )

    def forward(self, x):
        """Run ``x`` through every layer; ReLU follows all layers but the last."""
        last_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < last_index:
                x = F.relu(x)
        return x
def resize_fn(img, size):
    """Convert ``img`` to a PIL image and resize it with bicubic interpolation.

    Args:
        img: Input accepted by ``transforms.ToPILImage``.
        size: Target size forwarded to ``transforms.Resize``.

    Returns:
        The result of applying the resize transform to the PIL image.
    """
    resize = transforms.Resize(size, InterpolationMode.BICUBIC)
    to_pil = transforms.ToPILImage()
    return resize(to_pil(img))
import math
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
def _get_activation_fn(activation):
    """Return the functional activation named by ``activation``.

    Args:
        activation: One of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        The matching function from ``torch.nn.functional``.

    Raises:
        RuntimeError: If ``activation`` is not one of the supported names.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu), ("glu", F.glu))
    # Compare with ``==`` (rather than a dict lookup) so that unhashable
    # arguments still end up in the RuntimeError below.
    for name, fn in supported:
        if activation == name:
            return fn
    names = "/".join(name for name, _ in supported)
    # The message lists every accepted name, including "glu".
    raise RuntimeError(f"activation should be {names}, not {activation}.")
class TransformerDecoder(nn.Module):
    """Apply ``num_layers`` deep copies of ``decoder_layer`` one after another.

    Args:
        decoder_layer: Layer to replicate; it is called as
            ``layer(output, memory, pos=pos, query_pos=query_pos)``.
        num_layers: Number of copies in the stack.
    """

    def __init__(self, decoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, tgt, memory, pos=None, query_pos=None):
        """Feed ``tgt`` through every layer and return the final output.

        ``memory``, ``pos`` and ``query_pos`` are passed unchanged to each
        layer; only the running output changes between layers.
        """
        hidden = tgt
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, pos=pos, query_pos=query_pos)
        return hidden
class TransformerDecoderLayer(nn.Module):
    """Decoder layer: self-attention, cross-attention, then a feed-forward block.

    Each of the three sub-blocks normalizes its input first and adds its
    (dropped-out) result back onto the un-normalized input.

    Args:
        d_model: Embedding size used by both attention modules and the norms.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Dropout probability used in attention and on every residual.
        no_norm: When true, every LayerNorm is replaced by ``nn.Identity``.
        activation: Name understood by ``_get_activation_fn``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, no_norm=False,
                 activation="relu"):
        super().__init__()
        # Attention modules are created without bias terms.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)

        # Feed-forward block: linear1 -> activation -> dropout -> linear2.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One normalization per sub-block (or a no-op when no_norm is set).
        self.norm1 = nn.Identity() if no_norm else nn.LayerNorm(d_model)
        self.norm2 = nn.Identity() if no_norm else nn.LayerNorm(d_model)
        self.norm3 = nn.Identity() if no_norm else nn.LayerNorm(d_model)

        # One dropout per residual connection.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, tgt, memory, pos=None, query_pos=None):
        """Run the layer on ``tgt`` with ``memory`` as cross-attention source.

        Args:
            tgt: Query-side tensor that is refined and returned.
            memory: Tensor attended to by the cross-attention module.
            pos: Optional embedding added to ``memory`` for the keys only.
            query_pos: Optional embedding added to the queries (and to the
                self-attention keys).
        """
        # Self-attention: positions go into queries/keys, not into values.
        normed = self.norm1(tgt)
        queries = keys = self.with_pos_embed(normed, query_pos)
        attended, _ = self.self_attn(queries, keys, value=normed)
        tgt = tgt + self.dropout1(attended)

        # Cross-attention over memory.
        normed = self.norm2(tgt)
        attended, _ = self.multihead_attn(
            query=self.with_pos_embed(normed, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
        )
        tgt = tgt + self.dropout2(attended)

        # Feed-forward block.
        normed = self.norm3(tgt)
        expanded = self.activation(self.linear1(normed))
        projected = self.linear2(self.dropout(expanded))
        return tgt + self.dropout3(projected)
# Projection of x onto y
def proj(x, y):
    """Return the projection of ``x`` onto ``y``.

    Both arguments are used as 2-D matrices (``torch.mm``); the result is
    ``(y x^T) * y / (y y^T)``.
    """
    x_dot_y = torch.mm(y, x.t())
    y_dot_y = torch.mm(y, y.t())
    return x_dot_y * y / y_dot_y
# Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
    """Subtract from ``x`` its projection onto each vector of ``ys``, in order.

    Each projection is computed against the running residual, not against the
    original ``x``. An empty ``ys`` returns ``x`` unchanged.
    """
    residual = x
    for basis in ys:
        residual = residual - proj(residual, basis)
    return residual
def power_iteration(W, u_, update=True, eps=1e-12):
    """Run one power-iteration step per vector in ``u_`` against matrix ``W``.

    Args:
        W: 2-D weight matrix.
        u_: Sequence of current left-vector estimates, one per singular value.
        update: When true, each entry of ``u_`` is overwritten in place with
            its refreshed estimate.
        eps: Epsilon forwarded to ``F.normalize``.

    Returns:
        ``(svs, us, vs)``: lists of singular-value estimates and of the
        refreshed ``u`` / ``v`` vectors, in the order of ``u_``.
    """
    us, vs, svs = [], [], []
    for index, u in enumerate(u_):
        # NOTE(review): the vector updates are kept under no_grad while the
        # singular value below is computed outside of it, so that it remains
        # differentiable w.r.t. W -- confirm the extent of the no_grad block
        # against the upstream implementation.
        with torch.no_grad():
            # v side: multiply, orthogonalize against earlier v's, normalize.
            v = F.normalize(gram_schmidt(torch.matmul(u, W), vs), eps=eps)
            vs.append(v)
            # u side: same treatment using the fresh v.
            u = F.normalize(gram_schmidt(torch.matmul(v, W.t()), us), eps=eps)
            us.append(u)
            if update:
                # Write through to the caller's storage.
                u_[index][:] = u
        # Singular value estimate for this pair.
        sigma = torch.matmul(torch.matmul(v, W.t()), u.t())
        svs.append(torch.squeeze(sigma))
    return svs, us, vs
# Spectral normalization base class
class SN(object):
    """Mixin that adds spectral normalization to a weight-carrying module.

    It relies on the host class providing ``register_buffer``, ``weight`` and
    ``training`` (it is combined with ``nn.Linear`` / ``nn.Conv2d`` in this
    file).

    Args:
        num_svs: Number of singular values to track.
        num_itrs: Power iterations per call to ``W_``; must be at least 1,
            otherwise ``W_`` has no singular value to divide by.
        num_outputs: Length of each ``u`` vector.
        transpose: Whether to transpose the flattened weight before iterating.
        eps: Epsilon used when normalizing vectors.
    """

    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        self.num_itrs = num_itrs
        self.num_svs = num_svs
        self.transpose = transpose
        self.eps = eps
        # One (u, sv) buffer pair per tracked singular value.
        for index in range(self.num_svs):
            self.register_buffer(f'u{index}', torch.randn(1, num_outputs))
            self.register_buffer(f'sv{index}', torch.ones(1))

    @property
    def u(self):
        """List of the ``u`` buffers, in index order."""
        return [getattr(self, f'u{index}') for index in range(self.num_svs)]

    @property
    def sv(self):
        """List of the singular-value buffers.

        These buffers are only kept for logging; training does not read them.
        """
        return [getattr(self, f'sv{index}') for index in range(self.num_svs)]

    def W_(self):
        """Return ``self.weight`` divided by the first singular-value estimate."""
        weight_2d = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            weight_2d = weight_2d.t()

        # The u buffers are only refreshed in place while training.
        for _ in range(self.num_itrs):
            svs, us, vs = power_iteration(
                weight_2d, self.u, update=self.training, eps=self.eps
            )

        if self.training:
            # Record the values without tracking gradients (avoids leaks).
            with torch.no_grad():
                for index, value in enumerate(svs):
                    self.sv[index][:] = value
        return self.weight / svs[0]
# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
    """``nn.Linear`` whose weight is spectrally normalized on every forward.

    Args:
        in_features, out_features, bias: Forwarded to ``nn.Linear``.
        num_svs, num_itrs, eps: Forwarded to ``SN``.
    """

    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features, bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        """Apply the linear map using the normalized weight from ``W_``."""
        normalized_weight = self.W_()
        return F.linear(x, normalized_weight, self.bias)
# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
    """``nn.Conv2d`` whose weight is spectrally normalized on every forward.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: Forwarded to ``nn.Conv2d``.
        num_svs, num_itrs, eps: Forwarded to ``SN``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Conv2d.__init__(
            self, in_channels, out_channels, kernel_size, stride,
            padding, dilation, groups, bias,
        )
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        """Convolve ``x`` using the normalized weight from ``W_``."""
        normalized_weight = self.W_()
        return F.conv2d(
            x, normalized_weight, self.bias, self.stride,
            self.padding, self.dilation, self.groups,
        )
class SegBlock(nn.Module):
    """Residual block: (norm -> activation -> conv) twice, plus a shortcut.

    Normalization is ``F.batch_norm`` without affine parameters, using the
    ``stored_mean*`` / ``stored_var*`` buffers as running statistics.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        con_channels: Accepted for interface compatibility; not used here.
        which_conv: Callable building a conv layer from
            ``(in_channels, out_channels)``; it must also accept
            ``kernel_size`` / ``padding`` keyword overrides for the shortcut.
        which_linear: Stored on the instance; not used by this block.
        activation: Callable applied after each normalization.
        upsample: Optional callable applied to both branches after the first
            activation; any falsy value disables it.
    """

    def __init__(self, in_channels, out_channels, con_channels,
                 which_conv=nn.Conv2d, which_linear=None, activation=None,
                 upsample=None):
        super(SegBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_linear = which_conv, which_linear
        self.activation = activation
        self.upsample = upsample

        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)

        # Store a real bool: assigning the raw ``upsample`` object here would
        # register an nn.Module upsampler a second time under this name.
        self.learnable_sc = in_channels != out_channels or bool(upsample)
        if self.learnable_sc:
            # 1x1 conv so the shortcut matches the main branch's channels.
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)

        # Running statistics for the two normalizations.
        self.register_buffer('stored_mean1', torch.zeros(in_channels))
        self.register_buffer('stored_var1', torch.ones(in_channels))
        self.register_buffer('stored_mean2', torch.zeros(out_channels))
        self.register_buffer('stored_var2', torch.ones(out_channels))

    def _norm(self, tensor, mean, var):
        """Batch-normalize ``tensor`` without affine weight or bias."""
        return F.batch_norm(tensor, mean, var, None, None,
                            self.training, 0.1, 1e-4)

    def forward(self, x, y=None):
        """Return ``main_branch(x) + shortcut(x)``; ``y`` is ignored."""
        # The shortcut starts from the normalized input, not the raw input.
        x = self._norm(x, self.stored_mean1, self.stored_var1)
        # With an in-place activation (e.g. nn.ReLU(inplace=True)) this also
        # modifies ``x``, so the shortcut then sees the activated tensor.
        h = self.activation(x)
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv1(h)
        h = self._norm(h, self.stored_mean2, self.stored_var2)
        h = self.activation(h)
        h = self.conv2(h)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x
def make_coord(shape, ranges=None, flatten=True):
    """Build the coordinates of grid-cell centers.

    Args:
        shape: Number of cells along each axis.
        ranges: Optional per-axis ``(start, end)`` pairs; every axis spans
            ``(-1, 1)`` when omitted.
        flatten: When true, collapse all grid axes into one so the result has
            shape ``(num_cells, len(shape))``; otherwise the grid axes are kept.

    Returns:
        A float tensor whose last dimension holds one coordinate per axis.
    """
    axes = []
    for axis, cells in enumerate(shape):
        lo, hi = (-1, 1) if ranges is None else ranges[axis]
        half_cell = (hi - lo) / (2 * cells)
        # Centers sit half a cell in from the start, one full cell apart.
        centers = lo + half_cell + (2 * half_cell) * torch.arange(cells).float()
        axes.append(centers)
    grid = torch.stack(torch.meshgrid(*axes), dim=-1)
    return grid.view(-1, grid.shape[-1]) if flatten else grid
class Embedder:
    """Frequency embedding built from periodic functions.

    Keyword arguments (all read from ``kwargs``):
        input_dims: Size of the input's last dimension.
        include_input: Whether the raw input is the first embedding component.
        max_freq_log2: Log2 of the highest frequency.
        num_freqs: Number of frequency bands.
        log_sampling: Space bands evenly in log2 (true) or linearly (false).
        periodic_fns: Functions applied to ``input * freq`` for every band.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Populate ``self.embed_fns`` and ``self.out_dim`` from ``self.kwargs``."""
        cfg = self.kwargs
        input_dims = cfg['input_dims']
        fns = []
        total_dim = 0

        if cfg['include_input']:
            fns.append(lambda x: x)
            total_dim += input_dims

        max_freq = cfg['max_freq_log2']
        n_freqs = cfg['num_freqs']
        if cfg['log_sampling']:
            freq_bands = 2.**torch.linspace(0., max_freq, steps=n_freqs).double()
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=n_freqs)

        for freq in freq_bands:
            for p_fn in cfg['periodic_fns']:
                # Default arguments freeze this iteration's function/frequency.
                fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x.double() * freq))
                total_dim += input_dims

        self.embed_fns = fns
        self.out_dim = total_dim

    def embed(self, inputs):
        """Concatenate every embedding component along the last dimension."""
        components = [fn(inputs) for fn in self.embed_fns]
        return torch.cat(components, -1)
def get_embedder(multires, i=0):
    """Build a sin/cos frequency embedder for 2-D inputs.

    Args:
        multires: Number of frequency bands; the highest band is
            ``2 ** (multires - 1)``.
        i: Pass ``-1`` to get an identity embedding instead.

    Returns:
        ``(embed, out_dim)``: a callable mapping inputs to their embedding and
        the size reported for the embedding's last dimension.
    """
    if i == -1:
        # NOTE(review): the reported size is 3 although the embedder below is
        # configured with input_dims=2 -- confirm which one callers expect.
        return nn.Identity(), 3

    embedder_obj = Embedder(
        include_input=False,
        input_dims=2,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim
class Segmodule(nn.Module):
    """Segmentation head fusing multi-resolution features with text queries.

    Four groups of feature maps ("low", "mid", "high", "highest") are reduced
    with 1x1 convolutions and merged coarse-to-fine through ``SegBlock``
    mixers. The fused map is cut into patches that serve as memory for a
    transformer decoder whose queries come from the text embedding; the
    decoded queries become per-query mask embeddings that are multiplied with
    the upsampled fused map.

    Args:
        embedding_dim: Width of the transformer decoder.
        num_heads: Attention heads per decoder layer.
        num_layers: Number of decoder layers.
        hidden_dim: Feed-forward width inside each decoder layer.
        dropout_rate: Dropout probability inside the decoder.
    """

    # Target spatial sizes used when resizing the feature groups.
    low_feature_size = 16
    mid_feature_size = 32
    high_feature_size = 64
    # Side length of the square patches taken from the fused feature map.
    patch_size = 4

    def __init__(self,
                 embedding_dim=512,
                 num_heads=8,
                 num_layers=3,
                 hidden_dim=2048,
                 dropout_rate=0):
        super().__init__()
        low_feature_channel = 16
        mid_feature_channel = 32
        high_feature_channel = 64
        highest_feature_channel = 128

        # Channel count after each successive concatenation.
        mid_mix_channels = low_feature_channel + mid_feature_channel
        high_mix_channels = mid_mix_channels + high_feature_channel
        feature_dim = high_mix_channels + highest_feature_channel

        # NOTE: modules are created in the same order as before so parameter
        # ordering and seeded initialization stay unchanged. The input channel
        # counts are fixed by the feature extractor that produces the inputs.
        self.low_feature_conv = nn.Sequential(
            nn.Conv2d(1280*6*2, low_feature_channel, kernel_size=1, bias=False),
        )
        self.mid_feature_conv = nn.Sequential(
            nn.Conv2d((1280*5+640)*2, mid_feature_channel, kernel_size=1, bias=False),
        )
        self.mid_feature_mix_conv = self._make_mix_block(mid_mix_channels)
        self.high_feature_conv = nn.Sequential(
            nn.Conv2d((1280+640*4+320)*2, high_feature_channel, kernel_size=1, bias=False),
        )
        self.high_feature_mix_conv = self._make_mix_block(high_mix_channels)
        self.highest_feature_conv = nn.Sequential(
            nn.Conv2d((640+320*6)*2, highest_feature_channel, kernel_size=1, bias=False),
        )
        self.highest_feature_mix_conv = self._make_mix_block(feature_dim)

        # Every unfolded patch flattens patch_size x patch_size pixels of the
        # fused map, hence this input width for ``to_k``.
        query_dim = feature_dim * self.patch_size ** 2

        decoder_layer = TransformerDecoderLayer(embedding_dim, num_heads, hidden_dim, dropout_rate)
        # The attribute name (including its spelling) is part of the
        # state_dict keys and therefore must not be renamed.
        self.transfromer_decoder = TransformerDecoder(decoder_layer, num_layers)
        self.mlp = MLP(embedding_dim, embedding_dim, feature_dim, 3)

        context_dim = 768
        self.to_k = nn.Linear(query_dim, embedding_dim, bias=False)
        self.to_q = nn.Linear(context_dim, embedding_dim, bias=False)

    @staticmethod
    def _make_mix_block(channels):
        """Build the channel-preserving spectral-norm ``SegBlock`` mixer."""
        return SegBlock(
            in_channels=channels,
            out_channels=channels,
            con_channels=128,
            which_conv=functools.partial(SNConv2d,
                                         kernel_size=3, padding=1,
                                         num_svs=1, num_itrs=1,
                                         eps=1e-04),
            which_linear=functools.partial(SNLinear,
                                           num_svs=1, num_itrs=1,
                                           eps=1e-04),
            activation=nn.ReLU(inplace=True),
            upsample=False,
        )

    @staticmethod
    def _resize_and_concat(tensors, size, mode):
        """Resize every tensor to ``size`` and concatenate along channels."""
        resized = [
            F.interpolate(t, size=size, mode=mode, align_corners=False)
            for t in tensors
        ]
        return torch.cat(resized, dim=1)

    def forward(self, diffusion_feature, text_embedding):
        """Predict one mask logit map per text token.

        Args:
            diffusion_feature: Mapping with keys ``"low"``, ``"mid"``,
                ``"high"`` and ``"highest"``, each a sequence of feature maps.
            text_embedding: Tensor laid out as ``(b, n, d)`` whose last
                dimension matches the input width of ``to_q``.

        Returns:
            Tensor laid out as ``(b, n, 512, 512)``.
        """
        image_feature = self._prepare_features(diffusion_feature)
        final_image_feature = F.interpolate(image_feature, size=512, mode='bilinear', align_corners=False)
        b = final_image_feature.size(0)

        # Non-overlapping patches: (b, c * p * p, n_patches) -> (b, n_patches, c * p * p).
        patches = F.unfold(image_feature, self.patch_size, stride=self.patch_size)
        patches = patches.transpose(1, 2).contiguous()

        # NOTE(review): batch and sequence axes are merged before the decoder.
        # nn.MultiheadAttention treats 2-D inputs as one unbatched sequence,
        # so for b > 1 tokens of different samples appear to attend to each
        # other -- confirm this is intended.
        patches = rearrange(patches, 'b n d -> (b n) d ')
        text_embedding = rearrange(text_embedding, 'b n d -> (b n) d ')

        q = self.to_q(text_embedding)
        k = self.to_k(patches)
        output_query = self.transfromer_decoder(q, k, None)
        output_query = rearrange(output_query, '(b n) d -> b n d', b=b)

        # Dot product between every pixel feature and every mask embedding.
        mask_embedding = self.mlp(output_query)
        seg_result = einsum('b d h w, b n d -> b n h w', final_image_feature, mask_embedding)
        return seg_result

    def _prepare_features(self, features, upsample='bilinear'):
        """Fuse the four feature groups into one map, coarse to fine.

        Args:
            features: Mapping with keys ``"low"``, ``"mid"``, ``"high"`` and
                ``"highest"``. The first three groups are resized to their
                target size; the ``"highest"`` maps are concatenated without
                resizing, so they must already share the spatial size of the
                ``"high"`` branch.
            upsample: Interpolation mode used when resizing the input maps.

        Returns:
            The fused feature map produced by the last mixing block.
        """
        low = self._resize_and_concat(features["low"], self.low_feature_size, upsample)
        mid = self._resize_and_concat(features["mid"], self.mid_feature_size, upsample)
        high = self._resize_and_concat(features["high"], self.high_feature_size, upsample)
        highest = torch.cat(features["highest"], dim=1)

        # low -> mid resolution
        low_feat = self.low_feature_conv(low)
        low_feat = F.interpolate(low_feat, size=self.mid_feature_size, mode='bilinear', align_corners=False)
        mid_feat = self.mid_feature_conv(mid)
        mid_feat = torch.cat([low_feat, mid_feat], dim=1)
        mid_feat = self.mid_feature_mix_conv(mid_feat, y=None)

        # mid -> high resolution
        mid_feat = F.interpolate(mid_feat, size=self.high_feature_size, mode='bilinear', align_corners=False)
        high_feat = self.high_feature_conv(high)
        high_feat = torch.cat([mid_feat, high_feat], dim=1)
        high_feat = self.high_feature_mix_conv(high_feat, y=None)

        # high + highest (no resizing between them)
        highest_feat = self.highest_feature_conv(highest)
        highest_feat = torch.cat([high_feat, highest_feat], dim=1)
        highest_feat = self.highest_feature_mix_conv(highest_feat, y=None)
        return highest_feat