Commit ·
d66ddca
1
Parent(s): e2c268e
Remove print statements
Browse files
- model/patch_embed.py +1 -8
- model/transformer.py +0 -4
model/patch_embed.py
CHANGED
|
@@ -10,7 +10,7 @@ def get_patch_positional_embedding(pos_emb_dim, grid_size, device):
|
|
| 10 |
grid_size_h, grid_size_w = grid_size
|
| 11 |
grid_h = torch.arange(grid_size_h, dtype=torch.float32, device=device)
|
| 12 |
grid_w = torch.arange(grid_size_w, dtype=torch.float32, device=device)
|
| 13 |
-
grid = torch.meshgrid(grid_h, grid_w,indexing="ij")
|
| 14 |
grid = torch.stack(grid, dim=0)
|
| 15 |
|
| 16 |
grid_h_positions = grid[0].reshape(-1)
|
|
@@ -30,8 +30,6 @@ def get_patch_positional_embedding(pos_emb_dim, grid_size, device):
|
|
| 30 |
pos_emb = torch.cat([grid_h_emb, grid_w_emb], dim=-1)
|
| 31 |
|
| 32 |
# pos_emb -> (Number of patch tokens, pos_emb_dim)
|
| 33 |
-
print(f"Grid H emb: {grid_h_emb.shape} in Patch Embedding")
|
| 34 |
-
print(f"Grid W emb: {grid_w_emb.shape} in Patch Embedding")
|
| 35 |
|
| 36 |
return pos_emb
|
| 37 |
|
|
@@ -68,7 +66,6 @@ class PatchEmbedding(nn.Module):
|
|
| 68 |
nn.init.constant_(self.patch_embed[0].bias, 0)
|
| 69 |
|
| 70 |
def forward(self, x):
|
| 71 |
-
|
| 72 |
out = rearrange(
|
| 73 |
x,
|
| 74 |
"b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)",
|
|
@@ -76,20 +73,16 @@ class PatchEmbedding(nn.Module):
|
|
| 76 |
pw=self.patch_width,
|
| 77 |
)
|
| 78 |
|
| 79 |
-
print(f"Image shape after rearraging: {out.shape} in Patch Embedding Layer")
|
| 80 |
actual_h = x.shape[2] # Height from input tensor
|
| 81 |
actual_w = x.shape[3] # Width from input tensor
|
| 82 |
grid_size_h = actual_h // self.patch_height
|
| 83 |
grid_size_w = actual_w // self.patch_width
|
| 84 |
-
|
| 85 |
-
|
| 86 |
out = self.patch_embed(out)
|
| 87 |
pos_emb = get_patch_positional_embedding(
|
| 88 |
self.hidden_dim, grid_size=(grid_size_h, grid_size_w), device=x.device
|
| 89 |
)
|
| 90 |
|
| 91 |
out += pos_emb
|
| 92 |
-
print(f"Patch Embeddings: {out.shape}\n")
|
| 93 |
return out
|
| 94 |
|
| 95 |
|
|
|
|
| 10 |
grid_size_h, grid_size_w = grid_size
|
| 11 |
grid_h = torch.arange(grid_size_h, dtype=torch.float32, device=device)
|
| 12 |
grid_w = torch.arange(grid_size_w, dtype=torch.float32, device=device)
|
| 13 |
+
grid = torch.meshgrid(grid_h, grid_w, indexing="ij")
|
| 14 |
grid = torch.stack(grid, dim=0)
|
| 15 |
|
| 16 |
grid_h_positions = grid[0].reshape(-1)
|
|
|
|
| 30 |
pos_emb = torch.cat([grid_h_emb, grid_w_emb], dim=-1)
|
| 31 |
|
| 32 |
# pos_emb -> (Number of patch tokens, pos_emb_dim)
|
|
|
|
|
|
|
| 33 |
|
| 34 |
return pos_emb
|
| 35 |
|
|
|
|
| 66 |
nn.init.constant_(self.patch_embed[0].bias, 0)
|
| 67 |
|
| 68 |
def forward(self, x):
|
|
|
|
| 69 |
out = rearrange(
|
| 70 |
x,
|
| 71 |
"b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)",
|
|
|
|
| 73 |
pw=self.patch_width,
|
| 74 |
)
|
| 75 |
|
|
|
|
| 76 |
actual_h = x.shape[2] # Height from input tensor
|
| 77 |
actual_w = x.shape[3] # Width from input tensor
|
| 78 |
grid_size_h = actual_h // self.patch_height
|
| 79 |
grid_size_w = actual_w // self.patch_width
|
|
|
|
|
|
|
| 80 |
out = self.patch_embed(out)
|
| 81 |
pos_emb = get_patch_positional_embedding(
|
| 82 |
self.hidden_dim, grid_size=(grid_size_h, grid_size_w), device=x.device
|
| 83 |
)
|
| 84 |
|
| 85 |
out += pos_emb
|
|
|
|
| 86 |
return out
|
| 87 |
|
| 88 |
|
model/transformer.py
CHANGED
|
@@ -93,8 +93,6 @@ class DIT(nn.Module):
|
|
| 93 |
1 + pre_mlp_scale.unsqueeze(1)
|
| 94 |
) + pre_mlp_shift.unsqueeze(1)
|
| 95 |
|
| 96 |
-
print(f"\nOutput before unpatchify block {out.shape} in DIT block")
|
| 97 |
-
|
| 98 |
actual_h = x.shape[2] # Height from input tensor
|
| 99 |
actual_w = x.shape[3] # Width from input tensor
|
| 100 |
actual_nh = actual_h // self.patch_height
|
|
@@ -102,7 +100,6 @@ class DIT(nn.Module):
|
|
| 102 |
|
| 103 |
# Unpatchify
|
| 104 |
out = self.out_proj(out)
|
| 105 |
-
print(f"Output after projection: {out.shape} in DIT block")
|
| 106 |
out = rearrange(
|
| 107 |
out,
|
| 108 |
"b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)",
|
|
@@ -112,7 +109,6 @@ class DIT(nn.Module):
|
|
| 112 |
nh=actual_nh,
|
| 113 |
)
|
| 114 |
|
| 115 |
-
print(f"Output after unpatchify block: {out.shape} in DIT block\n")
|
| 116 |
return out
|
| 117 |
|
| 118 |
|
|
|
|
| 93 |
1 + pre_mlp_scale.unsqueeze(1)
|
| 94 |
) + pre_mlp_shift.unsqueeze(1)
|
| 95 |
|
|
|
|
|
|
|
| 96 |
actual_h = x.shape[2] # Height from input tensor
|
| 97 |
actual_w = x.shape[3] # Width from input tensor
|
| 98 |
actual_nh = actual_h // self.patch_height
|
|
|
|
| 100 |
|
| 101 |
# Unpatchify
|
| 102 |
out = self.out_proj(out)
|
|
|
|
| 103 |
out = rearrange(
|
| 104 |
out,
|
| 105 |
"b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)",
|
|
|
|
| 109 |
nh=actual_nh,
|
| 110 |
)
|
| 111 |
|
|
|
|
| 112 |
return out
|
| 113 |
|
| 114 |
|