# mini-ddpm / model.py
# Author: caixiaoshun (commit 020b1da)
import torch
from torch import nn
from torch.nn import functional as F
class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Flattens (H, W) into a token axis, runs scaled dot-product attention,
    and restores the original NCHW layout.
    """

    def __init__(self, n_head, dim):
        super().__init__()
        assert dim % n_head == 0
        # One fused projection produces Q, K and V in a single matmul.
        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)
        self.n_head = n_head
        self.head_dim = dim // self.n_head

    def forward(self, x: torch.Tensor):
        b, c, h, w = x.shape
        tokens = h * w
        # NCHW -> (B, H*W, C): one token per spatial location.
        x = x.reshape(b, c, tokens).transpose(-1, -2)
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        def split_heads(t):
            # (B, T, C) -> (B, heads, T, head_dim)
            return t.reshape(b, tokens, self.n_head, self.head_dim).transpose(1, 2)

        out = F.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v)
        )
        # Merge heads back, project, and return to NCHW.
        out = out.transpose(1, 2).reshape(b, tokens, c)
        out = self.out_proj(out)
        return out.transpose(-1, -2).reshape(b, c, h, w)
class TimePositionEmbedding(nn.Module):
    """Precomputed sinusoidal timestep embeddings (Transformer-style).

    Builds a (seq_len, dim) table where even columns hold sin(t * w_k)
    and odd columns hold cos(t * w_k); forward is a pure table lookup.
    """

    def __init__(self, seq_len=1000, dim=320):
        super().__init__()
        base = 10000
        # Frequencies w_k = base^(-2k/dim), one per sin/cos column pair.
        freqs = 1 / base ** (torch.arange(0, dim, step=2).float() / dim)
        angles = torch.arange(0, seq_len, step=1).unsqueeze(1) * freqs.unsqueeze(0)
        table = torch.zeros(size=(seq_len, dim))
        table[:, 0::2] = angles.sin()
        table[:, 1::2] = angles.cos()
        # Not persisted: the table is deterministic and rebuilt on construction.
        self.register_buffer("pe", table, persistent=False)

    def forward(self, time):
        # Accepts any integer tensor shape; flatten to a batch of row indices.
        return self.pe[time.reshape(-1)]
class TimeEmbedding(nn.Module):
    """Two-layer SiLU MLP that lifts a timestep embedding from dim to dim * 4."""

    def __init__(self, dim):
        super().__init__()
        hidden = dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )

    def forward(self, x):
        return self.mlp(x)
class ResidualBlock(nn.Module):
    """GroupNorm/conv residual block with additive timestep conditioning.

    The time embedding is projected to the output channel count and added
    as a per-channel bias between the two convolutions.
    """

    def __init__(self, in_channel, out_channel, time_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channel)
        self.norm2 = nn.GroupNorm(32, out_channel)
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        # 1x1 projection on the skip path only when the channel count changes.
        if in_channel != out_channel:
            self.residual_conv = nn.Conv2d(in_channel, out_channel, kernel_size=1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, time):
        skip = self.residual_conv(x)
        out = F.silu(self.conv1(self.norm1(x)))
        # Broadcast the (B, C) time projection over the spatial dimensions.
        out = out + self.time_proj(time)[:, :, None, None]
        out = F.silu(self.conv2(self.norm2(out)))
        return skip + out
class DownSampler(nn.Module):
    """Halves spatial resolution with a stride-2 3x3 conv; channels unchanged."""

    def __init__(self, in_channel):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channel, in_channel, kernel_size=3, stride=2, padding=1
        )

    def forward(self, x):
        return self.conv(x)
class UpSampler(nn.Module):
    """Doubles spatial resolution (nearest-neighbor) then smooths with a 3x3 conv."""

    def __init__(self, in_channel):
        super().__init__()
        # nn.Upsample defaults to nearest-neighbor interpolation.
        self.conv = nn.Conv2d(
            in_channel, in_channel, kernel_size=3, stride=1, padding=1
        )
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x):
        return self.conv(self.up(x))
class SwitchSequential(nn.Sequential):
    """Sequential container that routes the time embedding to ResidualBlocks.

    All other modules (convs, attention) are called with the feature map alone.
    """

    def forward(self, x, time):
        for layer in self:
            x = layer(x, time) if isinstance(layer, ResidualBlock) else layer(x)
        return x
class Unet(nn.Module):
    """DDPM noise-prediction UNet.

    Encoder (128 -> 256 -> 512 channels, downsampling to H/8), a three-stage
    attention bottleneck, and a mirrored decoder that concatenates the
    pre-downsample encoder features as skip connections. Every ResidualBlock
    is conditioned on a learned embedding of the diffusion timestep.

    Fix vs. original: the encoder no longer stores the bottom block's output
    in the skip list — the decoder never consumed that entry, so it was a
    dead reference kept alive for the whole forward pass.
    """

    def __init__(self, time_dim=320, n_head=8):
        super().__init__()
        # Timestep embedding: sinusoidal lookup table followed by an MLP
        # that expands the embedding to time_dim * 4 channels.
        self.time_position_embedding = TimePositionEmbedding()
        self.time_proj = TimeEmbedding(dim=320)
        time_dim = time_dim * 4
        # ---- Encoder: the feature saved *before* each downsample is the skip ----
        self.down_blocks = nn.ModuleList(
            [
                # Output: 128 channels at resolution H.
                SwitchSequential(
                    nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1),
                    ResidualBlock(64, 128, time_dim=time_dim),
                    ResidualBlock(128, 128, time_dim=time_dim),
                ),
                # Output: 256 channels at H/2.
                SwitchSequential(
                    ResidualBlock(128, 256, time_dim=time_dim),
                    ResidualBlock(256, 256, time_dim=time_dim),
                ),
                # Output: 512 channels at H/4.
                SwitchSequential(
                    ResidualBlock(256, 512, time_dim=time_dim),
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                # Bottom: 512 channels at H/8 (no further downsampling).
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
            ]
        )
        self.down_samplers = nn.ModuleList(
            [
                DownSampler(128),  # H -> H/2
                DownSampler(256),  # H/2 -> H/4
                DownSampler(512),  # H/4 -> H/8
            ]
        )
        # ---------------- Bottleneck ----------------
        self.mid_blocks = nn.ModuleList(
            [
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
            ]
        )
        # ---- Decoder: upsample first, then concatenate the matching skip ----
        # up_blocks[0]: one pass at the bottom resolution (no concat)
        # up_blocks[1]: at H/4, concat skip@H/4 (512 ch), stays at 512
        # up_blocks[2]: at H/2, concat skip@H/2 (256 ch), outputs 256
        # up_blocks[3]: at H,   concat skip@H   (128 ch), outputs 64
        self.up_blocks = nn.ModuleList(
            [
                SwitchSequential(  # H/8, 512 -> 512 (no concat)
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                SwitchSequential(  # H/4, (512 + 512) -> 512
                    ResidualBlock(512 + 512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                SwitchSequential(  # H/2, (512 + 256) -> 256
                    ResidualBlock(512 + 256, 256, time_dim=time_dim),
                    ResidualBlock(256, 256, time_dim=time_dim),
                    Attention(n_head, 256),
                    ResidualBlock(256, 256, time_dim=time_dim),
                ),
                SwitchSequential(  # H, (256 + 128) -> 64
                    ResidualBlock(256 + 128, 64, time_dim=time_dim),
                    ResidualBlock(64, 64, time_dim=time_dim),
                ),
            ]
        )
        # Upsamplers matched to each decoder stage's channel count:
        # 512@H/8 -> 512@H/4, then 512@H/2, finally 256@H.
        self.up_samplers = nn.ModuleList(
            [
                UpSampler(512),  # H/8 -> H/4
                UpSampler(512),  # H/4 -> H/2
                UpSampler(256),  # H/2 -> H
            ]
        )
        self.head = nn.Conv2d(64, 3, kernel_size=3, padding=1, stride=1)

    def forward(self, x, time):
        """Predict noise for image batch `x` at diffusion step `time`.

        Args:
            x: (B, 3, H, W) tensor; H and W must be divisible by 8.
            time: (B,) integer timesteps in [0, 1000).

        Returns:
            (B, 3, H, W) predicted noise.
        """
        # Timestep -> sinusoidal table -> MLP; shared by every ResidualBlock.
        t = self.time_proj(self.time_position_embedding(time))
        # -------- Encoder: save each pre-downsample feature as a skip --------
        skips = []
        for i, block in enumerate(self.down_blocks):
            x = block(x, t)
            if i < len(self.down_samplers):
                # Only stages that get downsampled contribute a skip; the
                # bottom block's output flows straight into the bottleneck.
                skips.append(x)
                x = self.down_samplers[i](x)
        # -------- Bottleneck --------
        for block in self.mid_blocks:
            x = block(x, t)
        # -------- Decoder --------
        # Bottom pass at H/8 without concatenation; still 512 channels.
        x = self.up_blocks[0](x, t)
        # Stage 1: H/8 -> H/4, concat skip@H/4 -> (512 + 512) channels.
        x = self.up_samplers[0](x)
        x = torch.cat([x, skips[2]], dim=1)
        x = self.up_blocks[1](x, t)
        # Stage 2: H/4 -> H/2, concat skip@H/2 -> (512 + 256) channels.
        x = self.up_samplers[1](x)
        x = torch.cat([x, skips[1]], dim=1)
        x = self.up_blocks[2](x, t)
        # Stage 3: H/2 -> H, concat skip@H -> (256 + 128) channels.
        x = self.up_samplers[2](x)
        x = torch.cat([x, skips[0]], dim=1)
        x = self.up_blocks[3](x, t)
        # Project back to RGB.
        return self.head(x)
if __name__ == "__main__":
    # Smoke test: a random batch must come back with its (B, 3, H, W) shape.
    model = Unet()
    batch = torch.randn(2, 3, 64, 64)
    steps = torch.randint(0, 1000, (2,))
    out = model(batch, steps)
    print(out.shape)
    # torch.save({"model": model.state_dict()}, "unet.pt")