Update tinydit.py
Browse files- tinydit.py +1 -1
tinydit.py
CHANGED
|
@@ -90,7 +90,7 @@ class TinyDit(nn.Module):
|
|
| 90 |
nn.GELU(),
|
| 91 |
nn.Linear(dim, dim)
|
| 92 |
)
|
| 93 |
-
self.pos_embed = nn.Parameter(torch.randn(1,
|
| 94 |
self.blocks = nn.ModuleList([DiTBlock(dim=dim, heads=8, mlp_dim=dim*4) for _ in range(depth)])
|
| 95 |
self.unpatchify = Unpatchify(patch_size=patch_size, dim=dim, out_channels=latent_channels)
|
| 96 |
|
|
|
|
| 90 |
nn.GELU(),
|
| 91 |
nn.Linear(dim, dim)
|
| 92 |
)
|
| 93 |
+
self.pos_embed = nn.Parameter(torch.randn(1, latent_channels, dim))
|
| 94 |
self.blocks = nn.ModuleList([DiTBlock(dim=dim, heads=8, mlp_dim=dim*4) for _ in range(depth)])
|
| 95 |
self.unpatchify = Unpatchify(patch_size=patch_size, dim=dim, out_channels=latent_channels)
|
| 96 |
|