import torch
from torch import nn
from torch.nn import functional as F


class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The (height, width) grid is flattened into a sequence of
    ``height * width`` tokens, attended with
    ``F.scaled_dot_product_attention``, projected, and reshaped back.
    No normalization or residual connection is applied here.

    Args:
        n_head: Number of attention heads.
        dim: Channel count of the input; must be divisible by ``n_head``.
    """

    def __init__(self, n_head: int, dim: int) -> None:
        super().__init__()
        assert dim % n_head == 0, "dim must be divisible by n_head"
        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)
        self.n_head = n_head
        self.head_dim = dim // self.n_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention.

        Args:
            x: Tensor of shape ``(batch, channel, height, width)`` where
                ``channel`` equals the ``dim`` given at construction.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, channel, height, width = x.shape
        n_tokens = height * width

        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = x.reshape(batch_size, channel, n_tokens).transpose(-1, -2)
        q, k, v = torch.chunk(self.qkv_proj(tokens), chunks=3, dim=-1)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, N, C) -> (B, n_head, N, head_dim)
            return t.reshape(
                batch_size, n_tokens, self.n_head, self.head_dim
            ).transpose(1, 2)

        out = F.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v)
        )

        # Merge heads, project, and restore the (B, C, H, W) layout.
        out = out.transpose(1, 2).reshape(batch_size, n_tokens, channel)
        out = self.out_proj(out)
        return out.transpose(-1, -2).reshape(batch_size, channel, height, width)


class TimePositionEmbedding(nn.Module):
    """Lookup table of sinusoidal embeddings for integer timesteps.

    Even columns hold sines and odd columns hold cosines of
    ``position / 10000 ** (2i / dim)``. The table is registered as a
    non-persistent buffer, so it is not part of the ``state_dict``.

    Args:
        seq_len: Number of timesteps in the table (valid indices are
            ``0 .. seq_len - 1``).
        dim: Embedding width.
    """

    def __init__(self, seq_len: int = 1000, dim: int = 320) -> None:
        super().__init__()
        base = 10000
        inv_freq = 1 / base ** (torch.arange(0, dim, step=2).float() / dim)
        positions = torch.arange(0, seq_len, step=1).unsqueeze(1)
        angles = positions * inv_freq.unsqueeze(0)

        pe = torch.zeros(size=(seq_len, dim))
        pe[:, 0::2] = angles.sin()
        # For odd ``dim`` there is one fewer cosine column than sine column.
        pe[:, 1::2] = angles.cos()[:, : dim // 2]
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        """Return the embedding rows for ``time``.

        Args:
            time: Integer index tensor of any shape; it is flattened first.

        Returns:
            Tensor of shape ``(time.numel(), dim)``.
        """
        return self.pe[time.reshape(-1)]


class TimeEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) mapping ``dim`` to ``4 * dim``."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim * 4),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last dimension ``dim``) to last dimension ``4 * dim``."""
        return self.mlp(x)


