Yuchan commited on
Commit
37ba2d7
ยท
verified ยท
1 Parent(s): 28e8f57

Create Inference.py

Browse files
Files changed (1) hide show
  1. Inference.py +248 -0
Inference.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import pandas as pd
4
+ import tensorflow as tf
5
+ from tensorflow.keras import layers
6
+ import sentencepiece as spm
7
+ import requests
8
+
9
+ # โฌ‡๏ธ ํ† ํฌ๋‚˜์ด์ € ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
10
+ sp = spm.SentencePieceProcessor()
11
+ sp.load("ko_unigram.model")
12
+
13
+ # โฌ‡๏ธ ํŠน์ˆ˜ ํ† ํฐ ID ์ถ”์ถœ
14
+ pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
15
+ start_id = sp.piece_to_id("<start>")
16
+ sep_id = sp.piece_to_id("<sep>")
17
+ end_id = sp.piece_to_id("<end>")
18
+ unk_id = sp.piece_to_id("<unk>")
19
+
20
+ vocab_size = sp.get_piece_size()
21
+ print(f"โœ… Vocabulary size: {vocab_size}")
22
+
23
+ # โฌ‡๏ธ ํ…์ŠคํŠธ <-> ID ๋ณ€ํ™˜ ํ•จ์ˆ˜
24
+ def text_to_ids(text):
25
+ return sp.encode(text, out_type=int)
26
+
27
+ def ids_to_text(ids):
28
+ return sp.decode(ids)
29
+
30
+ max_len = 100
31
+ batch_size = 128
32
+
33
+ class Lo(layers.Layer):
34
+ def __init__(self, d_model):
35
+ super().__init__()
36
+ # ๋‚ด๋ถ€ ๊ณ„์‚ฐ์€ float32๋กœ ์œ ์ง€
37
+ self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
38
+ self.p = layers.Dense(96, use_bias=True, dtype='float32')
39
+ self._out_dtype = 'float32'
40
+
41
+ def call(self, x):
42
+ # x may be bfloat16; cast to float32 for stable intermediate computation
43
+ x_f32 = tf.cast(x, tf.float32)
44
+ x = self.proj(x_f32)
45
+ x = tf.nn.gelu(x)
46
+ x = self.p(x)
47
+ # cast back to model dtype for consistency
48
+ return tf.cast(x, self._out_dtype)
49
+
50
+ class LoSoU(layers.Layer):
51
+ """
52
+ ์•ˆ์ •ํ™”๋œ LoSoU ๋ ˆ์ด์–ด (๋™์  alpha ์‚ฌ์šฉ)
53
+ - alpha ๊ฐ’์„ ์ž…๋ ฅ์— ๋”ฐ๋ผ ๋™์ ์œผ๋กœ ๊ณ„์‚ฐ: alpha = sigmoid(Linear(x))
54
+ - ๋ˆ„์ ํ•ฉ ๋Œ€์‹  ์ง€์ˆ˜์ด๋™ํ‰๊ท (EMA) ์‚ฌ์šฉ (alpha: smoothing factor)
55
+ - ๋‚ด๋ถ€ ๊ณ„์‚ฐ์€ float32๋กœ ์ˆ˜ํ–‰ (TPU bfloat16 ์•ˆ์ •์„ฑ ํ–ฅ์ƒ)
56
+ - EMA ๊ฒฐ๊ณผ ํด๋ฆฌํ•‘ ๋ฐ ์ž‘์€ epsilon ์ ์šฉ
57
+ - ์•ˆ์ „ํ•œ split ์ฒ˜๋ฆฌ (์ง์ˆ˜ ์ฐจ์› ๊ฐ€์ •; ์•„๋‹ˆ๋ผ๋ฉด ๋งˆ์ง€๋ง‰ ์ฐจ์› pad ํ•„์š”)
58
+ """
59
+ def __init__(self, d_model, clip_value=5.0, eps=1e-6):
60
+ super().__init__()
61
+ # ๋Œ€๋ถ€๋ถ„ ์—ฐ์‚ฐ์„ float32๋กœ ์ˆ˜ํ–‰
62
+ self.d_model = d_model
63
+ self.clip_value = float(clip_value)
64
+ self.eps = float(eps)
65
+
66
+ # projection / gating layers in float32
67
+ self.Q = layers.Dense(96, dtype='float32')
68
+ self.K = layers.Dense(96, dtype='float32')
69
+ self.V = Lo(d_model) # Lo already handles casting to model dtype; we'll cast back to float32
70
+ self.proj = layers.Dense(d_model, use_bias=True, dtype='float32')
71
+ self.O = layers.Dense(d_model, dtype='float32')
72
+ self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
73
+
74
+ # ๋™์  alpha ๊ณ„์‚ฐ์„ ์œ„ํ•œ ๋ ˆ์ด์–ด
75
+ # alpha๋Š” [0, 1] ๋ฒ”์œ„์—ฌ์•ผ ํ•˜๋ฏ€๋กœ sigmoid ์‚ฌ์šฉ
76
+ # ์ž…๋ ฅ x์˜ d_model ์ฐจ์›์„ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ ์ƒ˜ํ”Œ์— ๋Œ€ํ•ด alpha ๊ณ„์‚ฐ
77
+ # ์˜ˆ: (B, L, d_model) -> (B, L, 1) -> (B, L, 1) with sigmoid
78
+ # ๋˜๋Š” (B, L, d_model) -> (B, L, d_model) -> global reduce -> (B, L, 1)
79
+ # ๊ฐ„๋‹จํžˆ ๊ฐ ์œ„์น˜์— ๋Œ€ํ•ด ๋™์ผํ•œ alpha ์‚ฌ์šฉ (์ž…๋ ฅ์˜ ํ‰๊ท  ๊ธฐ๋ฐ˜)
80
+ # ๋˜๋Š” ์œ„์น˜๋ณ„๋กœ ๋‹ค๋ฅด๊ฒŒ ์‚ฌ์šฉ (๊ฐ ์œ„์น˜์— ๋Œ€ํ•ด ๊ณ„์‚ฐ)
81
+ # ์—ฌ๊ธฐ์„œ๋Š” ์œ„์น˜๋ณ„๋กœ ๋‹ค๋ฅด๊ฒŒ ๊ณ„์‚ฐ (B, L, 1)
82
+ self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
83
+
84
+ def _ema_over_time(self, score, alpha_dynamic):
85
+ # score: (B, L, D) float32 in [0,1] roughly
86
+ # alpha_dynamic: (B, L, 1) float32 in [0,1]
87
+
88
+ # transpose to (L, B, D) to scan over time steps
89
+ seq = tf.transpose(score, perm=[1, 0, 2]) # (L, B, D)
90
+ alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2]) # (L, B, 1)
91
+
92
+ def step(prev_ema, inputs):
93
+ x_t, alpha_t = inputs
94
+ # prev_ema: (B, D), x_t: (B, D), alpha_t: (B, 1)
95
+ new = alpha_t * x_t + (1.0 - alpha_t) * prev_ema
96
+ return new
97
+
98
+ # ์ดˆ๊ธฐ๊ฐ’์„ ์ฒซ step ๊ฐ’์œผ๋กœ ์„ค์ •
99
+ init = seq[0] # (B, D)
100
+ first_alpha = alpha_seq[0] # (B, 1)
101
+
102
+ # scan์˜ elems๋Š” (L-1, B, D) ๋ฐ (L-1, B, 1) ์ด์–ด์•ผ ํ•จ
103
+ remaining_seq = seq[1:] # (L-1, B, D)
104
+ remaining_alpha = alpha_seq[1:] # (L-1, B, 1)
105
+
106
+ # elems๋Š” ๋‘ ํ…์„œ์˜ ํŠœํ”Œ๋กœ ๊ตฌ์„ฑ: (x_t, alpha_t)
107
+ elems = (remaining_seq, remaining_alpha)
108
+
109
+ ema_seq = tf.scan(fn=step, elems=elems, initializer=init)
110
+ # ์ดˆ๊ธฐ๊ฐ’ ํฌํ•จ
111
+ ema_seq = tf.concat([tf.expand_dims(init, 0), ema_seq], axis=0) # (L, B, D)
112
+
113
+ # transpose back to (B, L, D)
114
+ ema = tf.transpose(ema_seq, perm=[1, 0, 2])
115
+ return ema
116
+
117
+ def call(self, x):
118
+ # x: (B, L, d_model) maybe bfloat16 or float32
119
+ # cast to float32 for all internal computations
120
+ x_f32 = tf.cast(x, tf.float32)
121
+ residual = x_f32
122
+
123
+ # Q, K, V
124
+ q = self.Q(x_f32) # (B, L, 96)
125
+ k = self.K(x_f32) # (B, L, 96)
126
+ V = tf.cast(self.V(x), tf.float32) # ensure V's output is float32
127
+
128
+ # gating signals in (0,1)
129
+ g_q = tf.nn.sigmoid(q)
130
+ g_k = tf.nn.sigmoid(k)
131
+
132
+ # elementwise product -> bounded roughly [0,1]
133
+ score = g_q * g_k
134
+
135
+ # ๋™์  alpha ๊ณ„์‚ฐ: (B, L, d_model) -> (B, L, 1)
136
+ alpha_dynamic = self.alpha_linear(x_f32) # (B, L, 1)
137
+ # ํ•„์š”์‹œ alpha_dynamic์— ๋Œ€ํ•œ ํ›„์ฒ˜๋ฆฌ (์˜ˆ: min/max ๋“ฑ) ๊ฐ€๋Šฅ
138
+ # ex: alpha_dynamic = tf.clip_by_value(alpha_dynamic, 0.01, 0.99)
139
+
140
+ # EMA across time (stable alternative to cumsum)
141
+ score_ema = self._ema_over_time(score, alpha_dynamic)
142
+
143
+ # optionally normalize by (mean + eps) across last dim to reduce scale variations
144
+ mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True) # (B, L, 1)
145
+ denom = tf.maximum(mean_last, self.eps)
146
+ score_norm = score_ema / denom
147
+
148
+ # clip to avoid extremes
149
+ score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
150
+
151
+ # combine with V
152
+ x_comb = score_clipped * V # (B, L, d_model)
153
+
154
+ out = self.proj(x_comb) # (B, L, d_model)
155
+
156
+ # ensure out dim even for split
157
+ d = out.shape[-1] # this is an int (static shape)
158
+ if d is not None and d % 2 == 1:
159
+ out = tf.pad(out, [[0,0],[0,0],[0,1]])
160
+
161
+ a, b = tf.split(out, 2, axis=-1)
162
+ gated = tf.nn.silu(a) * b
163
+ out = self.O(gated)
164
+
165
+ out = self.norm(out + residual)
166
+
167
+ # cast back to original dtype for downstream layers
168
+ return tf.cast(out, x.dtype)
169
+
170
+
171
+ class Block(layers.Layer):
172
+ def __init__(self, d_model, hyper_n):
173
+ super().__init__()
174
+ self.losou = [LoSoU(d_model) for _ in range(hyper_n)]
175
+
176
+ def call(self, x):
177
+ for losou in self.losou:
178
+ x = losou(x)
179
+ return x
180
+
181
+ class ReLaM(tf.keras.Model):
182
+ def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
183
+ super().__init__()
184
+ self.token_embedding = layers.Embedding(vocab_size, d_model)
185
+ self.pos_embedding = layers.Embedding(max_seq_len, d_model)
186
+ self.blocks = [Block(d_model, hyper_n=3) for _ in range(n_layers)]
187
+
188
+ # LayerNormalization์€ float32๋กœ ํ•ด์„œ ์ •๋ฐ€๋„ ๋ฌธ์ œ ๋ฐฉ์ง€
189
+ self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
190
+
191
+ def call(self, x, training=False):
192
+ batch_size, seq_len = tf.shape(x)[0], tf.shape(x)[1]
193
+ positions = tf.range(seq_len)[tf.newaxis, :]
194
+
195
+ x = self.token_embedding(x) + self.pos_embedding(positions)
196
+ for block in self.blocks:
197
+ x = block(x)
198
+
199
+ x = self.ln_f(x)
200
+
201
+ embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
202
+ logits = tf.matmul(x, embedding_matrix, transpose_b=True)
203
+ return tf.cast(logits, tf.float32)
204
+
205
+ # ๋ชจ๋ธ ์ƒ์„ฑ
206
+ model = ReLaM(
207
+ vocab_size=vocab_size,
208
+ max_seq_len=max_len,
209
+ d_model=256,
210
+ n_layers=1
211
+ )
212
+
213
+ dummy_input = tf.zeros((1, max_len), dtype=tf.int32)
214
+ _ = model(dummy_input)
215
+ model.load_weights('/content/Cobra.weights.h5')
216
+ print("๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋กœ๋“œ ์™„๋ฃŒ!")
217
+
218
+ def generate_text_topp(model, prompt, max_len=100, max_gen=98, p=0.9, temperature=0.8, min_len=30):
219
+ model_input = text_to_ids(f"<start> {prompt} <sep>")
220
+ model_input = model_input[:max_len]
221
+ generated = list(model_input)
222
+ for step in range(max_gen):
223
+ if len(generated) > max_len:
224
+ input_seq = generated[-max_len:]
225
+ else:
226
+ input_seq = generated
227
+ input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
228
+ input_tensor = tf.convert_to_tensor([input_padded])
229
+ logits = model(input_tensor, training=False)
230
+ next_token_logits = logits[0, len(input_seq) - 1].numpy()
231
+ next_token_logits[end_id] -= 5.0
232
+ next_token_logits[pad_id] -= 10.0
233
+ probs = tf.nn.softmax(next_token_logits / temperature).numpy()
234
+ sorted_indices = np.argsort(probs)[::-1]
235
+ sorted_probs = probs[sorted_indices]
236
+ cumulative_probs = np.cumsum(sorted_probs)
237
+ cutoff = np.searchsorted(cumulative_probs, p)
238
+ top_indices = sorted_indices[:cutoff + 1]
239
+ top_probs = sorted_probs[:cutoff + 1]
240
+ top_probs /= np.sum(top_probs)
241
+ next_token_id = np.random.choice(top_indices, p=top_probs)
242
+ if next_token_id == end_id and len(generated) >= min_len:
243
+ break
244
+ generated.append(int(next_token_id))
245
+ return ids_to_text(generated)
246
+
247
+ print("\n\n===== ์ƒ์„ฑ ๊ฒฐ๊ณผ =====")
248
+ print(generate_text_topp(model, "์ œ๊ฐ€ ์ด๋”ฐ๊ฐ€ ๋ฒ„์Šค๋ฅผ ํƒ€์•ผ ํ•ด์„œ ์ค€๋น„ ์ข€ ํ•ด์•ผ๊ฒ ์–ด์š”. ์žฌ๋ฏธ์žˆ๋Š” ๋Œ€ํ™”์˜€์Šต๋‹ˆ๋‹ค!", p=0.8))