| import os, sys, json, yaml, math, time |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader, IterableDataset |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import LambdaLR |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) |
| from scripts.model_tiny import TinyModel, restore_from_v2 |
|
|
|
|
def _format_item(item, tokenizer):
    """Render one raw dataset record as a single training string.

    Three record layouts are recognised, checked in this order:

    1. ``{"text": ...}`` -- the text is returned unchanged.
    2. ``{"messages": [...]}`` -- rendered with
       ``tokenizer.apply_chat_template(..., tokenize=False)``.
    3. Instruction-style records with optional ``system``, ``instruction``,
       ``input`` and ``output`` keys -- rendered into the
       ``<|system|>`` / ``<|user|>`` / ``<|assistant|>`` layout.

    Args:
        item: Mapping holding one dataset record.
        tokenizer: Only used for the ``messages`` layout; it must provide
            ``apply_chat_template``.

    Returns:
        The formatted text for the record.
    """
    if "text" in item:
        return item["text"]
    if "messages" in item:
        return tokenizer.apply_chat_template(item["messages"], tokenize=False)

    def _field(key):
        # A key that is present but set to None must behave like a missing
        # key; otherwise the literal text "None" ends up in the prompt (or
        # string concatenation fails for the instruction/input fields).
        value = item.get(key, "")
        return "" if value is None else value

    system = _field("system")
    inp = _field("input")
    instruction = _field("instruction")
    output = _field("output")
    # The optional input is appended to the instruction on its own line.
    user_msg = instruction + ("\n" + inp if inp else "")
    return f"<|system|>\n{system}\n<|user|>\n{user_msg}\n<|assistant|>\n{output}"
|
|
|
|
def _encode(text, tokenizer, max_seq_len):
    """Tokenize ``text`` and return its ids as a 1-D ``torch.long`` tensor.

    Sequences longer than ``max_seq_len`` are cut to their first
    ``max_seq_len`` ids; shorter ones are returned as produced by
    ``tokenizer.encode``.
    """
    token_ids = tokenizer.encode(text)
    # Only slice when the sequence is actually too long.
    kept = token_ids if len(token_ids) <= max_seq_len else token_ids[:max_seq_len]
    return torch.tensor(kept, dtype=torch.long)
|
|
|
|
class StreamingSFTDataset(IterableDataset):
    """Iterable dataset that streams records from a Hugging Face repo.

    Only the configuration is stored at construction time. Each call to
    ``__iter__`` opens a new streaming dataset and yields one encoded
    tensor per record, so nothing is materialised in memory.
    """

    def __init__(self, hf_repo, tokenizer, max_seq_len=2048, split="train", hf_name=None):
        """Store the repo coordinates and the tokenisation settings."""
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.hf_repo, self.hf_name, self.split = hf_repo, hf_name, split

    def __iter__(self):
        """Yield the encoded form of every record in the configured split."""
        from datasets import load_dataset

        stream = load_dataset(
            self.hf_repo,
            name=self.hf_name,
            split=self.split,
            streaming=True,
        )
        # NOTE(review): there is no per-worker sharding here; presumably this
        # is only meant to be used with num_workers=0 -- confirm before
        # raising the worker count.
        for record in stream:
            text = _format_item(record, self.tokenizer)
            yield _encode(text, self.tokenizer, self.max_seq_len)
|
|
|
|
class ListDataset(Dataset):
    """Map-style dataset backed by an in-memory sequence of samples."""

    def __init__(self, samples):
        """Keep a reference to ``samples`` (no copy is made)."""
        self.samples = samples

    def __len__(self):
        """Return the number of stored samples."""
        count = len(self.samples)
        return count

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.samples[idx]
        return sample
|
|
|
|
def make_eval_dataset(hf_repo, tokenizer, max_seq_len, num_eval=500, hf_name=None):
    """Build a fixed evaluation set from the head of a streamed dataset.

    The ``"train"`` split of ``hf_repo`` is opened in streaming mode and its
    first ``num_eval`` records are formatted, encoded and collected into a
    ``ListDataset``. Fewer samples are returned when the stream is shorter.
    """
    from datasets import load_dataset

    stream = load_dataset(hf_repo, name=hf_name, split="train", streaming=True)
    # NOTE(review): these records come from the start of the "train" split;
    # a training stream over the same split will presumably see them too --
    # confirm whether that overlap is intended.
    encoded = []
    for index, record in enumerate(stream):
        if index >= num_eval:
            break
        text = _format_item(record, tokenizer)
        encoded.append(_encode(text, tokenizer, max_seq_len))
    return ListDataset(encoded)
