# Diffusion-DDIM / model.py
# Author: Yash Nagraj
# Commit 990d40a: "Add the model and training scripts and some cleaning"
from audioop import reverse
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# We use the Swish activation function
class Swish(nn.Module):
    """Swish activation (a.k.a. SiLU): x * sigmoid(x)."""
    def forward(self, x):
        return torch.sigmoid(x).mul(x)
# Time embedding FFN for embedding timestep info
class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding table followed by a two-layer MLP.

    Builds a fixed (T, d_model) table of interleaved sin/cos features,
    then maps each looked-up row through Linear -> Swish -> Linear to `dim`.
    """
    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        # Frequencies: exp(-log(10000) * (2k / d_model)) for k = 0..d_model/2-1.
        freqs = torch.exp(-(torch.arange(0, d_model, step=2) / d_model * math.log(10000)))
        positions = torch.arange(T).float()
        angles = positions[:, None] * freqs[None, :]
        assert list(angles.shape) == [T, d_model // 2]
        # Interleave sin and cos along the last axis, then flatten to d_model.
        table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        assert list(table.shape) == [T, d_model // 2, 2]
        table = table.view(T, d_model)
        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(table),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()
    def initialize(self):
        # Xavier-init the MLP; the embedding table stays fixed (from_pretrained).
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
    def forward(self, t):
        """Look up and project the embedding for integer timesteps `t`."""
        return self.timembedding(t)
class DownSample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution."""
    def __init__(self, in_ch) -> None:
        super().__init__()
        self.down = nn.Conv2d(in_ch, in_ch, 3, 2, 1)
        # Bug fix: the original wrote `self.initialize` (attribute access,
        # no parentheses), so the custom initialization never ran.
        self.initialize()
    def initialize(self):
        nn.init.xavier_uniform_(self.down.weight)
        nn.init.zeros_(self.down.bias)
    def forward(self, x, temb):
        # `temb` is accepted only for a uniform layer interface; unused here.
        x = self.down(x)
        return x
class UpSample(nn.Module):
    """Double the spatial resolution: nearest-neighbor upsample then 3x3 conv."""
    def __init__(self, in_ch) -> None:
        super().__init__()
        self.up = nn.Conv2d(in_ch, in_ch, 3, 1, 1)
        # Bug fix: the original wrote `self.initialize` (attribute access,
        # no parentheses), so the custom initialization never ran.
        self.initialize()
    def initialize(self):
        nn.init.xavier_uniform_(self.up.weight)
        nn.init.zeros_(self.up.bias)
    def forward(self, x, temb):
        # `temb` is accepted only for a uniform layer interface; unused here.
        x = F.interpolate(
            x, scale_factor=2, mode='nearest'
        )
        x = self.up(x)
        return x
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, with a
    residual connection (non-local block as used in DDPM UNets)."""
    def __init__(self, in_ch) -> None:
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, 1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, 1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, 1, padding=0)
        self.project = nn.Conv2d(in_ch, in_ch, 1, 1, padding=0)
        # Bug fix: the original wrote `self.initialize` (attribute access,
        # no parentheses), so the custom initialization never ran.
        self.initialize()
    def initialize(self):
        for module in [self.proj_k, self.proj_q, self.proj_v, self.project]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)
        # Near-zero output projection so the block starts close to identity.
        nn.init.xavier_uniform_(self.project.weight, gain=1e-5)
    def forward(self, x):
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)
        # Flatten spatial dims: q -> (B, HW, C), k -> (B, C, HW), so that
        # torch.bmm(q, k) gives (B, HW, HW) pairwise attention logits.
        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        # Scale by 1/sqrt(C) to keep dot products from growing with the channel
        # count, which would saturate the softmax and shrink gradients.
        w = torch.bmm(q, k) * (C ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, -1)
        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.project(h)
        return x + h
class ResBlock(nn.Module):
    """Residual block with timestep conditioning.

    Path: GN -> Swish -> Conv, add time-embedding bias per channel,
    GN -> Swish -> Dropout -> Conv, then residual shortcut and (optionally)
    self-attention.
    """
    def __init__(self, in_ch, out_ch, t_dim, dropout, attn=False) -> None:
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        )
        # Projects the time embedding to one additive bias per output channel.
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(t_dim, out_ch)
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, 1, 1)
        )
        # 1x1 conv to match channels on the shortcut when they differ.
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, 1, 0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()
        # Bug fix: the original wrote `self.initialization` (attribute access,
        # no parentheses), so the custom initialization never ran.
        self.initialization()
    def initialization(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)
        # Near-zero final conv so the block starts close to the identity map.
        nn.init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)
    def forward(self, x, temb):
        h = self.block1(x)
        # Broadcast the (B, out_ch) time projection over spatial dims.
        h += self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)
        h += self.shortcut(x)
        h = self.attn(h)
        return h
class UNet(nn.Module):
    """DDPM-style UNet that maps a noisy RGB image + timestep to a 3-channel
    output of the same spatial size (the predicted noise).

    Args:
        T: number of diffusion timesteps (size of the time-embedding table).
        ch: base channel count.
        ch_mult: per-resolution channel multipliers (one entry per level).
        attn: indices into ch_mult of levels that get self-attention ResBlocks.
        num_res_blocks: ResBlocks per level on the down path.
        dropout: dropout rate inside each ResBlock.
    """
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4  # time-embedding width: 4x base channels by convention
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channel when downsample for upsample
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, t_dim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            # no DownSample after the last (coarsest) level
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            # one extra block per level to consume the skip saved at DownSample
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, t_dim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0  # every recorded skip channel must be consumed
        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
        self.initialize()
    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        # near-zero output conv so the network initially predicts ~zero
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)
    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling: collect every intermediate feature as a skip connection
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling: pop skips in reverse and concatenate along channels
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)
        assert len(hs) == 0  # all skip connections must have been used
        return h
if __name__ == "__main__":
    # Smoke test: push one random CIFAR-10-sized batch through the UNet
    # and print the output shape (should match the input image shape).
    batch_size = 128
    model = UNet(
        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
        num_res_blocks=2, dropout=0.1,
    )
    images = torch.randn(batch_size, 3, 32, 32)
    steps = torch.randint(1000, (batch_size,))
    out = model(images, steps)
    print(out.shape)