# tera-v2 / model.py
# Uploaded via huggingface_hub (commit 02136f2, verified)
import tensorflow as tf
import numpy as np
import json
# ---- Token Shift ----
class TokenShift(tf.keras.layers.Layer):
    """Blend each timestep with its predecessor: out[t] = (x[t] + x[t-1]) / 2.

    Position 0 is averaged with zeros (no predecessor exists).
    """

    def call(self, x):
        # Pad one zero step at the front of the time axis, then drop the
        # final step: this is x shifted right by one position in time.
        prev = tf.pad(x, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
        return (x + prev) / 2.0
# ---- Time Mix ----
class TimeMix(tf.keras.layers.Layer):
    """Token-shifted multi-head causal self-attention.

    The input is first blended with its previous timestep (TokenShift), then
    run through standard scaled dot-product attention with a causal mask.

    Args:
        d_model: Embedding width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
    """

    def __init__(self, d_model, n_heads, **kwargs):
        super().__init__(**kwargs)
        if d_model % n_heads != 0:
            # Fail loudly at construction; a silent floor division here would
            # make the final reshape back to (B, T, d_model) incorrect.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.shift = TokenShift()
        # One fused projection produces q, k and v in a single matmul.
        self.qkv = tf.keras.layers.Dense(3 * d_model, use_bias=False)
        self.out_proj = tf.keras.layers.Dense(d_model, use_bias=False)

    def call(self, x, training=False):
        x = self.shift(x)
        B = tf.shape(x)[0]
        T = tf.shape(x)[1]
        q, k, v = tf.split(self.qkv(x), 3, axis=-1)

        def to_heads(t):
            # (B, T, C) -> (B, n_heads, T, head_dim)
            t = tf.reshape(t, [B, T, self.n_heads, self.head_dim])
            return tf.transpose(t, [0, 2, 1, 3])

        q, k, v = to_heads(q), to_heads(k), to_heads(v)
        scale = tf.math.sqrt(tf.cast(self.head_dim, tf.float32))
        attn = tf.matmul(q, k, transpose_b=True) / scale
        # Lower-triangular causal mask: position t attends to positions <= t.
        mask = tf.linalg.band_part(tf.ones([T, T]), -1, 0)
        mask = mask[tf.newaxis, tf.newaxis, :, :]
        # Zero out masked scores and push them to -1e9 so softmax ignores them.
        attn = attn * mask + (1.0 - mask) * -1e9
        attn = tf.nn.softmax(attn, axis=-1)
        out = tf.matmul(attn, v)
        out = tf.transpose(out, [0, 2, 1, 3])  # back to (B, T, n_heads, head_dim)
        out = tf.reshape(out, [B, T, self.d_model])
        return self.out_proj(out)

    def get_config(self):
        # Required so tf.keras can re-instantiate this layer on model load;
        # without it, deserialization of a saved model fails.
        config = super().get_config()
        config.update({"d_model": self.d_model, "n_heads": self.n_heads})
        return config
# ---- Channel Mix (FFN) with Squared ReLU ----
class ChannelMix(tf.keras.layers.Layer):
    """Token-shifted position-wise feed-forward network with squared ReLU.

    Args:
        d_model: Embedding width of input and output.
        expand: Hidden-layer expansion factor (hidden = d_model * expand).
    """

    def __init__(self, d_model, expand=4, **kwargs):
        super().__init__(**kwargs)
        # Retain constructor args so get_config() can serialize the layer;
        # the original discarded them, making the layer unserializable.
        self.d_model = d_model
        self.expand = expand
        self.shift = TokenShift()
        self.fc1 = tf.keras.layers.Dense(d_model * expand, use_bias=False)
        self.fc2 = tf.keras.layers.Dense(d_model, use_bias=False)

    def call(self, x, training=False):
        x = self.shift(x)
        h = self.fc1(x)
        h = tf.square(tf.nn.relu(h))  # Squared ReLU
        return self.fc2(h)

    def get_config(self):
        config = super().get_config()
        config.update({"d_model": self.d_model, "expand": self.expand})
        return config
# ---- Single TERA Block ----
class TeraBlock(tf.keras.layers.Layer):
    """Pre-norm residual block: GroupNorm -> TimeMix, GroupNorm -> ChannelMix.

    During training the entire block is skipped (identity) with probability
    ``drop_rate`` (stochastic depth).

    Args:
        d_model: Embedding width.
        n_heads: Attention heads for the TimeMix sublayer.
        drop_rate: Probability of skipping this block during training.
    """

    def __init__(self, d_model, n_heads, drop_rate=0.0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.drop_rate = drop_rate
        self.norm1 = tf.keras.layers.GroupNormalization(groups=4, axis=-1)
        self.time_mix = TimeMix(d_model, n_heads)
        self.norm2 = tf.keras.layers.GroupNormalization(groups=4, axis=-1)
        self.channel_mix = ChannelMix(d_model)

    def _residual(self, x, training):
        # Pre-norm residual wiring for both sublayers.
        x = x + self.time_mix(self.norm1(x), training=training)
        x = x + self.channel_mix(self.norm2(x), training=training)
        return x

    def call(self, x, training=False):
        if training and self.drop_rate > 0.0:
            # Stochastic depth. The original `if tensor < rate:` comparison
            # treats a symbolic tensor as a Python bool, which raises inside
            # tf.function / graph mode; tf.cond works in both eager and graph.
            return tf.cond(
                tf.random.uniform([]) < self.drop_rate,
                lambda: x,
                lambda: self._residual(x, training),
            )
        return self._residual(x, training)

    def get_config(self):
        # Serialize constructor args so tf.keras can rebuild the block on load.
        config = super().get_config()
        config.update(
            {
                "d_model": self.d_model,
                "n_heads": self.n_heads,
                "drop_rate": self.drop_rate,
            }
        )
        return config
# ---- TERA LM ----
class TeraLM(tf.keras.Model):
    """TERA language model: token + position embeddings -> TeraBlocks -> head.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Embedding / hidden width.
        n_heads: Attention heads per block.
        n_layers: Number of TeraBlocks.
        max_seq: Maximum sequence length supported by the learned positional
            embedding; inputs longer than this are invalid.
        drop_rate: Peak stochastic-depth rate; scaled linearly across depth
            so the first block is never dropped and the last uses the full rate.
    """

    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=3,
                 max_seq=32, drop_rate=0.05, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq = max_seq
        self.drop_rate = drop_rate  # retained so get_config() round-trips
        self.tok_emb = tf.keras.layers.Embedding(vocab_size, d_model)
        self.pos_emb = tf.keras.layers.Embedding(max_seq, d_model)
        # Linearly increasing stochastic-depth schedule across the stack.
        self.blocks = [
            TeraBlock(d_model, n_heads,
                      drop_rate=drop_rate * (i / max(n_layers - 1, 1)))
            for i in range(n_layers)
        ]
        self.ln_f = tf.keras.layers.GroupNormalization(groups=4, axis=-1)
        self.head = tf.keras.layers.Dense(vocab_size, use_bias=False)

    def call(self, x, training=False):
        """Map int token ids of shape (B, T) to logits of shape (B, T, vocab)."""
        T = tf.shape(x)[1]
        pos = tf.range(T)[tf.newaxis, :]
        h = self.tok_emb(x) + self.pos_emb(pos)
        for block in self.blocks:
            h = block(h, training=training)
        h = self.ln_f(h)
        return self.head(h)

    def get_config(self):
        # drop_rate was missing from the original config, so a model rebuilt
        # via from_config silently lost stochastic depth.
        return {
            "vocab_size": self.vocab_size,
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "n_layers": self.n_layers,
            "max_seq": self.max_seq,
            "drop_rate": self.drop_rate,
        }
# Alias for compatibility.
# NOTE(review): presumably kept so existing callers importing the old name
# still resolve — confirm against downstream usage before removing.
TeraAIModel = TeraLM