File size: 4,455 Bytes
02136f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import tensorflow as tf
import numpy as np
import json

# ---- Token Shift ----
class TokenShift(tf.keras.layers.Layer):
    """Mix each position with its left neighbor: out[t] = (x[t] + x[t-1]) / 2.

    Position 0 is averaged with zeros (there is no left neighbor).
    """

    def call(self, x):
        # Pad one zero step on the time axis, then drop the last step —
        # equivalent to prepending zeros and shifting right by one.
        lagged = tf.pad(x, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
        return 0.5 * (x + lagged)

# ---- Time Mix ----
class TimeMix(tf.keras.layers.Layer):
    """Causal multi-head self-attention applied after a token shift.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.

    Input/output shape: ``(batch, time, d_model)``.
    """

    def __init__(self, d_model, n_heads, **kwargs):
        super().__init__(**kwargs)
        # Fail fast here rather than with an opaque reshape error in call().
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.shift = TokenShift()
        self.qkv = tf.keras.layers.Dense(3 * d_model, use_bias=False)
        self.out_proj = tf.keras.layers.Dense(d_model, use_bias=False)

    def call(self, x, training=False):
        x = self.shift(x)
        B, T, C = tf.shape(x)[0], tf.shape(x)[1], self.d_model

        # Fused QKV projection, then split into per-head tensors of shape
        # (batch, heads, time, head_dim).
        qkv = self.qkv(x)
        q, k, v = tf.split(qkv, 3, axis=-1)
        q = tf.transpose(tf.reshape(q, [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(k, [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(v, [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])

        # Scaled dot-product logits. Cast scale to the logits' dtype so the
        # layer also works under float16/bfloat16 (mixed precision) inputs;
        # the original hard-coded float32 here.
        attn = tf.matmul(q, k, transpose_b=True)
        scale = tf.math.sqrt(tf.cast(self.head_dim, attn.dtype))
        attn = attn / scale

        # Causal mask: lower-triangular ones keep position t attending only
        # to positions <= t; masked logits get a large negative value.
        mask = tf.linalg.band_part(tf.ones([T, T], dtype=attn.dtype), -1, 0)
        mask = mask[tf.newaxis, tf.newaxis, :, :]
        attn = attn * mask + (1.0 - mask) * tf.cast(-1e9, attn.dtype)
        attn = tf.nn.softmax(attn, axis=-1)

        out = tf.matmul(attn, v)
        out = tf.transpose(out, [0, 2, 1, 3])  # back to (B, T, heads, head_dim)
        out = tf.reshape(out, [B, T, C])
        return self.out_proj(out)

# ---- Channel Mix (FFN) with Squared ReLU ----
class ChannelMix(tf.keras.layers.Layer):
    """Position-wise feed-forward block with a squared-ReLU activation.

    Applies a token shift, expands the channel dimension by ``expand``,
    squares the ReLU output, and projects back to ``d_model``.
    """

    def __init__(self, d_model, expand=4, **kwargs):
        super().__init__(**kwargs)
        self.shift = TokenShift()
        self.fc1 = tf.keras.layers.Dense(d_model * expand, use_bias=False)
        self.fc2 = tf.keras.layers.Dense(d_model, use_bias=False)

    def call(self, x, training=False):
        # squared ReLU: relu(h)**2, written via tf.square for clarity
        hidden = tf.square(tf.nn.relu(self.fc1(self.shift(x))))
        return self.fc2(hidden)

# ---- Single TERA Block ----
class TeraBlock(tf.keras.layers.Layer):
    """One residual block: group-norm + TimeMix, then group-norm + ChannelMix.

    During training the entire block may be skipped at random
    (stochastic depth) with probability ``drop_rate``.
    """

    def __init__(self, d_model, n_heads, drop_rate=0.0, **kwargs):
        super().__init__(**kwargs)
        self.norm1 = tf.keras.layers.GroupNormalization(groups=4, axis=-1)
        self.time_mix = TimeMix(d_model, n_heads)
        self.norm2 = tf.keras.layers.GroupNormalization(groups=4, axis=-1)
        self.channel_mix = ChannelMix(d_model)
        self.drop_rate = drop_rate

    def call(self, x, training=False):
        # Stochastic depth: with probability drop_rate, the block is an
        # identity during training.
        # NOTE(review): no inverse scaling is applied to surviving residuals,
        # so train/eval expectations differ slightly — confirm intended.
        if training and self.drop_rate > 0.0:
            if tf.random.uniform([]) < self.drop_rate:
                return x

        # Pre-norm residual branches.
        out = x + self.time_mix(self.norm1(x), training=training)
        out = out + self.channel_mix(self.norm2(out), training=training)
        return out

# ---- TERA LM ----
class TeraLM(tf.keras.Model):
    """Small causal language model built from stacked TeraBlocks.

    Token + learned positional embeddings feed ``n_layers`` blocks whose
    stochastic-depth rate increases linearly with depth, followed by a final
    group-norm and an untied linear head producing vocabulary logits.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Embedding / model width.
        n_heads: Attention heads per TimeMix layer.
        n_layers: Number of TeraBlocks.
        max_seq: Maximum sequence length the positional table supports;
            inputs longer than this will fail in the positional embedding.
        drop_rate: Peak stochastic-depth rate (deepest block).
    """

    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=3,
                 max_seq=32, drop_rate=0.05, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq = max_seq
        # Stored so get_config() round-trips the full constructor signature
        # (the original omitted drop_rate, silently losing it on reload).
        self.drop_rate = drop_rate

        self.tok_emb = tf.keras.layers.Embedding(vocab_size, d_model)
        self.pos_emb = tf.keras.layers.Embedding(max_seq, d_model)
        # Linear stochastic-depth schedule: 0 for the first block up to
        # drop_rate for the last; max(...) guards n_layers == 1.
        self.blocks = [
            TeraBlock(d_model, n_heads, drop_rate=drop_rate * (i / max(n_layers - 1, 1)))
            for i in range(n_layers)
        ]
        self.ln_f = tf.keras.layers.GroupNormalization(groups=4, axis=-1)
        self.head = tf.keras.layers.Dense(vocab_size, use_bias=False)

    def call(self, x, training=False):
        """Map int token ids (batch, time) to logits (batch, time, vocab_size)."""
        B, T = tf.shape(x)[0], tf.shape(x)[1]
        pos = tf.range(T)[tf.newaxis, :]
        h = self.tok_emb(x) + self.pos_emb(pos)
        for block in self.blocks:
            h = block(h, training=training)
        h = self.ln_f(h)
        return self.head(h)

    def get_config(self):
        # Merge the base Model config (name, trainable, ...) so that
        # from_config() can reconstruct this model faithfully.
        config = super().get_config()
        config.update({
            "vocab_size": self.vocab_size,
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "n_layers": self.n_layers,
            "max_seq": self.max_seq,
            "drop_rate": self.drop_rate,
        })
        return config

# Backwards-compatible alias: older callers import the model as `TeraAIModel`.
TeraAIModel = TeraLM