HTill committed on
Commit
dc410bf
·
verified ·
1 Parent(s): 86c2a7c

Update model_core.py

Browse files
Files changed (1) hide show
  1. model_core.py +26 -18
model_core.py CHANGED
@@ -76,41 +76,49 @@ class Mlp(nn.Module):
76
  return x
77
 
78
 
79
- class SinCos2DEmbed(nn.Module):
80
- def __init__(self):
 
 
 
81
  super().__init__()
82
 
83
  def forward(self, x):
84
  # x has the shape [batch_size, embed_dim, grid_length, grid_height]
85
- # Note: grid_length corresponds to H (Time/Frequency), grid_height to W
86
- _, embed_dim, grid_length, grid_height = x.shape
87
 
88
  # Create grid positions
89
  grid_length_a = torch.arange(grid_length, dtype=torch.float32, device=x.device)
90
  grid_height_a = torch.arange(grid_height, dtype=torch.float32, device=x.device)
91
- grid = torch.meshgrid(grid_length_a, grid_height_a, indexing="xy")
 
92
 
93
- sub_embed_dim = embed_dim // 4
94
  omega = torch.arange(sub_embed_dim, dtype=torch.float32, device=x.device)
95
  omega /= sub_embed_dim
96
- omega = 1.0 / 10000**omega
97
 
98
- # embed_length (dimension 0 of grid)
99
- out_length = torch.einsum("mn,d->dmn", grid[0], omega)
100
  embed_length_sin = torch.sin(out_length)
101
  embed_length_cos = torch.cos(out_length)
102
- embed_length = torch.cat([embed_length_sin, embed_length_cos], dim=0)
103
 
104
- # embed_height (dimension 1 of grid)
105
- out_height = torch.einsum("mn,d->dmn", grid[1], omega)
106
- embed_height_sin = torch.sin(out_height)
107
- embed_height_cos = torch.cos(out_height)
108
- embed_height = torch.cat([embed_height_sin, embed_height_cos], dim=0)
 
 
 
 
 
 
 
 
109
 
110
- # concat length and height embeddings
111
- embed = torch.cat([embed_length, embed_height], dim=0).unsqueeze(dim=0)
112
 
113
- x = x + embed
114
  return x
115
 
116
 
 
76
  return x
77
 
78
 
79
+ class SinCos2DEmbed(torch.nn.Module):
80
+
81
+ def __init__(
82
+ self,
83
+ ):
84
  super().__init__()
85
 
86
  def forward(self, x):
87
  # x has the shape [batch_size, embed_dim, grid_length, grid_height]
88
+ batch_size, embed_dim, grid_length, grid_height = x.shape
 
89
 
90
  # Create grid positions
91
  grid_length_a = torch.arange(grid_length, dtype=torch.float32, device=x.device)
92
  grid_height_a = torch.arange(grid_height, dtype=torch.float32, device=x.device)
93
+ grid = torch.meshgrid(grid_height_a, grid_length_a, indexing="xy")
94
+
95
 
96
+ sub_embed_dim = embed_dim//4
97
  omega = torch.arange(sub_embed_dim, dtype=torch.float32, device=x.device)
98
  omega /= sub_embed_dim
99
+ omega = 1.0 / 10000**omega
100
 
101
+ # embed_length
102
+ out_length = torch.einsum("mn,d->dmn", grid[0],omega)
103
  embed_length_sin = torch.sin(out_length)
104
  embed_length_cos = torch.cos(out_length)
 
105
 
106
+ embed_length = torch.concatenate([embed_length_sin,embed_length_cos],dim=0)
107
+
108
+ # embed_heigth
109
+
110
+ out_heigth = torch.einsum("mn,d->dmn", grid[1], omega)
111
+ embed_heigth_sin = torch.sin(out_heigth)
112
+ embed_heigth_cos = torch.cos(out_heigth)
113
+
114
+ embed_heigth = torch.concatenate([embed_heigth_sin,embed_heigth_cos],dim=0)
115
+
116
+ # concat length and heigth
117
+
118
+ embed = torch.concatenate([embed_length, embed_heigth],dim=0).unsqueeze(dim=0)
119
 
120
+ x = x + embed
 
121
 
 
122
  return x
123
 
124