|
|
|
|
|
import os, math, numpy as np, sentencepiece as spm, requests, tqdm |
|
|
from functools import partial |
|
|
from typing import Any |
|
|
import jax, jax.numpy as jnp |
|
|
from jax import random |
|
|
from flax import linen as nn |
|
|
from flax.training import train_state, checkpoints |
|
|
import optax |
|
|
import requests |
|
|
|
|
|
def download_file(url, save_path):
    """Stream `url` to `save_path` in 16 KiB chunks.

    Args:
        url: HTTP(S) URL to fetch.
        save_path: local file path to write (overwritten if present).

    Raises:
        requests.HTTPError: on a non-2xx response.
    """
    # timeout added so a stalled connection cannot hang the whole script.
    r = requests.get(url, stream=True, timeout=60)
    r.raise_for_status()
    with open(save_path, "wb") as f:
        for chunk in r.iter_content(8192*2):
            f.write(chunk)
    # BUG FIX: the original f-string literal was split across two physical
    # lines (a SyntaxError as written); rejoined onto one line.
    print(f"โ {save_path} ์ ์ฅ๋จ")
|
|
|
|
|
|
|
|
|
|
|
# ----- Run configuration -----
SEQ_LEN = 512  # fixed token length of every training example
GLOBAL_BATCH = 256  # total batch across all devices (split per device below)
LIMIT = 200_000  # maximum number of corpus lines to load
VOCAB_MODEL = "ko_unigram.model"  # SentencePiece model file
CORPUS_PATH = "corpus.txt"  # training text, one example per line
SEED = 42  # base PRNG seed for init, shuffling, and sampling
LEARNING_RATE = 1e-4  # AdamW learning rate
EPOCHS = 1  # passes over the corpus
|
|
|
|
|
# Fetch the corpus and the tokenizer model once; skipped when already on disk.
if not os.path.exists(CORPUS_PATH):
    download_file(
        "https://huggingface.co/datasets/Yuchan5386/Prototype/resolve/main/corpus_ko.txt?download=true",
        CORPUS_PATH
    )

if not os.path.exists(VOCAB_MODEL):
    download_file(
        "https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
        VOCAB_MODEL
    )
