YashNagraj75 committed on
Commit
d66ddca
·
1 Parent(s): e2c268e

Remove print statements

Browse files
Files changed (2) hide show
  1. model/patch_embed.py +1 -8
  2. model/transformer.py +0 -4
model/patch_embed.py CHANGED
@@ -10,7 +10,7 @@ def get_patch_positional_embedding(pos_emb_dim, grid_size, device):
10
  grid_size_h, grid_size_w = grid_size
11
  grid_h = torch.arange(grid_size_h, dtype=torch.float32, device=device)
12
  grid_w = torch.arange(grid_size_w, dtype=torch.float32, device=device)
13
- grid = torch.meshgrid(grid_h, grid_w,indexing="ij")
14
  grid = torch.stack(grid, dim=0)
15
 
16
  grid_h_positions = grid[0].reshape(-1)
@@ -30,8 +30,6 @@ def get_patch_positional_embedding(pos_emb_dim, grid_size, device):
30
  pos_emb = torch.cat([grid_h_emb, grid_w_emb], dim=-1)
31
 
32
  # pos_emb -> (Number of patch tokens, pos_emb_dim)
33
- print(f"Grid H emb: {grid_h_emb.shape} in Patch Embedding")
34
- print(f"Grid W emb: {grid_w_emb.shape} in Patch Embedding")
35
 
36
  return pos_emb
37
 
@@ -68,7 +66,6 @@ class PatchEmbedding(nn.Module):
68
  nn.init.constant_(self.patch_embed[0].bias, 0)
69
 
70
  def forward(self, x):
71
-
72
  out = rearrange(
73
  x,
74
  "b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)",
@@ -76,20 +73,16 @@ class PatchEmbedding(nn.Module):
76
  pw=self.patch_width,
77
  )
78
 
79
- print(f"Image shape after rearraging: {out.shape} in Patch Embedding Layer")
80
  actual_h = x.shape[2] # Height from input tensor
81
  actual_w = x.shape[3] # Width from input tensor
82
  grid_size_h = actual_h // self.patch_height
83
  grid_size_w = actual_w // self.patch_width
84
-
85
-
86
  out = self.patch_embed(out)
87
  pos_emb = get_patch_positional_embedding(
88
  self.hidden_dim, grid_size=(grid_size_h, grid_size_w), device=x.device
89
  )
90
 
91
  out += pos_emb
92
- print(f"Patch Embeddings: {out.shape}\n")
93
  return out
94
 
95
 
 
10
  grid_size_h, grid_size_w = grid_size
11
  grid_h = torch.arange(grid_size_h, dtype=torch.float32, device=device)
12
  grid_w = torch.arange(grid_size_w, dtype=torch.float32, device=device)
13
+ grid = torch.meshgrid(grid_h, grid_w, indexing="ij")
14
  grid = torch.stack(grid, dim=0)
15
 
16
  grid_h_positions = grid[0].reshape(-1)
 
30
  pos_emb = torch.cat([grid_h_emb, grid_w_emb], dim=-1)
31
 
32
  # pos_emb -> (Number of patch tokens, pos_emb_dim)
 
 
33
 
34
  return pos_emb
35
 
 
66
  nn.init.constant_(self.patch_embed[0].bias, 0)
67
 
68
  def forward(self, x):
 
69
  out = rearrange(
70
  x,
71
  "b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)",
 
73
  pw=self.patch_width,
74
  )
75
 
 
76
  actual_h = x.shape[2] # Height from input tensor
77
  actual_w = x.shape[3] # Width from input tensor
78
  grid_size_h = actual_h // self.patch_height
79
  grid_size_w = actual_w // self.patch_width
 
 
80
  out = self.patch_embed(out)
81
  pos_emb = get_patch_positional_embedding(
82
  self.hidden_dim, grid_size=(grid_size_h, grid_size_w), device=x.device
83
  )
84
 
85
  out += pos_emb
 
86
  return out
87
 
88
 
model/transformer.py CHANGED
@@ -93,8 +93,6 @@ class DIT(nn.Module):
93
  1 + pre_mlp_scale.unsqueeze(1)
94
  ) + pre_mlp_shift.unsqueeze(1)
95
 
96
- print(f"\nOutput before unpatchify block {out.shape} in DIT block")
97
-
98
  actual_h = x.shape[2] # Height from input tensor
99
  actual_w = x.shape[3] # Width from input tensor
100
  actual_nh = actual_h // self.patch_height
@@ -102,7 +100,6 @@ class DIT(nn.Module):
102
 
103
  # Unpatichify
104
  out = self.out_proj(out)
105
- print(f"Output after projection: {out.shape} in DIT block")
106
  out = rearrange(
107
  out,
108
  "b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)",
@@ -112,7 +109,6 @@ class DIT(nn.Module):
112
  nh=actual_nh,
113
  )
114
 
115
- print(f"Output after unpatchify block: {out.shape} in DIT block\n")
116
  return out
117
 
118
 
 
93
  1 + pre_mlp_scale.unsqueeze(1)
94
  ) + pre_mlp_shift.unsqueeze(1)
95
 
 
 
96
  actual_h = x.shape[2] # Height from input tensor
97
  actual_w = x.shape[3] # Width from input tensor
98
  actual_nh = actual_h // self.patch_height
 
100
 
101
  # Unpatichify
102
  out = self.out_proj(out)
 
103
  out = rearrange(
104
  out,
105
  "b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)",
 
109
  nh=actual_nh,
110
  )
111
 
 
112
  return out
113
 
114