import os
import math
import numpy as np
import sentencepiece as spm
from functools import partial
from typing import Any

import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from flax.training import train_state, checkpoints
import optax
import tqdm


# ---------------- Hyperparameters ----------------
SEQ_LEN = 512
GLOBAL_BATCH = 256
LIMIT = 200_000  # maximum number of corpus lines to load
VOCAB_MODEL = "ko_unigram.model"
CORPUS_PATH = "corpus.txt"
# bfloat16 activations on TPU, float32 everywhere else.
DTYPE = jnp.bfloat16 if jax.local_devices()[0].platform == "tpu" else jnp.float32
SEED = 42
LEARNING_RATE = 1e-4
EPOCHS = 1

# ---------------- Device setup ----------------
NUM_DEVICES = jax.device_count()
assert GLOBAL_BATCH % NUM_DEVICES == 0, "GLOBAL_BATCH must be divisible by device count"
PER_DEVICE_BATCH = GLOBAL_BATCH // NUM_DEVICES

print("devices:", jax.devices())
print("num_devices:", NUM_DEVICES, "per_device_batch:", PER_DEVICE_BATCH, "dtype:", DTYPE)


# ---------------- Tokenizer ----------------
sp = spm.SentencePieceProcessor()
sp.load(VOCAB_MODEL)

# Fall back to id 0 if the vocabulary has no explicit <pad> piece.
_pad = sp.piece_to_id("<pad>")
pad_id = _pad if _pad != -1 else 0
start_id = sp.piece_to_id("<start>")
end_id = sp.piece_to_id("<end>")
vocab_size = sp.get_piece_size()
print("vocab_size:", vocab_size, "pad_id:", pad_id, "start_id:", start_id, "end_id:", end_id)


# ---------------- Data pipeline ----------------
def line_to_ids(line: str, max_len: int = SEQ_LEN):
    """Encode one line, append <end>, and right-pad with <pad> to max_len."""
    ids = sp.encode(line.strip(), out_type=int)
    if len(ids) > max_len - 1:
        ids = ids[: max_len - 1]  # truncate so <end> always fits
    ids = ids + [end_id]
    ids = ids + [pad_id] * (max_len - len(ids))
    return np.array(ids, dtype=np.int32)


def build_dataset(corpus_path: str, limit: int = LIMIT):
    """Read up to `limit` non-empty lines and stack them into an (N, SEQ_LEN) array."""
    arr = []
    with open(corpus_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= limit:
                break
            line = line.strip()
            if not line:
                continue
            arr.append(line_to_ids(line))
    data = np.stack(arr, axis=0)
    print("Loaded dataset shape:", data.shape)
    return data


data_np = build_dataset(CORPUS_PATH, LIMIT)
inputs = data_np
# Next-token targets: shift left by one position and pad the final slot.
targets = np.concatenate(
    [data_np[:, 1:], np.full((data_np.shape[0], 1), pad_id, dtype=np.int32)], axis=1
)
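
# For a row encoded as [t0, t1, ..., <end>, <pad>, ...], the target row is
# [t1, ..., <end>, <pad>, ..., <pad>]: each position predicts the next token,
# and pad positions are masked out of the loss later on.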


def create_batch_iter(inputs: np.ndarray, targets: np.ndarray, batch_size: int, rng: np.random.Generator):
    """Yield shuffled (x, y) batches; drops the final partial batch."""
    idx = np.arange(inputs.shape[0])
    rng.shuffle(idx)
    for i in range(0, len(idx) - batch_size + 1, batch_size):
        batch_idx = idx[i:i + batch_size]
        yield inputs[batch_idx], targets[batch_idx]


def shard(xs: np.ndarray):
    """Reshape a global batch to (NUM_DEVICES, PER_DEVICE_BATCH, ...) for pmap."""
    return xs.reshape((NUM_DEVICES, -1) + xs.shape[1:])
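
# For example, assuming 8 devices and the defaults above, a (256, 512) int32
# batch becomes (8, 32, 512); jax.pmap maps over the leading device axis.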


# ---------------- Model ----------------
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: one fused projection to 2 * d_model, with the SiLU
    of one half gating the other."""
    d_model: int

    @nn.compact
    def __call__(self, x):
        # Compute in float32 for stability, then cast back to the input dtype.
        proj = nn.Dense(self.d_model * 2, dtype=jnp.float32)(x)
        x_val, x_gate = jnp.split(proj, 2, axis=-1)
        out = x_val * nn.silu(x_gate)
        out = nn.Dense(self.d_model, dtype=jnp.float32)(out)
        return out.astype(x.dtype)


class LoU(nn.Module):
    """Gated EMA mixing layer: element-wise q/k gates produce per-channel scores,
    a learned exponential moving average accumulates them causally over the
    sequence, and the result scales v."""
    d_model: int
    clip_value: float = 5.0
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        x_f32 = x.astype(jnp.float32)
        residual = x_f32

        x_norm = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(x_f32)

        q = nn.Dense(self.d_model, dtype=jnp.float32)(x_norm)
        k = nn.Dense(self.d_model, dtype=jnp.float32)(x_norm)
        v = nn.Dense(self.d_model, dtype=jnp.float32)(x_norm)

        # Squash q and k into (0, 1) and combine into per-position gate scores.
        g_q = (jnp.tanh(q) + 1.0) / 2.0
        g_k = (jnp.tanh(k) + 1.0) / 2.0
        score = g_q * g_k

        # Per-position EMA coefficient; the sigmoid keeps 0 < alpha < 1 so the
        # recurrence below stays a convex combination and cannot blow up.
        alpha_dynamic = nn.sigmoid(nn.Dense(1, dtype=jnp.float32)(x_norm))

        # Scan over time: ema_t = alpha_t * score_t + (1 - alpha_t) * ema_{t-1}.
        score_t = jnp.transpose(score, (1, 0, 2))        # (T, B, D)
        alpha_t = jnp.transpose(alpha_dynamic, (1, 0, 2))

        def step(carry, inputs):
            prev_ema = carry
            s_t, a_t = inputs
            new = a_t * s_t + (1.0 - a_t) * prev_ema
            return new, new

        init = score_t[0]
        _, ema_seq = jax.lax.scan(step, init, (score_t[1:], alpha_t[1:]))
        ema_full = jnp.concatenate([init[None, ...], ema_seq], axis=0)
        ema = jnp.transpose(ema_full, (1, 0, 2))         # back to (B, T, D)

        # Normalize by the per-position channel mean, then clip the scale.
        mean_last = jnp.mean(ema, axis=-1, keepdims=True)
        denom = jnp.maximum(mean_last, self.eps)
        score_norm = ema / denom
        score_clipped = jnp.clip(score_norm, -self.clip_value, self.clip_value)

        x_comb = score_clipped * v
        out = x_comb + residual
        out = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(out)
        out = SwiGLU(self.d_model)(out.astype(x.dtype))
        return out.astype(x.dtype)
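
# Unrolled, the scan in LoU computes (with alpha_0 taken as 1):
#   ema_t = sum_{i=0..t} alpha_i * score_i * prod_{j=i+1..t} (1 - alpha_j)
# i.e. a causal, data-dependent low-pass filter over the gate scores, so each
# position sees an exponentially weighted summary of everything before it.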


class Lo(nn.Module):
    """Bottleneck MLP (d_model -> 64 -> d_model) with a residual connection."""
    d_model: int

    @nn.compact
    def __call__(self, x):
        h = nn.Dense(64, dtype=jnp.float32)(x)
        h = nn.silu(h)
        h = nn.Dense(self.d_model, dtype=jnp.float32)(h)
        out = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(h) + x
        return out.astype(x.dtype)


class Block(nn.Module):
    """One ReLM block: LoU mixing followed by the Lo bottleneck."""
    d_model: int

    @nn.compact
    def __call__(self, x):
        x = LoU(self.d_model)(x)
        x = Lo(self.d_model)(x)
        return x


class ReLM(nn.Module):
    vocab_size: int
    max_seq_len: int
    d_model: int
    n_layers: int
    dtype: Any = jnp.float32

    def setup(self):
        self.token_embed = nn.Embed(self.vocab_size, self.d_model, dtype=self.dtype)
        self.pos_embed = nn.Embed(self.max_seq_len, self.d_model, dtype=self.dtype)
        self.blocks = [Block(self.d_model) for _ in range(self.n_layers)]
        self.ln_f = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)

    def __call__(self, x, deterministic=True):
        b, seq = x.shape
        positions = jnp.arange(seq)[None, :]
        x = self.token_embed(x) + self.pos_embed(positions)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)

        # Weight tying: reuse the token embedding matrix as the output projection.
        embedding_matrix = self.token_embed.embedding
        logits = jnp.einsum("bld,vd->blv", x, embedding_matrix)
        return logits.astype(jnp.float32)


# ---------------- Loss and metrics ----------------
def smoothed_cross_entropy(logits, targets, pad_id, eps=0.1):
    """Label-smoothed cross-entropy averaged over non-pad tokens."""
    vocab = logits.shape[-1]
    logits = logits.reshape(-1, vocab)
    targets = targets.reshape(-1)
    mask = (targets != pad_id).astype(jnp.float32)

    # Smoothed targets: (1 - eps) on the gold token plus eps / vocab everywhere.
    one_hot = jax.nn.one_hot(targets, vocab)
    smooth = (1.0 - eps) * one_hot + eps / float(vocab)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss_per_token = -jnp.sum(smooth * log_probs, axis=-1)
    loss_per_token = loss_per_token * mask
    denom = jnp.sum(mask) + 1e-8
    loss = jnp.sum(loss_per_token) / denom
    return loss


def masked_perplexity_from_logits(logits, targets, pad_id, eps=0.0):
    """Perplexity = exp(mean cross-entropy) over non-pad tokens. eps defaults
    to 0.0 so the reported perplexity is not inflated by label smoothing."""
    vocab = logits.shape[-1]
    logits = logits.reshape(-1, vocab)
    targets = targets.reshape(-1)
    mask = (targets != pad_id).astype(jnp.float32)
    one_hot = jax.nn.one_hot(targets, vocab)
    smooth = (1.0 - eps) * one_hot + eps / float(vocab)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss_per_token = -jnp.sum(smooth * log_probs, axis=-1) * mask
    mean_loss = jnp.sum(loss_per_token) / (jnp.sum(mask) + 1e-8)
    return jnp.exp(mean_loss)


# ---------------- Train state ----------------
class TrainState(train_state.TrainState):
    pass


def create_train_state(rng, model, learning_rate):
    params = model.init(rng, jnp.zeros((1, SEQ_LEN), dtype=jnp.int32))["params"]
    # Clip the global gradient norm before the AdamW update.
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(learning_rate=learning_rate, b1=0.9, b2=0.95, eps=1e-8),
    )
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)


# ---------------- Parallel train / eval steps ----------------
@partial(jax.pmap, axis_name="batch")
def train_step(state, batch_x, batch_y, rng):
    # `rng` is threaded through for dropout; the current model has none, so it
    # is unused here.
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch_x, deterministic=False)
        loss = smoothed_cross_entropy(logits, batch_y, pad_id)
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    # Average gradients across devices before applying the update.
    grads = jax.lax.pmean(grads, axis_name="batch")
    new_state = state.apply_gradients(grads=grads)

    ppl = masked_perplexity_from_logits(logits, batch_y, pad_id)
    metrics = {"loss": loss, "ppl": ppl}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return new_state, metrics


@partial(jax.pmap, axis_name="batch")
def eval_step(state, batch_x, batch_y):
    logits = state.apply_fn({"params": state.params}, batch_x, deterministic=True)
    loss = smoothed_cross_entropy(logits, batch_y, pad_id)
    ppl = masked_perplexity_from_logits(logits, batch_y, pad_id)
    metrics = {"loss": loss, "ppl": ppl}
    return jax.lax.pmean(metrics, axis_name="batch")
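
# eval_step is never called below; a minimal sketch of a validation pass,
# assuming hypothetical held-out arrays `eval_inputs` / `eval_targets` built
# the same way as `inputs` / `targets`:
#
#   eval_iter = create_batch_iter(eval_inputs, eval_targets, GLOBAL_BATCH,
#                                 np.random.default_rng(0))
#   for ex, ey in eval_iter:
#       m = eval_step(state, shard(ex), shard(ey))
#       print("eval loss:", float(m["loss"][0]), "ppl:", float(m["ppl"][0]))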


# ---------------- Initialization ----------------
rng = random.PRNGKey(SEED)
rng, init_rng = random.split(rng)
model = ReLM(vocab_size=vocab_size, max_seq_len=SEQ_LEN, d_model=512, n_layers=9, dtype=DTYPE)
state = create_train_state(init_rng, model, LEARNING_RATE)

# Replicate the train state across local devices for pmap.
state = jax.device_put_replicated(state, jax.local_devices())


# ---------------- Training loop ----------------
print("Starting training...")

global_step = 0
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    np_rng = np.random.default_rng(SEED + epoch)
    batch_iter = create_batch_iter(inputs, targets, GLOBAL_BATCH, np_rng)
    pbar = tqdm.tqdm(batch_iter, total=max(1, inputs.shape[0] // GLOBAL_BATCH))

    for batch_x, batch_y in pbar:
        # Add the leading device axis expected by pmap.
        batch_x = shard(batch_x)
        batch_y = shard(batch_y)
        rng, step_rng = random.split(rng)
        step_rngs = random.split(step_rng, NUM_DEVICES)  # one RNG per device

        state, metrics = train_step(state, batch_x, batch_y, step_rngs)

        # Metrics are replicated across devices; read device 0's copy.
        m = jax.tree_util.tree_map(lambda v: v[0], metrics)
        pbar.set_postfix(loss=float(m["loss"]), ppl=float(m["ppl"]))
        global_step += 1


# ---------------- Checkpointing ----------------
save_dir = os.path.abspath("./checkpoints")  # newer Flax versions require an absolute path
os.makedirs(save_dir, exist_ok=True)

# Un-replicate (take device 0's copy) before saving, so the checkpoint does
# not carry the extra device axis.
save_state = jax.tree_util.tree_map(lambda v: np.array(v[0]), state)
checkpoints.save_checkpoint(save_dir, save_state, step=global_step, keep=3)
print("Saved checkpoint to", save_dir)


# ---------------- Sampling ----------------
def top_p_sample_logits(rng, logits, p=0.9, temperature=1.0):
    """Nucleus (top-p) sampling over a single logits vector. `rng` is accepted
    for signature symmetry but unused; sampling runs on the NumPy RNG."""
    probs = jax.nn.softmax(logits / temperature)
    probs_np = np.array(probs)

    # Sort descending and keep the smallest prefix whose mass reaches p.
    sorted_idx = np.argsort(probs_np)[::-1]
    sorted_probs = probs_np[sorted_idx]
    cum = np.cumsum(sorted_probs)
    cutoff = np.searchsorted(cum, p)
    top_idx = sorted_idx[: cutoff + 1]
    top_probs = sorted_probs[: cutoff + 1]
    top_probs = top_probs / top_probs.sum()

    next_token = np.random.choice(top_idx, p=top_probs)
    return int(next_token)


def generate_text(state, prompt: str, max_gen=256, p=0.9, temperature=0.8, min_len=20):
    """Autoregressive decoding with top-p sampling; stops at <end> once at
    least min_len tokens have been produced."""
    # Pull an un-replicated copy of the parameters back to the host.
    params = jax.tree_util.tree_map(lambda v: np.array(v[0]), state.params)
    tokens = sp.encode("<start> " + prompt, out_type=int)
    generated = tokens.copy()
    for step in range(max_gen):
        # Feed the last SEQ_LEN tokens, right-padded to the fixed length.
        cur = generated[-SEQ_LEN:]
        if len(cur) < SEQ_LEN:
            cur = cur + [pad_id] * (SEQ_LEN - len(cur))
        x = np.array([cur], dtype=np.int32)
        logits = model.apply({"params": params}, x, deterministic=True)
        # Read the logits at the position of the last real (non-pad) token.
        last_pos = min(len(generated), SEQ_LEN) - 1
        logits = np.array(logits[0, last_pos])

        # Mildly discourage <end> and strongly discourage <pad>.
        logits[end_id] -= 5.0
        logits[pad_id] -= 10.0
        next_id = top_p_sample_logits(None, logits, p=p, temperature=temperature)
        generated.append(next_id)
        if next_id == end_id and len(generated) >= min_len:
            break
    return sp.decode(generated)


print("\n\n===== Generation result =====")
# Korean prompt (roughly: "Over the past two years, the government-funded
# research institutes [have pursued] the research the nation needs"); the
# model is trained on a Korean corpus, so the prompt itself stays in Korean.
print(generate_text(state, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))