# Flax + JAX TPU-ready reimplementation of your ReLM model and training loop.
# Requirements:
# pip install --upgrade "jax[tpu]" flax optax sentencepiece
import os
import math
import numpy as np
import sentencepiece as spm
from functools import partial
from typing import Any, Callable, Optional, Tuple, Sequence
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from flax.training import train_state, checkpoints
import optax
import tqdm
# ------------------
# Config
# ------------------
SEQ_LEN = 512
# global batch size (across all devices); adjust for memory
GLOBAL_BATCH = 256
LIMIT = 200_000 # number of sequences to load (reduce if OOM)
VOCAB_MODEL = "ko_unigram.model"
CORPUS_PATH = "corpus.txt"
DTYPE = jnp.bfloat16 if jax.local_devices()[0].platform == "tpu" else jnp.float32
SEED = 42
LEARNING_RATE = 1e-4
EPOCHS = 1
# Derived
NUM_DEVICES = jax.device_count()
assert GLOBAL_BATCH % NUM_DEVICES == 0, "GLOBAL_BATCH must be divisible by device count"
PER_DEVICE_BATCH = GLOBAL_BATCH // NUM_DEVICES
print("devices:", jax.devices())
print("num_devices:", NUM_DEVICES, "per_device_batch:", PER_DEVICE_BATCH, "dtype:", DTYPE)
# ------------------
# Tokenizer loader
# ------------------
sp = spm.SentencePieceProcessor()
sp.load(VOCAB_MODEL)
pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
start_id = sp.piece_to_id("<start>")
end_id = sp.piece_to_id("<end>")
vocab_size = sp.get_piece_size()
print("vocab_size:", vocab_size, "pad_id:", pad_id, "start_id:", start_id, "end_id:", end_id)
# ------------------
# Data pipeline (simple, numpy-based)
# - Reads corpus line-by-line, tokenizes, pads/truncates to SEQ_LEN.
# - Builds a numpy array (N, SEQ_LEN) for inputs and targets (shifted by 1).
# - Shards batches across devices for pmap.
# ------------------
def line_to_ids(line: str, max_len: int = SEQ_LEN):
    ids = sp.encode(line.strip(), out_type=int)
    if len(ids) > max_len - 1:
        ids = ids[: max_len - 1]
    ids = ids + [end_id]
    pad_len = max_len - len(ids)
    ids = ids + [pad_id] * pad_len
    return np.array(ids, dtype=np.int32)
def build_dataset(corpus_path: str, limit: int = LIMIT):
    arr = []
    with open(corpus_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= limit:
                break
            line = line.strip()
            if not line:
                continue
            arr.append(line_to_ids(line))
    data = np.stack(arr, axis=0)  # (N, SEQ_LEN)
    print("Loaded dataset shape:", data.shape)
    return data
# create inputs and targets
data_np = build_dataset(CORPUS_PATH, LIMIT)
inputs = data_np
targets = np.concatenate([data_np[:,1:], np.full((data_np.shape[0],1), pad_id, dtype=np.int32)], axis=1)
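# e.g. an input row [t0, t1, ..., t_{L-1}] gets the target row [t1, ..., t_{L-1}, pad_id],
# so position i is trained to predict token i+1 (pad positions are masked in the loss).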
# shuffle and create batches
def create_batch_iter(inputs: np.ndarray, targets: np.ndarray, batch_size: int, rng: np.random.Generator):
    idx = np.arange(inputs.shape[0])
    rng.shuffle(idx)
    for i in range(0, len(idx) - batch_size + 1, batch_size):
        batch_idx = idx[i:i + batch_size]
        x = inputs[batch_idx]
        y = targets[batch_idx]
        yield x, y
# helper to shard numpy batch for pmap: shape (num_devices, per_device, ...)
def shard(xs: np.ndarray):
    return xs.reshape((NUM_DEVICES, -1) + xs.shape[1:])
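# Illustrative sanity check for the sharding helper (safe to delete): a full global
# batch of token ids gains one leading axis per device, e.g. (256, 512) -> (8, 32, 512)
# on an 8-core TPU.
_demo_batch = shard(np.zeros((GLOBAL_BATCH, SEQ_LEN), dtype=np.int32))
assert _demo_batch.shape == (NUM_DEVICES, PER_DEVICE_BATCH, SEQ_LEN)
del _demo_batch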
# ------------------
# Flax model implementation
# ------------------
class SwiGLU(nn.Module):
    d_model: int

    @nn.compact
    def __call__(self, x):
        # project to 2 * d_model, then split into value and gate halves
        proj = nn.Dense(self.d_model * 2, dtype=jnp.float32)(x)  # keep the projection in float32
        x_val, x_gate = jnp.split(proj, 2, axis=-1)
        out = x_val * nn.silu(x_gate)
        out = nn.Dense(self.d_model, dtype=jnp.float32)(out)
        return out.astype(x.dtype)
class LoU(nn.Module):
    d_model: int
    clip_value: float = 5.0
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        # x: (batch, seq, d)
        x_f32 = x.astype(jnp.float32)
        residual = x_f32
        norm1 = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)
        x_norm = norm1(x_f32)
        Q = nn.Dense(self.d_model, dtype=jnp.float32)
        K = nn.Dense(self.d_model, dtype=jnp.float32)
        V = nn.Dense(self.d_model, dtype=jnp.float32)
        q = Q(x_norm)
        k = K(x_norm)
        v = V(x_norm)
        g_q = (jnp.tanh(q) + 1.0) / 2.0
        g_k = (jnp.tanh(k) + 1.0) / 2.0
        score = g_q * g_k  # (b, seq, d)
        alpha_linear = nn.Dense(1, dtype=jnp.float32)
        alpha_dynamic = alpha_linear(x_norm)  # (b, seq, 1)
        # EMA over time with a data-dependent rate:
        #   ema_t = a_t * score_t + (1 - a_t) * ema_{t-1},  ema_0 = score_0
        # Transpose to (seq, batch, d) so lax.scan runs over the time axis.
        score_t = jnp.transpose(score, (1, 0, 2))
        alpha_t = jnp.transpose(alpha_dynamic, (1, 0, 2))

        def step(carry, inputs):
            prev_ema = carry
            x_t, a_t = inputs
            new = a_t * x_t + (1.0 - a_t) * prev_ema
            return new, new

        init = score_t[0]
        _, ema_seq = jax.lax.scan(step, init, (score_t[1:], alpha_t[1:]))
        ema_full = jnp.concatenate([init[None, ...], ema_seq], axis=0)  # (seq, batch, d)
        ema = jnp.transpose(ema_full, (1, 0, 2))  # (batch, seq, d)
        mean_last = jnp.mean(ema, axis=-1, keepdims=True)
        denom = jnp.maximum(mean_last, self.eps)
        score_norm = ema / denom
        score_clipped = jnp.clip(score_norm, -self.clip_value, self.clip_value)
        x_comb = score_clipped * v
        out = x_comb + residual
        out = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(out)
        out = SwiGLU(self.d_model)(out.astype(x.dtype))
        return out.astype(x.dtype)
class Lo(nn.Module):
    d_model: int

    @nn.compact
    def __call__(self, x):
        h = nn.Dense(64, dtype=jnp.float32)(x)
        h = nn.silu(h)
        h = nn.Dense(self.d_model, dtype=jnp.float32)(h)
        out = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(h) + x
        return out.astype(x.dtype)
class Block(nn.Module):
    d_model: int

    @nn.compact
    def __call__(self, x):
        x = LoU(self.d_model)(x)
        x = Lo(self.d_model)(x)
        return x
class ReLM(nn.Module):
    vocab_size: int
    max_seq_len: int
    d_model: int
    n_layers: int
    dtype: Any = jnp.float32

    def setup(self):
        self.token_embed = nn.Embed(self.vocab_size, self.d_model, dtype=self.dtype)
        self.pos_embed = nn.Embed(self.max_seq_len, self.d_model, dtype=self.dtype)
        self.blocks = [Block(self.d_model) for _ in range(self.n_layers)]
        self.ln_f = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)

    def __call__(self, x, deterministic=True):
        # x: (batch, seq)
        b, seq = x.shape
        positions = jnp.arange(seq)[None, :]
        x = self.token_embed(x) + self.pos_embed(positions)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        # tie weights: reuse the token embedding matrix as the output projection
        embedding_matrix = self.token_embed.embedding  # (vocab, d)
        logits = jnp.einsum("bld,vd->blv", x, embedding_matrix)
        return logits.astype(jnp.float32)
# ------------------
# Loss & metrics
# ------------------
def smoothed_cross_entropy(logits, targets, pad_id, eps=0.1):
    # logits: (b, seq, v); targets: (b, seq) int32
    # Label-smoothed target distribution: q(v) = (1 - eps) * 1[v == target] + eps / V,
    # loss = -sum_v q(v) * log p(v), averaged over non-pad tokens only.
    vocab = logits.shape[-1]
    logits = logits.reshape(-1, vocab)
    targets = targets.reshape(-1)
    mask = (targets != pad_id).astype(jnp.float32)
    # one-hot smoothed targets
    one_hot = jax.nn.one_hot(targets, vocab)
    smooth = (1.0 - eps) * one_hot + eps / float(vocab)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss_per_token = -jnp.sum(smooth * log_probs, axis=-1)
    loss_per_token = loss_per_token * mask
    denom = jnp.sum(mask) + 1e-8
    loss = jnp.sum(loss_per_token) / denom
    return loss
def masked_perplexity_from_logits(logits, targets, pad_id, eps=0.1):
    vocab = logits.shape[-1]
    logits = logits.reshape(-1, vocab)
    targets = targets.reshape(-1)
    mask = (targets != pad_id).astype(jnp.float32)
    one_hot = jax.nn.one_hot(targets, vocab)
    smooth = (1.0 - eps) * one_hot + eps / float(vocab)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss_per_token = -jnp.sum(smooth * log_probs, axis=-1) * mask
    mean_loss = jnp.sum(loss_per_token) / (jnp.sum(mask) + 1e-8)
    return jnp.exp(mean_loss)
# ------------------
# Training state
# ------------------
class TrainState(train_state.TrainState):
    pass
def create_train_state(rng, model, learning_rate):
    params = model.init(rng, jnp.zeros((1, SEQ_LEN), dtype=jnp.int32))["params"]
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(learning_rate=learning_rate, b1=0.9, b2=0.95, eps=1e-8),
    )
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)
# ------------------
# pmap'd step functions
# ------------------
@partial(jax.pmap, axis_name="batch")
def train_step(state, batch_x, batch_y, rng):
    # rng is currently unused (the model has no dropout) but is kept so stochastic
    # layers can be added later without changing the pmap signature.
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch_x, deterministic=False)
        loss = smoothed_cross_entropy(logits, batch_y, pad_id)
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    grads = jax.lax.pmean(grads, axis_name="batch")
    new_state = state.apply_gradients(grads=grads)
    # metrics (averaged across devices)
    ppl = masked_perplexity_from_logits(logits, batch_y, pad_id)
    metrics = {"loss": loss, "ppl": ppl}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return new_state, metrics
@partial(jax.pmap, axis_name="batch")
def eval_step(state, batch_x, batch_y):
    logits = state.apply_fn({"params": state.params}, batch_x, deterministic=True)
    loss = smoothed_cross_entropy(logits, batch_y, pad_id)
    ppl = masked_perplexity_from_logits(logits, batch_y, pad_id)
    metrics = {"loss": loss, "ppl": ppl}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return metrics
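# Minimal usage sketch for eval_step, assuming a held-out split `val_inputs` /
# `val_targets` (not built in this script) prepared exactly like the training arrays:
#   val_metrics = []
#   for vx, vy in create_batch_iter(val_inputs, val_targets, GLOBAL_BATCH,
#                                   np.random.default_rng(0)):
#       m = eval_step(state, shard(vx), shard(vy))
#       val_metrics.append(jax.tree_util.tree_map(lambda a: float(a[0]), m))
#   # average the per-batch loss/ppl over val_metrics as needed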
# ------------------
# Training loop
# ------------------
rng = random.PRNGKey(SEED)
rng, init_rng = random.split(rng)
model = ReLM(vocab_size=vocab_size, max_seq_len=SEQ_LEN, d_model=512, n_layers=9, dtype=DTYPE)
state = create_train_state(init_rng, model, LEARNING_RATE)
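# Illustrative parameter count of the freshly initialized (still unreplicated) state;
# purely informational and safe to delete.
n_params = sum(p.size for p in jax.tree_util.tree_leaves(state.params))
print("parameter count:", n_params)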
# replicate to devices
state = jax.device_put_replicated(state, jax.local_devices())
print("Starting training...")
global_step = 0
for epoch in range(EPOCHS):
print(f"Epoch {epoch+1}/{EPOCHS}")
np_rng = np.random.default_rng(SEED + epoch)
batch_iter = create_batch_iter(inputs, targets, GLOBAL_BATCH, np_rng)
pbar = tqdm.tqdm(batch_iter, total= max(1, inputs.shape[0] // GLOBAL_BATCH))
for batch_x, batch_y in pbar:
# shard
batch_x = shard(batch_x)
batch_y = shard(batch_y)
rng, step_rng = random.split(rng)
# make per-device rngs
step_rngs = random.split(step_rng, NUM_DEVICES)
state, metrics = train_step(state, batch_x, batch_y, step_rngs)
# metrics are per-device; take first replica
m = jax.tree_util.tree_map(lambda x: x[0], metrics)
pbar.set_postfix(loss=float(m["loss"]), ppl=float(m["ppl"]))
global_step += 1
# ------------------
# Save params
# ------------------
save_dir = "./checkpoints"
os.makedirs(save_dir, exist_ok=True)
# save using flax.serialization via checkpoints
# take the first replica so the checkpoint is not duplicated once per device
state_to_save = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state))
checkpoints.save_checkpoint(save_dir, state_to_save, step=global_step, keep=3)
print("Saved checkpoint to", save_dir)
# ------------------
# Sampling (top-p) - single-device (CPU) sampling for simplicity
# ------------------
def top_p_sample_logits(rng, logits, p=0.9, temperature=1.0):
    # logits: (vocab,)
    probs = jax.nn.softmax(logits / temperature)
    # convert to numpy for sorting (ok for single token)
    probs_np = np.array(probs)
    sorted_idx = np.argsort(probs_np)[::-1]
    sorted_probs = probs_np[sorted_idx]
    cum = np.cumsum(sorted_probs)
    cutoff = np.searchsorted(cum, p)
    top_idx = sorted_idx[: cutoff + 1]
    top_probs = sorted_probs[: cutoff + 1]
    top_probs = top_probs / top_probs.sum()
    # sample from the renormalized nucleus
    next_token = np.random.choice(top_idx, p=top_probs)
    return int(next_token)
def generate_text(state, prompt: str, max_gen=256, p=0.9, temperature=0.8, min_len=20):
    # load params from the replicated state (take the first replica)
    params = jax.tree_util.tree_map(lambda x: np.array(x[0]), state.params)
    tokens = sp.encode("<start> " + prompt, out_type=int)
    generated = tokens.copy()
    for step in range(max_gen):
        cur = generated[-SEQ_LEN:]
        if len(cur) < SEQ_LEN:
            cur = cur + [pad_id] * (SEQ_LEN - len(cur))
        x = np.array([cur], dtype=np.int32)
        logits = model.apply({"params": params}, x, deterministic=True)  # (1, seq, vocab)
        logits = np.array(logits[0, len(generated) - 1 if len(generated) - 1 < SEQ_LEN else SEQ_LEN - 1])
        # penalize end/pad a bit
        logits[end_id] -= 5.0
        logits[pad_id] -= 10.0
        next_id = top_p_sample_logits(None, logits, p=p, temperature=temperature)
        generated.append(next_id)
        if next_id == end_id and len(generated) >= min_len:
            break
    return sp.decode(generated)
# quick generate
print("\n\n===== 생성 결과 =====")
print(generate_text(state, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))