| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class PatchEmbed(nn.Module):
    """Turn a 4-D feature map into a sequence of patch tokens.

    A single ``nn.Conv2d`` whose kernel size and stride both equal
    ``patch_size`` projects every non-overlapping patch to ``embed_dim``
    channels.
    """

    def __init__(self, patch_size, in_channels, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed ``x`` and report its original spatial size.

        Args:
            x: 4-D tensor; the last two axes are taken as height and width.

        Returns:
            A tuple ``(tokens, latent_h, latent_w)`` where ``tokens`` has the
            conv output's spatial axes flattened into axis 1 and the channel
            axis moved last, and ``latent_h`` / ``latent_w`` are the height
            and width of the *input* (before patching).
        """
        # Unpacking four values also enforces that the input is 4-D.
        _, _, latent_h, latent_w = x.shape
        projected = self.proj(x)
        # (B, D, h, w) -> (B, D, h*w) -> (B, h*w, D)
        tokens = torch.flatten(projected, start_dim=2).permute(0, 2, 1)
        return tokens, latent_h, latent_w
|
|
class SinusoidalPositionEmbeddings(nn.Module):
    """Map a 1-D tensor of scalars (e.g. timesteps) to sin/cos features.

    The output always has exactly ``dim`` features: the first ``dim // 2``
    are sines, the next ``dim // 2`` are cosines, and for odd ``dim`` a
    single trailing zero column is appended.
    """

    def __init__(self, dim):
        super().__init__()
        if dim < 2:
            raise ValueError(f"dim must be at least 2, got {dim}")
        self.dim = dim

    def forward(self, time):
        """Embed ``time``.

        Args:
            time: 1-D tensor of shape ``(B,)``.

        Returns:
            Tensor of shape ``(B, dim)`` on the same device as ``time``.
        """
        device = time.device
        half_dim = self.dim // 2

        # Frequencies are spaced geometrically from 1 down to 1/10000.
        # With a single frequency (dim 2 or 3) ``half_dim - 1`` is 0; dividing
        # by it would give inf and then ``0 * -inf = NaN``, so clamp to 1.
        scale = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)

        # Outer product: (B, 1) * (1, half_dim) -> (B, half_dim)
        angles = time[:, None] * freqs[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)

        # For odd ``dim`` the sin/cos halves only cover ``dim - 1`` features;
        # pad one zero column so the width matches what was requested.
        if self.dim % 2 == 1:
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings
|
|
class LabelEmbedding(nn.Module):
    """Lookup table from integer class labels to ``dim``-sized vectors.

    The table holds ``num_classes + 1`` rows, i.e. one more than the number
    of real classes. ``null_classes`` stores ``num_classes``, which is the
    index of that extra row (presumably the "no label" / unconditional
    entry -- confirm against the training code).
    """

    def __init__(self, num_classes, dim):
        super().__init__()
        # Index of the additional row that follows the real classes.
        self.null_classes = num_classes
        self.embedding = nn.Embedding(
            num_embeddings=num_classes + 1,
            embedding_dim=dim,
        )

    def forward(self, labels):
        """Return the embedding rows selected by the integer tensor ``labels``."""
        vectors = self.embedding(labels)
        return vectors
|
|
class DiTBlock(nn.Module):
    """Pre-norm transformer block: self-attention then an MLP, both residual.

    Args:
        dim: Token feature width.
        heads: Number of attention heads (``dim`` must be divisible by it,
            as required by ``nn.MultiheadAttention``).
        mlp_dim: Hidden width of the two-layer GELU MLP.
    """

    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(B, N, dim)``; shape is preserved."""
        h = self.norm1(x)
        # The attention map is never used, so do not ask for it: this skips
        # materialising the averaged (B, N, N) weights and lets PyTorch use
        # its optimised scaled-dot-product attention path.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
|
|
class Unpatchify(nn.Module):
    """Project tokens back to pixels/latents and reassemble the 2-D grid.

    Each token is linearly mapped to ``out_channels * patch_size**2`` values,
    interpreted as one ``(out_channels, patch_size, patch_size)`` patch.
    """

    def __init__(self, patch_size, dim, out_channels):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size)

    def forward(self, x, latent_h, latent_w):
        """Rebuild a feature map from patch tokens.

        Args:
            x: Tensor of shape ``(B, N, D)``; ``N`` must equal
                ``(latent_h // patch_size) * (latent_w // patch_size)``
                for the internal ``view`` to succeed.
            latent_h: Height of the map before patching.
            latent_w: Width of the map before patching.

        Returns:
            Tensor of shape ``(B, out_channels, rows * p, cols * p)`` with
            ``rows = latent_h // p`` and ``cols = latent_w // p``.
        """
        batch, _, _ = x.shape
        p = self.patch_size
        channels = self.out_channels
        rows = latent_h // p
        cols = latent_w // p

        patches = self.proj(x).view(batch, rows, cols, channels, p, p)
        # (B, rows, cols, C, p, p) -> (B, C, rows, p, cols, p) so that merging
        # adjacent axes interleaves patch pixels into the full image.
        patches = patches.permute(0, 3, 1, 4, 2, 5)
        return patches.reshape(batch, channels, rows * p, cols * p)
