import os

# TensorFlow's C++ logging reads TF_CPP_MIN_LOG_LEVEL, so it has to be set before
# TensorFlow is imported for import-time messages to be silenced as well.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Raise the Hugging Face Hub download timeout to 60 unless the user already set one.
os.environ.setdefault('HF_HUB_DOWNLOAD_TIMEOUT', '60')

from tensorflow.keras import mixed_precision

# Set the global Keras dtype policy before any layers (including the project
# layers imported below) are constructed.
mixed_precision.set_global_policy('mixed_bfloat16')

import tensorflow as tf
import numpy as np

from layers import TransformerBlock, PositionalEmbedding, create_causal_mask
from tokenizer import CharTokenizer, BPE_tokenizer, HFTokenizer
|
|
|
class training_callback:
    """Track per-epoch loss, save weights on improvement, and signal early stopping.

    Args:
        patience: number of consecutive non-improving epochs after which
            ``on_epoch_end`` returns True.
        min_delta: a loss only counts as an improvement if it is lower than the
            best loss so far by more than this amount.
        save_path: file passed to ``model.save_weights`` on every improvement.
    """

    def __init__(self, patience=3, min_delta=0.01, save_path="best_model.weights.h5"):
        self.patience = patience
        self.min_delta = min_delta
        self.save_path = save_path
        self.best_loss = float("inf")
        self.wait = 0           # consecutive epochs without improvement
        self.stopped_epoch = 0  # epoch index at which early stopping fired
        self.history = []       # every loss passed to on_epoch_end, in order

    def on_epoch_end(self, epoch, loss, model):
        """Record ``loss``; save ``model`` weights if it improved.

        Args:
            epoch: zero-based epoch index (printed as ``epoch + 1``).
            loss: scalar loss for this epoch.
            model: any object with a ``save_weights(path)`` method.

        Returns:
            True when training should stop (``patience`` consecutive epochs
            without improvement), otherwise False.
        """
        self.history.append(loss)

        if loss < self.best_loss - self.min_delta:
            print(f"Loss improved{self.best_loss:.4f}->{loss:.4f},saving model")
            self.best_loss = loss
            model.save_weights(self.save_path)
            self.wait = 0
            return False

        self.wait += 1
        print(f"No improvement for {self.wait}/{self.patience} epochs")
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            print(f"Early stopping at epoch {epoch+1}, best loss: {self.best_loss:.4f}")
            return True
        return False

    def plot_history(self, path="loss_curve.png"):
        """Plot the recorded losses and save the figure to ``path``.

        Each call draws on a fresh figure and closes it afterwards, so repeated
        calls neither overlay earlier curves nor accumulate open figures.
        """
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        try:
            ax.plot(self.history, marker="o")
            ax.set_title("Training Loss")
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Loss")
            ax.grid(True)
            fig.savefig(path)
        finally:
            plt.close(fig)
        print("Loss curve saved to png")
