Yuchan commited on
Commit
5e06e87
ยท
verified ยท
1 Parent(s): ec2770f

Update Inference.py

Browse files
Files changed (1) hide show
  1. Inference.py +81 -119
Inference.py CHANGED
@@ -77,10 +77,6 @@ def text_to_ids(text):
77
  def ids_to_text(ids):
78
  return sp.decode(ids)
79
 
80
- # =======================
81
- # 3) ๋ชจ๋ธ ๋ ˆ์ด์–ด (๊ธฐ์กด ์ฝ”๋“œ ์œ ์ง€)
82
- # =======================
83
-
84
  class SwiGLU(layers.Layer):
85
  def __init__(self, d_model, d_ff):
86
  super().__init__()
@@ -91,139 +87,105 @@ class SwiGLU(layers.Layer):
91
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
92
  return self.out(x_val * tf.nn.silu(x_gate))
93
 
94
- class gMLPBlock(layers.Layer):
95
- def __init__(self, d_model, seq_len, dropout=0.1):
96
- super().__init__()
97
- self.d_model = d_model
98
- self.seq_len = seq_len
99
- self.norm = layers.LayerNormalization(epsilon=1e-6)
100
-
101
- # FFN: Channel Expansion
102
- # d_model * 4๋กœ ํ™•์žฅ
103
- self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
104
- self.dropout = layers.Dropout(dropout)
105
-
106
- # Spatial Gating Unit (SGU)
107
- self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
108
- self.sgu_proj = layers.Dense(seq_len, use_bias=False)
109
-
110
- # ์ถœ๋ ฅ ์ฐจ์›์„ d_model * 2 (U์˜ ์ฐจ์›)๋กœ ์„ค์ •
111
- self.sgu_final = layers.Dense(d_model * 2, use_bias=True)
112
-
113
- self.out_proj = layers.Dense(d_model, use_bias=True)
114
-
115
- def call(self, x, training=False):
116
- # 1. Norm and Channel Expansion
117
- residual = x
118
- x_norm = self.norm(x)
119
- x_proj = self.channel_proj(x_norm) # Shape: (B, L, 4*D)
120
-
121
- # 2. Split (U and V streams)
122
- u, v = tf.split(x_proj, 2, axis=-1) # u, v Shape: (B, L, 2*D)
123
-
124
- # 3. Spatial Gating Unit (SGU)
125
- v_norm = self.sgu_norm(v)
126
- v_norm_T = tf.transpose(v_norm, perm=[0, 2, 1]) # (B, 2D, L)
127
-
128
- # ๐Ÿ’ก ํ† ํฐ ๋ฏน์‹ฑ ๋ฐœ์ƒ (์‹œํ€€์Šค ์ถ•์œผ๋กœ Dense ์ ์šฉ)
129
- v_proj = self.sgu_proj(v_norm_T) # (B, 2D, L)
130
- v_proj_T = tf.transpose(v_proj, perm=[0, 2, 1]) # (B, L, 2D)
131
-
132
- # 4. Activation and Gate Generation
133
- # ํ‘œ์ค€ gMLP๋Š” U์— GELU๋ฅผ ์ ์šฉํ•˜๊ณ  V๋Š” ์„ ํ˜• ๊ฒŒ์ดํŠธ๋กœ ์‚ฌ์šฉ
134
- # ์—ฌ๊ธฐ์„œ๋Š” U์— GELU๋ฅผ ์ ์šฉ
135
- u_act = tf.nn.gelu(u)
136
- v_gate = self.sgu_final(v_proj_T) # Shape: (B, L, 2*D)
137
-
138
- # 5. Gating and Contraction
139
- z = u_act * v_gate # ๊ฒŒ์ดํŒ…
140
- z = self.dropout(z, training=training)
141
- out = self.out_proj(z) # Shape: (B, L, D)
142
-
143
- # 6. Residual Connection
144
- return residual + out
145
-
146
- class CrossBlock(layers.Layer):
147
- def __init__(self, clip_value=5.0, eps=1e-6): # ๐Ÿ’ก d_model ์ธ์ž ์ถ”๊ฐ€
148
- super().__init__()
149
- self.clip_value = clip_value
150
- self.eps = eps
151
- self.attn = layers.MultiHeadAttention(8, 20)
152
- # ๐Ÿ’ก ์ˆ˜์ •: ์ถœ๋ ฅ ์ฐจ์›์„ 1์—์„œ d_model๋กœ ๋ณ€๊ฒฝ
153
- def call(self, x, z):
154
- y = self.attn(x, z, z)
155
- return y
156
 