|
|
|
|
def collate_fn(batch):
    """Right-pad a list of 1-D id tensors into ``(inputs, labels)``.

    Both outputs are ``torch.long`` tensors of shape
    ``(len(batch), longest_sequence)``. Inputs are padded with ``0`` and
    labels with ``-100``, so padded positions are skipped by any loss that
    uses ``ignore_index=-100``.
    """
    lengths = [len(seq) for seq in batch]
    shape = (len(batch), max(lengths))
    padded = torch.zeros(shape, dtype=torch.long)
    labels = torch.full(shape, -100, dtype=torch.long)
    for row, (seq, length) in enumerate(zip(batch, lengths)):
        # Labels mirror the inputs on real tokens; padding keeps -100.
        padded[row, :length] = seq
        labels[row, :length] = seq
    return padded, labels
|
|
|
|
@torch.no_grad()
def compute_metrics(logits, labels, ignore_index=-100):
    """Compute next-token loss, perplexity, entropy and accuracy for a batch.

    Logits at position ``t`` are scored against the label at ``t + 1``
    (the last logit and the first label are dropped). Positions whose
    shifted label equals ``ignore_index`` are excluded from every metric,
    including the predictive entropy.

    Args:
        logits: Tensor whose last dimension is the vocabulary and whose
            second-to-last dimension is the sequence position.
        labels: Integer tensor matching ``logits`` without the vocabulary
            dimension.
        ignore_index: Label value marking positions to skip.

    Returns:
        ``(loss, ppl, entropy, acc)`` as Python floats. When no position is
        supervised, ``entropy`` and ``acc`` are ``0.0`` and ``loss``/``ppl``
        are whatever ``F.cross_entropy`` yields for an empty target set.
    """
    vocab = logits.size(-1)
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    loss = F.cross_entropy(
        shift_logits.view(-1, vocab),
        shift_labels.view(-1),
        ignore_index=ignore_index,
        reduction="mean",
    )
    # Exponentiate in float64 so large losses overflow later.
    ppl = torch.exp(loss.double())

    mask = shift_labels != ignore_index
    if not mask.any():
        return loss.item(), ppl.item(), 0.0, 0.0

    # Restrict entropy and accuracy to supervised positions so padding does
    # not skew them (loss already ignores those positions).
    valid_logits = shift_logits[mask]
    valid_labels = shift_labels[mask]
    log_probs = F.log_softmax(valid_logits, dim=-1)
    entropy = -(log_probs.exp() * log_probs).sum(-1).mean()
    acc = (valid_logits.argmax(dim=-1) == valid_labels).float().mean()
    return loss.item(), ppl.item(), entropy.item(), acc.item()
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, device):
    """Evaluate ``model`` on ``loader`` and return aggregate metrics.

    The model is switched to eval mode for the duration of the call and put
    back into training mode afterwards if it was training on entry, so a
    mid-epoch evaluation does not silently leave the model in eval mode.

    Per-batch loss, entropy and accuracy from ``compute_metrics`` are
    averaged weighted by the number of supervised target tokens in each
    batch. Perplexity is the exponential of that token-weighted mean loss.
    Batches with no supervised target are skipped so they cannot inject NaN.

    Args:
        model: Callable as ``model(x, labels=y)`` returning ``(logits, loss)``.
        loader: Iterable of ``(inputs, labels)`` tensor pairs.
        device: Device the batches are moved to.

    Returns:
        Dict with keys ``loss``, ``ppl``, ``entropy``, ``acc`` and
        ``tokens`` (count of labels different from ``-100``). All metrics
        are ``0`` when nothing could be scored.
    """
    was_training = bool(getattr(model, "training", False))
    model.eval()

    loss_sum = 0.0
    entropy_sum = 0.0
    acc_sum = 0.0
    target_tokens = 0
    total_tokens = 0
    try:
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            total_tokens += (y != -100).sum().item()
            # Targets are the labels shifted by one position.
            n_targets = (y[..., 1:] != -100).sum().item()
            if n_targets == 0:
                continue
            logits, _ = model(x, labels=y)
            loss_val, _, ent_val, acc_val = compute_metrics(logits, y)
            loss_sum += loss_val * n_targets
            entropy_sum += ent_val * n_targets
            acc_sum += acc_val * n_targets
            target_tokens += n_targets
    finally:
        if was_training:
            model.train()

    if target_tokens == 0:
        return {"loss": 0.0, "ppl": 0.0, "entropy": 0.0, "acc": 0.0, "tokens": total_tokens}

    mean_loss = loss_sum / target_tokens
    # torch.exp saturates to inf instead of raising on very large losses.
    ppl = torch.exp(torch.tensor(mean_loss, dtype=torch.float64)).item()
    return {
        "loss": mean_loss,
        "ppl": ppl,
        "entropy": entropy_sum / target_tokens,
        "acc": acc_sum / target_tokens,
        "tokens": total_tokens,
    }
