import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """Project non-overlapping square patches of a 4-D input to token vectors.

    A ``Conv2d`` whose kernel size equals its stride performs the patch
    extraction and the linear projection in one step.
    """

    def __init__(self, patch_size: int, in_channels: int, embed_dim: int) -> None:
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """Return ``(tokens, latent_h, latent_w)``.

        ``tokens`` has shape ``(B, num_patches, embed_dim)``. ``latent_h`` and
        ``latent_w`` are the spatial sizes of the *input* (before patching),
        which ``Unpatchify`` needs to rebuild the grid.
        """
        _, _, latent_h, latent_w = x.shape
        # (B, D, H', W') -> (B, D, H'*W') -> (B, H'*W', D)
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        return tokens, latent_h, latent_w


class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalar values (e.g. timesteps)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        """Embed ``time`` of shape ``(B,)`` into a ``(B, dim)`` tensor.

        The first ``dim // 2`` features are sines and the next ``dim // 2`` are
        cosines. When ``dim`` is odd, a trailing zero column is appended so the
        output width always equals ``dim``.
        """
        half_dim = self.dim // 2
        # Clamp the denominator: with half_dim == 1 the unguarded
        # ``half_dim - 1`` is zero, which turns the frequencies into NaN.
        log_step = math.log(10000.0) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        angles = time[:, None] * freqs[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover dim - 1 features for odd dim; pad so
            # downstream layers that expect ``dim`` inputs keep working.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings


class LabelEmbedding(nn.Module):
    """Embedding table for class labels with one extra row.

    The table has ``num_classes + 1`` rows; the index of the extra row
    (``num_classes``) is stored in ``null_classes``.
    NOTE(review): the extra row is presumably the "null"/unconditional label
    for classifier-free guidance -- confirm against the training code.
    """

    def __init__(self, num_classes: int, dim: int) -> None:
        super().__init__()
        self.embedding = nn.Embedding(num_classes + 1, dim)
        self.null_classes = num_classes

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Look up the embedding vector of every label index."""
        return self.embedding(labels)


class DiTBlock(nn.Module):
    """Transformer block: self-attention and an MLP, each with a residual.

    ``LayerNorm`` is applied before the attention and before the MLP.
    """

    def __init__(self, dim: int, heads: int, mlp_dim: int) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to tokens of shape ``(B, N, dim)``."""
        h = self.norm1(x)
        # The attention weights are never used, so skip computing them.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        return x + self.mlp(self.norm2(x))


class Unpatchify(nn.Module):
    """Project tokens back to pixels and reassemble the patch grid."""

    def __init__(self, patch_size: int, dim: int, out_channels: int) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size)

    def forward(self, x: torch.Tensor, latent_h: int, latent_w: int) -> torch.Tensor:
        """Turn ``(B, N, dim)`` tokens into a ``(B, C, H*p, W*p)`` map.

        ``latent_h`` / ``latent_w`` are the spatial sizes of the un-patched
        input; the grid is ``H = latent_h // p`` by ``W = latent_w // p`` and
        ``N`` must equal ``H * W``. Each token's projected features are laid
        out as ``(channel, patch_row, patch_col)``.
        """
        batch = x.shape[0]
        p = self.patch_size
        grid_h = latent_h // p
        grid_w = latent_w // p
        x = self.proj(x)
        x = x.view(batch, grid_h, grid_w, self.out_channels, p, p)
        # (B, H, W, C, p, p) -> (B, C, H, p, W, p) so that merging adjacent
        # axes yields contiguous rows and columns of pixels.
        x = x.permute(0, 3, 1, 4, 2, 5)
        return x.reshape(batch, self.out_channels, grid_h * p, grid_w * p)


class TinyDit(nn.Module):
    """Small diffusion transformer operating on a 4-D latent tensor.

    Args:
        latent_channels: Channels of the input and of the output.
        dim: Token width used throughout the transformer.
        depth: Number of ``DiTBlock`` layers.
        patch_size: Side length of the square patches.
        num_classes: Number of class labels (one extra embedding row is
            allocated by ``LabelEmbedding``).
        num_patches: Length of the learned positional embedding, i.e.
            ``(latent_h // patch_size) * (latent_w // patch_size)`` for the
            inputs the model will see. ``None`` keeps the legacy length of
            ``latent_channels``, which only works when the token count happens
            to equal ``latent_channels`` (or ``latent_channels == 1``).
        heads: Attention heads per block; ``dim`` must be divisible by it.
    """

    def __init__(
        self,
        latent_channels: int,
        dim: int,
        depth: int,
        patch_size: int,
        num_classes: int,
        num_patches: Optional[int] = None,
        heads: int = 8,
    ) -> None:
        super().__init__()
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_channels=latent_channels, embed_dim=dim
        )
        self.label_embed = LabelEmbedding(num_classes=num_classes, dim=dim)
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        # The positional embedding is added to the token sequence, so its
        # length must be the number of patches. The legacy default is kept
        # only so existing constructor calls and checkpoints still load.
        pos_len = latent_channels if num_patches is None else num_patches
        self.pos_embed = nn.Parameter(torch.randn(1, pos_len, dim))
        self.blocks = nn.ModuleList(
            [DiTBlock(dim=dim, heads=heads, mlp_dim=dim * 4) for _ in range(depth)]
        )
        self.unpatchify = Unpatchify(
            patch_size=patch_size, dim=dim, out_channels=latent_channels
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the model.

        Args:
            x: Input of shape ``(B, latent_channels, H, W)``.
            t: One scalar per batch element, shape ``(B,)``.
            labels: Optional label indices, shape ``(B,)``. When ``None`` no
                label embedding is added to the conditioning vector.

        Returns:
            Tensor of shape ``(B, latent_channels, H', W')`` where ``H'`` and
            ``W'`` are ``H`` and ``W`` rounded down to a multiple of the patch
            size.

        Raises:
            ValueError: If the positional embedding length does not match the
                number of patches produced from ``x``.
        """
        x, latent_h, latent_w = self.patch_embed(x)

        pos_len = self.pos_embed.shape[1]
        if pos_len not in (1, x.shape[1]):
            raise ValueError(
                f"pos_embed has {pos_len} positions but the input produced "
                f"{x.shape[1]} patches; construct TinyDit with "
                f"num_patches={x.shape[1]}"
            )
        x = x + self.pos_embed

        cond = self.time_embed(t)
        if labels is not None:
            cond = cond + self.label_embed(labels)
        # Broadcast the per-sample conditioning vector over all tokens.
        x = x + cond[:, None, :]

        for block in self.blocks:
            x = block(x)
        return self.unpatchify(x, latent_h, latent_w)