Commit
·
c3b4b22
1
Parent(s):
81dc951
Update augvit_model.py: prepend a trainable cls token to the patch embeddings and comment out the positional embedding
Browse files: augvit_model.py (+9 -9)
augvit_model.py
CHANGED
|
@@ -117,8 +117,8 @@ class AUGViT(Model):
|
|
| 117 |
self.patch_den= nn.Dense(units=dim,name='patchden')
|
| 118 |
|
| 119 |
|
| 120 |
-
self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb',trainable=True)
|
| 121 |
-
|
| 122 |
self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
|
| 123 |
|
| 124 |
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=dropout,name='trans')
|
|
@@ -135,13 +135,13 @@ class AUGViT(Model):
|
|
| 135 |
x = self.patch_den(x)
|
| 136 |
b, n, d = x.shape
|
| 137 |
print(x.shape)
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
x += self.pos_embedding[:, :(n )]
|
| 145 |
print(x.shape)
|
| 146 |
x = self.dropout(x, training=training)
|
| 147 |
print(x.shape)
|
|
|
|
| 117 |
self.patch_den= nn.Dense(units=dim,name='patchden')
|
| 118 |
|
| 119 |
|
| 120 |
+
# self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb',trainable=True)
|
| 121 |
+
self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
|
| 122 |
self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
|
| 123 |
|
| 124 |
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=dropout,name='trans')
|
|
|
|
| 135 |
x = self.patch_den(x)
|
| 136 |
b, n, d = x.shape
|
| 137 |
print(x.shape)
|
| 138 |
+
cls_tokens = tf.cast(
|
| 139 |
+
tf.broadcast_to(self.cls_token, [b, 1, d]),
|
| 140 |
+
dtype=x.dtype,
|
| 141 |
+
)
|
| 142 |
+
x = tf.concat([cls_tokens, x], axis=1)
|
| 143 |
+
print(x.shape,cls_tokens.shape )
|
| 144 |
+
# x += self.pos_embedding[:, :(n+1 )]
|
| 145 |
print(x.shape)
|
| 146 |
x = self.dropout(x, training=training)
|
| 147 |
print(x.shape)
|