157
  class LoU(layers.Layer):
158
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
159
  super().__init__()
160
  self.d_model = d_model
161
  self.clip_value = float(clip_value)
162
- self.mha = layers.MultiHeadAttention(8, 20)
163
- self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
 
164
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
165
 
166
- self.glu = SwiGLU(d_model, 350)
167
- self.cross = CrossBlock()
168
-
169
- def call(self, x, z):
170
  x_f32 = tf.cast(x, tf.float32)
171
  residual = x_f32
172
- x = self.norm1(x)
 
 
 
 
 
 
 
173
 
174
- x_comb = self.mha(x, x, x, use_causal_mask=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  out = self.norm(x_comb + residual)
177
- out = self.cross(out, z)
178
  out = self.glu(out)
179
  return tf.cast(out, x.dtype)
180
- # =======================
181
- # 4) AlphaS2S ๋ชจ๋ธ (๊ธฐ์กด ์ฝ”๋“œ ์œ ์ง€)
182
- # =======================
183
 
184
- class AlphaS2S(tf.keras.Model):
185
- def __init__(self, num_layers, d_model, num_heads, input_vocab_size, target_vocab_size, max_len=200, dropout=0.1):
 
186
  super().__init__()
187
- self.max_len = max_len
188
- self.d_model = d_model
189
-
190
- # ์ธ์ฝ”๋”์™€ ๋””์ฝ”๋” ์ž„๋ฒ ๋”ฉ ๋ฐ ์œ„์น˜ ์ž„๋ฒ ๋”ฉ์€ ๋ชจ๋‘ max_len์„ ์‚ฌ์šฉ
191
- self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
192
- self.enc_pos_embedding = layers.Embedding(max_len, d_model)
193
- self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
194
- self.dec_pos_embedding = layers.Embedding(max_len, d_model)
195
-
196
- # EncoderBlock๊ณผ LoU๋Š” ๊ธฐ์กด ์ฝ”๋“œ์™€ ๋™์ผํ•œ ๊ตฌ์กฐ
197
- self.enc_layers = [gMLPBlock(d_model, seq_len=max_len) for _ in range(num_layers)]
198
- self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
199
-
200
- self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
201
-
202
- def call(self, inputs, training=False):
203
- # enc_inputs์™€ dec_inputs๋Š” ๋™์ผํ•œ ์‹œํ€€์Šค (Unified Input)
204
- enc_inputs = inputs["enc_inputs"]
205
- dec_inputs = inputs["dec_inputs"]
206
-
207
- enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
208
- dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]
209
-
210
- # ์ธ์ฝ”๋” ์‹คํ–‰
211
- x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
212
- # Note: ๋งˆ์Šคํฌ ์—†์Œ -> Bi-directional (BERT-like Encoder)
213
- for layer in self.enc_layers: x = layer(x, training=training)
214
- enc_out = x # ์ธ์ฝ”๋”์˜ ์ตœ์ข… ์ถœ๋ ฅ (๋””์ฝ”๋”์˜ 'z' ์ž…๋ ฅ)
215
-
216
- # ๋””์ฝ”๋” ์‹คํ–‰
217
- y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
218
- # Note: LoU๋Š” ๋‚ด๋ถ€์ ์œผ๋กœ EMA๋ฅผ ์‚ฌ์šฉํ•˜๋ฉฐ, ์ผ๋ฐ˜์ ์ธ Cross-Attention ๋ธ”๋ก์˜ ์—ญํ• ์„ ์ˆ˜ํ–‰
219
- for layer in self.dec_layers: y = layer(y, enc_out, training=training)
220
-
221
- return self.final_layer(y)
222
 
