| import tensorflow as tf |
| import numpy as np |
| import json |
|
|
| |
class TokenShift(tf.keras.layers.Layer):
    """Mix every position with the one before it (zero-padded at t=0).

    Output[t] is the mean of input[t] and input[t-1]; the first timestep
    is averaged with a zero vector.
    """

    def call(self, x):
        # Prepend a zero timestep and drop the last one, then average
        # the sequence with its shifted copy.
        zeros = tf.zeros_like(x[:, :1, :])
        previous = tf.concat([zeros, x[:, :-1, :]], axis=1)
        return (x + previous) / 2.0
|
|
| |
class TimeMix(tf.keras.layers.Layer):
    """Causal multi-head self-attention applied after a token shift.

    The input is token-shifted, projected once into Q/K/V, attended with
    a lower-triangular (causal) mask, and projected back to d_model.
    """

    def __init__(self, d_model, n_heads, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        # Per-head width; assumes d_model is divisible by n_heads.
        self.head_dim = d_model // n_heads
        self.shift = TokenShift()
        # Single fused projection for query, key, and value.
        self.qkv = tf.keras.layers.Dense(3 * d_model, use_bias=False)
        self.out_proj = tf.keras.layers.Dense(d_model, use_bias=False)

    def call(self, x, training=False):
        x = self.shift(x)
        batch = tf.shape(x)[0]
        seq_len = tf.shape(x)[1]

        # Project once, then split the last axis into Q, K, V.
        q, k, v = tf.split(self.qkv(x), 3, axis=-1)

        def to_heads(t):
            # (B, T, C) -> (B, H, T, head_dim)
            t = tf.reshape(t, [batch, seq_len, self.n_heads, self.head_dim])
            return tf.transpose(t, [0, 2, 1, 3])

        q = to_heads(q)
        k = to_heads(k)
        v = to_heads(v)

        # Scaled dot-product attention scores: (B, H, T, T).
        scale = tf.math.sqrt(tf.cast(self.head_dim, tf.float32))
        scores = tf.matmul(q, k, transpose_b=True) / scale

        # Causal mask: keep the lower triangle, push future positions to
        # a large negative value so softmax assigns them ~zero weight.
        causal = tf.linalg.band_part(tf.ones([seq_len, seq_len]), -1, 0)
        causal = causal[tf.newaxis, tf.newaxis, :, :]
        scores = scores * causal + (1.0 - causal) * -1e9
        weights = tf.nn.softmax(scores, axis=-1)

        # Weighted sum of values, then merge heads back to (B, T, C).
        context = tf.matmul(weights, v)
        context = tf.transpose(context, [0, 2, 1, 3])
        context = tf.reshape(context, [batch, seq_len, self.d_model])
        return self.out_proj(context)
|
|
| |
class ChannelMix(tf.keras.layers.Layer):
    """Position-wise feed-forward block with a squared-ReLU activation.

    Token-shifts the input, expands the channel dimension by ``expand``,
    applies relu(h)**2, and projects back down to d_model.
    """

    def __init__(self, d_model, expand=4, **kwargs):
        super().__init__(**kwargs)
        self.shift = TokenShift()
        self.fc1 = tf.keras.layers.Dense(d_model * expand, use_bias=False)
        self.fc2 = tf.keras.layers.Dense(d_model, use_bias=False)

    def call(self, x, training=False):
        shifted = self.shift(x)
        # Squared ReLU, as used in the RWKV / Primer family of models.
        hidden = tf.nn.relu(self.fc1(shifted)) ** 2
        return self.fc2(hidden)
|
|
| |
class TeraBlock(tf.keras.layers.Layer):
    """Pre-norm residual block: TimeMix then ChannelMix.

    While training, the entire block may be skipped (identity) with
    probability ``drop_rate`` — a LayerDrop-style regularizer.
    """

    def __init__(self, d_model, n_heads, drop_rate=0.0, **kwargs):
        super().__init__(**kwargs)
        self.norm1 = tf.keras.layers.GroupNormalization(groups=4, axis=-1)
        self.time_mix = TimeMix(d_model, n_heads)
        self.norm2 = tf.keras.layers.GroupNormalization(groups=4, axis=-1)
        self.channel_mix = ChannelMix(d_model)
        self.drop_rate = drop_rate

    def call(self, x, training=False):
        # Randomly skip the whole block during training.
        # NOTE(review): the Python `if` on a tensor comparison relies on
        # AutoGraph when traced inside tf.function — confirm graph mode.
        if training and self.drop_rate > 0.0:
            if tf.random.uniform([]) < self.drop_rate:
                return x

        # Pre-norm residual wiring: normalize, transform, add back.
        attended = x + self.time_mix(self.norm1(x), training=training)
        mixed = attended + self.channel_mix(self.norm2(attended),
                                            training=training)
        return mixed
|
|
| |
class TeraLM(tf.keras.Model):
    """Small causal language model built from a stack of TeraBlocks.

    Token embeddings plus learned positional embeddings feed ``n_layers``
    residual blocks whose layer-drop rate grows linearly with depth, then
    a final normalization and an (untied) projection to vocabulary logits.
    """

    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=3,
                 max_seq=32, drop_rate=0.05, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq = max_seq
        # BUG FIX: drop_rate was never stored and was missing from
        # get_config(), so from_config() silently rebuilt the model with
        # the default drop rate instead of the configured one.
        self.drop_rate = drop_rate

        self.tok_emb = tf.keras.layers.Embedding(vocab_size, d_model)
        self.pos_emb = tf.keras.layers.Embedding(max_seq, d_model)
        # Per-block drop rate scales linearly from 0 (first block) to
        # drop_rate (last block); max(...) guards n_layers == 1.
        self.blocks = [
            TeraBlock(d_model, n_heads,
                      drop_rate=drop_rate * (i / max(n_layers - 1, 1)))
            for i in range(n_layers)
        ]
        self.ln_f = tf.keras.layers.GroupNormalization(groups=4, axis=-1)
        self.head = tf.keras.layers.Dense(vocab_size, use_bias=False)

    def call(self, x, training=False):
        """Map integer token ids of shape (B, T) to logits (B, T, vocab).

        T must not exceed max_seq, since positions index a learned
        embedding table of that size.
        """
        T = tf.shape(x)[1]
        pos = tf.range(T)[tf.newaxis, :]
        h = self.tok_emb(x) + self.pos_emb(pos)
        for block in self.blocks:
            h = block(h, training=training)
        return self.head(self.ln_f(h))

    def get_config(self):
        # Now includes drop_rate so tf.keras serialization round-trips
        # the model exactly via from_config().
        return {
            "vocab_size": self.vocab_size,
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "n_layers": self.n_layers,
            "max_seq": self.max_seq,
            "drop_rate": self.drop_rate,
        }
|
|
| |
# Backwards-compatible alias for the model's previous public name.
TeraAIModel = TeraLM
|
|