# TinyDiT / tinydit.py
# aniure
# uploading
# d069b0b
import torch
import torch.nn as nn
import torch.nn.functional as F
class PatchEmbed(nn.Module):
    """Cut a latent map into non-overlapping square patches, one token per patch.

    A Conv2d whose kernel size equals its stride applies the same linear
    projection to every patch independently.

    Args:
        patch_size: Side length of each square patch.
        in_channels: Channels of the incoming latent map.
        embed_dim: Width of each output token.
    """

    def __init__(self, patch_size, in_channels, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Patchify ``x`` (must be 4-D: batch, channels, height, width).

        Returns:
            ``(tokens, height, width)``: ``tokens`` is ``(batch, num_patches,
            embed_dim)`` in row-major patch order, and ``height``/``width``
            are the spatial size of the *input* map (not the patch grid).
        """
        _, _, height, width = x.shape
        feature_map = self.proj(x)
        # (B, E, h, w) -> (B, E, h*w) -> (B, h*w, E)
        tokens = feature_map.flatten(start_dim=2).transpose(1, 2)
        return tokens, height, width
class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embedding of scalar timesteps.

    Maps a 1-D tensor of ``B`` timesteps to a ``(B, dim)`` tensor laid out as
    ``[sin(t * f_0), ..., sin(t * f_{k-1}), cos(t * f_0), ..., cos(t * f_{k-1})]``
    with ``k = dim // 2``. The frequencies are geometrically spaced from 1 down
    to 1/10000 (a single frequency of 1 when ``k == 1``). When ``dim`` is odd,
    the last column is zero padding so the output is always exactly ``dim`` wide.

    Args:
        dim: Output embedding width.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time``, a 1-D tensor of timesteps (int or float), as ``(B, dim)``."""
        half_dim = self.dim // 2
        # Guard half_dim == 1 (dim 2 or 3): dividing by zero here made the only
        # frequency 0 * -inf = NaN.
        scale = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -scale)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2:
            # sin/cos only fill 2 * (dim // 2) columns; pad one zero column so
            # the result really is `dim` wide for the layer that consumes it.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings
class LabelEmbedding(nn.Module):
    """Learned lookup table of class-label embeddings with one extra row.

    The table holds ``num_classes + 1`` rows. Index ``num_classes`` is the
    extra row and is exposed as ``self.null_classes``. It is presumably the
    "no label" token for classifier-free guidance; nothing in this module uses
    it, so confirm with the training code.

    Args:
        num_classes: Number of real class labels.
        dim: Embedding width.
    """

    def __init__(self, num_classes, dim):
        super().__init__()
        table_rows = num_classes + 1
        self.embedding = nn.Embedding(table_rows, dim)
        self.null_classes = num_classes

    def forward(self, labels):
        """Look up ``labels`` (integer tensor); returns ``labels.shape + (dim,)``."""
        vectors = self.embedding(labels)
        return vectors
class DiTBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then an MLP, each residual.

    Args:
        dim: Token width. Must be divisible by ``heads`` (enforced by
            ``nn.MultiheadAttention``).
        heads: Number of attention heads.
        mlp_dim: Hidden width of the two-layer GELU MLP.
    """

    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, tokens, dim)``; same shape out."""
        h = self.norm1(x)
        # need_weights=False: the attention map is discarded anyway. Requesting
        # it (the default) makes MultiheadAttention materialize and average the
        # weights and skip the fused scaled_dot_product_attention path.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
class Unpatchify(nn.Module):
    """Project tokens back to pixel patches and tile them into a latent map.

    Each token is linearly mapped to ``out_channels * patch_size**2`` values.
    The token sequence is read as a row-major grid: token ``n`` lands at grid
    row ``n // W`` and column ``n % W``, where ``W = latent_w // patch_size``.

    Args:
        patch_size: Side length of each square patch.
        dim: Input token width.
        out_channels: Channels of the reconstructed map.
    """

    def __init__(self, patch_size, dim, out_channels):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size)

    def forward(self, x, latent_h, latent_w):
        """Rebuild a ``(batch, out_channels, rows*p, cols*p)`` map from tokens.

        Args:
            x: Tokens, must be 3-D ``(batch, rows * cols, dim)``.
            latent_h, latent_w: Spatial size of the original latent. The grid
                is ``latent_h // p`` by ``latent_w // p``.
        """
        batch, _, _ = x.shape
        p = self.patch_size
        channels = self.out_channels
        rows = latent_h // p
        cols = latent_w // p
        patches = self.proj(x).view(batch, rows, cols, channels, p, p)
        # (B, rows, cols, C, p, p) -> (B, C, rows, p, cols, p): interleave the
        # grid axes with each patch's pixel axes so one reshape tiles them.
        tiled = patches.permute(0, 3, 1, 4, 2, 5)
        return tiled.reshape(batch, channels, rows * p, cols * p)
