# wangyanhui666's picture
# fine tune decoder with mask
# 9cf79cf
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
import torch.nn.functional as F
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic
    Models (from Fairseq), which matches tensor2tensor but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".

    Args:
        timesteps: 1-D tensor of timesteps; it is cast to float internally.
        embedding_dim: width of the returned embedding.

    Returns:
        Tensor of shape ``(len(timesteps), embedding_dim)`` on the device of
        ``timesteps``. The first ``embedding_dim // 2`` columns are sines,
        the next ``embedding_dim // 2`` are cosines, and an odd
        ``embedding_dim`` gets one trailing zero column.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # With half_dim == 1 (embedding_dim of 2 or 3) the plain `half_dim - 1`
    # denominator is zero and raised ZeroDivisionError. Clamping to 1 leaves
    # every half_dim >= 2 unchanged; for half_dim == 1 the only frequency is
    # exp(0) == 1 whatever the step, and for half_dim == 0 nothing is produced.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_step)
    freqs = freqs.to(device=timesteps.device)

    # Outer product: one row per timestep, one column per frequency.
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad the last column
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
def nonlinearity(x):
    """Swish activation: scale the input element-wise by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
def Normalize(in_channels):
    """Return an affine GroupNorm with 32 groups and eps=1e-6 over `in_channels` channels."""
    norm_kwargs = dict(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    return torch.nn.GroupNorm(**norm_kwargs)
class Upsample(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation.

    When `with_conv` is truthy, a 3x3 convolution (stride 1, padding 1,
    same channel count) is applied after the interpolation.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not with_conv:
            return
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
class Downsample(nn.Module):
    """Halve the spatial size, either with a strided conv or with average pooling.

    With `with_conv` truthy the input is zero-padded by one pixel on the
    right and bottom only, then passed through a 3x3 stride-2 convolution.
    Otherwise a 2x2 stride-2 average pool is used.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not with_conv:
            return
        # torch convs have no asymmetric padding, so the conv itself is
        # unpadded and forward() pads manually.
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # (left, right, top, bottom) = (0, 1, 0, 1)
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class ResnetBlock(nn.Module):
    """Pre-activation residual block: (norm -> swish -> 3x3 conv) twice.

    All convolutions are bias-free. When `temb_channels > 0` a linear
    projection of the timestep embedding is added after the first conv.

    NOTE(review): when `in_channels != out_channels` the shortcut convolution
    takes `out_channels` inputs and is applied to the block's own output `h`,
    not to the input `x`, so the block returns `shortcut(h) + h`. This is
    presumably required by the weights this model is loaded from -- confirm
    before changing it.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output; defaults to `in_channels`.
        conv_shortcut: use a 3x3 shortcut conv instead of a 1x1 one when the
            channel counts differ.
        dropout: dropout probability applied before the second conv.
        temb_channels: width of the timestep embedding; `temb_proj` is only
            created when this is positive.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        conv3x3 = dict(kernel_size=3, stride=1, padding=1, bias=False)

        # Creation order is kept as-is so seeded initialisation and
        # state_dict ordering stay the same.
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, **conv3x3)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, **conv3x3)

        if in_channels == out_channels:
            return
        if conv_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(out_channels, out_channels, **conv3x3)
        else:
            self.nin_shortcut = torch.nn.Conv2d(
                out_channels, out_channels,
                kernel_size=1, stride=1, padding=0, bias=False,
            )

    def forward(self, x, temb):
        h = self.conv1(nonlinearity(self.norm1(x)))

        if temb is not None:
            # Broadcast the projected embedding over the spatial dimensions.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        if self.in_channels == self.out_channels:
            return x + h

        # Channel counts differ: the shortcut is computed from h (see class note).
        shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
        return shortcut(h) + h
class Encoder(nn.Module):
    """Convolutional encoder: input conv, residual down-stages, two mid blocks, 1x1 output conv.

    Each of the `len(ch_mult)` levels holds `num_res_blocks` ResnetBlocks;
    every level except the last is followed by a Downsample. No timestep
    embedding is used (`temb_ch` is 0 and `temb` is always None).

    Args:
        ch: base channel count; level `i` uses `ch * ch_mult[i]` channels.
        out_ch: accepted for config compatibility; not used by this class.
        ch_mult: per-level channel multipliers.
        num_res_blocks: ResnetBlocks per level.
        attn_resolutions: accepted for config compatibility; not used here.
        dropout: dropout probability passed to every ResnetBlock.
        resamp_with_conv: passed to Downsample (strided conv vs. average pool).
        in_channels: channels of the input image.
        resolution: stored on the instance; not otherwise used here.
        z_channels: latent channels; the output has `2 * z_channels`
            channels when `double_z` is true, else `z_channels`.
        double_z: see `z_channels`.
        **ignore_kwargs: extra config keys are silently ignored.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=False, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1,
                                       bias=False)

        # Level i takes the width of level i-1 (or `ch` for level 0) as input.
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
            down = nn.Module()
            down.block = block
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2 * z_channels if double_z else z_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        """Encode `x` and return the output of the final 1x1 convolution."""
        # No timestep embedding in this model.
        temb = None

        # downsampling
        # Only the latest activation is ever consumed, so keep a single
        # running tensor instead of a list of every intermediate feature map
        # (which would keep them all alive until forward() returns).
        h = self.conv_in(x)
        last_level = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h, temb)
            if i_level != last_level:
                h = level.downsample(h)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class Decoder(nn.Module):
    """Convolutional decoder: latent conv, two mid blocks, residual up-stages, output conv.

    The decoder has `len(ch_mult)` levels. Each holds `num_res_blocks`
    ResnetBlocks and every level except level 0 ends with an Upsample.
    `self.up` is indexed by level, so `forward` walks it from the highest
    index down to 0. No timestep embedding is used.

    Args:
        ch: base channel count; level `i` uses `ch * ch_mult[i]` channels.
        out_ch: channels of the decoded output.
        ch_mult: per-level channel multipliers.
        num_res_blocks: ResnetBlocks per level.
        attn_resolutions: accepted for config compatibility; not used here.
        dropout: dropout probability passed to every ResnetBlock.
        resamp_with_conv: passed to Upsample (adds a 3x3 conv after resize).
        in_channels: stored on the instance; not otherwise used here.
        resolution: used to derive the latent size recorded in `z_shape`.
        z_channels: channels of the latent input.
        give_pre_end: if true, `forward` returns the features before the
            final norm / activation / conv.
        **ignorekwargs: extra config keys are silently ignored.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # Channel width and spatial size at the lowest resolution.
        top_level = self.num_resolutions - 1
        block_in = ch * ch_mult[top_level]
        latent_res = resolution // 2 ** top_level
        self.z_shape = (1, z_channels, latent_res, latent_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in,
            temb_channels=self.temb_ch, dropout=dropout,
        )
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in,
            temb_channels=self.temb_ch, dropout=dropout,
        )

        # upsampling: built from the lowest resolution upwards
        self.up = nn.ModuleList()
        for i_level in range(top_level, -1, -1):
            block_out = ch * ch_mult[i_level]
            stage_blocks = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                stage_blocks.append(ResnetBlock(
                    in_channels=block_in, out_channels=block_out,
                    temb_channels=self.temb_ch, dropout=dropout,
                ))
                block_in = block_out
            stage = nn.Module()
            stage.block = stage_blocks
            if i_level != 0:
                stage.upsample = Upsample(block_in, resamp_with_conv)
            # Prepend so that self.up[i] is the stage for level i.
            self.up.insert(0, stage)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        """Decode latent `z`; also records `z.shape` in `self.last_z_shape`."""
        self.last_z_shape = z.shape

        # No timestep embedding in this model.
        temb = None

        # z to block_in, then the two middle blocks
        h = self.conv_in(z)
        h = self.mid.block_2(self.mid.block_1(h, temb), temb)

        # upsampling
        for i_level in range(self.num_resolutions - 1, -1, -1):
            stage = self.up[i_level]
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](h, temb)
            if i_level != 0:
                h = stage.upsample(h)

        # end
        if self.give_pre_end:
            return h
        return self.conv_out(nonlinearity(self.norm_out(h)))
def init_weights_zero(m):
    """Zero the weight (and bias, if any) of a Conv2d; other modules are left untouched.

    Intended to be passed to `nn.Module.apply`.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.constant_(m.weight, 0)
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0)
def init_weights_kaiming(m):
    """Kaiming-normal initialise a Conv2d's weight and zero its bias (if any).

    Modules that are not Conv2d are left untouched. Intended to be passed to
    `nn.Module.apply`.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.kaiming_normal_(m.weight)
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0)
##########################################################################
##---------- Prompt Gen Module -----------------------
class PromptGenBlock(nn.Module):
    """Learned spatial prompt, resized to the input's resolution and refined by a 3x3 conv.

    The prompt depends only on the spatial size of the input, never on its
    values: `prompt_len` learned maps are summed, bilinearly resized to
    `(H, W)` and passed through a bias-free 3x3 convolution.

    Args:
        prompt_dim: channels of the prompt (and of the returned tensor).
        prompt_size: side length of the stored square prompt maps.
        prompt_len: number of stored prompt maps that are summed together.
    """

    def __init__(self, prompt_dim=128, prompt_size=96, prompt_len=1):
        super().__init__()
        # Shape (1, prompt_len, prompt_dim, prompt_size, prompt_size),
        # initialised uniformly in [0, 1).
        self.prompt_param = nn.Parameter(
            torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size)
        )
        self.conv3x3 = nn.Conv2d(
            prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False
        )

    def init_weights_zero(self, m):
        """Zero a Conv2d's weight and bias; not applied by default, kept for callers."""
        if isinstance(m, nn.Conv2d):
            nn.init.constant_(m.weight, 0)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return a prompt of shape `(B, prompt_dim, H, W)` for an input of shape `(B, C, H, W)`."""
        B, _, H, W = x.shape

        # The prompt is identical for every sample, so it is computed once on
        # a batch of one rather than repeating the parameter B times and
        # running interpolate + conv on B identical copies.
        prompt = self.prompt_param.sum(dim=1)  # (1, prompt_dim, S, S)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        prompt = self.conv3x3(prompt)  # (1, prompt_dim, H, W)

        # Broadcast over the batch; .contiguous() hands callers an ordinary
        # dense tensor, as they received before.
        return prompt.expand(B, -1, -1, -1).contiguous()
class Attention(nn.Module):
    """Prompt-modulated multi-head attention computed across channels.

    The prompt is mapped by a 1x1 conv to a per-pixel scale and shift that
    modulate `x`. Queries, keys and values then come from a 1x1 conv followed
    by a depthwise 3x3 conv. The attention matrix is channels-by-channels
    within each head (queries and keys are L2-normalised over the flattened
    spatial axis) and is scaled by a learned per-head temperature.

    Only `qkv` and `qkv_dwconv` are re-initialised with
    `init_weights_kaiming`; the other layers keep PyTorch's default init.

    Args:
        dim: channels of `x` (and of the output).
        num_heads: number of heads the channel axis is split into.
        bias: whether the convolutions use a bias term.
        prompt_dim: channels of the prompt tensor.
    """

    def __init__(self, dim, num_heads, bias, prompt_dim=192):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        # Produces 2*dim channels: scale and shift for the modulation.
        self.shared_mlp = nn.Sequential(
            nn.Conv2d(prompt_dim, dim * 2, kernel_size=1, bias=bias)
        )

        qkv_width = dim * 3
        self.qkv = nn.Conv2d(dim, qkv_width, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            qkv_width, qkv_width, kernel_size=3, stride=1, padding=1,
            groups=qkv_width, bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

        # Re-initialise after all layers exist, in this order.
        for conv in (self.qkv, self.qkv_dwconv):
            conv.apply(init_weights_kaiming)

    def forward(self, x, prompt):
        b, c, h, w = x.shape

        # Prompt-driven affine modulation of the features.
        modulation = self.shared_mlp(prompt).expand(b, -1, -1, -1)
        gama, beta = modulation.chunk(2, dim=1)
        x = x * (1 + gama) + beta

        # Split into per-head q, k, v with the spatial axes flattened.
        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = (
            rearrange(t, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
            for t in qkv.chunk(3, dim=1)
        )

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # (b, head, c, c) attention over channels.
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = attn @ v
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        return self.project_out(out)
class DepthwiseSeparableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution, both without bias.

    Args:
        in_channels: channels of the input; also the number of depthwise groups.
        out_channels: channels produced by the pointwise conv.
        kernel_size, stride, padding: settings of the depthwise conv.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()

        # One filter per input channel (groups == in_channels).
        self.depthwise_conv = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            groups=in_channels, bias=False,
        )

        # 1x1 conv that mixes channels and sets the output width.
        self.pointwise_conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=False,
        )

    def forward(self, x):
        return self.pointwise_conv(self.depthwise_conv(x))
class SFT(nn.Module):
    """Prompt-conditioned affine modulation in residual form: `x * (1 + gama) + beta`.

    The prompt goes through a shared 1x1 conv + ReLU; two depthwise-separable
    convs then predict the per-pixel scale (`gama`) and shift (`beta`).
    The layers use PyTorch's default initialisation (no custom init is applied).

    Args:
        x_dim: channels of the feature map being modulated.
        prompt_dim: channels of the prompt tensor.
        ks: kernel size of the depthwise convs; padding is `ks // 2`.
        nhidden: channels of the shared hidden representation.
    """

    def __init__(self, x_dim, prompt_dim=192, ks=3, nhidden=128):
        super().__init__()
        pad = ks // 2

        self.mlp_shared = nn.Sequential(
            nn.Conv2d(prompt_dim, nhidden, kernel_size=1),
            nn.ReLU(),
        )
        self.mlp_gama = DepthwiseSeparableConv(nhidden, x_dim, kernel_size=ks, padding=pad)
        self.mlp_beta = DepthwiseSeparableConv(nhidden, x_dim, kernel_size=ks, padding=pad)

    def forward(self, x, prompt):
        hidden = self.mlp_shared(prompt)
        scale = self.mlp_gama(hidden)
        shift = self.mlp_beta(hidden)
        return x * (1 + scale) + shift
def init_weights_gama(m):
    """Initialise a Conv2d for the scale branch; other module types are ignored.

    * Depthwise conv (``groups == in_channels == out_channels``): all weights
      zero except the kernel centre, which is 1, so the layer copies the
      value under the centre tap.
    * Any other Conv2d (the pointwise conv in this file): all weights 1.
    * Bias, if present, is zeroed.

    The kernel centre is taken separately along height and width. Using the
    height centre for both axes picked the wrong tap for kernels such as
    (3, 5) and raised IndexError for kernels such as (5, 1).

    Args:
        m: module visited by ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Conv2d):
        return

    nn.init.constant_(m.weight, 0)

    is_depthwise = m.groups == m.in_channels and m.in_channels == m.out_channels
    if is_depthwise:
        center_h = m.kernel_size[0] // 2
        center_w = m.kernel_size[1] // 2
        # In-place write on a parameter must not be tracked by autograd.
        with torch.no_grad():
            m.weight[:, :, center_h, center_w] = 1
    else:  # pointwise conv
        nn.init.constant_(m.weight, 1)

    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
