"""
custom_objects.py - Fully Fixed & Compatible with TF 2.10+ / HF Spaces
"""
import tensorflow as tf
from tensorflow.keras import layers
# ======================================================
# COMPATIBILITY IDENTITY LAYER
# ======================================================
# Fallback Identity for environments lacking tf.keras.layers.Identity
try:
    Identity = layers.Identity
except AttributeError:
    class Identity(layers.Layer):
        def call(self, inputs):
            return inputs

        def compute_output_shape(self, input_shape):
            return input_shape

        def get_config(self):
            return super().get_config()
# ======================================================
# VISION TRANSFORMER LAYERS
# ======================================================
class ClassToken(layers.Layer):
    """Prepends a learnable [CLS] token to the patch sequence."""

    def __init__(self, name="class_token", **kwargs):
        super().__init__(name=name, **kwargs)
        self.supports_masking = True

    def build(self, input_shape):
        embed_dim = input_shape[-1]
        self.cls = self.add_weight(
            name="cls_token",
            shape=(1, 1, embed_dim),
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        b = tf.shape(x)[0]
        cls = tf.tile(self.cls, [b, 1, 1])
        return tf.concat([cls, x], axis=1)
class PatchEmbeddings(layers.Layer):
    """Projects an image into a sequence of flattened patch embeddings."""

    def __init__(self, patch_size=16, embed_dim=768, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.embed_dim = embed_dim

    def build(self, input_shape):
        self.proj = layers.Conv2D(
            filters=self.embed_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            padding="valid",
        )
        super().build(input_shape)

    def call(self, x):
        x = self.proj(x)
        # Flatten the spatial grid into a token sequence. Prefer static
        # dims where available so downstream layers (e.g. the positional
        # embedding weight) see a known sequence length and embedding size.
        b = tf.shape(x)[0]
        h = x.shape[1] if x.shape[1] is not None else tf.shape(x)[1]
        w = x.shape[2] if x.shape[2] is not None else tf.shape(x)[2]
        return tf.reshape(x, [b, h * w, self.embed_dim])
class AddPositionEmbs(layers.Layer):
    """Adds learnable position embeddings, interpolating them when the input
    sequence length differs from the one seen at build time."""

    def __init__(self, initializer="zeros", **kwargs):
        super().__init__(**kwargs)
        self.initializer = initializer

    def build(self, input_shape):
        seq_len, dim = input_shape[1], input_shape[2]
        self.pe = self.add_weight(
            name="position_embeddings",
            shape=(1, seq_len, dim),
            initializer=self.initializer,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        # Use static shapes so this branch is resolved at trace time instead
        # of testing a symbolic tensor with a Python `if`.
        x_len = x.shape[1]
        pe_len = self.pe.shape[1]
        dim = self.pe.shape[2]
        # Same (or unknown) length → plain addition
        if x_len is None or x_len == pe_len:
            return x + tf.cast(self.pe, x.dtype)
        # Different length → resize the embeddings along the sequence axis
        # only, keeping the embedding dimension intact.
        pe = tf.reshape(self.pe, (1, pe_len, dim, 1))   # treat as an NHWC image
        pe = tf.image.resize(pe, (x_len, dim))          # resize LENGTH only
        pe = tf.reshape(pe, (1, x_len, dim))            # back to (1, L, D)
        return x + tf.cast(pe, x.dtype)
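# Example (assumption, not part of the original model): the resize branch in
# AddPositionEmbs lets position embeddings trained at one resolution be reused
# at another, e.g. 224px weights (196 patches + 1 class token) applied to a
# 384px input (576 patches + 1 class token):
#
#     pos = AddPositionEmbs()
#     pos.build((None, 197, 768))          # weights built for 197 tokens
#     out = pos(tf.zeros((1, 577, 768)))   # interpolated to 577 tokens
#     assert out.shape == (1, 577, 768)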
class TransformerBlock(layers.Layer):
    """Pre-norm Transformer encoder block: LayerNorm → MHSA → residual,
    then LayerNorm → MLP → residual."""

    def __init__(self, num_heads=12, mlp_dim=3072, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        dim = input_shape[-1]
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.att = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=dim // self.num_heads,
        )
        self.drop1 = layers.Dropout(self.dropout_rate)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.d1 = layers.Dense(self.mlp_dim, activation="gelu")
        self.drop2 = layers.Dropout(self.dropout_rate)
        self.d2 = layers.Dense(dim)
        self.drop3 = layers.Dropout(self.dropout_rate)
        super().build(input_shape)

    def call(self, x, training=None):
        # Attention sub-block with residual connection
        h = self.norm1(x)
        h = self.att(h, h)
        h = self.drop1(h, training=training)
        x = x + h
        # MLP sub-block with residual connection
        h = self.norm2(x)
        h = self.d1(h)
        h = self.drop2(h, training=training)
        h = self.d2(h)
        h = self.drop3(h, training=training)
        return x + h
class ExtractToken(layers.Layer):
    """Selects the leading [CLS] token from the sequence."""

    def call(self, x):
        return x[:, 0]
class MlpBlock(layers.Layer):
    def __init__(self, hidden_dim=3072, dropout=0.1, activation="gelu", **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.activation = activation

    def build(self, input_shape):
        self.d1 = layers.Dense(self.hidden_dim)
        self.d2 = layers.Dense(input_shape[-1])
        self.drop1 = layers.Dropout(self.dropout)
        self.drop2 = layers.Dropout(self.dropout)
        super().build(input_shape)

    def call(self, x, training=None):
        h = self.d1(x)
        h = tf.nn.gelu(h) if self.activation == "gelu" else tf.nn.relu(h)
        h = self.drop1(h, training=training)
        h = self.d2(h)
        return self.drop2(h, training=training)
class SimpleMultiHeadAttention(layers.Layer):
    def __init__(self, num_heads=8, key_dim=64, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim

    def build(self, input_shape):
        self.mha = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.key_dim,
        )
        super().build(input_shape)

    def call(self, x):
        return self.mha(x, x)
class FixedDropout(layers.Dropout):
    """Placeholder so H5 models that reference a `FixedDropout` layer can load;
    it behaves like the standard Dropout layer."""
# ======================================================
# RETURN ALL CUSTOM OBJECTS
# ======================================================
def get_custom_objects():
    return {
        "Identity": Identity,
        "ClassToken": ClassToken,
        "PatchEmbeddings": PatchEmbeddings,
        "AddPositionEmbs": AddPositionEmbs,
        "TransformerBlock": TransformerBlock,
        "ExtractToken": ExtractToken,
        "MlpBlock": MlpBlock,
        "SimpleMultiHeadAttention": SimpleMultiHeadAttention,
        "FixedDropout": FixedDropout,
        # Standard layers exposed for H5 compatibility
        "MultiHeadAttention": layers.MultiHeadAttention,
        "LayerNormalization": layers.LayerNormalization,
        "Dropout": layers.Dropout,
        "Dense": layers.Dense,
        "Conv2D": layers.Conv2D,
        "Flatten": layers.Flatten,
        "Reshape": layers.Reshape,
        "Activation": layers.Activation,
        # Activations
        "gelu": tf.nn.gelu,
        "swish": tf.nn.swish,
        "relu": tf.nn.relu,
        "sigmoid": tf.nn.sigmoid,
        "softmax": tf.nn.softmax,
    }
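# ======================================================
# USAGE SKETCH (assumption: a ViT classifier saved as an H5 checkpoint;
# "model.h5" below is a placeholder path, not part of this module)
# ======================================================
#
#     import tensorflow as tf
#     from custom_objects import get_custom_objects
#
#     model = tf.keras.models.load_model(
#         "model.h5",
#         custom_objects=get_custom_objects(),
#         compile=False,  # skip optimizer/loss restoration for inference-only use
#     )
#
# Equivalently, the same mapping can be applied as a scope:
#
#     with tf.keras.utils.custom_object_scope(get_custom_objects()):
#         model = tf.keras.models.load_model("model.h5", compile=False)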