|
|
|
|
|
# Use bfloat16 on TPU (natively supported there); float32 everywhere else.
DTYPE = jnp.bfloat16 if jax.local_devices()[0].platform == "tpu" else jnp.float32
NUM_DEVICES = jax.device_count()
# NOTE(review): assumes GLOBAL_BATCH is divisible by the device count — confirm.
PER_DEVICE_BATCH = GLOBAL_BATCH // NUM_DEVICES
print("devices:", jax.devices(), "dtype:", DTYPE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the SentencePiece tokenizer and resolve the special-token ids.
sp = spm.SentencePieceProcessor()
sp.load(VOCAB_MODEL)
# Fall back to id 0 when the model defines no explicit <pad> piece.
pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>")!=-1 else 0
start_id = sp.piece_to_id("<start>")
end_id = sp.piece_to_id("<end>")
vocab_size = sp.get_piece_size()
print("vocab_size:", vocab_size, "pad_id:", pad_id, "start_id:", start_id, "end_id:", end_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def line_to_ids(line, max_len=SEQ_LEN):
    """Encode one text line to a fixed-length int32 id array.

    Layout: tokens (truncated to max_len-1) + <end> + <pad> padding,
    always exactly `max_len` long.
    """
    token_ids = sp.encode(line.strip(), out_type=int)
    token_ids = token_ids[:max_len - 1]
    padding = [pad_id] * (max_len - len(token_ids) - 1)
    return np.array(token_ids + [end_id] + padding, dtype=np.int32)
|
|
|
|
|
def build_dataset(corpus_path, limit=LIMIT):
    """Read up to `limit` non-empty lines and encode each to a fixed-length row.

    Args:
        corpus_path: UTF-8 text file, one training example per line.
        limit: maximum number of lines (including skipped blanks) to scan.

    Returns:
        int32 array of shape (num_examples, SEQ_LEN).

    Raises:
        ValueError: if no usable lines are found — np.stack on an empty list
            would otherwise fail with a cryptic message.
    """
    rows = []
    with open(corpus_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= limit:
                break
            line = line.strip()
            if not line:
                continue  # skip blank lines without consuming an example slot
            rows.append(line_to_ids(line))
    if not rows:
        raise ValueError(f"no non-empty lines found in {corpus_path!r}")
    data = np.stack(rows, axis=0)
    print("Loaded dataset:", data.shape)
    return data
|
|
|
|
|
# Build next-token-prediction pairs: targets are the inputs shifted left by
# one position, with the final column filled with pad (masked out by the loss).
data_np = build_dataset(CORPUS_PATH, LIMIT)
inputs = data_np
targets = np.concatenate([data_np[:,1:], np.full((data_np.shape[0],1), pad_id, np.int32)], axis=1)
|
|
|
|
|
def create_batch_iter(inputs, targets, batch_size, rng):
    """Yield shuffled (inputs, targets) minibatches; the last partial batch is dropped.

    Args:
        inputs, targets: arrays indexed along axis 0 by example.
        batch_size: examples per yielded batch.
        rng: numpy Generator used for the shuffle order.
    """
    order = np.arange(inputs.shape[0])
    rng.shuffle(order)
    n_full = len(order) // batch_size
    for b in range(n_full):
        sel = order[b * batch_size:(b + 1) * batch_size]
        yield inputs[sel], targets[sel]
|
|
|
|
|
def shard(xs): return xs.reshape(NUM_DEVICES, -1, xs.shape[1]) |
|
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: a silu-gated projection computed in float32.

    The input is upcast to float32, passed through one fused 2*d_model
    projection split into value/gate halves, gated, projected back to
    d_model, and finally cast back to the input dtype.
    """

    d_model: int

    @nn.compact
    def __call__(self, x):
        inputs = x.astype(jnp.float32)
        fused = nn.Dense(self.d_model * 2, dtype=jnp.float32)(inputs)
        value, gate = jnp.split(fused, 2, axis=-1)
        gated = value * nn.silu(gate)
        projected = nn.Dense(self.d_model, dtype=jnp.float32)(gated)
        return projected.astype(x.dtype)
|
|
|
|
|
class LoU(nn.Module):
    """Gated, attention-free causal sequence mixer.

    Projects the normalized input to q/k/v, squashes q/k into (0, 1) gates,
    and blends their elementwise product over time with a learned
    per-position EMA coefficient via lax.scan (strictly causal: step t sees
    only steps <= t).  The mixed values are added to the residual, normalized,
    and passed through a SwiGLU block.  All internal math runs in float32;
    the output is cast back to x.dtype.
    """
    d_model: int
    clip_value: float = 5.0  # NOTE(review): currently unused in __call__
    eps: float = 1e-6  # NOTE(review): currently unused in __call__
    @nn.compact
    def __call__(self, x):
        x_f32 = x.astype(jnp.float32)
        residual = x_f32
        x_norm = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(x_f32)
        Q = nn.Dense(self.d_model, dtype=jnp.float32)
        K = nn.Dense(self.d_model, dtype=jnp.float32)
        V = nn.Dense(self.d_model, dtype=jnp.float32)
        q,k,v = Q(x_norm), K(x_norm), V(x_norm)
        # Map q/k into (0, 1) gates via tanh rescaling.
        g_q = (jnp.tanh(q)+1)/2
        g_k = (jnp.tanh(k)+1)/2
        score = g_q * g_k
        # Per-position EMA coefficient, broadcast over the feature axis.
        # NOTE(review): this is a raw Dense output, not squashed into [0, 1],
        # so a*s + (1-a)*prev can extrapolate outside the score range — confirm intended.
        alpha_dynamic = nn.Dense(1, dtype=jnp.float32)(x_norm)

        # Move time to the leading axis for lax.scan: (seq, batch, d_model).
        score_t = jnp.transpose(score,(1,0,2))
        alpha_t = jnp.transpose(alpha_dynamic,(1,0,2))
        def step(prev, cur):
            # EMA recurrence: new = a * score_t + (1 - a) * prev.
            s, a = cur
            new = a*s + (1-a)*prev
            return new,new
        # The first timestep seeds the EMA; scan covers timesteps 1..seq-1.
        init = score_t[0]
        _, ema_seq = jax.lax.scan(step, init, (score_t[1:], alpha_t[1:]))
        ema_full = jnp.concatenate([init[None,...], ema_seq], 0)
        # Back to (batch, seq, d_model).
        ema = jnp.transpose(ema_full,(1,0,2))
        out = v * ema + residual
        out = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(out)
        return SwiGLU(self.d_model)(out).astype(x.dtype)
|
|
|
|
|
|
|
|
class Lo(nn.Module):
    """Small residual MLP: Dense(64) -> silu -> Dense(d_model) -> LayerNorm, plus skip."""

    d_model: int
    dtype: Any = DTYPE

    @nn.compact
    def __call__(self, x):
        hidden = nn.Dense(64, dtype=self.dtype)(x)
        hidden = nn.silu(hidden)
        hidden = nn.Dense(self.d_model, dtype=self.dtype)(hidden)
        normed = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)(hidden)
        return normed + x
|
|
|
|
|
class Block(nn.Module):
    """One model block: LoU sequence mixing followed by the Lo residual MLP."""

    d_model: int
    dtype: Any = DTYPE

    @nn.compact
    def __call__(self, x):
        # BUG FIX: LoU declares no `dtype` field — its second positional
        # parameter is `clip_value`, so the old call LoU(self.d_model, self.dtype)
        # silently bound a dtype object to clip_value.  clip_value is unused in
        # LoU.__call__, so behavior is unchanged, but the call was misleading.
        x = LoU(self.d_model)(x)
        x = Lo(self.d_model, self.dtype)(x)
        return x
|
|
|
|
|
class ReLM(nn.Module):
    """Decoder-style LM: token + position embeddings, n_layers Blocks, tied output head."""

    vocab_size: int
    max_seq_len: int
    d_model: int
    n_layers: int
    dtype: Any = DTYPE

    def setup(self):
        self.token_embed = nn.Embed(self.vocab_size, self.d_model, dtype=self.dtype)
        self.pos_embed = nn.Embed(self.max_seq_len, self.d_model, dtype=self.dtype)
        self.blocks = [Block(self.d_model, self.dtype) for _ in range(self.n_layers)]
        self.ln_f = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)

    def __call__(self, x, deterministic=True):
        _, seq_len = x.shape
        positions = jnp.arange(seq_len)[None, :]
        hidden = self.token_embed(x) + self.pos_embed(positions)
        for block in self.blocks:
            hidden = block(hidden)
        hidden = self.ln_f(hidden)
        # Weight tying: project back through the token embedding matrix.
        return jnp.einsum("bld,vd->blv", hidden, self.token_embed.embedding)
|
|
|
|
|
def smoothed_ce(logits, targets, pad_id, eps=0.1): |
|
|
logits = logits.astype(jnp.float32) |
|
|
targets = targets.astype(jnp.int32) |
|
|
vocab = logits.shape[-1] |
|
|
mask = (targets != pad_id).astype(jnp.float32) |
|
|
one_hot = jax.nn.one_hot(targets, vocab) |
|
|
smooth = (1-eps)*one_hot + eps/vocab |
|
|
log_probs = jax.nn.log_softmax(logits, axis=-1) |
|
|
loss = -jnp.sum(smooth * log_probs, axis=-1) * mask |
|
|
return jnp.sum(loss) / (jnp.sum(mask)+1e-8) |
|
|
|
|
|
def masked_ppl(logits, targets, pad_id, eps=0.1):
    """Masked perplexity: exp of the label-smoothed CE over non-pad tokens.

    Delegates to smoothed_ce, which computes exactly the formula this
    function previously duplicated line-for-line — so the reported
    perplexity can never drift from the training loss.
    """
    return jnp.exp(smoothed_ce(logits, targets, pad_id, eps))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainState(train_state.TrainState): pass |
|
|
def create_train_state(rng, model, lr):
    """Initialize model parameters and build a TrainState with gradient-clipped AdamW."""
    dummy_input = jnp.zeros((1, SEQ_LEN), dtype=jnp.int32)
    variables = model.init(rng, dummy_input)
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(lr, b1=0.9, b2=0.95, eps=1e-8),
    )
    return TrainState.create(apply_fn=model.apply, params=variables["params"], tx=optimizer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@partial(jax.pmap, axis_name="batch")
def train_step(state,bx,by,rngs):
    """One data-parallel optimizer step (pmapped over the "batch" axis).

    Args:
        state: replicated TrainState (one copy per device).
        bx: per-device input ids, shape (per_device_batch, seq).
        by: per-device shifted targets, same shape as bx.
        rngs: per-device PRNG keys — currently unused by the body
            (the model defines no dropout); kept in the signature for callers.

    Returns:
        (updated state, metrics dict with pmean-ed "loss" and "ppl").
    """
    def loss_fn(params):
        # deterministic=False is currently a no-op (no dropout in the model).
        logits=state.apply_fn({"params":params},bx,deterministic=False)
        return smoothed_ce(logits,by,pad_id),logits
    (loss,logits),grads=jax.value_and_grad(loss_fn,has_aux=True)(state.params)
    # Average gradients across devices before applying, so every replica
    # takes the same update.
    grads=jax.lax.pmean(grads,"batch")
    state=state.apply_gradients(grads=grads)
    metrics={"loss":loss,"ppl":masked_ppl(logits,by,pad_id)}
    metrics=jax.lax.pmean(metrics,"batch")
    return state,metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def top_p_sample(rng, logits, p=0.9, temperature=1.0):
    """Sample one token id from `logits` via nucleus (top-p) sampling.

    Args:
        rng: jax PRNG key.
        logits: 1-D unnormalized scores over the vocabulary.
        p: cumulative-probability mass to keep.
        temperature: softmax temperature (must be non-zero).

    Returns:
        Python int token id.
    """
    probs = jax.nn.softmax(logits / temperature)
    sorted_probs, sorted_idx = jax.lax.top_k(probs, logits.shape[-1])
    cum_probs = jnp.cumsum(sorted_probs)
    # BUG FIX: the old mask `cum_probs <= p` could be all-False whenever the
    # single most-probable token already exceeded p, yielding 0/0 = NaN
    # probabilities.  Keep every token whose cumulative mass *before* it is
    # < p — this always retains at least the top token.
    mask = (cum_probs - sorted_probs) < p
    top_probs = jnp.where(mask, sorted_probs, 0.0)
    top_probs = top_probs / jnp.sum(top_probs)
    # Masked entries become log(0) = -inf, i.e. never sampled.
    return int(sorted_idx[jax.random.categorical(rng, jnp.log(top_probs))])
|
|
|
|
|
def generate_text(state, prompt, max_gen=256, p=0.9, temperature=0.8, min_len=20):
    """Autoregressively generate text from `prompt` with nucleus sampling.

    Args:
        state: replicated TrainState; device 0's parameters are used.
        prompt: text prepended with "<start> " before encoding.
        max_gen: maximum number of tokens to sample.
        p, temperature: nucleus-sampling controls (see top_p_sample).
        min_len: minimum total length before <end> may terminate generation.

    Returns:
        Decoded string including the prompt tokens.
    """
    # Take device 0's copy of the replicated params.
    # (jax.tree_util.tree_map: jax.tree_map is deprecated; matches usage
    # elsewhere in this file.)
    params = jax.tree_util.tree_map(lambda x: np.array(x[0]), state.params)
    # NOTE(review): sp.encode of the literal "<start> " text may tokenize it
    # as plain characters rather than the special piece — confirm.
    generated = sp.encode("<start> " + prompt, out_type=int)
    rng = random.PRNGKey(SEED)
    for _ in range(max_gen):
        window = generated[-SEQ_LEN:]
        # BUG FIX: index the last *real* token inside the window.  The old
        # code used len(generated)-1, which runs past the window once more
        # than SEQ_LEN tokens exist and only worked via jax's silent clamping.
        last_pos = len(window) - 1
        if len(window) < SEQ_LEN:
            window = window + [pad_id] * (SEQ_LEN - len(window))
        x = jnp.array([window], dtype=jnp.int32)
        logits = model.apply({"params": params}, x, deterministic=True)[0, last_pos]
        # Discourage early <end>; effectively forbid <pad>.
        logits = logits.at[end_id].add(-5.0).at[pad_id].add(-10.0)
        # BUG FIX: split the key every step; the old code reused one key, so
        # every step drew the same underlying randomness.
        rng, step_rng = random.split(rng)
        next_id = top_p_sample(step_rng, logits, p, temperature)
        generated.append(next_id)
        if next_id == end_id and len(generated) >= min_len:
            break
    return sp.decode(generated)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the model, initialize parameters, and replicate the train state
# across all local devices for pmap-based data parallelism.
rng=random.PRNGKey(SEED)
rng,init_rng=random.split(rng)
model=ReLM(vocab_size=vocab_size,max_seq_len=SEQ_LEN,d_model=512,n_layers=9,dtype=DTYPE)
state=create_train_state(init_rng,model,LEARNING_RATE)
state=jax.device_put_replicated(state,jax.local_devices())
|
|
|
|
|
# ----- Training loop -----
global_step=0
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    # Fresh, deterministic shuffle order each epoch.
    np_rng=np.random.default_rng(SEED+epoch)
    batch_iter=create_batch_iter(inputs,targets,GLOBAL_BATCH,np_rng)
    pbar=tqdm.tqdm(batch_iter,total=max(1,inputs.shape[0]//GLOBAL_BATCH))
    for bx,by in pbar:
        # (GLOBAL_BATCH, seq) -> (NUM_DEVICES, per-device batch, seq) for pmap.
        bx_sh,by_sh=shard(bx),shard(by)
        # NOTE(review): `rng` is never re-split, so every step passes the same
        # per-device keys; train_step currently ignores them, but confirm
        # before adding dropout.
        state,metrics=train_step(state,bx_sh,by_sh,jax.random.split(rng,NUM_DEVICES))
        # Metrics were pmean-ed inside train_step, so device 0's copy is global.
        m=jax.tree_util.tree_map(lambda x:x[0],metrics)
        pbar.set_postfix(loss=float(m["loss"]),ppl=float(m["ppl"]))
        global_step+=1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ----- Checkpointing -----
save_dir="./checkpoints"
os.makedirs(save_dir,exist_ok=True)

import jax.tree_util
# Convert device arrays to host numpy before saving.
# NOTE(review): `state` is still replicated, so every saved leaf keeps its
# leading device axis; restoring on a different device count would need the
# state unreplicated first — confirm intended.
checkpoints.save_checkpoint(save_dir, jax.tree_util.tree_map(lambda x: np.array(x), state), step=global_step, keep=3)

print("Saved checkpoint to",save_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n\n===== ์์ฑ ๊ฒฐ๊ณผ =====")
# BUG FIX: the prompt string literal was split across two physical lines
# (a SyntaxError as written); rejoined onto a single line.
print(generate_text(state,"์ง๋ 2๋ ๋์ ์ถ์ฐ์ฐ์ด ๊ตญ๊ฐ๊ฐ ํ์ํ ์ฐ๊ตฌ๋ฅผ",p=0.9))
|
|
|