Yuchan committed · verified
Commit 63439a6 · Parent: b3fccf3

Create Mo_jax.py

Files changed (1): Mo_jax.py (+381, -0)
Mo_jax.py ADDED
@@ -0,0 +1,381 @@
# Flax + JAX TPU-ready reimplementation of your ReLM model and training loop.
# Requirements:
# pip install --upgrade "jax[tpu]" flax optax sentencepiece

import os
import numpy as np
import sentencepiece as spm
from functools import partial
from typing import Any

import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from flax.training import train_state, checkpoints
import optax
import tqdm

# ------------------
# Config
# ------------------
SEQ_LEN = 512
GLOBAL_BATCH = 256  # global batch size (across all devices); adjust for memory
LIMIT = 200_000  # number of sequences to load (reduce if OOM)
VOCAB_MODEL = "ko_unigram.model"
CORPUS_PATH = "corpus.txt"
DTYPE = jnp.bfloat16 if jax.local_devices()[0].platform == "tpu" else jnp.float32
SEED = 42
LEARNING_RATE = 1e-4
EPOCHS = 1

# Derived
NUM_DEVICES = jax.device_count()
assert GLOBAL_BATCH % NUM_DEVICES == 0, "GLOBAL_BATCH must be divisible by device count"
PER_DEVICE_BATCH = GLOBAL_BATCH // NUM_DEVICES

print("devices:", jax.devices())
print("num_devices:", NUM_DEVICES, "per_device_batch:", PER_DEVICE_BATCH, "dtype:", DTYPE)

# ------------------
# Tokenizer loader
# ------------------
sp = spm.SentencePieceProcessor()
sp.load(VOCAB_MODEL)

# piece_to_id returns unk_id (not -1) for pieces missing from the vocab,
# so check against unk_id and fall back to 0 if <pad> is absent.
pad_id = sp.piece_to_id("<pad>")
if pad_id == sp.unk_id():
    pad_id = 0
start_id = sp.piece_to_id("<start>")
end_id = sp.piece_to_id("<end>")
vocab_size = sp.get_piece_size()
print("vocab_size:", vocab_size, "pad_id:", pad_id, "start_id:", start_id, "end_id:", end_id)

# ------------------
# Data pipeline (simple, numpy-based)
# - Reads corpus line-by-line, tokenizes, pads/truncates to SEQ_LEN.
# - Builds a numpy array (N, SEQ_LEN) for inputs and targets (shifted by 1).
# - Shards batches across devices for pmap.
# ------------------
def line_to_ids(line: str, max_len: int = SEQ_LEN):
    ids = sp.encode(line.strip(), out_type=int)
    if len(ids) > max_len - 1:
        ids = ids[: max_len - 1]
    ids = ids + [end_id]
    pad_len = max_len - len(ids)
    ids = ids + [pad_id] * pad_len
    return np.array(ids, dtype=np.int32)
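
# Quick shape check (illustrative; any short string works as input): each line
# becomes exactly SEQ_LEN int32 ids, with <end> appended before the padding.
_demo_ids = line_to_ids("예시 문장입니다")
assert _demo_ids.shape == (SEQ_LEN,) and _demo_ids.dtype == np.int32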

def build_dataset(corpus_path: str, limit: int = LIMIT):
    arr = []
    with open(corpus_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= limit:
                break
            line = line.strip()
            if not line:
                continue
            arr.append(line_to_ids(line))
    data = np.stack(arr, axis=0)  # (N, SEQ_LEN)
    print("Loaded dataset shape:", data.shape)
    return data

# create inputs and targets: targets are inputs shifted left by one token
data_np = build_dataset(CORPUS_PATH, LIMIT)
inputs = data_np
targets = np.concatenate([data_np[:, 1:], np.full((data_np.shape[0], 1), pad_id, dtype=np.int32)], axis=1)
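
# Consequence of the shift (illustrative check): position i in `inputs` is
# trained to predict position i+1, so the rows overlap exactly like this:
assert np.array_equal(targets[0, :-1], inputs[0, 1:])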

# shuffle and create batches (drops the final partial batch)
def create_batch_iter(inputs: np.ndarray, targets: np.ndarray, batch_size: int, rng: np.random.Generator):
    idx = np.arange(inputs.shape[0])
    rng.shuffle(idx)
    for i in range(0, len(idx) - batch_size + 1, batch_size):
        batch_idx = idx[i:i + batch_size]
        x = inputs[batch_idx]
        y = targets[batch_idx]
        yield x, y

# helper to shard a numpy batch for pmap: shape (num_devices, per_device, ...)
def shard(xs: np.ndarray):
    return xs.reshape((NUM_DEVICES, -1) + xs.shape[1:])
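
# Sanity check (illustrative): a full global batch reshapes cleanly into the
# leading device axis that pmap expects.
_demo_batch = np.zeros((GLOBAL_BATCH, SEQ_LEN), dtype=np.int32)
assert shard(_demo_batch).shape == (NUM_DEVICES, PER_DEVICE_BATCH, SEQ_LEN)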

# ------------------
# Flax model implementation
# ------------------
class SwiGLU(nn.Module):
    d_model: int

    @nn.compact
    def __call__(self, x):
        # project to 2*intermediate, then split into value and gate halves
        proj = nn.Dense(self.d_model * 2, dtype=jnp.float32)(x)  # keep proj in float32
        x_val, x_gate = jnp.split(proj, 2, axis=-1)
        out = x_val * nn.silu(x_gate)
        out = nn.Dense(self.d_model, dtype=jnp.float32)(out)
        return out.astype(x.dtype)
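
# In formula form (for reference): SwiGLU(x) = W_out((x W_v) * SiLU(x W_g)),
# where W_v and W_g come from the single fused Dense above; note the
# intermediate width here is d_model itself, with no expansion factor.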

class LoU(nn.Module):
    d_model: int
    clip_value: float = 5.0
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        # x: (batch, seq, d)
        x_f32 = x.astype(jnp.float32)
        residual = x_f32

        norm1 = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)
        x_norm = norm1(x_f32)

        Q = nn.Dense(self.d_model, dtype=jnp.float32)
        K = nn.Dense(self.d_model, dtype=jnp.float32)
        V = nn.Dense(self.d_model, dtype=jnp.float32)

        q = Q(x_norm)
        k = K(x_norm)
        v = V(x_norm)

        # gates in (0, 1) via shifted tanh
        g_q = (jnp.tanh(q) + 1.0) / 2.0
        g_k = (jnp.tanh(k) + 1.0) / 2.0
        score = g_q * g_k  # (b, seq, d)

        # note: alpha is unbounded here; a sigmoid would keep the EMA
        # coefficient in (0, 1) if that matches the reference model
        alpha_linear = nn.Dense(1, dtype=jnp.float32)
        alpha_dynamic = alpha_linear(x_norm)  # (b, seq, 1)

        # EMA over time: use scan across the sequence axis;
        # transpose to (seq, batch, d) to scan over time
        score_t = jnp.transpose(score, (1, 0, 2))
        alpha_t = jnp.transpose(alpha_dynamic, (1, 0, 2))

        def step(carry, inputs):
            prev_ema = carry
            x_t, a_t = inputs
            new = a_t * x_t + (1.0 - a_t) * prev_ema
            return new, new

        init = score_t[0]
        _, ema_seq = jax.lax.scan(step, init, (score_t[1:], alpha_t[1:]))
        ema_full = jnp.concatenate([init[None, ...], ema_seq], axis=0)  # (seq, batch, d)
        ema = jnp.transpose(ema_full, (1, 0, 2))  # (batch, seq, d)

        mean_last = jnp.mean(ema, axis=-1, keepdims=True)
        denom = jnp.maximum(mean_last, self.eps)
        score_norm = ema / denom
        score_clipped = jnp.clip(score_norm, -self.clip_value, self.clip_value)

        x_comb = score_clipped * v
        out = x_comb + residual
        out = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(out)
        out = SwiGLU(self.d_model)(out.astype(x.dtype))
        return out.astype(x.dtype)
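
# For reference, the scan above implements a data-dependent EMA per feature:
#   ema_t = a_t * score_t + (1 - a_t) * ema_{t-1},  with ema_0 = score_0.
# It is causal: the value at position t depends only on positions <= t.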

class Lo(nn.Module):
    d_model: int

    @nn.compact
    def __call__(self, x):
        # small bottleneck MLP (fixed width 64) with a normalized residual
        h = nn.Dense(64, dtype=jnp.float32)(x)
        h = nn.silu(h)
        h = nn.Dense(self.d_model, dtype=jnp.float32)(h)
        out = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(h) + x
        return out.astype(x.dtype)

class Block(nn.Module):
    d_model: int

    @nn.compact
    def __call__(self, x):
        x = LoU(self.d_model)(x)
        x = Lo(self.d_model)(x)
        return x

class ReLM(nn.Module):
    vocab_size: int
    max_seq_len: int
    d_model: int
    n_layers: int
    dtype: Any = jnp.float32

    def setup(self):
        self.token_embed = nn.Embed(self.vocab_size, self.d_model, dtype=self.dtype)
        self.pos_embed = nn.Embed(self.max_seq_len, self.d_model, dtype=self.dtype)
        self.blocks = [Block(self.d_model) for _ in range(self.n_layers)]
        self.ln_f = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)

    def __call__(self, x, deterministic=True):
        # x: (batch, seq)
        b, seq = x.shape
        positions = jnp.arange(seq)[None, :]
        x = self.token_embed(x) + self.pos_embed(positions)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        # tie weights: reuse the token embedding matrix as the output head
        embedding_matrix = self.token_embed.embedding  # (vocab, d)
        logits = jnp.einsum("bld,vd->blv", x, embedding_matrix)
        return logits.astype(jnp.float32)
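
# Shape walkthrough (illustrative): ids (B, L) -> embed (B, L, d_model) ->
# n_layers shape-preserving Blocks -> tied logits (B, L, vocab_size).
# nn.Embed.attend would be the equivalent built-in for the tied projection.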

# ------------------
# Loss & metrics
# ------------------
def smoothed_cross_entropy(logits, targets, pad_id, eps=0.1):
    # logits: (b, seq, v); targets: (b, seq) int32
    vocab = logits.shape[-1]
    logits = logits.reshape(-1, vocab)
    targets = targets.reshape(-1)
    mask = (targets != pad_id).astype(jnp.float32)
    # label-smoothed one-hot targets
    one_hot = jax.nn.one_hot(targets, vocab)
    smooth = (1.0 - eps) * one_hot + eps / float(vocab)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss_per_token = -jnp.sum(smooth * log_probs, axis=-1)
    loss_per_token = loss_per_token * mask
    denom = jnp.sum(mask) + 1e-8
    loss = jnp.sum(loss_per_token) / denom
    return loss
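
# With smoothing eps, the per-token target distribution is
#   (1 - eps) * one_hot(t) + eps / vocab,
# so even a perfect model cannot drive this loss to zero; the floor is the
# entropy of that smoothed mixture.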

def masked_perplexity_from_logits(logits, targets, pad_id, eps=0.1):
    # note: this exponentiates the *smoothed* loss, so reported perplexity is
    # inflated relative to the standard metric; pass eps=0.0 for plain CE ppl
    vocab = logits.shape[-1]
    logits = logits.reshape(-1, vocab)
    targets = targets.reshape(-1)
    mask = (targets != pad_id).astype(jnp.float32)
    one_hot = jax.nn.one_hot(targets, vocab)
    smooth = (1.0 - eps) * one_hot + eps / float(vocab)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss_per_token = -jnp.sum(smooth * log_probs, axis=-1) * mask
    mean_loss = jnp.sum(loss_per_token) / (jnp.sum(mask) + 1e-8)
    return jnp.exp(mean_loss)

# ------------------
# Training state
# ------------------
class TrainState(train_state.TrainState):
    pass

def create_train_state(rng, model, learning_rate):
    params = model.init(rng, jnp.zeros((1, SEQ_LEN), dtype=jnp.int32))["params"]
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(learning_rate=learning_rate, b1=0.9, b2=0.95, eps=1e-8),
    )
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)
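
# Illustrative: once the (un-replicated) state exists you can count parameters
#   n_params = sum(p.size for p in jax.tree_util.tree_leaves(state.params))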

# ------------------
# pmap'd step functions
# ------------------
@partial(jax.pmap, axis_name="batch")
def train_step(state, batch_x, batch_y, rng):
    # rng is currently unused (no dropout), but kept in the signature for
    # stochastic extensions
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch_x, deterministic=False)
        loss = smoothed_cross_entropy(logits, batch_y, pad_id)
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    grads = jax.lax.pmean(grads, axis_name="batch")  # average grads across devices
    new_state = state.apply_gradients(grads=grads)
    # metrics
    ppl = masked_perplexity_from_logits(logits, batch_y, pad_id)
    metrics = {"loss": loss, "ppl": ppl}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return new_state, metrics

@partial(jax.pmap, axis_name="batch")
def eval_step(state, batch_x, batch_y):
    logits = state.apply_fn({"params": state.params}, batch_x, deterministic=True)
    loss = smoothed_cross_entropy(logits, batch_y, pad_id)
    ppl = masked_perplexity_from_logits(logits, batch_y, pad_id)
    metrics = {"loss": loss, "ppl": ppl}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return metrics
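
# Note on the pmap setup: the state carries a leading device axis (replicated
# below), batches are shard()-ed to match, and lax.pmean keeps replicas in
# sync, so every device returns identical metrics; the loop reads replica 0.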

# ------------------
# Training loop
# ------------------
rng = random.PRNGKey(SEED)
rng, init_rng = random.split(rng)
model = ReLM(vocab_size=vocab_size, max_seq_len=SEQ_LEN, d_model=512, n_layers=9, dtype=DTYPE)
state = create_train_state(init_rng, model, LEARNING_RATE)

# replicate to devices
state = jax.device_put_replicated(state, jax.local_devices())

print("Starting training...")

global_step = 0
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    np_rng = np.random.default_rng(SEED + epoch)
    batch_iter = create_batch_iter(inputs, targets, GLOBAL_BATCH, np_rng)
    pbar = tqdm.tqdm(batch_iter, total=max(1, inputs.shape[0] // GLOBAL_BATCH))

    for batch_x, batch_y in pbar:
        # shard to (num_devices, per_device, seq)
        batch_x = shard(batch_x)
        batch_y = shard(batch_y)
        rng, step_rng = random.split(rng)
        # make per-device rngs
        step_rngs = random.split(step_rng, NUM_DEVICES)
        state, metrics = train_step(state, batch_x, batch_y, step_rngs)
        # metrics are already pmean'd across devices; read the first replica
        m = jax.tree_util.tree_map(lambda x: x[0], metrics)
        pbar.set_postfix(loss=float(m["loss"]), ppl=float(m["ppl"]))
        global_step += 1

# ------------------
# Save params
# ------------------
save_dir = os.path.abspath("./checkpoints")  # flax checkpoints expect an absolute path
os.makedirs(save_dir, exist_ok=True)
# un-replicate (take replica 0), pull to host, then save via flax checkpoints
state_to_save = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state))
checkpoints.save_checkpoint(save_dir, state_to_save, step=global_step, keep=3)
print("Saved checkpoint to", save_dir)

# ------------------
# Sampling (top-p) - single-device (CPU) sampling for simplicity
# ------------------
def top_p_sample_logits(logits, p=0.9, temperature=1.0):
    # logits: (vocab,)
    probs = jax.nn.softmax(logits / temperature)
    # convert to numpy for sorting (fine for one token at a time)
    probs_np = np.array(probs)
    sorted_idx = np.argsort(probs_np)[::-1]
    sorted_probs = probs_np[sorted_idx]
    cum = np.cumsum(sorted_probs)
    cutoff = np.searchsorted(cum, p)
    top_idx = sorted_idx[: cutoff + 1]
    top_probs = sorted_probs[: cutoff + 1]
    top_probs = top_probs / top_probs.sum()
    # sample from the renormalized nucleus
    next_token = np.random.choice(top_idx, p=top_probs)
    return int(next_token)
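
# Worked example (illustrative): with sorted probs (0.5, 0.3, 0.15, 0.05) and
# p=0.9, the cumulative sums are (0.5, 0.8, 0.95); searchsorted returns 2, so
# the nucleus is the top three tokens, renormalized before sampling.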

def generate_text(state, prompt: str, max_gen=256, p=0.9, temperature=0.8, min_len=20):
    # pull params from the replicated state (take replica 0)
    params = jax.tree_util.tree_map(lambda x: np.array(x[0]), state.params)
    # prepend the start token by id: encoding the literal "<start>" text would
    # only round-trip if it were a user-defined symbol in the SentencePiece model
    tokens = [start_id] + sp.encode(prompt, out_type=int)
    generated = tokens.copy()
    for step in range(max_gen):
        cur = generated[-SEQ_LEN:]
        if len(cur) < SEQ_LEN:
            cur = cur + [pad_id] * (SEQ_LEN - len(cur))
        x = np.array([cur], dtype=np.int32)
        logits = model.apply({"params": params}, x, deterministic=True)  # (1, seq, vocab)
        # read logits at the last real (non-pad) position of the window
        logits = np.array(logits[0, min(len(generated), SEQ_LEN) - 1])
        # penalize end/pad a bit
        logits[end_id] -= 5.0
        logits[pad_id] -= 10.0
        next_id = top_p_sample_logits(logits, p=p, temperature=temperature)
        generated.append(next_id)
        if next_id == end_id and len(generated) >= min_len:
            break
    return sp.decode(generated)

# quick generation sample (the prompt is Korean, roughly: "over the past two
# years, government-funded research institutes ... research the nation needs")
print("\n\n===== Generation sample =====")
print(generate_text(state, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))