|
|
class TinyDit(nn.Module):
    """Small transformer over latent patches, conditioned on time and label.

    The input map is patch-embedded, a learned positional embedding and a
    conditioning vector (time embedding, plus label embedding when labels
    are given) are added to every token, the tokens pass through ``depth``
    ``DiTBlock`` layers, and the result is unpatchified back to a map with
    ``latent_channels`` channels.

    Args:
        latent_channels: Channel count of the input and output maps.
        dim: Token feature width.
        depth: Number of transformer blocks.
        patch_size: Side length of each square patch.
        num_classes: Number of class labels for ``LabelEmbedding``.
        latent_size: Spatial size the positional embedding is built for,
            either an int (square) or an ``(height, width)`` pair. Inputs of
            a different size still work: the embedding is bilinearly
            resized to their patch grid.
        heads: Attention heads per block; ``dim`` must be divisible by it.

    Raises:
        ValueError: If ``latent_size`` is smaller than one patch.
    """

    def __init__(self, latent_channels, dim, depth, patch_size, num_classes,
                 latent_size=32, heads=8):
        super().__init__()
        if isinstance(latent_size, int):
            latent_size = (latent_size, latent_size)
        grid_h, grid_w = (int(side) // patch_size for side in latent_size)
        if grid_h < 1 or grid_w < 1:
            raise ValueError(
                f"latent_size {tuple(latent_size)} is smaller than one "
                f"patch of size {patch_size}"
            )
        self.patch_size = patch_size
        # Patch grid (rows, cols) that ``pos_embed`` is laid out for.
        self.grid_size = (grid_h, grid_w)

        self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=latent_channels, embed_dim=dim)
        self.label_embed = LabelEmbedding(num_classes=num_classes, dim=dim)
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        # One row per *patch token*. Sizing this by ``latent_channels`` (as
        # before) cannot broadcast against the (B, num_patches, dim) tokens
        # unless the two numbers happen to coincide.
        self.pos_embed = nn.Parameter(torch.randn(1, grid_h * grid_w, dim))
        self.blocks = nn.ModuleList(
            [DiTBlock(dim=dim, heads=heads, mlp_dim=dim * 4) for _ in range(depth)]
        )
        self.unpatchify = Unpatchify(patch_size=patch_size, dim=dim, out_channels=latent_channels)

    def _resized_pos_embed(self, grid_h, grid_w):
        """Return the positional embedding as ``(1, grid_h * grid_w, dim)``."""
        if (grid_h, grid_w) == self.grid_size:
            return self.pos_embed
        base_h, base_w = self.grid_size
        # Treat the table as a (dim, base_h, base_w) image and resample it.
        pos = self.pos_embed.reshape(1, base_h, base_w, -1).permute(0, 3, 1, 2)
        pos = F.interpolate(pos, size=(grid_h, grid_w), mode="bilinear", align_corners=False)
        return pos.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, -1)

    def forward(self, x, t, labels=None):
        """Run the model.

        Args:
            x: 4-D input map with ``latent_channels`` channels.
            t: 1-D tensor with one time value per batch element.
            labels: Optional integer class labels, one per batch element.
                When ``None`` no label embedding is added.

        Returns:
            Tensor with ``latent_channels`` channels and spatial size
            ``(H // patch_size * patch_size, W // patch_size * patch_size)``.
        """
        x, latent_h, latent_w = self.patch_embed(x)
        grid_h = latent_h // self.patch_size
        grid_w = latent_w // self.patch_size
        x = x + self._resized_pos_embed(grid_h, grid_w)

        cond = self.time_embed(t)
        if labels is not None:
            cond = cond + self.label_embed(labels)
        # Broadcast the per-sample conditioning vector over all tokens.
        x = x + cond[:, None, :]

        for block in self.blocks:
            x = block(x)
        return self.unpatchify(x, latent_h, latent_w)