# NOTE: "Spaces: / Runtime error / Runtime error" was hosting-platform residue
# from the original paste, not part of the program; kept here as a comment.
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Treats every (h, w) location of a (batch, channel, height, width) input
    as one token, runs scaled-dot-product attention, and restores the
    original layout.  `dim` must be divisible by `n_head`.
    """

    def __init__(self, n_head, dim):
        super().__init__()
        assert dim % n_head == 0
        # Single projection producing Q, K and V in one matmul.
        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)
        self.n_head = n_head
        self.head_dim = dim // self.n_head

    def _split_heads(self, t: torch.Tensor, batch: int, seq: int) -> torch.Tensor:
        # (batch, seq, dim) -> (batch, n_head, seq, head_dim)
        return t.reshape(batch, seq, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor):
        batch, channel, height, width = x.shape
        seq = height * width
        # Flatten the spatial grid into a token sequence: (batch, seq, channel).
        tokens = x.reshape(batch, channel, seq).transpose(-1, -2)
        q, k, v = torch.chunk(self.qkv_proj(tokens), chunks=3, dim=-1)
        attn = F.scaled_dot_product_attention(
            self._split_heads(q, batch, seq),
            self._split_heads(k, batch, seq),
            self._split_heads(v, batch, seq),
        )
        # Merge the heads, project, and restore the (batch, c, h, w) layout.
        merged = attn.transpose(1, 2).reshape(batch, seq, channel)
        projected = self.out_proj(merged)
        return projected.transpose(-1, -2).reshape(batch, channel, height, width)
class TimePositionEmbedding(nn.Module):
    """Sinusoidal timestep-embedding lookup table (Transformer-style).

    Precomputes a (seq_len, dim) buffer where even columns hold
    sin(position * inv_freq) and odd columns hold the matching cos;
    forward() indexes it with integer timesteps.  `dim` must be even.
    """

    def __init__(self, seq_len=1000, dim=320):
        super().__init__()
        base = 10000
        # One inverse frequency per sin/cos pair: base^(-2i/dim).
        inv_freq = 1 / base ** (torch.arange(0, dim, step=2).float() / dim)
        # Outer product of positions and frequencies: (seq_len, dim // 2).
        angles = torch.arange(0, seq_len, step=1).unsqueeze(1) * inv_freq.unsqueeze(0)
        # Interleave so sin lands in even columns and cos in odd columns.
        pe = torch.stack((angles.sin(), angles.cos()), dim=-1).reshape(seq_len, dim)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, time):
        # Accept any integer-tensor shape; flatten to 1-D indices first.
        return self.pe[time.reshape(-1)]
class TimeEmbedding(nn.Module):
    """Expand a time embedding of width `dim` to width `dim * 4` with a
    two-layer SiLU MLP."""

    def __init__(self, dim):
        super().__init__()
        hidden = dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )

    def forward(self, x):
        return self.mlp(x)
class ResidualBlock(nn.Module):
    """Residual block with GroupNorm and additive time conditioning.

    Main path: norm -> 3x3 conv -> SiLU, add the projected time embedding
    (broadcast over H and W), norm -> 3x3 conv -> SiLU.  The skip path is a
    1x1 conv when the channel count changes, identity otherwise.
    """

    def __init__(self, in_channel, out_channel, time_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channel)
        self.norm2 = nn.GroupNorm(32, out_channel)
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        # Only pay for a projection on the skip path when shapes differ.
        if in_channel != out_channel:
            self.residual_conv = nn.Conv2d(in_channel, out_channel, kernel_size=1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, time):
        skip = self.residual_conv(x)
        h = F.silu(self.conv1(self.norm1(x)))
        # (batch, out_channel) -> (batch, out_channel, 1, 1) for broadcasting.
        h = h + self.time_proj(time)[:, :, None, None]
        h = F.silu(self.conv2(self.norm2(h)))
        return skip + h
class DownSampler(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution
    (channel count unchanged)."""

    def __init__(self, in_channel):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channel,
            in_channel,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        return self.conv(x)
class UpSampler(nn.Module):
    """Double the spatial resolution (nn.Upsample default, nearest-neighbor)
    and refine the result with a 3x3 convolution (channel count unchanged)."""

    def __init__(self, in_channel):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channel,
            in_channel,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x):
        return self.conv(self.up(x))
class SwitchSequential(nn.Sequential):
    """Sequential container that routes the time embedding only to the
    submodules that accept it (ResidualBlock); every other module is called
    with `x` alone."""

    def forward(self, x, time):
        for layer in self:
            x = layer(x, time) if isinstance(layer, ResidualBlock) else layer(x)
        return x
