Update ViT.py
Browse files
ViT.py
CHANGED
|
@@ -101,8 +101,18 @@ class ViT(Model):
|
|
| 101 |
self.to_patch_embedding.add(Dense(dim))
|
| 102 |
self.to_patch_embedding.add(LayerNormalization())
|
| 103 |
|
| 104 |
-
self.pos_embedding =
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
self.dropout = Dropout(emb_dropout)
|
| 107 |
|
| 108 |
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, drop_rate)
|
|
|
|
| 101 |
self.to_patch_embedding.add(Dense(dim))
|
| 102 |
self.to_patch_embedding.add(LayerNormalization())
|
| 103 |
|
| 104 |
+
self.pos_embedding = self.add_weight(
|
| 105 |
+
name='pos_embedding',
|
| 106 |
+
shape=(1, self.num_patches + 1, self.dim),
|
| 107 |
+
initializer=tf.keras.initializers.RandomNormal(stddev=0.02), # 设定标准差 stddev
|
| 108 |
+
trainable=True
|
| 109 |
+
)
|
| 110 |
+
self.cls_token = self.add_weight(
|
| 111 |
+
name='cls_token',
|
| 112 |
+
shape=(1, 1, self.dim),
|
| 113 |
+
initializer=tf.keras.initializers.RandomNormal(stddev=0.02), # 设定标准差 stddev
|
| 114 |
+
trainable=True
|
| 115 |
+
)
|
| 116 |
self.dropout = Dropout(emb_dropout)
|
| 117 |
|
| 118 |
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, drop_rate)
|