| | """
|
| | custom_objects.py - Fully Fixed & Compatible with TF 2.10+ / HF Spaces
|
| | """
|
| |
|
| | import tensorflow as tf
|
| | from tensorflow.keras import layers
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
# tf.keras.layers.Identity is only present in newer TF releases; provide a
# minimal pass-through fallback so older versions can still deserialize it.
if hasattr(layers, "Identity"):
    Identity = layers.Identity
else:
    class Identity(layers.Layer):
        """No-op layer: returns its input unchanged."""

        def call(self, inputs):
            return inputs

        def compute_output_shape(self, input_shape):
            return input_shape

        def get_config(self):
            return super().get_config()
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
class ClassToken(layers.Layer):
    """Prepends a learnable [CLS] token to a (batch, seq, dim) token sequence.

    Output shape is (batch, seq + 1, dim); the class token occupies index 0.
    """

    def __init__(self, name="class_token", **kwargs):
        super().__init__(name=name, **kwargs)
        self.supports_masking = True

    def build(self, input_shape):
        embed_dim = input_shape[-1]
        # Pass `name` as a keyword: Keras 3 made `shape` the first positional
        # argument of add_weight, so a positional name string breaks there.
        self.cls = self.add_weight(
            name="cls_token",
            shape=(1, 1, embed_dim),
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        batch = tf.shape(x)[0]
        cls = tf.tile(self.cls, [batch, 1, 1])
        # Cast guards against mixed-precision inputs; no-op when dtypes match.
        return tf.concat([tf.cast(cls, x.dtype), x], axis=1)
|
| |
|
| |
|
class PatchEmbeddings(layers.Layer):
    """Splits an image into non-overlapping patches and linearly embeds them.

    Input:  (batch, H, W, C) images.
    Output: (batch, num_patches, embed_dim) patch embeddings.
    """

    def __init__(self, patch_size=16, embed_dim=768, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.embed_dim = embed_dim

    def build(self, input_shape):
        # A strided Conv2D is equivalent to patch-splitting + a shared Dense.
        self.proj = layers.Conv2D(
            filters=self.embed_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            padding="valid",
        )
        super().build(input_shape)

    def call(self, x):
        x = self.proj(x)  # (B, H/ps, W/ps, embed_dim)
        shape = tf.shape(x)
        b, h, w, c = shape[0], shape[1], shape[2], shape[3]
        # Flatten the spatial grid into a token sequence.
        return tf.reshape(x, [b, h * w, c])

    def get_config(self):
        # Required so save/load round-trips non-default patch_size/embed_dim;
        # without it, deserialization silently rebuilds with the defaults.
        config = super().get_config()
        config.update({"patch_size": self.patch_size, "embed_dim": self.embed_dim})
        return config
|
| |
|
| |
|
class AddPositionEmbs(layers.Layer):
    """Adds learnable position embeddings to a (batch, seq, dim) sequence.

    If the runtime sequence length differs from the trained length (e.g. a
    different image resolution), the embeddings are bilinearly interpolated
    along the token axis.
    """

    def __init__(self, initializer="zeros", **kwargs):
        super().__init__(**kwargs)
        self.initializer = initializer

    def build(self, input_shape):
        seq_len, dim = input_shape[1], input_shape[2]
        # Keyword `name=`: Keras 3's add_weight takes `shape` first positionally.
        self.pe = self.add_weight(
            name="position_embeddings",
            shape=(1, seq_len, dim),
            initializer=self.initializer,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        # Compare *static* shapes: `tf.shape(x)[1] == tf.shape(pe)[1]` yields a
        # symbolic tensor inside tf.function, which cannot drive a Python `if`.
        static_len = x.shape[1]
        if static_len is not None and static_len == self.pe.shape[1]:
            return x + tf.cast(self.pe, x.dtype)

        # Length mismatch (or unknown static length): interpolate dynamically.
        x_len = tf.shape(x)[1]
        pe_len = tf.shape(self.pe)[1]
        dim = tf.shape(self.pe)[2]

        pe = tf.reshape(self.pe, (1, pe_len, dim, 1))  # treat tokens as an image row axis
        pe = tf.image.resize(pe, (x_len, dim))
        pe = tf.reshape(pe, (1, x_len, dim))
        return x + tf.cast(pe, x.dtype)

    def get_config(self):
        # Needed so save/load restores a non-default initializer.
        config = super().get_config()
        config.update({"initializer": self.initializer})
        return config
|
| |
|
class TransformerBlock(layers.Layer):
    """Pre-norm ViT encoder block: self-attention + MLP, each with a residual.

    Args:
        num_heads: attention heads; the embedding dim must be divisible by it
            (key_dim is derived as dim // num_heads).
        mlp_dim: hidden width of the feed-forward sub-block.
        dropout_rate: dropout applied after attention and each Dense.
    """

    def __init__(self, num_heads=12, mlp_dim=3072, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        dim = input_shape[-1]

        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.att = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=dim // self.num_heads,
        )
        self.drop1 = layers.Dropout(self.dropout_rate)

        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.d1 = layers.Dense(self.mlp_dim, activation="gelu")
        self.drop2 = layers.Dropout(self.dropout_rate)
        self.d2 = layers.Dense(dim)
        self.drop3 = layers.Dropout(self.dropout_rate)

        super().build(input_shape)

    def call(self, x, training=None):
        # Attention sub-block (pre-norm residual).
        h = self.norm1(x)
        h = self.att(h, h)
        h = self.drop1(h, training=training)
        x = x + h

        # MLP sub-block (pre-norm residual).
        h = self.norm2(x)
        h = self.d1(h)
        h = self.drop2(h, training=training)
        h = self.d2(h)
        h = self.drop3(h, training=training)
        return x + h

    def get_config(self):
        # Required for save/load to restore non-default hyperparameters.
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "mlp_dim": self.mlp_dim,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config
|
| |
|
| |
|
class ExtractToken(layers.Layer):
    """Selects the leading (class) token from a token sequence."""

    def call(self, x):
        cls_token = x[:, 0]
        return cls_token
|
| |
|
| |
|
class MlpBlock(layers.Layer):
    """Two-layer MLP (Dense -> activation -> Dense) with dropout after each Dense.

    The second Dense projects back to the input width, so the block can be
    used inside a residual connection.
    """

    def __init__(self, hidden_dim=3072, dropout=0.1, activation="gelu", **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.activation = activation

    def build(self, input_shape):
        self.d1 = layers.Dense(self.hidden_dim)
        self.d2 = layers.Dense(input_shape[-1])
        self.drop1 = layers.Dropout(self.dropout)
        self.drop2 = layers.Dropout(self.dropout)
        super().build(input_shape)

    def call(self, x, training=None):
        h = self.d1(x)
        # Only "gelu" is special-cased; any other value falls back to relu.
        h = tf.nn.gelu(h) if self.activation == "gelu" else tf.nn.relu(h)
        h = self.drop1(h, training=training)
        h = self.d2(h)
        return self.drop2(h, training=training)

    def get_config(self):
        # Required so save/load round-trips non-default hyperparameters.
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "dropout": self.dropout,
                "activation": self.activation,
            }
        )
        return config
|
| |
|
| |
|
class SimpleMultiHeadAttention(layers.Layer):
    """Self-attention wrapper around keras MultiHeadAttention (query = key = value = x)."""

    def __init__(self, num_heads=8, key_dim=64, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim

    def build(self, input_shape):
        self.mha = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.key_dim,
        )
        super().build(input_shape)

    def call(self, x):
        return self.mha(x, x)

    def get_config(self):
        # Required so save/load round-trips non-default num_heads/key_dim.
        config = super().get_config()
        config.update({"num_heads": self.num_heads, "key_dim": self.key_dim})
        return config
|
| |
|
| |
|
class FixedDropout(layers.Dropout):
    # NOTE(review): name-only shim — behaves exactly like keras.layers.Dropout.
    # Presumably exists so saved models whose configs reference a custom
    # "FixedDropout" layer (seen in some EfficientNet ports) can deserialize;
    # confirm against the models this module is meant to load.
    pass
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def get_custom_objects():
    """Return the name -> object mapping required to load the saved models."""
    # Layers defined in this module.
    custom_layers = {
        "Identity": Identity,
        "ClassToken": ClassToken,
        "PatchEmbeddings": PatchEmbeddings,
        "AddPositionEmbs": AddPositionEmbs,
        "TransformerBlock": TransformerBlock,
        "ExtractToken": ExtractToken,
        "MlpBlock": MlpBlock,
        "SimpleMultiHeadAttention": SimpleMultiHeadAttention,
        "FixedDropout": FixedDropout,
    }
    # Stock Keras layers, registered explicitly by name.
    builtin_layers = {
        "MultiHeadAttention": layers.MultiHeadAttention,
        "LayerNormalization": layers.LayerNormalization,
        "Dropout": layers.Dropout,
        "Dense": layers.Dense,
        "Conv2D": layers.Conv2D,
        "Flatten": layers.Flatten,
        "Reshape": layers.Reshape,
        "Activation": layers.Activation,
    }
    # Activation functions referenced by name in saved configs.
    activations = {
        "gelu": tf.nn.gelu,
        "swish": tf.nn.swish,
        "relu": tf.nn.relu,
        "sigmoid": tf.nn.sigmoid,
        "softmax": tf.nn.softmax,
    }
    return {**custom_layers, **builtin_layers, **activations}
|
| |
|