class Unet(nn.Module):
    """Diffusion U-Net: 3-channel image + integer timestep -> 3-channel output.

    The encoder saves each down-block's output (before downsampling) as a
    skip connection; the decoder upsamples and concatenates them back in
    reverse order.  Input is (batch, 3, H, W); H and W should be divisible
    by 8 so the decoder's upsampled features align with the saved skips.

    Args:
        time_dim: width of the sinusoidal time embedding; the MLP expands it
            to time_dim * 4 before it reaches the residual blocks.
        n_head: number of heads used by every Attention layer.
    """

    def __init__(self, time_dim=320, n_head=8):
        super().__init__()
        # Time embedding.  Pass time_dim explicitly instead of relying on the
        # modules' defaults: previously TimePositionEmbedding() and
        # TimeEmbedding(dim=320) hard-coded 320, so any non-default time_dim
        # silently mismatched the residual blocks' expected time_dim * 4.
        self.time_position_embedding = TimePositionEmbedding(dim=time_dim)
        self.time_proj = TimeEmbedding(dim=time_dim)
        time_dim = time_dim * 4
        # ---------------- Encoder: keep pre-downsample features as skips ----------------
        self.down_blocks = nn.ModuleList(
            [
                # Output: 128 channels, resolution H
                SwitchSequential(
                    nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1),
                    ResidualBlock(64, 128, time_dim=time_dim),
                    ResidualBlock(128, 128, time_dim=time_dim),
                ),
                # Output: 256 channels, resolution H/2
                SwitchSequential(
                    ResidualBlock(128, 256, time_dim=time_dim),
                    ResidualBlock(256, 256, time_dim=time_dim),
                ),
                # Output: 512 channels, resolution H/4
                SwitchSequential(
                    ResidualBlock(256, 512, time_dim=time_dim),
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                # Bottom: 512 channels, resolution H/8 (no further downsampling)
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
            ]
        )
        self.down_samplers = nn.ModuleList(
            [
                DownSampler(128),  # H -> H/2
                DownSampler(256),  # H/2 -> H/4
                DownSampler(512),  # H/4 -> H/8
            ]
        )
        # ---------------- Bottleneck ----------------
        self.mid_blocks = nn.ModuleList(
            [
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
            ]
        )
        # ---------------- Decoder: upsample first, then concat the matching skip ----------------
        # up_blocks[0]: one extra pass at the bottom (no concat)
        # up_blocks[1]: resolution H/4, concats skip@H/4 (512 ch), keeps 512
        # up_blocks[2]: resolution H/2, concats skip@H/2 (256 ch), outputs 256
        # up_blocks[3]: resolution H,   concats skip@H   (128 ch), outputs 64
        self.up_blocks = nn.ModuleList(
            [
                SwitchSequential(  # H/8, 512 -> 512 (no concat)
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                SwitchSequential(  # H/4, (512 + 512) -> 512
                    ResidualBlock(512 + 512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                SwitchSequential(  # H/2, (512 + 256) -> 256
                    ResidualBlock(512 + 256, 256, time_dim=time_dim),
                    ResidualBlock(256, 256, time_dim=time_dim),
                    Attention(n_head, 256),
                    ResidualBlock(256, 256, time_dim=time_dim),
                ),
                SwitchSequential(  # H, (256 + 128) -> 64
                    ResidualBlock(256 + 128, 64, time_dim=time_dim),
                    ResidualBlock(64, 64, time_dim=time_dim),
                ),
            ]
        )
        # Upsamplers matched to each decoder stage's channel count:
        # 512@H/8 -> 512@H/4, then 512@H/4 -> 512@H/2, finally 256@H/2 -> 256@H
        self.up_samplers = nn.ModuleList(
            [
                UpSampler(512),  # H/8 -> H/4
                UpSampler(512),  # H/4 -> H/2
                UpSampler(256),  # H/2 -> H
            ]
        )
        self.head = nn.Conv2d(64, 3, kernel_size=3, padding=1, stride=1)

    def forward(self, x, time):
        # Time embedding: (batch,) integer timesteps -> (batch, time_dim * 4).
        t = self.time_proj(self.time_position_embedding(time))
        # -------- Encoder: each down-block output becomes a pre-downsample skip --------
        skips = []
        for i, block in enumerate(self.down_blocks):
            x = block(x, t)    # process at the current resolution
            skips.append(x)    # save the pre-downsample feature
            if i < len(self.down_samplers):
                x = self.down_samplers[i](x)  # move to the next (smaller) resolution
        # (skips[3], the bottom feature, is intentionally never concatenated.)
        # -------- Bottleneck --------
        for block in self.mid_blocks:
            x = block(x, t)
        # -------- Decoder --------
        # Extra pass at the bottom (no concat).
        x = self.up_blocks[0](x, t)          # still H/8, 512 channels
        # Stage 1: H/8 -> H/4, concat skip@H/4 (skips[2])
        x = self.up_samplers[0](x)           # 512@H/4
        x = torch.cat([x, skips[2]], dim=1)  # (512 + 512)@H/4
        x = self.up_blocks[1](x, t)          # 512@H/4
        # Stage 2: H/4 -> H/2, concat skip@H/2 (skips[1])
        x = self.up_samplers[1](x)           # 512@H/2
        x = torch.cat([x, skips[1]], dim=1)  # (512 + 256)@H/2
        x = self.up_blocks[2](x, t)          # 256@H/2
        # Stage 3: H/2 -> H, concat skip@H (skips[0])
        x = self.up_samplers[2](x)           # 256@H
        x = torch.cat([x, skips[0]], dim=1)  # (256 + 128)@H
        x = self.up_blocks[3](x, t)          # 64@H
        # Output head.
        x = self.head(x)                     # -> 3@H
        return x
if __name__ == "__main__":
    # Smoke test: one forward pass on a random batch of 64x64 images.
    model = Unet()
    images = torch.randn(2, 3, 64, 64)
    timesteps = torch.randint(0, 1000, (2,))
    prediction = model(images, timesteps)
    print(prediction.shape)
    # torch.save({"model": model.state_dict()}, "unet.pt")