223
- # ๊ฐ€์ค‘์น˜ ์ €์žฅ
224
- chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
225
- input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=max_len)
226
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  dummy_input = {
228
  "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
229
  "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
 
77
  def ids_to_text(ids):
78
  return sp.decode(ids)
79
 
 
 
 
 
80
  class SwiGLU(layers.Layer):
81
  def __init__(self, d_model, d_ff):
82
  super().__init__()
 
87
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
88
  return self.out(x_val * tf.nn.silu(x_gate))
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  class LoU(layers.Layer):
92
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
93
  super().__init__()
94
  self.d_model = d_model
95
  self.clip_value = float(clip_value)
96
+ self.eps = float(eps)
97
+ self.Q = layers.Dense(d_model, dtype='float32')
98
+ self.K = layers.Dense(d_model, dtype='float32')
99
+ self.V = layers.Dense(d_model, dtype='float32')
100
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
101
+ self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
102
 
103
+ self.glu = SwiGLU(d_model, 320)
104
+ def call(self, x):
 
 
105
  x_f32 = tf.cast(x, tf.float32)
106
  residual = x_f32
107
+ x_f32 = self.norm1(x)
108
+
109
+ q = self.Q(x_f32)
110
+ k = self.K(x_f32)
111
+ V = self.V(x_f32)
112
+ g_q = (tf.nn.tanh(q) + 1.0) / 2.0
113
+ g_k = (tf.nn.tanh(k) + 1.0) / 2.0
114
+ score = g_q * g_k
115
 
116
+ score = tf.cumsum(score, axis=1) # (B, L, D)
117
+
118
+ # ๐Ÿ’ก ์ˆ˜์ •๋œ ๋ถ€๋ถ„: ํ˜„์žฌ ํ† ํฐ๊นŒ์ง€์˜ ๋ˆ„์ ํ•ฉ ํ‰๊ท ์œผ๋กœ ์ •๊ทœํ™”
119
+ seq_len = tf.shape(score)[1]
120
+ # [1, 2, 3, ..., L]์„ D_model ์ฐจ์›์œผ๋กœ ํ™•์žฅ
121
+ count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
122
+ count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
123
+
124
+ # ๋ˆ„์ ํ•ฉ์„ ํ˜„์žฌ๊นŒ์ง€์˜ ํ† ํฐ ๊ฐœ์ˆ˜๋กœ ๋‚˜๋ˆ„์–ด ํ‰๊ท  ๋ˆ„์ ํ•ฉ ๊ณ„์‚ฐ (B, L, D)
125
+ score_mean = score / count_for_mean
126
+
127
+ # ์ •๊ทœํ™” ๋ถ„๋ชจ ์„ค์ •
128
+ denom = tf.maximum(score_mean, self.eps)
129
+ score_norm = score / denom
130
+ # -----------------------------------------------
131
+
132
+ score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
133
+ x_comb = score_clipped * V
134
 
135
  out = self.norm(x_comb + residual)
 
136
  out = self.glu(out)
137
  return tf.cast(out, x.dtype)
 
 
 
138
 
139
+
140
+ class Lo(layers.Layer):
141
+ def __init__(self, d_model):
142
  super().__init__()
143
+ self.d = layers.Dense(64, activation='silu')
144
+ self.w = layers.Dense(d_model)
145
+ self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
+ def call(self, x):
148
+ p = self.d(x)
149
+ p = self.w(p)
150
+ return self.norm(p) + x
151
+
152
+ class Block(layers.Layer):
153
+ def __init__(self, d_model):
154
+ super().__init__()
155
+ self.lou = LoU(d_model)
156
+ self.lo = Lo(d_model)
157
+
158
+ def call(self, x):
159
+ x = self.lou(x)
160
+ x = self.lo(x)
161
+ return x
162
+
163
+ class ReLM(tf.keras.Model):
164
+ def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
165
+ super().__init__()
166
+ self.token_embedding = layers.Embedding(vocab_size, d_model)
167
+ self.pos_embedding = layers.Embedding(max_seq_len, d_model)
168
+ self.blocks = [Block(d_model) for _ in range(n_layers)]
169
+ self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
170
+
171
+ def call(self, x, training=False):
172
+ batch_size, seq_len = tf.shape(x)[0], tf.shape(x)[1]
173
+ positions = tf.range(seq_len)[tf.newaxis, :]
174
+ x = self.token_embedding(x) + self.pos_embedding(positions)
175
+ for block in self.blocks:
176
+ x = block(x)
177
+ x = self.ln_f(x)
178
+ embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
179
+ logits = tf.matmul(x, embedding_matrix, transpose_b=True)
180
+ return tf.cast(logits, tf.float32)
181
+
182
+
183
+ model = ReLM(
184
+ vocab_size=vocab_size,
185
+ max_seq_len=max_len,
186
+ d_model=256,
187
+ n_layers=1
188
+ )
189
  dummy_input = {
190
  "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
191
  "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)