class SFT_new(nn.Module):
    """Prompt-conditioned affine modulation in plain form: `x * gama + beta`.

    Same layers as `SFT`, but the scale is applied directly rather than as
    `1 + gama`. The layers use PyTorch's default initialisation (no custom
    init is applied).

    Args:
        x_dim: channels of the feature map being modulated.
        prompt_dim: channels of the prompt tensor.
        ks: kernel size of the depthwise convs; padding is `ks // 2`.
        nhidden: channels of the shared hidden representation.
    """

    def __init__(self, x_dim, prompt_dim=192, ks=3, nhidden=128):
        super().__init__()
        pad = ks // 2

        self.mlp_shared = nn.Sequential(
            nn.Conv2d(prompt_dim, nhidden, kernel_size=1),
            nn.ReLU(),
        )
        self.mlp_gama = DepthwiseSeparableConv(nhidden, x_dim, kernel_size=ks, padding=pad)
        self.mlp_beta = DepthwiseSeparableConv(nhidden, x_dim, kernel_size=ks, padding=pad)

    def forward(self, x, prompt):
        hidden = self.mlp_shared(prompt)
        scale = self.mlp_gama(hidden)
        shift = self.mlp_beta(hidden)
        return x * scale + shift
class Decoder_w_Prompt(nn.Module):
    """`Decoder` variant that injects learned prompts via Attention before upsampling.

    The layout matches `Decoder` (latent conv, two mid blocks, residual
    up-stages, output conv). In addition, levels 1-4 carry a PromptGenBlock
    and levels 1-3 carry an Attention block that is applied to the features
    right before that level's Upsample.

    Constraints that follow from the hard-coded prompt table below:
      * At levels 1-3 the feature map must have exactly the listed number of
        channels (`ch * ch_mult[i_level]`), since Attention is built for that
        width.
      * The level-4 PromptGenBlock is constructed but `forward` never calls it.
      * Levels above 4 get no prompt modules, while `forward` would still
        look one up, so configurations with more than five levels are not
        supported.

    Args:
        ch: base channel count; level `i` uses `ch * ch_mult[i]` channels.
        out_ch: channels of the decoded output.
        ch_mult: per-level channel multipliers.
        num_res_blocks: ResnetBlocks per level.
        attn_resolutions: accepted for config compatibility; not used here.
        dropout: dropout probability passed to every ResnetBlock.
        resamp_with_conv: passed to Upsample (adds a 3x3 conv after resize).
        in_channels: stored on the instance; not otherwise used here.
        resolution: used to derive the latent size recorded in `z_shape`.
        z_channels: channels of the latent input.
        give_pre_end: if true, `forward` returns the features before the
            final norm / activation / conv.
        **ignorekwargs: extra config keys are silently ignored.
    """

    # i_level -> (channel width, stored prompt size, attach Attention?)
    _PROMPT_SPECS = {
        1: (128, 16, True),
        2: (256, 32, True),
        3: (256, 64, True),
        4: (512, 128, False),
    }

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # Channel width and spatial size at the lowest resolution.
        top_level = self.num_resolutions - 1
        block_in = ch * ch_mult[top_level]
        latent_res = resolution // 2 ** top_level
        self.z_shape = (1, z_channels, latent_res, latent_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in,
            temb_channels=self.temb_ch, dropout=dropout,
        )
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in,
            temb_channels=self.temb_ch, dropout=dropout,
        )

        # upsampling: built from the lowest resolution upwards
        self.up = nn.ModuleList()
        for i_level in range(top_level, -1, -1):
            block_out = ch * ch_mult[i_level]
            stage_blocks = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                stage_blocks.append(ResnetBlock(
                    in_channels=block_in, out_channels=block_out,
                    temb_channels=self.temb_ch, dropout=dropout,
                ))
                block_in = block_out

            stage = nn.Module()
            stage.block = stage_blocks
            if i_level != 0:
                spec = self._PROMPT_SPECS.get(i_level)
                if spec is not None:
                    width, prompt_size, with_attn = spec
                    # Prompt first, then attention: keeps the creation order.
                    stage.prompt = PromptGenBlock(prompt_dim=width, prompt_size=prompt_size)
                    if with_attn:
                        stage.prompt_attn = Attention(
                            dim=width, num_heads=8, bias=True, prompt_dim=width
                        )
                stage.upsample = Upsample(block_in, resamp_with_conv)
            # Prepend so that self.up[i] is the stage for level i.
            self.up.insert(0, stage)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        """Decode latent `z`; also records `z.shape` in `self.last_z_shape`."""
        self.last_z_shape = z.shape

        # No timestep embedding in this model.
        temb = None

        # z to block_in, then the two middle blocks
        h = self.conv_in(z)
        h = self.mid.block_2(self.mid.block_1(h, temb), temb)

        # upsampling
        for i_level in range(self.num_resolutions - 1, -1, -1):
            stage = self.up[i_level]
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](h, temb)
            if i_level == 0:
                continue
            # prompt & attention (skipped at level 4)
            if i_level != 4:
                prompt = stage.prompt(h)
                h = stage.prompt_attn(h, prompt)
            h = stage.upsample(h)

        # end
        if self.give_pre_end:
            return h
        return self.conv_out(nonlinearity(self.norm_out(h)))