| import torch |
| import torch.nn as nn |
| import math |
|
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal (Transformer-style) embedding of scalar timesteps.

    Each timestep ``t`` is mapped to ``[sin(t * f_0), ..., sin(t * f_{k-1}),
    cos(t * f_0), ..., cos(t * f_{k-1})]`` with ``k = dim // 2`` frequencies
    spaced geometrically from 1 down to 1/10000.

    Args:
        dim: Width of the produced embedding. Odd widths are supported by
            appending a single zero column, so the output always has exactly
            ``dim`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time``, expected to be a 1-D tensor of shape (batch,).

        Returns:
            Float tensor of shape (batch, dim).
        """
        half_dim = self.dim // 2
        # Log-spacing between consecutive frequencies. With half_dim == 1
        # (dim 2 or 3) the original ``half_dim - 1`` divisor is zero; there is
        # only one frequency (exp(0) == 1) then, so any non-zero divisor works.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        args = time[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Pad the missing column so the width matches self.dim exactly.
            emb = torch.cat((emb, emb.new_zeros(emb.shape[0], 1)), dim=-1)
        return emb
|
|
class ResidualBlock(nn.Module):
    """Two conv/GroupNorm/SiLU stages with timestep conditioning and a skip path.

    The timestep embedding is projected to a per-channel ``(scale, shift)``
    pair that modulates the output of the first stage as
    ``h * (1 + scale) + shift`` (FiLM-style). The block input is added back at
    the end, through a 1x1 convolution when the channel count changes.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced; must be divisible by 32 because both
            normalisations are ``GroupNorm(32, out_channels)``.
        time_emb_dim: Width of the timestep embedding passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        # Projects the embedding to 2 * out_channels values: scale, then shift.
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, 2 * out_channels))

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)

        # The skip path only needs a 1x1 projection when channel counts differ.
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv2d(in_channels, out_channels, 1)

        self.act = nn.SiLU()

    def forward(self, x, time_emb):
        """Apply the block.

        Args:
            x: Feature map, presumably (batch, in_channels, H, W).
            time_emb: Timestep embedding, expected as (batch, time_emb_dim).

        Returns:
            Feature map with ``out_channels`` channels and the spatial size of ``x``.
        """
        film = self.time_mlp(time_emb)[:, :, None, None]
        scale, shift = film.chunk(2, dim=1)

        hidden = self.act(self.norm1(self.conv1(x)))
        hidden = hidden * (1 + scale) + shift
        hidden = self.act(self.norm2(self.conv2(hidden)))

        return hidden + self.residual_conv(x)
|
|
class SelfAttention(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Every spatial position attends to every other position. Queries, keys and
    values come from one 1x1 convolution over the group-normalised input. The
    attended features pass through a second 1x1 convolution and are added back
    onto the input.

    Args:
        channels: Channel count of input and output; must be divisible by 32
            because of ``GroupNorm(32, channels)``.
    """

    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        # A single conv produces q, k and v stacked along the channel axis.
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.out = nn.Conv2d(channels, channels, 1)
        # Dot-product temperature 1 / sqrt(C).
        self.scale = 1.0 / math.sqrt(channels)

    def forward(self, x):
        """Return ``x`` plus the attention output; the result has the shape of ``x``."""
        batch, channels, height, width = x.shape
        n_tokens = height * width

        def to_tokens(t):
            # Flatten the spatial grid: (batch, C, H, W) -> (batch, C, tokens).
            return t.reshape(batch, channels, n_tokens)

        query, key, value = self.qkv(self.norm(x)).chunk(3, dim=1)

        # (batch, tokens, C) @ (batch, C, tokens) -> (batch, tokens, tokens)
        scores = to_tokens(query).transpose(-2, -1) @ to_tokens(key) * self.scale
        # Normalise over keys so each query's weights sum to one.
        weights = scores.softmax(dim=-1)

        mixed = weights @ to_tokens(value).transpose(-2, -1)
        mixed = mixed.transpose(-2, -1).reshape(batch, channels, height, width)
        return x + self.out(mixed)
|
|
class UNet(nn.Module):
    """Timestep-conditioned U-Net built from ResidualBlock / SelfAttention stages.

    Encoder: for each multiplier in ``ch_mult``, two residual blocks at
    ``base_channels * mult`` channels, then a stride-2 conv (except at the
    deepest level). Bottleneck: residual block, self-attention, residual block.
    Decoder: mirrors the encoder. Each level concatenates the matching encoder
    activation (skip) along channels before its two residual blocks, and all
    levels but the last upsample by 2 afterwards.

    Args:
        img_size: Accepted for interface compatibility; not used by the model,
            which works on any spatial size.
        in_channels: Channels of the input image.
        out_channels: Channels of the prediction.
        base_channels: Width of the first level and of the sinusoidal timestep
            embedding. Every ``base_channels * mult`` must be divisible by 32
            (GroupNorm groups).
        ch_mult: Per-level channel multipliers, shallowest level first.
    """

    def __init__(self, img_size=64, in_channels=3, out_channels=3, base_channels=128, ch_mult=(1, 2, 4)):
        super().__init__()
        time_dim = base_channels * 4

        self.time_embed = SinusoidalPosEmb(base_channels)
        self.time_mlp = nn.Sequential(
            nn.Linear(base_channels, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )

        self.init_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Encoder. Module creation order and nesting match the original layout,
        # so state_dict keys and seeded initialisation are unchanged.
        self.down_channels = []
        self.down_blocks = nn.ModuleList([])
        channels = base_channels
        last_level = len(ch_mult) - 1
        for i, mult in enumerate(ch_mult):
            out_ch = base_channels * mult
            self.down_channels.append(out_ch)
            res1 = ResidualBlock(channels, out_ch, time_dim)
            res2 = ResidualBlock(out_ch, out_ch, time_dim)
            if i < last_level:
                downsample = nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1)
            else:
                downsample = nn.Identity()
            self.down_blocks.append(nn.ModuleList([res1, res2, downsample]))
            channels = out_ch

        self.bottleneck = nn.ModuleList([
            ResidualBlock(channels, channels, time_dim),
            SelfAttention(channels),
            ResidualBlock(channels, channels, time_dim),
        ])

        # Decoder: walk the levels deepest-first, consuming the skips in reverse.
        self.up_blocks = nn.ModuleList([])
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = base_channels * mult
            in_ch = channels + self.down_channels[i]
            res1 = ResidualBlock(in_ch, out_ch, time_dim)
            res2 = ResidualBlock(out_ch, out_ch, time_dim)
            upsample = nn.Upsample(scale_factor=2) if i > 0 else nn.Identity()
            self.up_blocks.append(nn.ModuleList([res1, res2, upsample]))
            channels = out_ch

        # The decoder ends at base_channels * ch_mult[0] channels, not
        # necessarily base_channels; size the head from what it actually receives.
        self.final_conv = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )

    def forward(self, x, t):
        """Predict an output map for images ``x`` at timesteps ``t``.

        Args:
            x: Image batch, presumably (batch, in_channels, H, W).
            t: Timesteps, expected as a 1-D tensor of length batch (it is fed
                to SinusoidalPosEmb, which indexes ``time[:, None]``).

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of ``x``.
        """
        t_emb = self.time_mlp(self.time_embed(t))

        h = self.init_conv(x)

        skips = []
        for res1, res2, downsample in self.down_blocks:
            h = res2(res1(h, t_emb), t_emb)
            skips.append(h)
            h = downsample(h)

        for layer in self.bottleneck:
            # Only the residual blocks take the timestep embedding.
            h = layer(h) if isinstance(layer, SelfAttention) else layer(h, t_emb)

        for (res1, res2, upsample), skip in zip(self.up_blocks, reversed(skips)):
            if h.shape[-2:] != skip.shape[-2:]:
                # Stride-2 convs round odd sizes up, so a 2x upsample can
                # overshoot the encoder resolution; snap back before concatenating.
                h = nn.functional.interpolate(h, size=skip.shape[-2:], mode="nearest")
            h = torch.cat([h, skip], dim=1)
            h = res2(res1(h, t_emb), t_emb)
            h = upsample(h)

        return self.final_conv(h)
|
|
if __name__ == "__main__":
    # Smoke test: build the default model, report its size, and run one
    # forward pass to check that the output matches the input shape.
    print("Initializing UNet...")
    model = UNet(img_size=64, base_channels=128)
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    print("\nTesting forward pass...")
    x = torch.randn(1, 3, 64, 64)
    t = torch.randint(0, 1000, (1,))

    with torch.no_grad():
        output = model(x, t)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    # With default in/out channels the prediction must match the input exactly;
    # fail loudly instead of claiming success unconditionally.
    if output.shape != x.shape:
        raise RuntimeError(
            f"Output shape {tuple(output.shape)} does not match input shape {tuple(x.shape)}"
        )
    print("\nModel architecture verified successfully!")
|
|