# TPU-optimized Flax + JAX ReLM: download data, train with pmap, save, sample.
import math
import os
from functools import partial
from typing import Any

import jax
import jax.numpy as jnp
import jax.tree_util
import numpy as np
import optax
import requests
import sentencepiece as spm
import tqdm
from flax import linen as nn
from flax.training import checkpoints, train_state
from jax import random


def download_file(url, save_path, timeout=60):
    """Stream `url` to `save_path`, then print a confirmation message.

    The body is written to `<save_path>.part` and renamed only after the whole
    response has been received, so an interrupted download never leaves a
    truncated file behind that the `os.path.exists` checks below would accept.

    Args:
        url: HTTP(S) URL to fetch.
        save_path: destination file path.
        timeout: seconds passed to `requests.get` as its timeout.

    Raises:
        requests.HTTPError: on a non-2xx response (via `raise_for_status`).
    """
    tmp_path = save_path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(8192 * 2):
                f.write(chunk)
    os.replace(tmp_path, save_path)
    print(f"✅ {save_path} 저장됨")


# ------------------
# Config
# ------------------
SEQ_LEN = 512
GLOBAL_BATCH = 256
LIMIT = 200_000
VOCAB_MODEL = "ko_unigram.model"
CORPUS_PATH = "corpus.txt"
SEED = 42
LEARNING_RATE = 1e-4
EPOCHS = 1

if not os.path.exists(CORPUS_PATH):
    download_file(
        "https://huggingface.co/datasets/Yuchan5386/Prototype/resolve/main/corpus_ko.txt?download=true",
        CORPUS_PATH,
    )
if not os.path.exists(VOCAB_MODEL):
    download_file(
        "https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
        VOCAB_MODEL,
    )

DTYPE = jnp.bfloat16 if jax.local_devices()[0].platform == "tpu" else jnp.float32
# pmap maps over *local* devices and the state is replicated onto
# jax.local_devices() below, so the shard axis must use the local count.
NUM_DEVICES = jax.local_device_count()
if GLOBAL_BATCH % NUM_DEVICES:
    raise ValueError(
        f"GLOBAL_BATCH={GLOBAL_BATCH} is not divisible by {NUM_DEVICES} devices"
    )
PER_DEVICE_BATCH = GLOBAL_BATCH // NUM_DEVICES
print("devices:", jax.devices(), "dtype:", DTYPE)

# ------------------
# Tokenizer
# ------------------
sp = spm.SentencePieceProcessor()
sp.load(VOCAB_MODEL)


def _piece_id(piece, default=None):
    """Return the vocab id of `piece`, or `default` when the vocab lacks it.

    `piece_to_id` never returns -1: unknown pieces map to the unk id, so
    presence is checked by mapping the id back to its piece string. With
    `default=None` the (unk) id is returned as-is for a missing piece.
    """
    idx = sp.piece_to_id(piece)
    if sp.id_to_piece(idx) == piece or default is None:
        return idx
    return default


# NOTE(review): the special-piece names were lost from the pasted source (it
# read `piece_to_id("")`); "<pad>"/"<start>"/"<end>" are inferred from the
# variable names -- confirm against ko_unigram.model's vocabulary.
pad_id = _piece_id("<pad>", default=0)
start_id = _piece_id("<start>")
end_id = _piece_id("<end>")
vocab_size = sp.get_piece_size()
print("vocab_size:", vocab_size, "pad_id:", pad_id, "start_id:", start_id, "end_id:", end_id)


# ------------------
# Data pipeline
# ------------------
def line_to_ids(line, max_len=SEQ_LEN):
    """Encode one text line as exactly `max_len` int32 ids.

    Layout: up to `max_len - 1` SentencePiece ids, then `end_id`, then
    `pad_id` filling up to `max_len`. Longer lines are truncated before
    `end_id` is appended. No start token is prepended.
    """
    ids = sp.encode(line.strip(), out_type=int)[: max_len - 1]
    ids += [end_id] + [pad_id] * (max_len - len(ids) - 1)
    return np.array(ids, dtype=np.int32)


