nitesh501 commited on
Commit
45f73cb
·
verified ·
1 Parent(s): 5e8ff6c

Update tinydit.py

Browse files
Files changed (1) hide show
  1. tinydit.py +1 -1
tinydit.py CHANGED
@@ -90,7 +90,7 @@ class TinyDit(nn.Module):
90
  nn.GELU(),
91
  nn.Linear(dim, dim)
92
  )
93
- self.pos_embed = nn.Parameter(torch.randn(1, 16, dim))
94
  self.blocks = nn.ModuleList([DiTBlock(dim=dim, heads=8, mlp_dim=dim*4) for _ in range(depth)])
95
  self.unpatchify = Unpatchify(patch_size=patch_size, dim=dim, out_channels=latent_channels)
96
 
 
90
  nn.GELU(),
91
  nn.Linear(dim, dim)
92
  )
93
+ self.pos_embed = nn.Parameter(torch.randn(1, latent_channels, dim))
94
  self.blocks = nn.ModuleList([DiTBlock(dim=dim, heads=8, mlp_dim=dim*4) for _ in range(depth)])
95
  self.unpatchify = Unpatchify(patch_size=patch_size, dim=dim, out_channels=latent_channels)
96