|
|
|
|
|
class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup to ``peak_lr`` followed by cosine decay to a floor.

    For ``step < warmup_steps``::

        lr = peak_lr * step / warmup_steps

    Afterwards, with ``progress = (step - warmup_steps) / (total_steps - warmup_steps)``
    clipped to [0, 1] and ``min_lr = peak_lr * min_lr_ratio``::

        lr = min_lr + (peak_lr - min_lr) * (1 + cos(pi * progress)) / 2

    The schedule is continuous at the end of warmup (lr == peak_lr there) and
    stays at ``min_lr`` once ``step >= total_steps``.

    Args:
        peak_lr: learning rate reached at the end of warmup.
        total_steps: step at which the decay reaches ``min_lr``.
        warmup_steps: number of linear-warmup steps (0 disables warmup).
        min_lr_ratio: final learning rate as a fraction of ``peak_lr``.
    """

    def __init__(self, peak_lr, total_steps, warmup_steps, min_lr_ratio=0.1):
        super().__init__()
        self.peak_lr = peak_lr
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.min_lr_ratio = min_lr_ratio

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        peak_lr = tf.cast(self.peak_lr, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        min_lr = peak_lr * self.min_lr_ratio

        # max(..., 1) avoids 0/0 when warmup is disabled; that branch is not selected then.
        warmup_lr = peak_lr * step / tf.maximum(warmup_steps, 1.0)

        # Measure decay progress from the end of warmup, not from step 0.
        decay_steps = tf.maximum(tf.cast(self.total_steps, tf.float32) - warmup_steps, 1.0)
        progress = tf.clip_by_value((step - warmup_steps) / decay_steps, 0.0, 1.0)
        cosine_lr = min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + tf.cos(np.pi * progress))

        return tf.where(step < warmup_steps, warmup_lr, cosine_lr)

    def get_config(self):
        """Constructor arguments, so the schedule can be serialized with the optimizer."""
        return {
            "peak_lr": self.peak_lr,
            "total_steps": self.total_steps,
            "warmup_steps": self.warmup_steps,
            "min_lr_ratio": self.min_lr_ratio,
        }
|
|
|
|
|
|
|
class GPT(tf.keras.Model):
    """Decoder-only, GPT-style language model with an optional per-block cache.

    As implemented here:
      * ``PositionalEmbedding`` (token + position embedding, with dropout),
      * ``num_layers`` ``TransformerBlock``s under a causal mask,
      * a final LayerNormalization,
      * weight tying: the output logits reuse the token-embedding matrix.

    Whether each block normalises before or after its sub-layers (pre-LN vs
    post-LN) is decided by ``TransformerBlock``, which is not defined here.

    Historical context, per the respective papers: GPT-1 used post-LayerNorm
    (normalisation after each residual sub-block) with tied input/output
    embeddings; GPT-2 moved LayerNorm to the input of each sub-block (pre-LN)
    and added a final LayerNorm; GPT-3 kept GPT-2's layout.
    """

    def __init__(self, vocab_size, d_model, num_heads, dff, num_layers, max_len, dropout_rate=0.1):
        super().__init__()
        self.num_layers = num_layers
        self.max_len = max_len

        # Its token table doubles as the output projection in call().
        self.embedding = PositionalEmbedding(vocab_size, d_model, max_len, rate=dropout_rate)

        # NOTE(review): num_layers and max_len are forwarded to every block;
        # confirm this matches TransformerBlock's signature.
        self.blocks = [
            TransformerBlock(d_model, num_heads, dff, num_layers, max_len, rate=dropout_rate)
            for _ in range(num_layers)
        ]

        self.final_ln = tf.keras.layers.LayerNormalization()

    def call(self, x, training=False, caches=None):
        """Run the model on token ids ``x``.

        Args:
            x: integer token ids; ``tf.shape(x)[1]`` is used as the number of new
                positions (presumably (batch, seq_len) — confirm).
            training: forwarded to the embedding and every block.
            caches: None/empty, or one entry per block as returned by a previous
                call. When the first entry is a dict with a "k" tensor, axis 2 of
                that tensor is taken as the number of cached positions.

        Returns:
            ``(logits, new_caches)``: float32 logits from the tied embedding
            matrix, and the list of caches returned by each block.
        """
        seq_len = tf.shape(x)[1]

        # An empty list would make the zip below skip every block, so treat it as "no cache".
        if not caches:
            caches = [None] * len(self.blocks)

        first = caches[0] if caches else None
        if first is not None and first.get("k") is not None:
            cache_len = tf.shape(first["k"])[2]
        else:
            cache_len = 0

        # Queries cover the new positions; keys cover cached + new positions.
        mask = create_causal_mask(seq_len, seq_len + cache_len)

        # NOTE(review): no position offset is passed, so with a cache the new
        # tokens presumably get positions starting at 0 rather than cache_len —
        # confirm against PositionalEmbedding.
        x = self.embedding(x, training=training)

        new_caches = []
        for block, cache in zip(self.blocks, caches):
            x, updated = block(x, mask=mask, cache=cache, training=training)
            new_caches.append(updated)

        x = self.final_ln(x)

        # Weight tying. Cast the table to the activation dtype explicitly so the
        # matmul does not depend on mixed-precision autocasting of the variable.
        embedding_matrix = tf.cast(self.embedding.token_embedding.embeddings, x.dtype)
        logits = tf.matmul(x, embedding_matrix, transpose_b=True)

        return tf.cast(logits, tf.float32), new_caches
|
|
|
def sample_with_temperature(logits, temperature=1.0):
    """Sample one token id per row from ``softmax(logits / temperature)``.

    Args:
        logits: unnormalised scores of shape (batch, vocab); tf.random.categorical
            requires 2-D input.
        temperature: > 0 rescales the distribution; <= 0 selects the argmax
            (greedy) instead of dividing by zero.

    Returns:
        int64 tensor of shape (batch,) with the chosen ids.
    """
    if temperature <= 0:
        return tf.argmax(logits, axis=-1)

    scaled = logits / temperature
    # tf.random.categorical takes unnormalised log-probabilities, so the scaled
    # logits are passed directly; softmax-then-log is redundant and can underflow.
    next_token = tf.random.categorical(scaled, num_samples=1)
    return tf.squeeze(next_token, axis=-1)
|
|
|
|
|
def sample_top_k(logits, k=50):
    """Sample one token id per row from the ``k`` highest-scoring tokens.

    Logits below each row's k-th largest value are pushed to -1e9 before
    sampling, so they have effectively zero probability. Ties with the k-th
    value are kept.

    Args:
        logits: unnormalised scores of shape (batch, vocab).
        k: number of candidates; clamped to [1, vocab] so a large ``k`` does not
            make ``tf.math.top_k`` fail.

    Returns:
        int64 tensor of shape (batch,) with the chosen ids.
    """
    vocab_size = tf.shape(logits)[-1]
    k = tf.clip_by_value(tf.cast(k, tf.int32), 1, vocab_size)

    values, _ = tf.math.top_k(logits, k=k)
    kth_best = values[:, -1, tf.newaxis]

    masked = tf.where(
        logits < kth_best,
        tf.ones_like(logits) * -1e9,
        logits,
    )

    # categorical() accepts unnormalised logits directly.
    next_token = tf.random.categorical(masked, num_samples=1)
    return tf.squeeze(next_token, axis=-1)
|
|
|
|
|
def sample_top_p(logits, p=0.9):
    """Nucleus sampling: sample one id per row from the smallest top set whose mass exceeds ``p``.

    Tokens are sorted by score; the kept set runs up to and including the first
    token at which the cumulative probability exceeds ``p``. If no prefix
    exceeds ``p`` (e.g. ``p >= 1`` or float rounding), every token is kept.

    Args:
        logits: unnormalised scores of shape (batch, vocab).
        p: cumulative-probability cutoff.

    Returns:
        int64 tensor of shape (batch,) with the chosen ids.
    """
    vocab_size = tf.shape(logits)[-1]
    sorted_logits = tf.sort(logits, axis=-1, direction="DESCENDING")
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)

    # cumsum of probabilities is non-decreasing, so the number of entries <= p
    # is the index of the first entry > p. When none exceeds p that count is
    # vocab_size; clamping to the last index keeps all tokens. The previous
    # argmax returned 0 there, which collapsed sampling to greedy.
    cutoff_index = tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1)
    cutoff_index = tf.minimum(cutoff_index, vocab_size - 1)

    threshold = tf.gather(sorted_logits, cutoff_index, batch_dims=1)

    masked = tf.where(
        logits < threshold[:, tf.newaxis],
        tf.ones_like(logits) * -1e9,
        logits,
    )

    # categorical() accepts unnormalised logits directly.
    next_token = tf.random.categorical(masked, num_samples=1)
    return tf.squeeze(next_token, axis=-1)
|
|
|
|
|
|
|
def apply_repetition_penalty(logits, generated_ids, penalty):
    """Penalise tokens that already appear in each row's history.

    For every token id present in a row's ``generated_ids``, positive logits
    are divided by ``penalty`` and non-positive ones multiplied by it; all
    other logits are unchanged.

    Args:
        logits: scores of shape (batch, vocab).
        generated_ids: integer ids seen so far. A 2-D (rows, positions) tensor
            is counted per row. A 1-D tensor is treated as one history shared by
            every row. A leading axis of size 1 broadcasts over the batch.
        penalty: None or 1.0 disables the penalty.

    Returns:
        Tensor shaped like ``logits``.
    """
    if penalty is None or penalty == 1.0:
        return logits

    vocab_size = tf.shape(logits)[-1]
    ids = tf.convert_to_tensor(generated_ids)
    if ids.shape.rank is not None and ids.shape.rank >= 2:
        ids = tf.reshape(ids, [tf.shape(ids)[0], -1])
    else:
        ids = tf.reshape(ids, [1, -1])

    # Per-row counts, shape (rows, vocab). Flattening the whole batch would let
    # one row's tokens penalise every other row.
    counts = tf.math.bincount(ids, minlength=vocab_size, maxlength=vocab_size, axis=-1)
    seen = counts > 0

    penalized = tf.where(logits > 0, logits / penalty, logits * penalty)
    # tf.where instead of a 0/1 blend: -inf * 0 would otherwise turn into NaN.
    return tf.where(seen, penalized, logits)
|
|
|
|
|
def generate_text(model,
                  start_tokens,
                  max_new_tokens,
                  temperature=1.0,
                  top_k=None,
                  top_p=None,
                  eos_token_id=None,
                  repetition_penalty=1.0,
                  max_len=256):
    """Autoregressively extend ``start_tokens``, reusing the model's caches.

    Args:
        model: callable as ``model(tokens, training=False, caches=...)``
            returning ``(logits, caches)`` (e.g. ``GPT``).
        start_tokens: int32 prompt ids indexed as [row, position]; generated
            ids are concatenated onto axis 1.
        max_new_tokens: maximum number of ids to append; <= 0 returns the prompt
            without running the model.
        temperature: logits are divided by it before sampling; <= 0 selects
            greedy argmax decoding.
        top_k: if set, top-k sampling (takes precedence over ``top_p``).
        top_p: if set and ``top_k`` is None, nucleus sampling.
        eos_token_id: if set, a row that emits it keeps emitting it as padding,
            and generation stops once every row has emitted it.
        repetition_penalty: applied to every id generated so far, prompt included.
        max_len: currently unused; generation length is not capped by it.
            NOTE(review): callers may assume it bounds the sequence length.

    Returns:
        The prompt with the generated ids appended along axis 1.
    """
    if max_new_tokens <= 0:
        return start_tokens

    eos = None if eos_token_id is None else tf.constant(int(eos_token_id), dtype=tf.int32)
    finished = None  # per-row "has emitted EOS" flags, created on the first step

    logits, caches = model(start_tokens, training=False)

    for step in range(max_new_tokens):
        last_token_logits = logits[:, -1, :]
        last_token_logits = apply_repetition_penalty(last_token_logits, start_tokens, repetition_penalty)

        if temperature <= 0:
            next_token = tf.argmax(last_token_logits, axis=-1)
        else:
            last_token_logits = last_token_logits / temperature
            if top_k is not None:
                next_token = sample_top_k(last_token_logits, k=top_k)
            elif top_p is not None:
                next_token = sample_top_p(last_token_logits, p=top_p)
            else:
                # Temperature was already applied above.
                next_token = sample_with_temperature(last_token_logits, temperature=1.0)

        next_token = tf.cast(next_token, tf.int32)  # shape (rows,)

        if eos is not None:
            if finished is not None:
                # Rows that already ended keep emitting EOS as padding.
                next_token = tf.where(finished, eos, next_token)
            hit = tf.equal(next_token, eos)
            finished = hit if finished is None else tf.logical_or(finished, hit)

        next_token = next_token[:, tf.newaxis]
        start_tokens = tf.concat([start_tokens, next_token], axis=1)

        if finished is not None and bool(tf.reduce_all(finished)):
            break

        # Skip the forward pass after the final token; its logits would be unused.
        if step + 1 < max_new_tokens:
            logits, caches = model(next_token, training=False, caches=caches)

    return start_tokens
|
|
|
def create_sequences(encoded_text, seq_len):
    """Split a token sequence into next-token (input, target) windows.

    Window k uses ids ``[k*seq_len, (k+1)*seq_len)`` as input and the same span
    shifted right by one as target. The target's last id is therefore the next
    window's first input. Only complete windows are returned.

    Args:
        encoded_text: 1-D sequence of token ids (list or array).
        seq_len: positive window length.

    Returns:
        ``(inputs, targets)``, each of shape ``(n, seq_len)`` with
        ``n = (len(encoded_text) - 1) // seq_len``. Both are ``(0, seq_len)``
        when there is not enough data for a single window.

    Raises:
        ValueError: if ``seq_len`` is not positive.
    """
    if seq_len <= 0:
        raise ValueError("seq_len must be positive")

    ids = np.asarray(encoded_text)
    n = max(0, (len(ids) - 1) // seq_len)
    span = n * seq_len

    # .copy() so the result never aliases a caller-provided array.
    inputs = ids[:span].reshape(n, seq_len).copy()
    targets = ids[1:span + 1].reshape(n, seq_len).copy()
    return inputs, targets
|
|
|
|
|
def stream_documents(dataset_specs, max_docs=None, loop=False, max_retries=None):
    """
    I yield raw text documents from a mix of streaming HF datasets.

    dataset_specs: iterable of (name, config, weight). Weights are normalised into
    interleave probabilities, so they do not have to sum to 1.
    I stop after max_docs documents in total (None = no limit); with loop=True the
    count spans all passes.
    loop=True: when the dataset is exhausted I restart it from the beginning, so a
    token budget larger than the dataset can still be met (multiple passes). If a
    whole pass yields no document I stop instead of restarting forever.
    I retry on errors with exponential backoff (capped at 60 s). max_retries bounds
    the number of consecutive failed attempts (None = retry forever). After a
    retry I skip exactly the raw items already consumed in the current pass
    (including items with empty text), so no document is emitted twice.

    Raises:
        ValueError: if dataset_specs is empty, has a negative weight, or its
            weights sum to 0.
    """
    # Materialise once: the spec list is re-read on every (re)build of the stream.
    specs = list(dataset_specs)
    weights = [float(weight) for _, _, weight in specs]
    if not specs or any(w < 0 for w in weights) or sum(weights) <= 0:
        raise ValueError(
            "dataset_specs needs at least one entry and non-negative weights with a positive sum"
        )
    total_weight = sum(weights)
    probs = [w / total_weight for w in weights]

    import time
    from datasets import load_dataset, interleave_datasets

    def build_mixed():
        streams = []
        for name, config, _weight in specs:
            ds = load_dataset(name, config, split="train", streaming=True)
            # Keep only "text" so datasets with different schemas can be interleaved.
            # column_names can be None for streaming datasets.
            drop = [c for c in (ds.column_names or []) if c != "text"]
            if drop:
                ds = ds.remove_columns(drop)
            streams.append(ds)
        return interleave_datasets(streams, probabilities=probs, seed=42)

    emitted = 0          # documents yielded so far, across passes (max_docs limit)
    consumed = 0         # raw items pulled in the current pass (resume point)
    emitted_in_pass = 0  # documents yielded in the current pass
    attempt = 0          # consecutive failed attempts (drives the backoff)

    while True:
        if max_docs is not None and emitted >= max_docs:
            return
        try:
            it = iter(build_mixed())
            # Replay the current pass up to where the previous attempt stopped.
            for _ in range(consumed):
                next(it, None)
            for item in it:
                consumed += 1
                attempt = 0
                text = item.get("text")
                if text:
                    emitted += 1
                    emitted_in_pass += 1
                    yield text
                    if max_docs is not None and emitted >= max_docs:
                        return

            if not loop or emitted_in_pass == 0:
                return
            # Start the next pass from the beginning.
            consumed = 0
            emitted_in_pass = 0
        except Exception as e:
            attempt += 1
            if max_retries is not None and attempt > max_retries:
                raise
            wait = min(60, 2 ** attempt)
            print(f"[stream] error: {e!r} — retrying in {wait}s (attempt {attempt})", flush=True)
            time.sleep(wait)
|
|
|
|
|
def token_stream(doc_iterator, tokenizer, max_tokens=None):
    """
    I tokenize each document, append tokenizer.eos_id as a document boundary, and
    yield token ids one at a time. I stop after max_tokens ids (None = no limit);
    max_tokens <= 0 yields nothing.

    tokenizer must provide encode(text) -> iterable of ids and an eos_id attribute.
    When eos_id is None no boundary is added. encode may return any iterable
    (list, tuple, array): I chain the boundary on instead of concatenating, which
    would fail for tuples and add eos elementwise for numpy arrays.
    """
    from itertools import chain

    if max_tokens is not None and max_tokens <= 0:
        return

    eos = tokenizer.eos_id
    boundary = () if eos is None else (eos,)
    produced = 0
    for doc in doc_iterator:
        for tid in chain(tokenizer.encode(doc), boundary):
            yield tid
            produced += 1
            if max_tokens is not None and produced >= max_tokens:
                return
|
|
|
|
|
def make_tf_dataset(token_iter_fn, seq_len, batch_size):
    """
    I turn a token-id generator into a batched (x, y) tf.data pipeline.

    Windows advance by seq_len ids: each sample is x = ids[i:i+seq_len],
    y = ids[i+1:i+seq_len+1] for i = 0, seq_len, 2*seq_len, ... The last id of a
    window is carried over as the first input of the next, so no token transition
    is skipped. That gives one sample per seq_len tokens, matching
    create_sequences and a token budget of seq_len tokens per sample.

    token_iter_fn: zero-argument callable returning an iterable of int ids; tf.data
    calls it each time the dataset is iterated.
    Shapes: x, y = (batch_size, seq_len), int32; the final batch may be smaller.
    """
    window = seq_len + 1

    def gen():
        buffer = []
        for tid in token_iter_fn():
            buffer.append(tid)
            if len(buffer) == window:
                chunk = np.asarray(buffer, dtype=np.int32)
                yield chunk[:-1], chunk[1:]
                # Keep the last id: it is the first input of the next window.
                buffer = buffer[-1:]

    ds = tf.data.Dataset.from_generator(
        gen,
        output_signature=(
            tf.TensorSpec(shape=(seq_len,), dtype=tf.int32),
            tf.TensorSpec(shape=(seq_len,), dtype=tf.int32),
        ),
    )
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
|
|
|
|
|
if __name__ == "__main__":
    # Allocate GPU memory on demand instead of reserving it all at start-up.
    for _gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(_gpu, True)

    # GPT.call returns float32 logits, so the loss consumes raw logits.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # Gradient accumulation: one optimizer update per `accum_steps` micro-batches.
    accum_steps = 4

    @tf.function
    def micro_step(x, y):
        """Forward/backward one micro-batch and add its gradients to `accum_grads`.

        The loss is divided by accum_steps before differentiation, so the summed
        gradients equal the mean over the accumulation window. Returns the
        unscaled loss for logging. `model` and `accum_grads` are bound further
        down in this block, before the first call.
        """
        with tf.GradientTape() as tape:
            logits, _ = model(x, training=True)
            loss = loss_fn(y, logits)
            loss = loss / accum_steps

        gradients = tape.gradient(loss, model.trainable_variables)
        for buffer, grad in zip(accum_grads, gradients):
            if grad is not None:
                buffer.assign_add(grad)
        return loss * accum_steps

    @tf.function
    def apply_accum():
        """Apply the accumulated gradients, then reset the accumulators to zero."""
        optimizer.apply_gradients(zip(accum_grads, model.trainable_variables))
        for buffer in accum_grads:
            buffer.assign(tf.zeros_like(buffer))

    seq_len = 256
    TOKENIZER_PATH = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..", "saved_models", "tinystories_tokenizer.json",
    )

    # (name, config, weight) entries consumed by stream_documents.
    DATASET_MIX = [
        ("noanabeshima/TinyStoriesV2", None, 1.0),
    ]

    tokenizer = HFTokenizer(vocab_size=10000)
    if os.path.exists(TOKENIZER_PATH):
        tokenizer.load(TOKENIZER_PATH)
    else:
        # Train the tokenizer on the first 200k streamed documents.
        print("Training HF tokenizer on a sample...")
        sample_docs = stream_documents(DATASET_MIX, max_docs=200_000)
        tokenizer.train(sample_docs)
        # saved_models/ may not exist yet on a fresh checkout.
        os.makedirs(os.path.dirname(TOKENIZER_PATH), exist_ok=True)
        tokenizer.save(TOKENIZER_PATH)

    vocab_size = tokenizer.vocab_size
    print("Vocab_size: ", vocab_size)

    batch_size = 4

    # Training token budget.
    MAX_TOKENS = 700_000_000

    # One optimizer update consumes seq_len * batch_size * accum_steps tokens.
    effective_batch = batch_size * accum_steps
    total_steps = MAX_TOKENS // (seq_len * effective_batch)
    warmup_steps = total_steps // 10
    print("Planned optimizer updates:", total_steps)

    lr_schedule = WarmupCosineDecay(peak_lr=3e-4,
                                    total_steps=total_steps,
                                    warmup_steps=warmup_steps)

    # NOTE(review): Keras `clipnorm` clips each gradient tensor by its own norm;
    # GPT-style training usually clips the global norm (`global_clipnorm`). Confirm intent.
    optimizer = tf.keras.optimizers.Adam(lr_schedule,
                                         beta_1=0.9,
                                         beta_2=0.95,
                                         epsilon=1e-8,
                                         clipnorm=1.0)

    # The first VAL_TOKENS ids of the stream form the held-out validation set.
    VAL_TOKENS = 2_000_000

    def train_token_iter():
        """Yield MAX_TOKENS training ids, dropping the validation prefix on every pass.

        Each pass streams DATASET_MIX once from the start. Its first VAL_TOKENS ids
        are the same ids the validation set below is built from, so they are
        skipped on every re-read of the data, not only on the first pass.
        """
        remaining = MAX_TOKENS
        while remaining > 0:
            produced_this_pass = 0
            stream = token_stream(stream_documents(DATASET_MIX), tokenizer)
            for idx, tid in enumerate(stream):
                if idx < VAL_TOKENS:
                    continue
                yield tid
                produced_this_pass += 1
                remaining -= 1
                if remaining <= 0:
                    return
            if produced_this_pass == 0:
                # Nothing beyond the validation prefix: stop instead of looping forever.
                return

    dataset = make_tf_dataset(train_token_iter, seq_len, batch_size)

    print("Building in-memory validation set...")
    _val_ids = list(token_stream(stream_documents(DATASET_MIX), tokenizer, max_tokens=VAL_TOKENS))
    # Windows of seq_len inputs, targets shifted by one id.
    _n = (len(_val_ids) - 1) // seq_len
    _vx = np.array([_val_ids[k*seq_len:(k+1)*seq_len] for k in range(_n)], dtype=np.int32)
    _vy = np.array([_val_ids[k*seq_len+1:(k+1)*seq_len+1] for k in range(_n)], dtype=np.int32)
    val_dataset = (tf.data.Dataset.from_tensor_slices((_vx, _vy))
                   .batch(batch_size).prefetch(tf.data.AUTOTUNE))
    print(f"Validation set: {len(_val_ids)} tokens -> {_n} sequences")

    model = GPT(vocab_size=vocab_size,
                d_model=640,
                num_heads=10,
                dff=2560,
                num_layers=10,
                max_len=seq_len)

    # One dummy forward pass creates every weight before accumulators are sized from them.
    model(tf.zeros((1, seq_len), dtype=tf.int32), training=False)

    # One zero buffer per trainable variable, used for gradient accumulation.
    accum_grads = [
        tf.Variable(tf.zeros_like(v), trainable=False)
        for v in model.trainable_variables
    ]

    # Create the optimizer's slot variables eagerly (presumably so they exist
    # before checkpoint restore and before apply_accum is traced).
    optimizer.build(model.trainable_variables)

    @tf.function
    def val_step(x, y):
        """Loss of one validation batch with training=False."""
        logits, _ = model(x, training=False)
        return loss_fn(y, logits)

    def evaluate(max_batches=200):
        """Mean val_step loss over at most `max_batches` batches (inf if there are none)."""
        losses = []
        for j, (vx, vy) in enumerate(val_dataset):
            if j >= max_batches:
                break
            losses.append(val_step(vx, vy).numpy())
        return float(np.mean(losses)) if losses else float("inf")

    SAVE_PATH = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..", "saved_models", "tinystories_model.weights.h5",
    )
    EVAL_EVERY = 1000

    # Resumable state: model, optimizer, our update counter and the best val loss.
    ckpt_step = tf.Variable(0, dtype=tf.int64, trainable=False)
    best_val_var = tf.Variable(float("inf"), dtype=tf.float32, trainable=False)
    CKPT_DIR = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..", "saved_models", "ckpt_tinystories",
    )
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer,
                               step=ckpt_step, best_val=best_val_var)
    ckpt_mgr = tf.train.CheckpointManager(ckpt, CKPT_DIR, max_to_keep=2)

    start_step = 0
    if ckpt_mgr.latest_checkpoint:
        ckpt.restore(ckpt_mgr.latest_checkpoint).expect_partial()
        start_step = int(ckpt_step.numpy())
        best_val = float(best_val_var.numpy())
        print(f"Resumed from {ckpt_mgr.latest_checkpoint} at update {start_step}, best_val {best_val:.4f}")
    else:
        best_val = float("inf")
        print("No checkpoint found, starting fresh.")

    def eval_and_checkpoint(step, best):
        """Evaluate, save the weights if val loss improved, write a checkpoint; return the new best."""
        val_loss = evaluate()
        print(f" [eval] update {step}, val_loss: {val_loss:.4f}")
        if val_loss < best:
            print(f" val improved {best:.4f} -> {val_loss:.4f}, saving")
            best = val_loss
            best_val_var.assign(best)
            model.save_weights(SAVE_PATH)
        ckpt_mgr.save()
        return best

    update_step = start_step
    micro_count = 0
    # On resume, replay the data pipeline and discard micro-batches already trained
    # on (assumes the stream replays in the same order).
    skip_micro = start_step * accum_steps

    for i, (batch_x, batch_y) in enumerate(dataset):
        if i < skip_micro:
            continue

        loss = micro_step(batch_x, batch_y)
        micro_count += 1

        if micro_count % accum_steps == 0:
            apply_accum()
            update_step += 1
            ckpt_step.assign(update_step)

            if update_step % 50 == 0:
                print(f"update {update_step}/{total_steps}, loss: {loss.numpy():.4f}")

            if update_step % EVAL_EVERY == 0:
                best_val = eval_and_checkpoint(update_step, best_val)

            if update_step >= total_steps:
                break

    # The loop only evaluates/checkpoints every EVAL_EVERY updates; flush the tail
    # (total_steps not a multiple of EVAL_EVERY, or the data ran out early).
    if update_step > start_step and update_step % EVAL_EVERY != 0:
        best_val = eval_and_checkpoint(update_step, best_val)

    print("Training complete. best val_loss:", best_val)
|
|
|