def build_dataset(corpus_path, limit=LIMIT):
    """Encode the non-blank lines among the first `limit` lines of a corpus.

    Blank lines are skipped but still count toward `limit`. Returns an int32
    array of shape (num_lines, SEQ_LEN), one `line_to_ids` row per line.

    Raises:
        ValueError: if no non-blank line was found.
    """
    arr = []
    with open(corpus_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= limit:
                break
            line = line.strip()
            if not line:
                continue
            arr.append(line_to_ids(line))
    if not arr:
        raise ValueError(f"no non-empty lines found in {corpus_path!r}")
    data = np.stack(arr, axis=0)
    print("Loaded dataset:", data.shape)
    return data


data_np = build_dataset(CORPUS_PATH, LIMIT)
inputs = data_np
# Next-token targets: shift left by one and pad the final column.
targets = np.concatenate(
    [data_np[:, 1:], np.full((data_np.shape[0], 1), pad_id, np.int32)], axis=1
)


def create_batch_iter(inputs, targets, batch_size, rng):
    """Yield shuffled (inputs, targets) batches; a trailing partial batch is dropped."""
    idx = np.arange(inputs.shape[0])
    rng.shuffle(idx)
    for i in range(0, len(idx) - batch_size + 1, batch_size):
        batch_idx = idx[i:i + batch_size]
        yield inputs[batch_idx], targets[batch_idx]


def shard(xs):
    """Reshape a (batch, seq) array to (NUM_DEVICES, batch // NUM_DEVICES, seq) for pmap."""
    return xs.reshape(NUM_DEVICES, -1, xs.shape[1])


# ------------------
# Model
# ------------------
class SwiGLU(nn.Module):
    """SwiGLU feed-forward in float32: Dense(2*d) -> value * silu(gate) -> Dense(d).

    The result is cast back to the input dtype.
    """
    d_model: int

    @nn.compact
    def __call__(self, x):
        x_f32 = x.astype(jnp.float32)
        proj = nn.Dense(self.d_model * 2, dtype=jnp.float32)(x_f32)
        x_val, x_gate = jnp.split(proj, 2, axis=-1)
        out = x_val * nn.silu(x_gate)
        out = nn.Dense(self.d_model, dtype=jnp.float32)(out)
        return out.astype(x.dtype)


class LoU(nn.Module):
    """Gated token mixer with a causal, input-dependent EMA, followed by SwiGLU.

    All arithmetic runs in float32; the output is cast back to the input dtype.
    """
    d_model: int
    clip_value: float = 5.0  # not used by __call__
    eps: float = 1e-6        # not used by __call__

    @nn.compact
    def __call__(self, x):
        x_f32 = x.astype(jnp.float32)
        residual = x_f32
        x_norm = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(x_f32)

        q = nn.Dense(self.d_model, dtype=jnp.float32)(x_norm)
        k = nn.Dense(self.d_model, dtype=jnp.float32)(x_norm)
        v = nn.Dense(self.d_model, dtype=jnp.float32)(x_norm)

        # (tanh + 1) / 2 maps each gate into (0, 1).
        g_q = (jnp.tanh(q) + 1) / 2
        g_k = (jnp.tanh(k) + 1) / 2
        score = g_q * g_k

        # The EMA mixing weight must lie in [0, 1]. An unbounded Dense output
        # lets |1 - a| exceed 1, so the recurrence below can grow
        # exponentially over the sequence. The sigmoid keeps each step a
        # convex combination.
        alpha_dynamic = nn.sigmoid(nn.Dense(1, dtype=jnp.float32)(x_norm))

        # EMA scan along the sequence axis (swap batch/seq so scan walks seq).
        score_t = jnp.transpose(score, (1, 0, 2))
        alpha_t = jnp.transpose(alpha_dynamic, (1, 0, 2))

        def step(prev, cur):
            s, a = cur
            new = a * s + (1 - a) * prev
            return new, new

        init = score_t[0]
        _, ema_seq = jax.lax.scan(step, init, (score_t[1:], alpha_t[1:]))
        ema_full = jnp.concatenate([init[None, ...], ema_seq], 0)
        ema = jnp.transpose(ema_full, (1, 0, 2))

        out = v * ema + residual
        out = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(out)
        return SwiGLU(self.d_model)(out).astype(x.dtype)


class Lo(nn.Module):
    """Bottleneck MLP (d_model -> 64 -> d_model), post-LayerNorm, plus residual."""
    d_model: int
    dtype: Any = DTYPE

    @nn.compact
    def __call__(self, x):
        h = nn.silu(nn.Dense(64, dtype=self.dtype)(x))
        h = nn.Dense(self.d_model, dtype=self.dtype)(h)
        return nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)(h) + x


