Update model_core.py
Browse files- model_core.py +26 -18
model_core.py
CHANGED
|
@@ -76,41 +76,49 @@ class Mlp(nn.Module):
|
|
| 76 |
return x
|
| 77 |
|
| 78 |
|
| 79 |
-
class SinCos2DEmbed(nn.Module):
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
| 81 |
super().__init__()
|
| 82 |
|
| 83 |
def forward(self, x):
|
| 84 |
# x has the shape [batch_size, embed_dim, grid_length, grid_height]
|
| 85 |
-
|
| 86 |
-
_, embed_dim, grid_length, grid_height = x.shape
|
| 87 |
|
| 88 |
# Create grid positions
|
| 89 |
grid_length_a = torch.arange(grid_length, dtype=torch.float32, device=x.device)
|
| 90 |
grid_height_a = torch.arange(grid_height, dtype=torch.float32, device=x.device)
|
| 91 |
-
grid = torch.meshgrid(
|
|
|
|
| 92 |
|
| 93 |
-
sub_embed_dim = embed_dim
|
| 94 |
omega = torch.arange(sub_embed_dim, dtype=torch.float32, device=x.device)
|
| 95 |
omega /= sub_embed_dim
|
| 96 |
-
omega = 1.0 / 10000**omega
|
| 97 |
|
| 98 |
-
# embed_length
|
| 99 |
-
out_length = torch.einsum("mn,d->dmn", grid[0],
|
| 100 |
embed_length_sin = torch.sin(out_length)
|
| 101 |
embed_length_cos = torch.cos(out_length)
|
| 102 |
-
embed_length = torch.cat([embed_length_sin, embed_length_cos], dim=0)
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
-
|
| 111 |
-
embed = torch.cat([embed_length, embed_height], dim=0).unsqueeze(dim=0)
|
| 112 |
|
| 113 |
-
x = x + embed
|
| 114 |
return x
|
| 115 |
|
| 116 |
|
|
|
|
| 76 |
return x
|
| 77 |
|
| 78 |
|
| 79 |
+
class SinCos2DEmbed(torch.nn.Module):
    """Add fixed (non-learned) 2D sine-cosine positional embeddings to a feature map.

    The embedding is recomputed from the input's spatial size on every call,
    so the module has no parameters and works for any grid shape whose
    channel count is divisible by 4.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` plus a 2D sin-cos positional embedding.

        Args:
            x: Tensor of shape ``[batch_size, embed_dim, grid_length,
                grid_height]``. ``embed_dim`` must be divisible by 4, since
                each spatial axis receives ``embed_dim/4`` sine channels and
                ``embed_dim/4`` cosine channels.

        Returns:
            Tensor of the same shape as ``x`` with the embedding added
            (broadcast over the batch dimension).

        Raises:
            ValueError: if ``embed_dim`` is not divisible by 4.
        """
        _, embed_dim, grid_length, grid_height = x.shape
        if embed_dim % 4 != 0:
            # Without this check the final `x + embed` fails with an opaque
            # broadcasting error; fail early with a clear message instead.
            raise ValueError(f"embed_dim must be divisible by 4, got {embed_dim}")

        # Integer positions along each spatial axis, on the input's device.
        length_pos = torch.arange(grid_length, dtype=torch.float32, device=x.device)
        height_pos = torch.arange(grid_height, dtype=torch.float32, device=x.device)
        # With indexing="xy" both outputs have shape [grid_length, grid_height].
        # NOTE(review): grid[0] varies along the *height* axis and grid[1]
        # along the *length* axis, so the length/height labels below look
        # swapped relative to the math — kept as-is to preserve the exact
        # channel layout; confirm the intended axis mapping.
        grid = torch.meshgrid(height_pos, length_pos, indexing="xy")

        # Standard transformer frequency schedule: omega_k = 1 / 10000^(k/D).
        sub_embed_dim = embed_dim // 4
        omega = torch.arange(sub_embed_dim, dtype=torch.float32, device=x.device)
        omega = 1.0 / 10000 ** (omega / sub_embed_dim)

        # First half of the channels: sin/cos of grid[0], shape [embed_dim/2, L, H].
        out_length = torch.einsum("mn,d->dmn", grid[0], omega)
        embed_length = torch.cat([torch.sin(out_length), torch.cos(out_length)], dim=0)

        # Second half of the channels: sin/cos of grid[1], shape [embed_dim/2, L, H].
        out_height = torch.einsum("mn,d->dmn", grid[1], omega)
        embed_height = torch.cat([torch.sin(out_height), torch.cos(out_height)], dim=0)

        # Stack both halves and add a broadcastable batch dim: [1, embed_dim, L, H].
        embed = torch.cat([embed_length, embed_height], dim=0).unsqueeze(dim=0)
        return x + embed
|
| 123 |
|
| 124 |
|