import tensorflow as tf
import numpy as np
import json


# ---- Token Shift ----
class TokenShift(tf.keras.layers.Layer):
    """Average each position with its predecessor: y_t = (x_t + x_{t-1}) / 2.

    The first position has no predecessor and is averaged with zeros.
    Expects input of shape (batch, time, channels).
    """

    def call(self, x):
        # Right-shift the sequence by one step, zero-padding the front.
        shifted = tf.concat([tf.zeros_like(x[:, :1, :]), x[:, :-1, :]], axis=1)
        return (x + shifted) / 2.0


# ---- Time Mix ----
class TimeMix(tf.keras.layers.Layer):
    """Causal multi-head self-attention preceded by a token shift.

    Args:
        d_model: embedding width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
    """

    def __init__(self, d_model, n_heads, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.shift = TokenShift()
        # Fused Q/K/V projection; split into three equal parts in call().
        self.qkv = tf.keras.layers.Dense(3 * d_model, use_bias=False)
        self.out_proj = tf.keras.layers.Dense(d_model, use_bias=False)

    def call(self, x, training=False):
        x = self.shift(x)
        B, T, C = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        qkv = self.qkv(x)
        q, k, v = tf.split(qkv, 3, axis=-1)
        # (B, T, C) -> (B, n_heads, T, head_dim)
        q = tf.reshape(q, [B, T, self.n_heads, self.head_dim])
        k = tf.reshape(k, [B, T, self.n_heads, self.head_dim])
        v = tf.reshape(v, [B, T, self.n_heads, self.head_dim])
        q = tf.transpose(q, [0, 2, 1, 3])
        k = tf.transpose(k, [0, 2, 1, 3])
        v = tf.transpose(v, [0, 2, 1, 3])
        # Scaled dot-product attention scores: (B, H, T, T).
        scale = tf.math.sqrt(tf.cast(self.head_dim, tf.float32))
        attn = tf.matmul(q, k, transpose_b=True) / scale
        # Lower-triangular causal mask; masked positions get a large
        # negative value so softmax assigns them ~zero probability.
        mask = tf.linalg.band_part(tf.ones([T, T]), -1, 0)
        attn = (attn * mask[tf.newaxis, tf.newaxis, :, :]
                + (1.0 - mask[tf.newaxis, tf.newaxis, :, :]) * -1e9)
        attn = tf.nn.softmax(attn, axis=-1)
        out = tf.matmul(attn, v)
        # (B, H, T, head_dim) -> (B, T, C)
        out = tf.transpose(out, [0, 2, 1, 3])
        out = tf.reshape(out, [B, T, C])
        return self.out_proj(out)


# ---- Channel Mix (FFN) with Squared ReLU ----
class ChannelMix(tf.keras.layers.Layer):
    """Position-wise feed-forward block: token shift -> expand -> ReLU^2 -> project.

    Args:
        d_model: embedding width (output width of the block).
        expand: hidden-layer expansion factor.
    """

    def __init__(self, d_model, expand=4, **kwargs):
        super().__init__(**kwargs)
        self.shift = TokenShift()
        self.fc1 = tf.keras.layers.Dense(d_model * expand, use_bias=False)
        self.fc2 = tf.keras.layers.Dense(d_model, use_bias=False)

    def call(self, x, training=False):
        x = self.shift(x)
        h = self.fc1(x)
        h = tf.nn.relu(h) ** 2  # Squared ReLU
        return self.fc2(h)


# ---- Single TERA Block ----
class TeraBlock(tf.keras.layers.Layer):
    """Pre-norm residual block: GroupNorm -> TimeMix, GroupNorm -> ChannelMix.

    Applies stochastic depth: during training the whole block is skipped
    (identity) with probability ``drop_rate``.

    Args:
        d_model: embedding width.
        n_heads: attention heads for the TimeMix sub-layer.
        drop_rate: probability of skipping the block during training.
    """

    def __init__(self, d_model, n_heads, drop_rate=0.0, **kwargs):
        super().__init__(**kwargs)
        self.norm1 = tf.keras.layers.GroupNormalization(groups=4, axis=-1)
        self.time_mix = TimeMix(d_model, n_heads)
        self.norm2 = tf.keras.layers.GroupNormalization(groups=4, axis=-1)
        self.channel_mix = ChannelMix(d_model)
        self.drop_rate = drop_rate

    def call(self, x, training=False):
        # Stochastic depth: skip the whole block with probability drop_rate.
        # NOTE(review): a Python `if` on a tensor only works in eager mode;
        # under tf.function this branch would need tf.cond — confirm the
        # intended execution mode before wrapping this model in tf.function.
        if training and self.drop_rate > 0.0:
            if tf.random.uniform([]) < self.drop_rate:
                return x
        h = self.norm1(x)
        x = x + self.time_mix(h, training=training)
        h = self.norm2(x)
        x = x + self.channel_mix(h, training=training)
        return x


# ---- TERA LM ----
class TeraLM(tf.keras.Model):
    """Small causal language model built from TeraBlocks.

    Token + learned positional embeddings feed a stack of ``n_layers``
    TeraBlocks, followed by a final GroupNorm and a linear head that
    produces per-position logits over the vocabulary.

    Args:
        vocab_size: size of the token vocabulary.
        d_model: embedding width.
        n_heads: attention heads per block.
        n_layers: number of TeraBlocks.
        max_seq: maximum sequence length (positional-embedding table size);
            inputs longer than this will fail the positional lookup.
        drop_rate: maximum stochastic-depth rate; scaled linearly per block
            so that earlier blocks are dropped less often than later ones.
    """

    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=3,
                 max_seq=32, drop_rate=0.05, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq = max_seq
        # Stored so get_config() can round-trip the full constructor state.
        self.drop_rate = drop_rate
        self.tok_emb = tf.keras.layers.Embedding(vocab_size, d_model)
        self.pos_emb = tf.keras.layers.Embedding(max_seq, d_model)
        # Linear stochastic-depth schedule: block i drops with rate
        # drop_rate * i / (n_layers - 1)  (0 for the first block).
        self.blocks = [
            TeraBlock(d_model, n_heads,
                      drop_rate=drop_rate * (i / max(n_layers - 1, 1)))
            for i in range(n_layers)
        ]
        self.ln_f = tf.keras.layers.GroupNormalization(groups=4, axis=-1)
        self.head = tf.keras.layers.Dense(vocab_size, use_bias=False)

    def call(self, x, training=False):
        B, T = tf.shape(x)[0], tf.shape(x)[1]
        pos = tf.range(T)[tf.newaxis, :]
        h = self.tok_emb(x) + self.pos_emb(pos)
        for block in self.blocks:
            h = block(h, training=training)
        h = self.ln_f(h)
        return self.head(h)

    def get_config(self):
        # Includes every constructor argument so from_config() rebuilds an
        # equivalent model; drop_rate was previously omitted here, which
        # silently disabled stochastic depth on reload.
        return {
            "vocab_size": self.vocab_size,
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "n_layers": self.n_layers,
            "max_seq": self.max_seq,
            "drop_rate": self.drop_rate,
        }


# Alias for compatibility
TeraAIModel = TeraLM