class Block(nn.Module):
    """One ReLM layer: LoU mixer followed by the Lo bottleneck MLP."""
    d_model: int
    dtype: Any = DTYPE

    @nn.compact
    def __call__(self, x):
        # LoU's second field is `clip_value`, not a dtype, so only d_model is
        # passed (LoU computes in float32 regardless).
        x = LoU(self.d_model)(x)
        x = Lo(self.d_model, self.dtype)(x)
        return x


class ReLM(nn.Module):
    """Token + learned positional embeddings, `n_layers` Blocks, final
    LayerNorm, and output logits tied to the token-embedding matrix."""
    vocab_size: int
    max_seq_len: int
    d_model: int
    n_layers: int
    dtype: Any = DTYPE

    def setup(self):
        self.token_embed = nn.Embed(self.vocab_size, self.d_model, dtype=self.dtype)
        self.pos_embed = nn.Embed(self.max_seq_len, self.d_model, dtype=self.dtype)
        self.blocks = [Block(self.d_model, self.dtype) for _ in range(self.n_layers)]
        self.ln_f = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)

    def __call__(self, x, deterministic=True):
        """Map int token ids (batch, seq) to logits (batch, seq, vocab_size).

        `deterministic` is accepted but unused (the model has no dropout).
        """
        _, seq = x.shape
        pos = jnp.arange(seq)[None, :]
        x = self.token_embed(x) + self.pos_embed(pos)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        # Weight tying: project hidden states onto the token-embedding table.
        logits = jnp.einsum("bld,vd->blv", x, self.token_embed.embedding)
        return logits


# ------------------
# Losses / metrics
# ------------------
def smoothed_ce(logits, targets, pad_id, eps=0.1):
    """Mean label-smoothed cross-entropy over positions whose target != pad_id."""
    logits = logits.astype(jnp.float32)
    targets = targets.astype(jnp.int32)
    vocab = logits.shape[-1]
    mask = (targets != pad_id).astype(jnp.float32)
    one_hot = jax.nn.one_hot(targets, vocab)
    smooth = (1 - eps) * one_hot + eps / vocab
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss = -jnp.sum(smooth * log_probs, axis=-1) * mask
    return jnp.sum(loss) / (jnp.sum(mask) + 1e-8)


def masked_ppl(logits, targets, pad_id, eps=0.1):
    """exp of `smoothed_ce` with the same `eps`.

    Only `eps=0.0` yields true perplexity; larger `eps` inflates the value
    with the smoothing term.
    """
    return jnp.exp(smoothed_ce(logits, targets, pad_id, eps))


# ------------------
# Train state
# ------------------
class TrainState(train_state.TrainState):
    pass


def create_train_state(rng, model, lr):
    """Initialize params on a dummy (1, SEQ_LEN) batch; AdamW with global-norm clipping at 1.0."""
    params = model.init(rng, jnp.zeros((1, SEQ_LEN), dtype=jnp.int32))["params"]
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(lr, b1=0.9, b2=0.95, eps=1e-8),
    )
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)


# ------------------
# pmap step
# ------------------
@partial(jax.pmap, axis_name="batch")
def train_step(state, bx, by, rngs):
    """One data-parallel optimizer step on this device's shard.

    `rngs` holds one PRNG key per device; nothing consumes it yet (no dropout).
    Gradients and metrics are averaged over the "batch" axis. "loss" is the
    label-smoothed training loss; "ppl" is the unsmoothed masked perplexity.
    """
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, bx, deterministic=False)
        return smoothed_ce(logits, by, pad_id), logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    grads = jax.lax.pmean(grads, "batch")
    state = state.apply_gradients(grads=grads)
    # eps=0.0 so perplexity comes from plain NLL, not the smoothed loss.
    metrics = {"loss": loss, "ppl": masked_ppl(logits, by, pad_id, eps=0.0)}
    metrics = jax.lax.pmean(metrics, "batch")
    return state, metrics


# ------------------
# Top-p sampling (JAX-native)
# ------------------
def top_p_sample(rng, logits, p=0.9, temperature=1.0):
    """Sample one token id from 1-D `logits` with nucleus (top-p) sampling.

    Keeps the smallest highest-probability prefix whose cumulative mass
    reaches `p`. The top-1 token is always kept, so the renormalized
    distribution is never empty (no divide-by-zero when one token has > p).
    Returns a Python int.
    """
    probs = jax.nn.softmax(logits / temperature)
    sorted_probs, sorted_idx = jax.lax.top_k(probs, logits.shape[-1])
    cum_probs = jnp.cumsum(sorted_probs)
    # Mass strictly before each token; the first entry is 0, so it is always kept.
    keep = (cum_probs - sorted_probs) < p
    top_probs = jnp.where(keep, sorted_probs, 0.0)
    top_probs = top_probs / jnp.sum(top_probs)
    choice = jax.random.categorical(rng, jnp.log(top_probs))
    return int(sorted_idx[choice])


