"""
custom_objects.py - Fully Fixed & Compatible with TF 2.10+ / HF Spaces

Custom Keras layers (Vision Transformer building blocks plus a few
compatibility shims) and ``get_custom_objects()``, which returns the
name -> object mapping for them.
"""

import tensorflow as tf
from tensorflow.keras import layers

# ======================================================
# COMPATIBILITY IDENTITY LAYER
# ======================================================

# Fallback Identity for environments lacking tf.keras.layers.Identity
try:
    Identity = layers.Identity
except AttributeError:

    class Identity(layers.Layer):
        """Pass-through layer that returns its input unchanged."""

        def call(self, inputs):
            return inputs

        def compute_output_shape(self, input_shape):
            return input_shape

        def get_config(self):
            return super().get_config()


# ======================================================
# VISION TRANSFORMER LAYERS
# ======================================================

class ClassToken(layers.Layer):
    """Prepend a trainable token to the input along axis 1.

    The token weight has shape ``(1, 1, embed_dim)`` where ``embed_dim`` is
    the last input dimension; it is tiled to the batch size at call time.
    """

    def __init__(self, name="class_token", **kwargs):
        super().__init__(name=name, **kwargs)
        self.supports_masking = True

    def build(self, input_shape):
        embed_dim = input_shape[-1]
        # The weight name is passed by keyword: older tf.keras takes ``name``
        # as the first positional parameter, Keras 3 takes ``shape`` first.
        self.cls = self.add_weight(
            name="cls_token",
            shape=(1, 1, embed_dim),
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        batch = tf.shape(x)[0]
        cls = tf.tile(self.cls, [batch, 1, 1])
        return tf.concat([cls, x], axis=1)


class PatchEmbeddings(layers.Layer):
    """Project an image into a sequence of patch embeddings.

    A ``Conv2D`` with ``kernel_size == strides == patch_size`` and
    ``embed_dim`` filters is applied, then axes 1 and 2 of its output are
    flattened into a single sequence axis.

    Args:
        patch_size: Kernel size and stride of the projection convolution.
        embed_dim: Number of convolution filters.
    """

    def __init__(self, patch_size=16, embed_dim=768, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.embed_dim = embed_dim

    def build(self, input_shape):
        self.proj = layers.Conv2D(
            filters=self.embed_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            padding="valid",
        )
        super().build(input_shape)

    def call(self, x):
        x = self.proj(x)
        height, width, channels = x.shape[1], x.shape[2], x.shape[3]
        if None not in (height, width, channels):
            # Static dimensions keep the sequence length known to downstream
            # layers (AddPositionEmbs sizes its weight from it in build()).
            return tf.reshape(x, [tf.shape(x)[0], height * width, channels])
        dyn = tf.shape(x)
        return tf.reshape(x, [dyn[0], dyn[1] * dyn[2], dyn[3]])

    def get_config(self):
        config = super().get_config()
        config.update(
            {"patch_size": self.patch_size, "embed_dim": self.embed_dim}
        )
        return config


class AddPositionEmbs(layers.Layer):
    """Add a trainable position-embedding weight to the input.

    The weight has shape ``(1, seq_len, dim)`` taken from the build-time
    input shape. If the input length at call time differs from the stored
    length, the embeddings are resized along the length axis with
    ``tf.image.resize`` before being added.

    Args:
        initializer: Initializer for the position-embedding weight.
    """

    def __init__(self, initializer="zeros", **kwargs):
        super().__init__(**kwargs)
        self.initializer = initializer

    def build(self, input_shape):
        seq_len, dim = input_shape[1], input_shape[2]
        # ``name`` by keyword for compatibility with both add_weight
        # signatures (name-first in older tf.keras, shape-first in Keras 3).
        self.pe = self.add_weight(
            name="position_embeddings",
            shape=(1, seq_len, dim),
            initializer=self.initializer,
            trainable=True,
        )
        super().build(input_shape)

    def _resized_embeddings(self, target_len, dtype):
        """Return the embeddings resized to ``target_len`` and cast to ``dtype``."""
        pe_len, dim = self.pe.shape[1], self.pe.shape[2]
        pe = tf.reshape(self.pe, (1, pe_len, dim, 1))  # to NHWC
        pe = tf.image.resize(pe, (target_len, dim))  # resize LENGTH only
        pe = tf.reshape(pe, (1, target_len, dim))  # back to (1, L, D)
        return tf.cast(pe, dtype)

    def call(self, x):
        pe_len = self.pe.shape[1]
        static_len = x.shape[1]

        if static_len is not None:
            # Decide in Python when the length is known: a tensor-valued
            # ``if`` only works in eager mode or under autograph.
            if static_len == pe_len:
                return x + self.pe
            return x + self._resized_embeddings(static_len, x.dtype)

        # Length only known at run time: branch inside the graph.
        x_len = tf.shape(x)[1]
        return tf.cond(
            tf.equal(x_len, pe_len),
            lambda: x + self.pe,
            lambda: x + self._resized_embeddings(x_len, x.dtype),
        )

    def get_config(self):
        config = super().get_config()
        initializer = self.initializer
        if not isinstance(initializer, (str, dict)):
            # Initializer objects are not JSON-serializable as-is.
            initializer = tf.keras.initializers.serialize(initializer)
        config.update({"initializer": initializer})
        return config


class TransformerBlock(layers.Layer):
    """Transformer encoder block with two residual branches.

    Branch 1: LayerNorm -> self-attention -> dropout, added to the input.
    Branch 2: LayerNorm -> Dense(mlp_dim, gelu) -> dropout -> Dense(dim)
    -> dropout, added to the result of branch 1.

    Args:
        num_heads: Number of attention heads; ``key_dim`` is
            ``dim // num_heads`` where ``dim`` is the last input dimension.
        mlp_dim: Width of the hidden Dense layer.
        dropout_rate: Rate shared by the three dropout layers.
    """

    def __init__(self, num_heads=12, mlp_dim=3072, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        dim = input_shape[-1]

        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.att = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=dim // self.num_heads,
        )
        self.drop1 = layers.Dropout(self.dropout_rate)

        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.d1 = layers.Dense(self.mlp_dim, activation="gelu")
        self.drop2 = layers.Dropout(self.dropout_rate)
        self.d2 = layers.Dense(dim)
        self.drop3 = layers.Dropout(self.dropout_rate)
        super().build(input_shape)

    def call(self, x, training=None):
        h = self.norm1(x)
        h = self.att(h, h)
        h = self.drop1(h, training=training)
        x = x + h

        h = self.norm2(x)
        h = self.d1(h)
        h = self.drop2(h, training=training)
        h = self.d2(h)
        h = self.drop3(h, training=training)
        return x + h

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "mlp_dim": self.mlp_dim,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config


class ExtractToken(layers.Layer):
    """Select index 0 along axis 1 of the input."""

    def call(self, x):
        return x[:, 0]


class MlpBlock(layers.Layer):
    """Two-layer MLP: Dense -> activation -> dropout -> Dense -> dropout.

    The second Dense layer projects back to the last input dimension.

    Args:
        hidden_dim: Width of the first Dense layer.
        dropout: Rate shared by both dropout layers.
        activation: ``"gelu"`` selects GELU; any other value selects ReLU.
    """

    def __init__(self, hidden_dim=3072, dropout=0.1, activation="gelu", **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.activation = activation

    def build(self, input_shape):
        self.d1 = layers.Dense(self.hidden_dim)
        self.d2 = layers.Dense(input_shape[-1])
        self.drop1 = layers.Dropout(self.dropout)
        self.drop2 = layers.Dropout(self.dropout)
        super().build(input_shape)

    def call(self, x, training=None):
        h = self.d1(x)
        h = tf.nn.gelu(h) if self.activation == "gelu" else tf.nn.relu(h)
        h = self.drop1(h, training=training)
        h = self.d2(h)
        return self.drop2(h, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "dropout": self.dropout,
                "activation": self.activation,
            }
        )
        return config


class SimpleMultiHeadAttention(layers.Layer):
    """Self-attention wrapper: calls ``MultiHeadAttention`` with ``(x, x)``.

    Args:
        num_heads: Number of attention heads.
        key_dim: Per-head key dimension.
    """

    def __init__(self, num_heads=8, key_dim=64, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim

    def build(self, input_shape):
        self.mha = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.key_dim,
        )
        super().build(input_shape)

    def call(self, x):
        return self.mha(x, x)

    def get_config(self):
        config = super().get_config()
        config.update({"num_heads": self.num_heads, "key_dim": self.key_dim})
        return config


class FixedDropout(layers.Dropout):
    """Placeholder subclass of ``Dropout`` so H5 files naming it can load."""


# ======================================================
# RETURN ALL CUSTOM OBJECTS
# ======================================================

def get_custom_objects():
    """Return the mapping of names to custom layers, standard layers and
    activation functions defined or exposed by this module."""
    return {
        "Identity": Identity,
        "ClassToken": ClassToken,
        "PatchEmbeddings": PatchEmbeddings,
        "AddPositionEmbs": AddPositionEmbs,
        "TransformerBlock": TransformerBlock,
        "ExtractToken": ExtractToken,
        "MlpBlock": MlpBlock,
        "SimpleMultiHeadAttention": SimpleMultiHeadAttention,
        "FixedDropout": FixedDropout,

        # Standard layers exposed for H5 compatibility
        "MultiHeadAttention": layers.MultiHeadAttention,
        "LayerNormalization": layers.LayerNormalization,
        "Dropout": layers.Dropout,
        "Dense": layers.Dense,
        "Conv2D": layers.Conv2D,
        "Flatten": layers.Flatten,
        "Reshape": layers.Reshape,
        "Activation": layers.Activation,

        # Activations
        "gelu": tf.nn.gelu,
        "swish": tf.nn.swish,
        "relu": tf.nn.relu,
        "sigmoid": tf.nn.sigmoid,
        "softmax": tf.nn.softmax,
    }