Delete tinydit.py
Browse files- tinydit.py +0 -107
tinydit.py
DELETED
|
@@ -1,107 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
|
| 5 |
-
class PatchEmbed(nn.Module):
|
| 6 |
-
def __init__(self, patch_size, in_channels, embed_dim):
|
| 7 |
-
super().__init__()
|
| 8 |
-
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 9 |
-
|
| 10 |
-
def forward(self, x):
|
| 11 |
-
B, C, latent_h, latent_w = x.shape
|
| 12 |
-
x = self.proj(x)
|
| 13 |
-
x = x.flatten(2).transpose(1, 2)
|
| 14 |
-
return x, latent_h, latent_w
|
| 15 |
-
|
| 16 |
-
class SinusoidalPositionEmbeddings(nn.Module):
|
| 17 |
-
def __init__(self, dim):
|
| 18 |
-
super().__init__()
|
| 19 |
-
self.dim = dim
|
| 20 |
-
|
| 21 |
-
def forward(self, time):
|
| 22 |
-
device = time.device
|
| 23 |
-
half_dim = self.dim // 2
|
| 24 |
-
|
| 25 |
-
embeddings = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
| 26 |
-
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
|
| 27 |
-
|
| 28 |
-
embeddings = time[:, None] * embeddings[None, :]
|
| 29 |
-
|
| 30 |
-
embeddings = torch.cat(
|
| 31 |
-
(embeddings.sin(), embeddings.cos()), dim=-1
|
| 32 |
-
)
|
| 33 |
-
return embeddings
|
| 34 |
-
|
| 35 |
-
class LabelEmbedding(nn.Module):
|
| 36 |
-
def __init__(self, num_classes, dim):
|
| 37 |
-
super().__init__()
|
| 38 |
-
|
| 39 |
-
self.embedding = nn.Embedding(num_classes + 1, dim)
|
| 40 |
-
self.null_classes = num_classes
|
| 41 |
-
|
| 42 |
-
def forward(self, labels):
|
| 43 |
-
return self.embedding(labels)
|
| 44 |
-
|
| 45 |
-
class DiTBlock(nn.Module):
|
| 46 |
-
def __init__(self, dim, heads, mlp_dim):
|
| 47 |
-
super().__init__()
|
| 48 |
-
|
| 49 |
-
self.norm1 = nn.LayerNorm(dim)
|
| 50 |
-
self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
|
| 51 |
-
self.norm2 = nn.LayerNorm(dim)
|
| 52 |
-
self.mlp = nn.Sequential(
|
| 53 |
-
nn.Linear(dim, mlp_dim),
|
| 54 |
-
nn.GELU(),
|
| 55 |
-
nn.Linear(mlp_dim, dim)
|
| 56 |
-
)
|
| 57 |
-
|
| 58 |
-
def forward(self, x):
|
| 59 |
-
h = self.norm1(x)
|
| 60 |
-
attn_out, _ = self.attn(h, h, h)
|
| 61 |
-
x = x + attn_out
|
| 62 |
-
x = x + self.mlp(self.norm2(x))
|
| 63 |
-
return x
|
| 64 |
-
|
| 65 |
-
class Unpatchify(nn.Module):
|
| 66 |
-
def __init__(self, patch_size, dim, out_channels):
|
| 67 |
-
super().__init__()
|
| 68 |
-
self.patch_size = patch_size
|
| 69 |
-
self.out_channels = out_channels
|
| 70 |
-
self.proj = nn.Linear(dim, out_channels * patch_size * patch_size)
|
| 71 |
-
|
| 72 |
-
def forward(self, x, latent_h, latent_w):
|
| 73 |
-
B, N, D = x.shape
|
| 74 |
-
x = self.proj(x)
|
| 75 |
-
H = latent_h // self.patch_size
|
| 76 |
-
W = latent_w // self.patch_size
|
| 77 |
-
x = x.view(B, H, W, self.out_channels, self.patch_size, self.patch_size)
|
| 78 |
-
x = x.permute(0, 3, 1, 4, 2, 5)
|
| 79 |
-
x = x.reshape(B, self.out_channels, H*self.patch_size, W*self.patch_size)
|
| 80 |
-
return x
|
| 81 |
-
|
| 82 |
-
class TinyDit(nn.Module):
|
| 83 |
-
def __init__(self, latent_channels, dim, depth, patch_size, num_classes):
|
| 84 |
-
super().__init__()
|
| 85 |
-
self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=latent_channels, embed_dim=dim)
|
| 86 |
-
self.label_embed = LabelEmbedding(num_classes=num_classes, dim=dim)
|
| 87 |
-
self.time_embed = nn.Sequential(
|
| 88 |
-
SinusoidalPositionEmbeddings(dim),
|
| 89 |
-
nn.Linear(dim, dim),
|
| 90 |
-
nn.GELU(),
|
| 91 |
-
nn.Linear(dim, dim)
|
| 92 |
-
)
|
| 93 |
-
self.pos_embed = nn.Parameter(torch.randn(1, latent_channels, dim))
|
| 94 |
-
self.blocks = nn.ModuleList([DiTBlock(dim=dim, heads=8, mlp_dim=dim*4) for _ in range(depth)])
|
| 95 |
-
self.unpatchify = Unpatchify(patch_size=patch_size, dim=dim, out_channels=latent_channels)
|
| 96 |
-
|
| 97 |
-
def forward(self, x, t, labels=None):
|
| 98 |
-
x, H, W = self.patch_embed(x)
|
| 99 |
-
x = x + self.pos_embed
|
| 100 |
-
t = self.time_embed(t)
|
| 101 |
-
if labels is not None:
|
| 102 |
-
t = t + self.label_embed(labels)
|
| 103 |
-
x = x + t[:, None, :]
|
| 104 |
-
for block in self.blocks:
|
| 105 |
-
x = block(x)
|
| 106 |
-
x = self.unpatchify(x, H, W)
|
| 107 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|