class TinyDit(nn.Module):
    """Small diffusion transformer over a latent feature map.

    Pipeline:
        1. Patchify the latent.
        2. Add a learned positional embedding.
        3. Add the timestep conditioning, plus the class-label conditioning
           when labels are given, to every token.
        4. Run ``depth`` transformer blocks.
        5. Linearly un-patchify back to a latent map.

    Args:
        latent_channels: Channels of the input (and output) latent.
        dim: Token width. Must be divisible by ``num_heads``.
        depth: Number of ``DiTBlock`` layers.
        patch_size: Side length of the square patches.
        num_classes: Number of class labels. ``LabelEmbedding`` reserves one
            extra index beyond these.
        input_size: Spatial size of the latent, as an ``int`` or ``(h, w)``.
            Sizes the learned positional embedding to one vector per patch.
            If omitted, the legacy shape ``(1, latent_channels, dim)`` is kept
            so old checkpoints still load. That only works when the number of
            patches equals ``latent_channels``.
        num_heads: Attention heads per block (default 8, the previous
            hard-coded value).
        mlp_ratio: Block MLP hidden width as a multiple of ``dim`` (default 4,
            the previous hard-coded value).
    """

    def __init__(self, latent_channels, dim, depth, patch_size, num_classes,
                 input_size=None, num_heads=8, mlp_ratio=4):
        super().__init__()
        self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=latent_channels, embed_dim=dim)
        self.label_embed = LabelEmbedding(num_classes=num_classes, dim=dim)
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        if input_size is None:
            # Legacy sizing. The sequence axis is sized by a *channel* count,
            # kept only for backward compatibility with existing checkpoints.
            num_tokens = latent_channels
        else:
            if isinstance(input_size, int):
                input_size = (input_size, input_size)
            latent_h, latent_w = input_size
            # Must match PatchEmbed's conv output grid (floor division).
            num_tokens = (latent_h // patch_size) * (latent_w // patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, num_tokens, dim))
        self.blocks = nn.ModuleList(
            [DiTBlock(dim=dim, heads=num_heads, mlp_dim=int(dim * mlp_ratio)) for _ in range(depth)]
        )
        self.unpatchify = Unpatchify(patch_size=patch_size, dim=dim, out_channels=latent_channels)

    def forward(self, x, t, labels=None):
        """Run the model.

        Args:
            x: Latent map ``(batch, latent_channels, height, width)``.
            t: 1-D tensor of timesteps, one per batch element.
            labels: Optional integer class labels, one per batch element.

        Returns:
            Tensor ``(batch, latent_channels, H', W')``, where ``H'`` and ``W'``
            are ``height`` and ``width`` rounded down to multiples of
            ``patch_size``.

        Raises:
            ValueError: If the number of patches does not match the
                positional embedding's length.
        """
        x, latent_h, latent_w = self.patch_embed(x)
        num_tokens = x.shape[1]
        expected = self.pos_embed.shape[1]
        if num_tokens != expected:
            # Without this check the add below either fails with an opaque
            # broadcast error or, when num_tokens == 1, silently broadcasts to
            # the wrong sequence length.
            raise ValueError(
                f"TinyDit got {num_tokens} patch tokens but pos_embed holds "
                f"{expected}; construct the model with input_size set to the "
                f"latent's spatial size"
            )
        x = x + self.pos_embed
        cond = self.time_embed(t)
        if labels is not None:
            cond = cond + self.label_embed(labels)
        # Broadcast the per-sample conditioning vector over every token.
        x = x + cond[:, None, :]
        for block in self.blocks:
            x = block(x)
        return self.unpatchify(x, latent_h, latent_w)