nitesh501 committed on
Commit
02259e1
·
verified ·
1 Parent(s): 4691adb

Delete tinydit.py

Browse files
Files changed (1) hide show
  1. tinydit.py +0 -107
tinydit.py DELETED
@@ -1,107 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
class PatchEmbed(nn.Module):
    """Turn a 2-D feature map into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects each non-overlapping patch to an ``embed_dim``-wide vector.
    """

    def __init__(self, patch_size, in_channels, embed_dim):
        """Create the patch projection.

        Args:
            patch_size: Side length of each patch (conv kernel size and stride).
            in_channels: Number of channels of the incoming feature map.
            embed_dim: Width of the embedding produced for each patch.
        """
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed ``x`` and report the spatial size it had on entry.

        Args:
            x: 4-D tensor ``(batch, in_channels, height, width)``.

        Returns:
            A tuple ``(tokens, height, width)``: ``tokens`` has shape
            ``(batch, num_patches, embed_dim)``, and ``height`` / ``width``
            are the spatial dimensions of the *input* (before patching).
        """
        # Four-way unpacking keeps the implicit "input must be 4-D" check.
        _, _, latent_h, latent_w = x.shape
        x = self.proj(x)
        # (B, E, h, w) -> (B, E, h*w) -> (B, h*w, E): one token per patch.
        x = x.flatten(2).transpose(1, 2)
        return x, latent_h, latent_w
15
-
16
class SinusoidalPositionEmbeddings(nn.Module):
    """Map a batch of scalar time values to fixed sinusoidal features.

    The first half of the feature vector holds sines and the second half
    cosines, over geometrically spaced frequencies from 1 down to 1/10000.
    """

    def __init__(self, dim):
        """Store the embedding width.

        Args:
            dim: Number of features to produce for each time value.
        """
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Compute the embeddings.

        Args:
            time: 1-D tensor of time values, one per batch element.

        Returns:
            Tensor of shape ``(len(time), dim)``.
        """
        device = time.device
        half_dim = self.dim // 2

        # Guard the divisor: with dim in {2, 3} ``half_dim - 1`` is 0, which
        # made the step infinite and the resulting embeddings NaN.
        step = torch.log(torch.tensor(10000.0, device=device)) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -step)

        angles = time[:, None] * freqs[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)

        # An odd ``dim`` would otherwise yield ``dim - 1`` features and break
        # any layer expecting exactly ``dim`` inputs; pad one zero column.
        if self.dim % 2:
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings
34
-
35
class LabelEmbedding(nn.Module):
    """Learned lookup table that maps integer class labels to vectors.

    The table holds ``num_classes + 1`` rows; the extra row's index
    (``num_classes``) is kept in ``null_classes``.
    NOTE(review): the extra row is presumably a "null"/unconditional label
    (e.g. for classifier-free guidance) -- nothing in this module uses it, so
    confirm against the training code.
    """

    def __init__(self, num_classes, dim):
        """Create the embedding table.

        Args:
            num_classes: Number of real classes.
            dim: Width of each label embedding.
        """
        super().__init__()

        self.embedding = nn.Embedding(num_classes + 1, dim)
        # Index of the additional (last) row of the table.
        self.null_classes = num_classes

    def forward(self, labels):
        """Look up embeddings for ``labels``.

        Args:
            labels: Integer tensor of class indices in ``[0, num_classes]``.

        Returns:
            Tensor of shape ``labels.shape + (dim,)``.
        """
        return self.embedding(labels)