|
|
|
|
def upload_to_hf(local_dir, repo_id, token, files):
    """Upload the listed files from ``local_dir`` to a Hugging Face repo.

    Each entry of ``files`` is a path relative to ``local_dir`` and is
    stored under the same relative path in ``repo_id``. Entries that do not
    exist on disk are skipped without any message.
    """
    from huggingface_hub import HfApi

    api = HfApi()
    for rel_path in files:
        local_path = os.path.join(local_dir, rel_path)
        if not os.path.exists(local_path):
            continue
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=rel_path,
            repo_id=repo_id,
            token=token,
        )
        print(f" Uploaded {rel_path}")
|
|
|
|
def train():
    """Run the fine-tuning pipeline configured by ``config/train_tiny.yaml``.

    In order, the function:

    1. loads the YAML config and picks the device;
    2. obtains ``HF_TOKEN`` from Kaggle secrets when it is not already set;
    3. builds the tokenizer from ``tokenizer/tokenizer.json``;
    4. builds the model, either a fresh ``TinyModel`` optionally initialised
       from a checkpoint (``cft.enabled``) or via ``restore_from_v2``;
    5. optionally applies QLoRA and ``torch.compile``;
    6. builds train/eval loaders (streaming HF repo or local JSONL file);
    7. creates AdamW plus a warmup/cosine ``LambdaLR`` schedule and resumes
       from ``<output_dir>/checkpoint.pt`` when that file exists;
    8. trains with gradient accumulation, logging every ``logging_steps``
       and evaluating/checkpointing every ``save_steps``;
    9. saves ``final/model.pt`` and optionally uploads artefacts to the Hub.

    Returns:
        Path of the ``final`` directory inside the output directory.
    """
    config_path = os.path.join(os.path.dirname(__file__), "..", "config", "train_tiny.yaml")
    with open(config_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    t_cfg = cfg["training"]
    d_cfg = cfg["data"]

    device = "cuda" if torch.cuda.is_available() and not t_cfg.get("use_cpu", False) else "cpu"
    print(f"Device: {device}")

    # Best effort: pull HF_TOKEN from Kaggle secrets. Outside Kaggle the
    # import fails, which is expected and deliberately ignored.
    if not os.environ.get("HF_TOKEN"):
        try:
            from kaggle_secrets import UserSecretsClient
            secret = UserSecretsClient().get_secret("HF_TOKEN")
            if secret:
                os.environ["HF_TOKEN"] = secret
        except Exception:
            pass

    # Tokenizer: wrap the raw tokenizers file and attach the chat template.
    from tokenizers import Tokenizer as Tk
    from transformers import PreTrainedTokenizerFast
    tok_obj = Tk.from_file(os.path.join(os.path.dirname(__file__), "..", "tokenizer", "tokenizer.json"))
    tok = PreTrainedTokenizerFast(tokenizer_object=tok_obj)
    tok.add_special_tokens({"pad_token": "<pad>", "bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>"})
    tok.chat_template = "{% for message in messages %}{% if message['role'] == 'system' %}<|system|>\n{{ message['content'] }}\n{% elif message['role'] == 'user' %}<|user|>\n{{ message['content'] }}\n{% elif message['role'] == 'assistant' %}<|assistant|>\n{{ message['content'] }}\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>\n{% endif %}"

    vocab_size = tok.vocab_size
    max_seq = d_cfg.get("max_seq_length", 2048)

    # Optional training modes.
    qlora_cfg = cfg.get("qlora", {})
    cft_cfg = cfg.get("cft", {})
    qlora_enabled = qlora_cfg.get("enabled", False)
    cft_enabled = cft_cfg.get("enabled", False)

    # Model construction.
    if cft_enabled:
        # Fresh TinyModel, optionally initialised from a checkpoint.
        resume_ckpt = cft_cfg.get("resume_checkpoint", "")
        if resume_ckpt and not os.path.exists(resume_ckpt):
            # The checkpoint is not on disk: fetch it from the Hub.
            print(f"[CFT] Downloading {resume_ckpt} from HF ...")
            from huggingface_hub import hf_hub_download
            ckpt = hf_hub_download("samcheng0/lumia-tiny", resume_ckpt, repo_type="model")
            import shutil
            shutil.copy(ckpt, resume_ckpt)
            print(f"[CFT] Downloaded to {resume_ckpt}")

        model = TinyModel(vocab_size=vocab_size, hidden=128, code_dim=96,
                          num_layers=6, num_heads=8, num_kv_heads=4,
                          max_seq_len=max_seq, tie_weights=True)
        model.reset_weights()

        if resume_ckpt and os.path.exists(resume_ckpt):
            raw = torch.load(resume_ckpt, map_location="cpu", weights_only=True)
            # Accept both a bare state dict and a {"model": state_dict} wrapper.
            sd = raw if "model" not in raw else raw["model"]
            missing, unexpected = model.load_state_dict(sd, strict=False)
            if missing:
                print(f" Missing keys: {len(missing)} (e.g. {missing[:3]})")
            if unexpected:
                print(f" Unexpected keys: {len(unexpected)} (e.g. {unexpected[:3]})")
            print(f"[CFT] Loaded checkpoint from {resume_ckpt}")

        if cft_cfg.get("reset_embeddings", False):
            nn.init.normal_(model.token_embed.weight, std=0.02)
            print(f"[CFT] Reset token embeddings (incompatible vocab)")
            # Only re-initialise the head separately when it is not tied.
            if model.lm_head.weight is not model.token_embed.weight:
                nn.init.normal_(model.lm_head.weight, std=0.02)
    else:
        # Default path: restore the model from checkpoint.pt.
        if not os.path.exists("checkpoint.pt"):
            print("[Restore] Downloading checkpoint.pt from samcheng0/lumia-tiny ...")
            from huggingface_hub import hf_hub_download
            ckpt = hf_hub_download("samcheng0/lumia-tiny", "checkpoint.pt", repo_type="model")
            import shutil
            shutil.copy(ckpt, "checkpoint.pt")
            print(f"[Restore] Downloaded to checkpoint.pt")
        model = restore_from_v2("checkpoint.pt")

    model = model.to(device)

    # Optional QLoRA adapters.
    if qlora_enabled:
        from scripts.model_tiny import apply_qlora
        model = apply_qlora(model,
                            r=qlora_cfg.get("r", 8),
                            alpha=qlora_cfg.get("alpha", 16),
                            dropout=qlora_cfg.get("dropout", 0.0))
        # Freshly reset embeddings must stay trainable.
        if cft_cfg.get("reset_embeddings", False):
            for name, param in model.named_parameters():
                if "token_embed" in name or "lm_head" in name:
                    param.requires_grad = True

    # Optional torch.compile; a failure is reported and training continues.
    if t_cfg.get("compile", False):
        try:
            model = torch.compile(model, dynamic=True)
            print("[Compile] torch.compile enabled (dynamic=True)")
        except Exception as e:
            print(f"[Compile] Failed: {e} — continuing without compile")

    hf_repo = d_cfg.get("hf_repo")
    is_streaming = bool(hf_repo)

    if is_streaming:
        train_ds = StreamingSFTDataset(hf_repo, tok, max_seq, split=d_cfg.get("hf_split", "train"), hf_name=d_cfg.get("hf_repo_name"))
        eval_ds = make_eval_dataset(hf_repo, tok, max_seq, num_eval=d_cfg.get("hf_num_eval", 500), hf_name=d_cfg.get("hf_repo_name"))
        print(f"Train: streaming from {hf_repo} Eval: {len(eval_ds.samples)} samples")
    else:
        path = d_cfg["train_file"]
        with open(path, encoding="utf-8") as f:
            # One JSON object per line; blank lines are skipped.
            raw = [json.loads(line) for line in f if line.strip()]
        split = int(len(raw) * (1 - d_cfg.get("eval_split_ratio", 0.1)))
        train_raw = raw[:split]
        eval_raw = raw[split:]
        train_ds = ListDataset([_encode(_format_item(x, tok), tok, max_seq) for x in train_raw])
        eval_ds = ListDataset([_encode(_format_item(x, tok), tok, max_seq) for x in eval_raw])
        print(f"Train: {len(train_ds.samples)} Eval: {len(eval_ds.samples)}")

    bs = t_cfg.get("per_device_train_batch_size", 8)
    eval_bs = t_cfg.get("per_device_eval_batch_size", 8)
    ga_steps = t_cfg.get("gradient_accumulation_steps", 4)
    lr = t_cfg.get("learning_rate", 3e-4)
    epochs = t_cfg.get("num_train_epochs", 1)
    max_grad_norm = t_cfg.get("max_grad_norm", 1.0)
    log_steps = t_cfg.get("logging_steps", 5)
    save_steps = t_cfg.get("save_steps", 200)
    output_dir = t_cfg.get("output_dir", "outputs/tiny-sft")

    # An iterable (streaming) dataset cannot be shuffled by the DataLoader.
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=not is_streaming, collate_fn=collate_fn, num_workers=0)
    eval_loader = DataLoader(eval_ds, batch_size=eval_bs, shuffle=False, collate_fn=collate_fn, num_workers=0)

    # Hub upload target; uploads happen only when both values are set.
    hf_repo_id = t_cfg.get("hf_repo_id", "")
    hf_token = os.environ.get("HF_TOKEN", "")

    n_params = sum(p.numel() for p in model.parameters())

    step = 0
    opt_step = 0
    start_epoch = 1
    best_loss = float('inf')
    best_acc = 0.0
    cumul_tokens = 0
    global_step_offset = 0
    ckpt_path = os.path.join(output_dir, "checkpoint.pt")

    # Parameter groups: norm and LoRA parameters get no weight decay.
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "norm" in name or "ln_" in name or "lora_" in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    if not decay_params and not no_decay_params:
        print("[WARNING] No trainable parameters found!")
    optimizer = AdamW([
        {"params": decay_params, "weight_decay": t_cfg.get("weight_decay", 0.1)},
        {"params": no_decay_params, "weight_decay": 0.0},
    ], lr=lr, betas=(0.9, 0.95))

    warmup_ratio = t_cfg.get("warmup_ratio", 0.1)
    max_steps_cfg = t_cfg.get("max_steps", -1)
    if is_streaming and max_steps_cfg > 0:
        total_model_steps = max_steps_cfg
    elif not is_streaming:
        total_model_steps = len(train_ds) // bs * epochs
    else:
        # Streaming without a step budget: fall back to a fixed horizon.
        total_model_steps = 50000

    total_opt_steps = total_model_steps // ga_steps
    warmup_opt_steps = int(total_opt_steps * warmup_ratio)

    def _lr_lambda(current_step):
        """Linear warmup followed by cosine decay to zero."""
        if current_step < warmup_opt_steps:
            return float(current_step) / max(1, warmup_opt_steps)
        progress = float(current_step - warmup_opt_steps) / max(1, total_opt_steps - warmup_opt_steps)
        # Clamp so the cosine cannot swing back up when training runs past
        # the planned number of optimizer steps.
        progress = min(1.0, progress)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    scheduler = LambdaLR(optimizer, lr_lambda=_lr_lambda)

    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        # NOTE(review): checkpoints are written mid-epoch, so resuming at
        # epoch + 1 skips the rest of the checkpointed epoch (and skips
        # training entirely when it was the last epoch) -- confirm intent.
        start_epoch = ckpt["epoch"] + 1
        global_step_offset = ckpt["step"]
        opt_step = ckpt.get("opt_step", 0)
        best_loss = ckpt.get("best_loss", float('inf'))
        best_acc = ckpt.get("best_acc", 0.0)
        cumul_tokens = ckpt.get("cumul_tokens", 0)
        print(f"Resumed: epoch={ckpt['epoch']}, step={ckpt['step']}, best_loss={best_loss:.4f}, best_acc={best_acc:.3f}, total_tok={cumul_tokens}")

    os.makedirs(output_dir, exist_ok=True)
    t_start = time.time()
    t_epoch = time.time()

    eff_bs = bs * ga_steps
    print()
    print("=" * 72)
    print(f" PCT-V3 Training")
    print(f" Model: {n_params:,} params | Device: {device.upper()}")
    print(f" Vocab: {vocab_size} | Depth: 6 | Hidden: 128 | Code: 96 | Heads: 8/4 | RPW+GPP+VCR")
    print(f" Epochs: {epochs} | Batch: {bs} | GA: {ga_steps} | Eff BS: {eff_bs} | LR: {lr}")
    print(f" Warmup: {warmup_opt_steps} opt steps | Total: {total_opt_steps} opt steps")
    print("=" * 72)
    print(f" {'Time':>8} {'Ep':>2} {'Step':>8} {'Loss':>9} {'PPL':>9} {'Entropy':>9} {'Acc':>6} {'GradNorm':>9} {'LR':>10} {'tok/s':>7} {'smp/s':>7}")
    print(f" {'-'*8:>8} {'-'*2:>2} {'-'*8:>8} {'-'*9:>9} {'-'*9:>9} {'-'*9:>9} {'-'*6:>6} {'-'*9:>9} {'-'*10:>10} {'-'*7:>7} {'-'*7:>7}")
    print()

    for epoch in range(start_epoch, epochs + 1):
        model.train()
        epoch_loss = 0
        epoch_tokens = 0
        epoch_steps = 0
        optimizer.zero_grad()

        for x, y in train_loader:
            step += 1
            epoch_steps += 1
            x, y = x.to(device), y.to(device)
            # NOTE(review): numel() includes padded positions, so the token
            # counters and tok/s include padding.
            n_tokens = y.numel()
            logits, loss = model(x, labels=y)
            loss_val = loss.item()
            # Scale so accumulated gradients average over ga_steps batches.
            loss = loss / ga_steps
            loss.backward()

            if step % ga_steps == 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm).item()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                opt_step += 1
            else:
                # No optimizer step on this batch, so no norm to report.
                grad_norm = 0.0

            epoch_loss += loss_val
            epoch_tokens += n_tokens
            cumul_tokens += n_tokens

            if step % log_steps == 0:
                tok_per_sec = epoch_tokens / max(time.time() - t_epoch, 1)
                smp_per_sec = epoch_steps * bs / max(time.time() - t_epoch, 1)
                cur_lr = optimizer.param_groups[0]['lr']
                ts = time.strftime("%H:%M:%S")

                loss_s, ppl_s, ent_s, acc_s = compute_metrics(logits, y)

                print(f" {ts:>8} {epoch:>2} {step:>8} {loss_s:>9.4f} {ppl_s:>9.2f} {ent_s:>9.4f} {acc_s:>6.3f} {grad_norm:>9.2e} {cur_lr:>10.2e} {tok_per_sec:>7.0f} {smp_per_sec:>7.1f} cum_tok={cumul_tokens}")
                sys.stdout.flush()

            if step % save_steps == 0:
                global_step = global_step_offset + step
                eval_metrics = evaluate(model, eval_loader, device)
                # Evaluation switches to eval mode; go back to training mode
                # so the rest of the epoch trains as before.
                model.train()
                elapsed = time.time() - t_start
                pct = min(100, step / max_steps_cfg * 100) if max_steps_cfg > 0 else 0
                print(f" {'':>8} {'':>2} {'':>8} {'-'*9} {'-'*9} {'-'*9} {'-'*6} {'-'*9} {'-'*10} {'-'*7} {'-'*7}")
                print(f" -- Eval: loss={eval_metrics['loss']:.4f} ppl={eval_metrics['ppl']:.2f} ent={eval_metrics['entropy']:.4f} acc={eval_metrics['acc']:.3f} tokens={eval_metrics['tokens']} | best={best_loss:.4f} | step={global_step} [{pct:.1f}%] [{elapsed:.0f}s]")
                sys.stdout.flush()

                torch.save({
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "step": global_step,
                    "opt_step": opt_step,
                    "epoch": epoch,
                    "best_loss": best_loss,
                    "best_acc": best_acc,
                    "cumul_tokens": cumul_tokens,
                    "eval_metrics": eval_metrics,
                }, ckpt_path)
                print(f" + Checkpoint saved (step={global_step})")
                sys.stdout.flush()

                if eval_metrics['loss'] < best_loss:
                    best_loss = eval_metrics['loss']
                    best_acc = eval_metrics['acc']
                    torch.save(model.state_dict(), os.path.join(output_dir, "best.pt"))
                    print(f" + New best! loss={best_loss:.4f} ppl={eval_metrics['ppl']:.2f} acc={best_acc:.3f}")
                    if hf_repo_id and hf_token:
                        upload_to_hf(output_dir, hf_repo_id, hf_token, ["best.pt"])
                        print(f" + Uploaded best.pt to {hf_repo_id}")
                elif eval_metrics['acc'] > best_acc:
                    best_acc = eval_metrics['acc']
                sys.stdout.flush()

            # Checked after bookkeeping so the final step is still counted,
            # logged and checkpointed.
            if max_steps_cfg > 0 and step >= max_steps_cfg:
                break

        t_epoch_end = time.time()
        avg_loss = epoch_loss / max(epoch_steps, 1)
        eval_metrics = evaluate(model, eval_loader, device)
        epoch_time = t_epoch_end - t_epoch
        t_epoch = t_epoch_end
        print(f" {'':>8} {'':>2} {'':>8} {'-'*9} {'-'*9} {'-'*9} {'-'*6} {'-'*9} {'-'*10} {'-'*7} {'-'*7}")
        print(f" -- Epoch {epoch}/{epochs} done | loss={avg_loss:.4f} | eval={eval_metrics['loss']:.4f} | ppl={eval_metrics['ppl']:.2f} | ent={eval_metrics['entropy']:.4f} | acc={eval_metrics['acc']:.3f} | tok={epoch_tokens} | cum_tok={cumul_tokens} | time={epoch_time:.0f}s --")
        sys.stdout.flush()
        torch.save(model.state_dict(), os.path.join(output_dir, f"epoch_{epoch}.pt"))

        # Once the step budget is spent, do not start further epochs; each
        # would otherwise run one extra batch before stopping.
        if max_steps_cfg > 0 and step >= max_steps_cfg:
            break

    final = os.path.join(output_dir, "final")
    os.makedirs(final, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(final, "model.pt"))
    total_time = time.time() - t_start
    print()
    print("=" * 72)
    print(f" Training complete!")
    print(f" Total time: {total_time:.0f}s ({total_time/3600:.1f}h)")
    print(f" Total tokens: {cumul_tokens:,}")
    print(f" Total steps: {step}")
    print(f" Best loss: {best_loss:.4f}")
    print(f" Best acc: {best_acc:.3f}")
    print(f" Avg tok/s: {cumul_tokens/max(total_time, 1):.0f}")
    print(f" Saved: {final}/model.pt")
    print("=" * 72)
    print()

    if hf_repo_id and hf_token:
        print(f"Uploading to HF: {hf_repo_id} ...")
        upload_files = ["best.pt", "final/model.pt"]
        for ep in range(start_epoch, epochs + 1):
            ep_path = os.path.join(output_dir, f"epoch_{ep}.pt")
            if os.path.exists(ep_path):
                upload_files.append(f"epoch_{ep}.pt")
        upload_to_hf(output_dir, hf_repo_id, hf_token, upload_files)
        print(f"Uploaded to https://huggingface.co/{hf_repo_id}")

    return final
|
|
|
|
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    train()
|
|