Update CLIP.py
Browse files
CLIP.py
CHANGED
|
@@ -54,9 +54,14 @@ class Bottleneck(tf.keras.layers.Layer):
|
|
| 54 |
return out
|
| 55 |
|
| 56 |
|
| 57 |
-
class AttentionPool2d:
|
| 58 |
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
| 59 |
-
self.positional_embedding =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
self.k_proj = Dense(embed_dim)
|
| 61 |
self.q_proj = Dense(embed_dim)
|
| 62 |
self.v_proj = Dense(embed_dim)
|
|
@@ -213,15 +218,25 @@ class Transformer:
|
|
| 213 |
return self.resblocks(x)
|
| 214 |
|
| 215 |
|
| 216 |
-
class VisionTransformer:
|
| 217 |
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
| 218 |
self.input_resolution = input_resolution
|
| 219 |
self.output_dim = output_dim
|
| 220 |
self.conv1 = Conv2d(width, kernel_size=patch_size, strides=patch_size, use_bias=False)
|
| 221 |
|
| 222 |
scale = width ** -0.5
|
| 223 |
-
self.class_embedding =
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
self.ln_pre = LayerNorm(width)
|
| 226 |
|
| 227 |
self.transformer = Transformer(width, layers, heads)
|
|
@@ -296,17 +311,32 @@ class CLIP(Model):
|
|
| 296 |
)
|
| 297 |
|
| 298 |
self.vocab_size = vocab_size
|
| 299 |
-
self.token_embedding =
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
self.ln_final = LayerNorm(transformer_width)
|
| 305 |
|
| 306 |
-
self.text_projection =
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
def build_attention_mask(self):
|
| 312 |
mask = tf.ones((self.context_length, self.context_length))
|
|
|
|
| 54 |
return out
|
| 55 |
|
| 56 |
|
| 57 |
+
class AttentionPool2d(tf.keras.layers.Layer):
|
| 58 |
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
| 59 |
+
self.positional_embedding = self.add_weight(
|
| 60 |
+
name='positional_embedding',
|
| 61 |
+
shape=[self.spacial_dim ** 2 + 1, self.embed_dim],
|
| 62 |
+
initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1./self.embed_dim**0.5),
|
| 63 |
+
trainable=True
|
| 64 |
+
)
|
| 65 |
self.k_proj = Dense(embed_dim)
|
| 66 |
self.q_proj = Dense(embed_dim)
|
| 67 |
self.v_proj = Dense(embed_dim)
|
|
|
|
| 218 |
return self.resblocks(x)
|
| 219 |
|
| 220 |
|
| 221 |
+
class VisionTransformer(tf.keras.layers.Layer):
|
| 222 |
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
| 223 |
self.input_resolution = input_resolution
|
| 224 |
self.output_dim = output_dim
|
| 225 |
self.conv1 = Conv2d(width, kernel_size=patch_size, strides=patch_size, use_bias=False)
|
| 226 |
|
| 227 |
scale = width ** -0.5
|
| 228 |
+
self.class_embedding = self.add_weight(
|
| 229 |
+
name='class_embedding',
|
| 230 |
+
shape=[self.width],
|
| 231 |
+
initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1.0) * self.scale,
|
| 232 |
+
trainable=True
|
| 233 |
+
)
|
| 234 |
+
self.positional_embedding = self.add_weight(
|
| 235 |
+
name='positional_embedding',
|
| 236 |
+
shape=[(self.input_resolution // self.patch_size) ** 2 + 1, self.width],
|
| 237 |
+
initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1.0) * self.scale,
|
| 238 |
+
trainable=True
|
| 239 |
+
)
|
| 240 |
self.ln_pre = LayerNorm(width)
|
| 241 |
|
| 242 |
self.transformer = Transformer(width, layers, heads)
|
|
|
|
| 311 |
)
|
| 312 |
|
| 313 |
self.vocab_size = vocab_size
|
| 314 |
+
self.token_embedding = self.add_weight(
|
| 315 |
+
name='token_embedding',
|
| 316 |
+
shape=(vocab_size, transformer_width),
|
| 317 |
+
initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
|
| 318 |
+
trainable=True
|
| 319 |
+
)
|
| 320 |
+
self.positional_embedding = self.add_weight(
|
| 321 |
+
name='positional_embedding',
|
| 322 |
+
shape=(self.context_length, transformer_width),
|
| 323 |
+
initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
|
| 324 |
+
trainable=True
|
| 325 |
+
)
|
| 326 |
self.ln_final = LayerNorm(transformer_width)
|
| 327 |
|
| 328 |
+
self.text_projection = self.add_weight(
|
| 329 |
+
name='text_projection',
|
| 330 |
+
shape=(transformer_width, embed_dim),
|
| 331 |
+
initializer=tf.keras.initializers.RandomNormal(stddev=transformer_width ** -0.5),
|
| 332 |
+
trainable=True
|
| 333 |
+
)
|
| 334 |
+
self.logit_scale = self.add_weight(
|
| 335 |
+
name='logit_scale',
|
| 336 |
+
shape=[],
|
| 337 |
+
initializer=tf.keras.initializers.Constant(np.log(1 / 0.07)),
|
| 338 |
+
trainable=True
|
| 339 |
+
)
|
| 340 |
|
| 341 |
def build_attention_mask(self):
|
| 342 |
mask = tf.ones((self.context_length, self.context_length))
|