class ResidualBlock(nn.Module):
    """Time-conditioned convolutional block with a shortcut connection.

    Computation: ``norm1 -> conv1 -> SiLU``, add the projected time
    embedding per channel, ``norm2 -> conv2 -> SiLU``, then add the
    shortcut. The shortcut is the identity when the channel counts match
    and a 1x1 convolution otherwise.

    Args:
        in_channel: Input channels (GroupNorm uses 32 groups, so this must
            be divisible by 32).
        out_channel: Output channels (also must be divisible by 32).
        time_dim: Width of the time embedding passed to ``forward``.
    """

    def __init__(self, in_channel: int, out_channel: int, time_dim: int) -> None:
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channel)
        self.norm2 = nn.GroupNorm(32, out_channel)
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        if in_channel != out_channel:
            self.residual_conv = nn.Conv2d(in_channel, out_channel, kernel_size=1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
        """Run the block.

        Args:
            x: Feature map of shape ``(batch, in_channel, H, W)``.
            time: Time embedding of shape ``(batch, time_dim)``.

        Returns:
            Feature map of shape ``(batch, out_channel, H, W)``.
        """
        shortcut = self.residual_conv(x)
        hidden = F.silu(self.conv1(self.norm1(x)))
        # Broadcast the per-channel time bias over the spatial dimensions.
        hidden = hidden + self.time_proj(time)[:, :, None, None]
        hidden = F.silu(self.conv2(self.norm2(hidden)))
        return shortcut + hidden


class DownSampler(nn.Module):
    """Stride-2 3x3 convolution that keeps the channel count."""

    def __init__(self, in_channel: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channel, in_channel, stride=2, padding=1, kernel_size=3
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the stride-2 convolution."""
        return self.conv(x)


class UpSampler(nn.Module):
    """2x ``nn.Upsample`` followed by a 3x3 convolution, channels unchanged."""

    def __init__(self, in_channel: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channel, in_channel, stride=1, padding=1, kernel_size=3
        )
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` by a factor of two, then convolve."""
        return self.conv(self.up(x))


class SwitchSequential(nn.Sequential):
    """Sequential container that passes ``time`` only to ``ResidualBlock``s."""

    def forward(self, x: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
        """Apply each child in order; other modules receive ``x`` alone."""
        for module in self:
            if isinstance(module, ResidualBlock):
                x = module(x, time)
            else:
                x = module(x)
        return x


class Unet(nn.Module):
    """Time-conditioned U-Net mapping a 3-channel image to a 3-channel output.

    The encoder has four stages separated by three stride-2 downsamplers,
    followed by a bottleneck and a decoder that upsamples three times and
    concatenates the matching pre-downsampling encoder feature at each
    resolution.

    Args:
        time_dim: Width of the sinusoidal timestep embedding. The time MLP
            expands it to ``4 * time_dim``, which is what every
            ``ResidualBlock`` consumes.
        n_head: Number of heads in every ``Attention`` layer; must divide
            both 256 and 512.
    """

    def __init__(self, time_dim: int = 320, n_head: int = 8) -> None:
        super().__init__()

        # Time embedding. Both modules must use ``time_dim`` so that the MLP
        # output width matches the ``4 * time_dim`` the residual blocks expect.
        self.time_position_embedding = TimePositionEmbedding(dim=time_dim)
        self.time_proj = TimeEmbedding(dim=time_dim)
        cond_dim = time_dim * 4

        # ---------------- Encoder: pre-downsampling features become skips ----
        self.down_blocks = nn.ModuleList(
            [
                # Output: 128 channels at resolution H.
                SwitchSequential(
                    nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1),
                    ResidualBlock(64, 128, time_dim=cond_dim),
                    ResidualBlock(128, 128, time_dim=cond_dim),
                ),
                # Output: 256 channels at resolution H/2.
                SwitchSequential(
                    ResidualBlock(128, 256, time_dim=cond_dim),
                    ResidualBlock(256, 256, time_dim=cond_dim),
                ),
                # Output: 512 channels at resolution H/4.
                SwitchSequential(
                    ResidualBlock(256, 512, time_dim=cond_dim),
                    ResidualBlock(512, 512, time_dim=cond_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=cond_dim),
                ),
                # Bottom: 512 channels at resolution H/8 (not downsampled again).
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=cond_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=cond_dim),
                ),
            ]
        )
        self.down_samplers = nn.ModuleList(
            [
                DownSampler(128),  # H -> H/2
                DownSampler(256),  # H/2 -> H/4
                DownSampler(512),  # H/4 -> H/8
            ]
        )

        # ---------------- Bottleneck ----------------
        self.mid_blocks = nn.ModuleList(
            [
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=cond_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=cond_dim),
                ),
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=cond_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=cond_dim),
                ),
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=cond_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=cond_dim),
                ),
            ]
        )

        # ---------------- Decoder: upsample, then concatenate the skip ------
        # up_blocks[0]: processes the bottom level first (no concatenation)
        # up_blocks[1]: H/4, concatenates skip@H/4 (512 ch), output 512
        # up_blocks[2]: H/2, concatenates skip@H/2 (256 ch), output 256
        # up_blocks[3]: H,   concatenates skip@H   (128 ch), output 64
        self.up_blocks = nn.ModuleList(
            [
                SwitchSequential(  # H/8, 512 -> 512 (no concatenation)
                    ResidualBlock(512, 512, time_dim=cond_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=cond_dim),
                ),
                SwitchSequential(  # H/4, (512 + 512) -> 512
                    ResidualBlock(512 + 512, 512, time_dim=cond_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=cond_dim),
                ),
                SwitchSequential(  # H/2, (512 + 256) -> 256
                    ResidualBlock(512 + 256, 256, time_dim=cond_dim),
                    ResidualBlock(256, 256, time_dim=cond_dim),
                    Attention(n_head, 256),
                    ResidualBlock(256, 256, time_dim=cond_dim),
                ),
                SwitchSequential(  # H, (256 + 128) -> 64
                    ResidualBlock(256 + 128, 64, time_dim=cond_dim),
                    ResidualBlock(64, 64, time_dim=cond_dim),
                ),
            ]
        )
        # Upsamplers matching the channel count leaving each decoder stage:
        # 512@H/8 -> 512@H/4, then 512@H/4 -> 512@H/2, finally 256@H/2 -> 256@H.
        self.up_samplers = nn.ModuleList(
            [
                UpSampler(512),  # H/8 -> H/4
                UpSampler(512),  # H/4 -> H/2
                UpSampler(256),  # H/2 -> H
            ]
        )

        self.head = nn.Conv2d(64, 3, kernel_size=3, padding=1, stride=1)

    def forward(self, x: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
        """Run the U-Net.

        Args:
            x: Image batch of shape ``(batch, 3, H, W)``. ``H`` and ``W``
                should be divisible by 8; otherwise the upsampled features
                and the skips differ in size and concatenation fails.
            time: Integer timestep indices, one per batch element, each in
                ``[0, 1000)``.

        Returns:
            Tensor of shape ``(batch, 3, H, W)``.
        """
        # Time embedding.
        t = self.time_proj(self.time_position_embedding(time))

        # -------- Encoder --------
        # Keep the output of every stage that is subsequently downsampled;
        # the bottom stage (H/8) has no decoder counterpart, so it is not kept.
        skips = []
        for i, block in enumerate(self.down_blocks):
            x = block(x, t)
            if i < len(self.down_samplers):
                skips.append(x)
                x = self.down_samplers[i](x)

        # -------- Bottleneck --------
        for block in self.mid_blocks:
            x = block(x, t)

        # -------- Decoder --------
        # One round of processing at the bottom (no concatenation): 512@H/8.
        x = self.up_blocks[0](x, t)

        # Each stage: upsample, concatenate the skip of the same resolution
        # (deepest first, hence ``pop``), then run the stage's block.
        #   stage 1: 512@H/4 + 512 -> 512
        #   stage 2: 512@H/2 + 256 -> 256
        #   stage 3: 256@H   + 128 -> 64
        for stage, up_sampler in enumerate(self.up_samplers, start=1):
            x = up_sampler(x)
            x = torch.cat([x, skips.pop()], dim=1)
            x = self.up_blocks[stage](x, t)

        # Output head: 64 -> 3 channels at resolution H.
        return self.head(x)


if __name__ == "__main__":
    model = Unet()
    x = torch.randn(2, 3, 64, 64)
    t = torch.randint(0, 1000, (2,))
    out = model(x, t)
    print(out.shape)
    # torch.save({"model": model.state_dict()}, "unet.pt")