def generate_text(state, prompt, max_gen=256, p=0.9, temperature=0.8, min_len=20):
    """Autoregressively extend `prompt` with nucleus sampling and decode it.

    Args:
        state: pmap-replicated TrainState; replica 0's params are used.
        prompt: text to continue, encoded the same way as training lines
            (`sp.encode(text.strip())`, no start token).
        max_gen: maximum number of tokens to sample.
        p, temperature: forwarded to `top_p_sample`.
        min_len: `end_id` cannot be sampled while the sequence (prompt
            included) is shorter than this many tokens.

    Returns:
        The decoded prompt + continuation (the end token is not included).

    Raises:
        ValueError: if `prompt` encodes to zero tokens.

    Uses the module-level `model`, which must exist before this is called.
    """
    params = jax.tree_util.tree_map(lambda x: x[0], state.params)
    # Input shape is always (1, SEQ_LEN) after padding, so this compiles once.
    forward = jax.jit(
        lambda prm, ids: model.apply({"params": prm}, ids, deterministic=True)
    )

    generated = sp.encode(prompt.strip(), out_type=int)
    if not generated:
        raise ValueError("prompt must encode to at least one token")

    rng = random.PRNGKey(SEED)
    for _ in range(max_gen):
        window = generated[-SEQ_LEN:]
        last = len(window) - 1
        # Right-padding is safe because every layer is causal: LoU's scan runs
        # left-to-right and everything else is per-position. So the logits at
        # `last` do not depend on the pad tokens after it.
        padded = window + [pad_id] * (SEQ_LEN - len(window))
        logits = forward(params, jnp.asarray([padded], dtype=jnp.int32))
        next_logits = logits[0, last].astype(jnp.float32)
        if len(generated) < min_len:
            next_logits = next_logits.at[end_id].set(-jnp.inf)
        rng, sample_rng = random.split(rng)
        next_id = top_p_sample(sample_rng, next_logits, p=p, temperature=temperature)
        if next_id == end_id:
            break
        generated.append(next_id)
    return sp.decode(generated)


# ------------------
# Training
# ------------------
rng = random.PRNGKey(SEED)
rng, init_rng = random.split(rng)
model = ReLM(vocab_size=vocab_size, max_seq_len=SEQ_LEN, d_model=512, n_layers=9, dtype=DTYPE)
state = create_train_state(init_rng, model, LEARNING_RATE)
state = jax.device_put_replicated(state, jax.local_devices())

global_step = 0
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    np_rng = np.random.default_rng(SEED + epoch)
    batch_iter = create_batch_iter(inputs, targets, GLOBAL_BATCH, np_rng)
    pbar = tqdm.tqdm(batch_iter, total=max(1, inputs.shape[0] // GLOBAL_BATCH))
    for bx, by in pbar:
        bx_sh, by_sh = shard(bx), shard(by)
        # Advance the key every step so each step/device gets distinct keys.
        rng, step_rng = random.split(rng)
        state, metrics = train_step(
            state, bx_sh, by_sh, random.split(step_rng, NUM_DEVICES)
        )
        m = jax.tree_util.tree_map(lambda x: x[0], metrics)
        pbar.set_postfix(loss=float(m["loss"]), ppl=float(m["ppl"]))
        global_step += 1

# ------------------
# Save
# ------------------
# Flax's orbax-backed checkpointing expects an absolute directory.
save_dir = os.path.abspath("./checkpoints")
os.makedirs(save_dir, exist_ok=True)
# Save one host copy (replica 0 of every leaf) instead of the replicated
# stack, so the checkpoint has no leading per-device axis.
host_state = jax.tree_util.tree_map(lambda x: np.asarray(x[0]), state)
checkpoints.save_checkpoint(save_dir, host_state, step=global_step, keep=3)
print("Saved checkpoint to", save_dir)

# ------------------
# Generate
# ------------------
print("\n\n===== 생성 결과 =====")
print(generate_text(state, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))