| from audioop import reverse |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.nn.init as init |
|
|
| |
class Swish(nn.Module):
    """Swish activation: the input gated by its own sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)``, computed element-wise."""
        gate = torch.sigmoid(x)
        return x * gate
| |
|
|
| |
class TimeEmbedding(nn.Module):
    """Map integer timesteps to learned feature vectors.

    A fixed sinusoidal table with ``T`` rows of width ``d_model`` is built
    once and wrapped in ``nn.Embedding.from_pretrained``. Looked-up rows are
    then passed through Linear -> Swish -> Linear, producing ``dim`` features.

    Args:
        T: number of rows in the table (one per timestep index).
        d_model: width of the sinusoidal encoding; must be even.
        dim: output width of both linear layers.
    """

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()

        # Frequencies exp(-log(10000) * 2k / d_model) for k = 0 .. d_model/2 - 1.
        exponents = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        freqs = torch.exp(-exponents)

        # Outer product of timestep positions and frequencies.
        steps = torch.arange(T).float()
        angles = steps.unsqueeze(1) * freqs.unsqueeze(0)
        assert tuple(angles.shape) == (T, d_model // 2)

        # Interleave sin/cos per frequency: [sin0, cos0, sin1, cos1, ...].
        table = torch.stack((angles.sin(), angles.cos()), dim=-1)
        assert tuple(table.shape) == (T, d_model // 2, 2)
        table = table.flatten(start_dim=1)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(table),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        """Xavier-uniform weights and zero biases for every linear layer."""
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            init.xavier_uniform_(layer.weight)
            init.zeros_(layer.bias)

    def forward(self, t):
        """Embed the timestep indices ``t``.

        ``t`` is fed to ``nn.Embedding``, so it must hold integer indices
        (presumably in ``[0, T)`` -- confirm against callers).
        """
        return self.timembedding(t)
|
|
class DownSample(nn.Module):
    """Downsample a feature map with a 3x3, stride-2, padding-1 convolution.

    The channel count is preserved. With these conv settings the output
    height/width is ``ceil(size / 2)``.

    Args:
        in_ch: number of input (and output) channels.
    """

    def __init__(self, in_ch) -> None:
        super().__init__()
        self.down = nn.Conv2d(in_ch, in_ch, 3, 2, 1)
        # The original referenced ``self.initialize`` without calling it, so
        # the custom initialisation was silently skipped.
        self.initialize()

    def initialize(self):
        """Xavier-uniform weight and zero bias for the convolution."""
        nn.init.xavier_uniform_(self.down.weight)
        nn.init.zeros_(self.down.bias)

    def forward(self, x, temb):
        """Return the downsampled ``x``.

        ``temb`` is unused; it is accepted so the layer can be called with
        the same ``(x, temb)`` signature as the other blocks in this file.
        """
        return self.down(x)
| |
|
|
class UpSample(nn.Module):
    """Double the spatial size, then apply a 3x3 convolution.

    Upsampling uses nearest-neighbour interpolation with ``scale_factor=2``;
    the stride-1, padding-1 convolution that follows keeps both the spatial
    size and the channel count.

    Args:
        in_ch: number of input (and output) channels.
    """

    def __init__(self, in_ch) -> None:
        super().__init__()
        self.up = nn.Conv2d(in_ch, in_ch, 3, 1, 1)
        # The original referenced ``self.initialize`` without calling it, so
        # the custom initialisation was silently skipped.
        self.initialize()

    def initialize(self):
        """Xavier-uniform weight and zero bias for the convolution."""
        nn.init.xavier_uniform_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, x, temb):
        """Return the upsampled and convolved ``x``.

        ``temb`` is unused; it is accepted so the layer can be called with
        the same ``(x, temb)`` signature as the other blocks in this file.
        """
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.up(x)
| |
class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The input is group-normalised, projected to queries/keys/values with 1x1
    convolutions, attended over all ``H*W`` positions, projected back with a
    1x1 convolution and added to the input as a residual.

    Args:
        in_ch: channel count of the input and output. ``nn.GroupNorm(32, in_ch)``
            requires it to be divisible by 32.
    """

    def __init__(self, in_ch) -> None:
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, 1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, 1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, 1, padding=0)
        self.project = nn.Conv2d(in_ch, in_ch, 1, 1, padding=0)
        # The original referenced ``self.initialize`` without calling it, so
        # the near-zero output projection below was never applied.
        self.initialize()

    def initialize(self):
        """Xavier-uniform weights and zero biases for all four projections.

        The output projection is then re-drawn with ``gain=1e-5`` so that the
        attention branch contributes almost nothing at initialisation and the
        block starts out close to the identity.
        """
        for module in [self.proj_k, self.proj_q, self.proj_v, self.project]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)
        nn.init.xavier_uniform_(self.project.weight, gain=1e-5)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, C, H, W)``.

        Returns a tensor of the same shape: ``x`` plus the projected
        attention output.
        """
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        # Flatten spatial positions: q, v -> (B, H*W, C); k -> (B, C, H*W).
        # reshape (rather than view) also accepts non-contiguous layouts.
        q = q.permute(0, 2, 3, 1).reshape(B, H * W, C)
        k = k.reshape(B, C, H * W)

        # torch.bmm(q, k) gives, per batch element, the dot product of every
        # query position with every key position. Scaling by C ** -0.5 keeps
        # these logits from growing with the channel count, which would
        # saturate the softmax and shrink its gradients.
        w = torch.bmm(q, k) * (C ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)

        v = v.permute(0, 2, 3, 1).reshape(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]

        # Back to (B, C, H, W) for the output projection.
        h = h.reshape(B, H, W, C).permute(0, 3, 1, 2)
        h = self.project(h)
        return x + h
|
|
class ResBlock(nn.Module):
    """Residual block conditioned on a time embedding, with optional attention.

    Layout: GroupNorm -> Swish -> Conv3x3, add the projected time embedding,
    GroupNorm -> Swish -> Dropout -> Conv3x3, add the shortcut, then an
    ``AttnBlock`` (or identity).

    Args:
        in_ch: input channels.
        out_ch: output channels.
        t_dim: width of the time-embedding vector passed to ``forward``.
        dropout: dropout probability used in the second sub-block.
        attn: if true, append an ``AttnBlock(out_ch)`` after the residual sum.

    ``nn.GroupNorm(32, ...)`` requires both ``in_ch`` and ``out_ch`` to be
    divisible by 32.
    """

    def __init__(self, in_ch, out_ch, t_dim, dropout, attn=False) -> None:
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
        )

        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(t_dim, out_ch),
        )

        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, 1, 1),
        )

        # A 1x1 convolution matches channel counts on the skip path when needed.
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, 1, 0)
        else:
            self.shortcut = nn.Identity()

        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()

        # The original referenced ``self.initialization`` without calling it,
        # so none of the custom initialisation below was ever applied.
        self.initialization()

    def initialization(self):
        """Xavier-uniform weights and zero biases for every conv/linear layer.

        The last convolution of ``block2`` is re-drawn with ``gain=1e-5`` so
        the residual branch starts out near zero.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)

        nn.init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

        # The loop above also walks into the attention sub-module and would
        # overwrite its small-gain output projection; re-apply its own scheme.
        if isinstance(self.attn, AttnBlock):
            self.attn.initialize()

    def forward(self, x, temb):
        """Run the block.

        Args:
            x: feature map with ``in_ch`` channels on dim 1.
            temb: time embedding; its projection is indexed as
                ``[:, :, None, None]``, so it must be 2-D
                (presumably ``(batch, t_dim)`` -- confirm against callers).

        Returns:
            Feature map with ``out_ch`` channels.
        """
        h = self.block1(x)
        # Broadcast the per-sample, per-channel time bias over H and W.
        h = h + self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)

        h = h + self.shortcut(x)
        return self.attn(h)
| |
|
|
class UNet(nn.Module):
    """Time-conditioned U-Net built from ``ResBlock``/``DownSample``/``UpSample``.

    Args:
        T: number of timesteps, forwarded to ``TimeEmbedding``.
        ch: base channel width; the time embedding has width ``4 * ch``.
        ch_mult: per-resolution channel multipliers; level ``i`` uses
            ``ch * ch_mult[i]`` channels.
        attn: indices of the resolution levels whose residual blocks get
            attention; every index must be ``< len(ch_mult)``.
        num_res_blocks: residual blocks per level on the way down (the way
            up uses one more per level).
        dropout: dropout probability passed to every ``ResBlock``.

    The head and tail convolutions are hard-wired to 3 input and 3 output
    channels.
    """

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all(idx < len(ch_mult) for idx in attn), 'attn index out of bound'

        tdim = 4 * ch
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)

        # ---- encoder: record the channel count of every skip connection ----
        self.downblocks = nn.ModuleList()
        skip_chs = [ch]
        cur_ch = ch
        last_level = len(ch_mult) - 1
        for level, mult in enumerate(ch_mult):
            level_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(
                    ResBlock(cur_ch, level_ch, tdim, dropout, attn=(level in attn))
                )
                cur_ch = level_ch
                skip_chs.append(cur_ch)
            if level < last_level:
                self.downblocks.append(DownSample(cur_ch))
                skip_chs.append(cur_ch)

        # ---- bottleneck ----
        self.middleblocks = nn.ModuleList()
        self.middleblocks.append(
            ResBlock(in_ch=cur_ch, out_ch=cur_ch, t_dim=tdim, dropout=dropout, attn=True)
        )
        self.middleblocks.append(
            ResBlock(in_ch=cur_ch, out_ch=cur_ch, t_dim=tdim, dropout=dropout, attn=False)
        )

        # ---- decoder: each residual block consumes one recorded skip ----
        self.upblocks = nn.ModuleList()
        for level, mult in list(enumerate(ch_mult))[::-1]:
            level_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                skip_ch = skip_chs.pop()
                self.upblocks.append(
                    ResBlock(skip_ch + cur_ch, level_ch, tdim, dropout, attn=(level in attn))
                )
                cur_ch = level_ch
            if level > 0:
                self.upblocks.append(UpSample(cur_ch))
        assert not skip_chs

        self.tail = nn.Sequential(
            nn.GroupNorm(32, cur_ch),
            Swish(),
            nn.Conv2d(cur_ch, 3, 3, stride=1, padding=1),
        )
        self.initialize()

    def initialize(self):
        """Initialise the head and tail convolutions.

        Both get Xavier-uniform weights and zero biases; the tail uses
        ``gain=1e-5`` so the network's initial output is close to zero.
        """
        out_conv = self.tail[-1]
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(out_conv.weight, gain=1e-5)
        init.zeros_(out_conv.bias)

    def forward(self, x, t):
        """Run the network on images ``x`` at timesteps ``t``.

        ``x`` must have 3 channels on dim 1 (fixed by the head convolution);
        ``t`` is passed to the time embedding. The result has 3 channels.
        """
        temb = self.time_embedding(t)

        # Encoder: keep every intermediate activation for the skip connections.
        h = self.head(x)
        skips = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            skips.append(h)

        for layer in self.middleblocks:
            h = layer(h, temb)

        # Decoder: residual blocks take a skip concatenated on the channel
        # axis; UpSample layers take the running activation alone.
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat((h, skips.pop()), dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert not skips
        return h
| |
|
|
if __name__ == "__main__":
    # Smoke test: build the model, push one random batch through it and
    # print the shape of the output.
    batch_size = 128
    model = UNet(
        T=1000,
        ch=128,
        ch_mult=[1, 2, 2, 2],
        attn=[1],
        num_res_blocks=2,
        dropout=0.1,
    )

    x = torch.randn(batch_size, 3, 32, 32)
    t = torch.randint(1000, (batch_size,))
    y = model(x, t)
    print(y.shape)