44
-
45
class DiTBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then an MLP, each residual."""

    def __init__(self, dim, heads, mlp_dim):
        """Build the block.

        Args:
            dim: Token width; must be divisible by ``heads``.
            heads: Number of attention heads.
            mlp_dim: Hidden width of the feed-forward network.
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim),
        )

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape ``(batch, tokens, dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        h = self.norm1(x)
        # The attention map is never used, so skip computing/averaging it.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
64
-
65
class Unpatchify(nn.Module):
    """Project tokens back to pixel patches and stitch them into a 2-D map."""

    def __init__(self, patch_size, dim, out_channels):
        """Create the per-token projection.

        Args:
            patch_size: Side length of each square patch.
            dim: Width of the incoming tokens.
            out_channels: Number of channels of the reconstructed map.
        """
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size)

    def forward(self, x, latent_h, latent_w):
        """Rebuild a feature map from a token sequence.

        Args:
            x: Tensor ``(batch, num_patches, dim)`` with tokens in row-major
                patch order; ``num_patches`` must equal
                ``(latent_h // patch_size) * (latent_w // patch_size)``.
            latent_h: Height of the map that was originally patchified.
            latent_w: Width of the map that was originally patchified.

        Returns:
            Tensor ``(batch, out_channels, rows * patch_size, cols * patch_size)``
            where ``rows`` / ``cols`` are the patch-grid dimensions.
        """
        patch = self.patch_size
        # Three-way unpacking keeps the implicit "input must be 3-D" check.
        batch, _, _ = x.shape
        rows = latent_h // patch
        cols = latent_w // patch

        x = self.proj(x)
        # Each token's features are laid out as (channel, patch_row, patch_col).
        # ``reshape`` (not ``view``) so a non-contiguous input cannot fail here.
        x = x.reshape(batch, rows, cols, self.out_channels, patch, patch)
        # -> (batch, channel, grid_row, patch_row, grid_col, patch_col)
        x = x.permute(0, 3, 1, 4, 2, 5)
        return x.reshape(batch, self.out_channels, rows * patch, cols * patch)
81
-
82
class TinyDit(nn.Module):
    """Small diffusion transformer over latent feature maps.

    The input is patchified, given a learned positional embedding and a
    combined time (+ optional label) conditioning vector, run through a stack
    of ``DiTBlock`` layers and finally unpatchified back to a map with
    ``latent_channels`` channels.
    """

    def __init__(self, latent_channels, dim, depth, patch_size, num_classes,
                 num_patches=None, heads=8):
        """Build the model.

        Args:
            latent_channels: Channels of the input (and output) latent map.
            dim: Token width used throughout the transformer.
            depth: Number of transformer blocks.
            patch_size: Side length of each square patch.
            num_classes: Number of class labels for conditioning.
            num_patches: Number of tokens the positional embedding covers,
                i.e. ``(H // patch_size) * (W // patch_size)`` for the latent
                size the model will see. When ``None`` the historical size
                ``latent_channels`` is used so existing checkpoints still
                load; that only works if the patch count happens to equal
                ``latent_channels`` (or ``latent_channels`` is 1).
            heads: Attention heads per block; ``dim`` must be divisible by it.
        """
        super().__init__()
        self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=latent_channels, embed_dim=dim)
        self.label_embed = LabelEmbedding(num_classes=num_classes, dim=dim)
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        # The positional table needs one row per *token*, not per channel.
        # The legacy default is kept only for backward compatibility.
        if num_patches is None:
            num_patches = latent_channels
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, dim))
        self.blocks = nn.ModuleList(
            [DiTBlock(dim=dim, heads=heads, mlp_dim=dim * 4) for _ in range(depth)]
        )
        self.unpatchify = Unpatchify(patch_size=patch_size, dim=dim, out_channels=latent_channels)

    def forward(self, x, t, labels=None):
        """Run the model.

        Args:
            x: Latent map ``(batch, latent_channels, H, W)``.
            t: 1-D tensor of time values, one per batch element.
            labels: Optional integer class labels, one per batch element;
                when ``None`` only the time embedding conditions the tokens.

        Returns:
            Tensor ``(batch, latent_channels, H', W')`` where ``H'`` / ``W'``
            are ``H`` / ``W`` rounded down to a multiple of the patch size.

        Raises:
            RuntimeError: If the positional embedding cannot be broadcast
                over the number of patches produced from ``x``.
        """
        x, H, W = self.patch_embed(x)

        num_tokens = x.shape[1]
        pos_tokens = self.pos_embed.shape[1]
        # Fail with an actionable message instead of an opaque broadcast error.
        if pos_tokens not in (1, num_tokens):
            raise RuntimeError(
                f"pos_embed covers {pos_tokens} tokens but the input produced "
                f"{num_tokens} patches; construct TinyDit with num_patches={num_tokens}"
            )
        x = x + self.pos_embed

        t = self.time_embed(t)
        if labels is not None:
            t = t + self.label_embed(labels)
        # Broadcast the per-sample conditioning vector over every token.
        x = x + t[:, None, :]

        for block in self.blocks:
            x = block(x)
        x = self.unpatchify(x, H, W)
        return x