"""U-Net with sinusoidal timestep embedding, residual blocks and self-attention."""

# NOTE(review): `reverse` is not used anywhere in this module and `audioop`
# was removed from the standard library in Python 3.13.  The import is kept
# (guarded) only so that the module still imports on every interpreter.
try:
    from audioop import reverse
except ImportError:
    pass

import math
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    """Map integer timesteps to a ``dim``-sized embedding vector.

    A fixed (non-trainable) sinusoidal table of shape ``[T, d_model]`` is
    looked up by timestep index and then passed through a two-layer MLP
    (``Linear -> Swish -> Linear``).

    Args:
        T: Number of rows in the sinusoidal table (valid indices ``0..T-1``).
        d_model: Width of the sinusoidal table; must be even.
        dim: Output width of the MLP.
    """

    def __init__(self, T: int, d_model: int, dim: int) -> None:
        assert d_model % 2 == 0
        super().__init__()
        # Frequencies exp(-2i / d_model * ln(10000)) for i in [0, d_model / 2).
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)
        pos = torch.arange(T).float()
        emb = pos[:, None] * emb[None, :]
        assert list(emb.shape) == [T, d_model // 2]
        # Interleave sin/cos pairs along the last axis, then flatten them.
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self) -> None:
        """Xavier-initialise the weights and zero the biases of the Linear layers."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Return the embedding for the integer timestep indices ``t``."""
        return self.timembedding(t)


class DownSample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, in_ch: int) -> None:
        super().__init__()
        self.down = nn.Conv2d(in_ch, in_ch, 3, 2, 1)
        # Fix: the original referenced the method without calling it, so the
        # custom initialisation was silently skipped.
        self.initialize()

    def initialize(self) -> None:
        """Xavier-initialise the convolution weight and zero its bias."""
        nn.init.xavier_uniform_(self.down.weight)
        nn.init.zeros_(self.down.bias)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        """Downsample ``x``; ``temb`` is accepted for a uniform call signature and ignored."""
        return self.down(x)


class UpSample(nn.Module):
    """Double the spatial resolution (nearest neighbour) and apply a 3x3 convolution."""

    def __init__(self, in_ch: int) -> None:
        super().__init__()
        self.up = nn.Conv2d(in_ch, in_ch, 3, 1, 1)
        # Fix: actually invoke the initialiser (was a bare attribute access).
        self.initialize()

    def initialize(self) -> None:
        """Xavier-initialise the convolution weight and zero its bias."""
        nn.init.xavier_uniform_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``; ``temb`` is accepted for a uniform call signature and ignored."""
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.up(x)


class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The input is group-normalised, projected to queries/keys/values with 1x1
    convolutions, attended over all ``H * W`` positions and added back to the
    input as a residual.

    Args:
        in_ch: Number of channels; must be divisible by 32 (``GroupNorm(32, in_ch)``).
    """

    def __init__(self, in_ch: int) -> None:
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, 1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, 1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, 1, padding=0)
        self.project = nn.Conv2d(in_ch, in_ch, 1, 1, padding=0)
        # Fix: actually invoke the initialiser (was a bare attribute access).
        self.initialize()

    def initialize(self) -> None:
        """Xavier-initialise all projections; the output projection gets a tiny gain."""
        for module in [self.proj_k, self.proj_q, self.proj_v, self.project]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)
        # Near-zero output projection: the block starts out close to identity.
        nn.init.xavier_uniform_(self.project.weight, gain=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the attention output; the shape ``[B, C, H, W]`` is preserved."""
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        q = q.permute(0, 2, 3, 1).reshape(B, H * W, C)
        k = k.reshape(B, C, H * W)
        # Batched dot product of every query with every key, scaled by
        # C ** -0.5 so the logits do not grow with the channel count (large
        # logits saturate the softmax and produce very small gradients).
        w = torch.bmm(q, k) * (C ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, -1)

        v = v.permute(0, 2, 3, 1).reshape(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.reshape(B, H, W, C).permute(0, 3, 1, 2)
        h = self.project(h)
        return x + h


class ResBlock(nn.Module):
    """Residual block conditioned on a timestep embedding, with optional attention.

    Args:
        in_ch: Input channels; must be divisible by 32.
        out_ch: Output channels; must be divisible by 32.
        t_dim: Width of the timestep embedding passed to ``forward``.
        dropout: Dropout probability used in the second conv stack.
        attn: If true, an ``AttnBlock`` is applied to the block output.
    """

    def __init__(self, in_ch: int, out_ch: int, t_dim: int, dropout: float,
                 attn: bool = False) -> None:
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(t_dim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, 1, 1),
        )
        # A 1x1 convolution matches channel counts on the skip path when needed.
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, 1, 0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()
        # Fix: actually invoke the initialiser (was a bare attribute access).
        self.initialization()

    def initialization(self) -> None:
        """Xavier-initialise every Conv2d/Linear in the block and zero the biases.

        ``self.modules()`` also yields the layers of the optional attention
        sub-block, so those are re-initialised here with the default gain.
        The last convolution of ``block2`` gets a tiny gain.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)
        nn.init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` (``[B, in_ch, H, W]``) with embedding ``temb`` (``[B, t_dim]``)."""
        h = self.block1(x)
        # Broadcast the projected embedding over the two spatial dimensions.
        h = h + self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)
        h = h + self.shortcut(x)
        h = self.attn(h)
        return h


class UNet(nn.Module):
    """U-Net over 3-channel images, conditioned on an integer timestep.

    Args:
        T: Number of timesteps supported by the embedding table.
        ch: Base channel count; must be even and divisible by 32.
        ch_mult: Channel multiplier per resolution level.
        attn: Indices into ``ch_mult`` of the levels whose ResBlocks use attention.
        num_res_blocks: ResBlocks per level on the down path (the up path
            uses ``num_res_blocks + 1``).
        dropout: Dropout probability passed to every ResBlock.
    """

    def __init__(self, T: int, ch: int, ch_mult: Sequence[int],
                 attn: Sequence[int], num_res_blocks: int,
                 dropout: float) -> None:
        super().__init__()
        assert all(i < len(ch_mult) for i in attn), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        # Output channels of every down-path layer, consumed in reverse by
        # the up path to size the concatenated skip connections.
        chs = [ch]
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, t_dim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            # No downsampling after the last (lowest-resolution) level.
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, t_dim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            # No upsampling after the first (full-resolution) level.
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1),
        )
        self.initialize()

    def initialize(self) -> None:
        """Xavier-initialise the head and tail convolutions (tail with a tiny gain)."""
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run the network on images ``x`` (``[B, 3, H, W]``) at timesteps ``t`` (``[B]``).

        Returns a tensor with the same shape as ``x``.
        """
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling: remember every intermediate output as a skip connection.
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling: every ResBlock consumes one skip connection.
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h


if __name__ == "__main__":
    batch_size = 128
    model = UNet(T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
                 num_res_blocks=2, dropout=0.1)
    x = torch.randn(batch_size, 3, 32, 32)
    t = torch.randint(1000, (batch_size,))
    # Shape smoke test only: no autograd graph is needed.
    with torch.no_grad():
        y = model(x, t)
    print(y.shape)