# lumia-tiny / train_tiny.py
# samcheng0's picture
# Upload train_tiny.py with huggingface_hub
# d1fa37b verified
# Raw
# History Blame Contribute Delete
# 21.4 kB
import os, sys, json, yaml, math, time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from scripts.model_tiny import TinyModel, restore_from_v2
def _format_item(item, tokenizer):
if "text" in item:
return item["text"]
if "messages" in item:
return tokenizer.apply_chat_template(item["messages"], tokenize=False)
system = item.get("system", "")
inp = item.get("input", "")
instruction = item.get("instruction", "")
output = item.get("output", "")
user_msg = instruction + ("\n" + inp if inp else "")
return f"<|system|>\n{system}\n<|user|>\n{user_msg}\n<|assistant|>\n{output}"
def _encode(text, tokenizer, max_seq_len):
enc = tokenizer.encode(text)
if len(enc) > max_seq_len:
enc = enc[:max_seq_len]
return torch.tensor(enc, dtype=torch.long)
class StreamingSFTDataset(IterableDataset):
    """Iterable dataset that streams records from a Hugging Face dataset repo.

    Iterating opens the dataset with ``load_dataset(..., streaming=True)`` and
    yields, for every record, the tensor produced by
    ``_encode(_format_item(record, tokenizer), tokenizer, max_seq_len)``.
    """

    def __init__(self, hf_repo, tokenizer, max_seq_len=2048, split="train", hf_name=None):
        # Only the settings are stored; the dataset itself is opened in __iter__.
        self.hf_repo, self.hf_name, self.split = hf_repo, hf_name, split
        self.tokenizer, self.max_seq_len = tokenizer, max_seq_len

    def __iter__(self):
        from datasets import load_dataset

        stream = load_dataset(
            self.hf_repo,
            name=self.hf_name,
            split=self.split,
            streaming=True,
        )
        for record in stream:
            text = _format_item(record, self.tokenizer)
            yield _encode(text, self.tokenizer, self.max_seq_len)
class ListDataset(Dataset):
    """Map-style dataset that serves items straight from an in-memory sequence."""

    def __init__(self, samples):
        # Kept by reference (no copy); other code in this file reads ``.samples``.
        self.samples = samples

    def __len__(self):
        """Number of stored samples."""
        stored = self.samples
        return len(stored)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        stored = self.samples
        return stored[idx]
def make_eval_dataset(hf_repo, tokenizer, max_seq_len, num_eval=500, hf_name=None):
    """Build a fixed, in-memory evaluation set from a streamed dataset.

    Streams the ``"train"`` split of ``hf_repo`` (config ``hf_name``) and keeps
    the first ``num_eval`` records, each formatted with ``_format_item`` and
    tokenised with ``_encode``.

    Returns:
        A ``ListDataset`` holding at most ``num_eval`` token tensors.
    """
    from datasets import load_dataset

    stream = load_dataset(hf_repo, name=hf_name, split="train", streaming=True)
    encoded = []
    for taken, record in enumerate(stream):
        # ``taken`` equals the number of samples collected so far.
        if taken >= num_eval:
            break
        text = _format_item(record, tokenizer)
        encoded.append(_encode(text, tokenizer, max_seq_len))
    return ListDataset(encoded)
def collate_fn(batch):
    """Right-pad a list of token tensors into an ``(inputs, labels)`` pair.

    Both outputs are ``torch.long`` tensors of shape
    ``(len(batch), longest sequence)``. Inputs are padded with ``0`` and
    labels with ``-100``; inside each sequence's real length the two are equal.
    """
    lengths = [len(seq) for seq in batch]
    shape = (len(batch), max(lengths))
    inputs = torch.zeros(shape, dtype=torch.long)
    targets = torch.full(shape, -100, dtype=torch.long)
    for row, (seq, n) in enumerate(zip(batch, lengths)):
        inputs[row, :n] = seq
        targets[row, :n] = seq
    return inputs, targets
@torch.no_grad()
def compute_metrics(logits, labels, ignore_index=-100):
    """Compute next-token loss, perplexity, entropy and accuracy for a batch.

    ``logits`` and ``labels`` are shifted against each other so that position
    ``t`` of the logits is scored against label ``t + 1``. Targets equal to
    ``ignore_index`` are excluded from every metric, including the predictive
    entropy, so padded positions do not skew the entropy average.

    Args:
        logits: Model outputs whose last dimension is the vocabulary and whose
            second-to-last dimension is the sequence position.
        labels: Integer targets aligned with ``logits`` along the sequence.
        ignore_index: Label value marking positions to skip.

    Returns:
        ``(loss, ppl, entropy, acc)`` as Python floats. ``ppl`` is
        ``exp(loss)`` evaluated in float64. When no target is valid, entropy
        and accuracy are ``0.0`` and loss/ppl are whatever
        ``F.cross_entropy`` yields for an all-ignored batch (NaN in current
        PyTorch).
    """
    vocab = logits.size(-1)
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    loss = F.cross_entropy(
        shift_logits.view(-1, vocab),
        shift_labels.view(-1),
        ignore_index=ignore_index,
        reduction='mean',
    )
    ppl = torch.exp(loss.double())

    valid = shift_labels != ignore_index
    if valid.any():
        # log_softmax is the numerically stable form of logits - logsumexp.
        log_probs = F.log_softmax(shift_logits, dim=-1)
        token_entropy = -(log_probs.exp() * log_probs).sum(-1)
        entropy = token_entropy[valid].mean().item()
        preds = shift_logits.argmax(dim=-1)
        acc = (preds == shift_labels)[valid].float().mean().item()
    else:
        entropy = 0.0
        acc = 0.0
    return loss.item(), ppl.item(), entropy, acc
@torch.no_grad()
def evaluate(model, loader, device):
    """Run the model over ``loader`` and return averaged metrics.

    The model is put in eval mode for the duration of the pass and then
    returned to whatever mode it was in on entry, so calling this in the
    middle of training does not leave the model in eval mode.

    Args:
        model: Callable as ``model(x, labels=y)`` returning ``(logits, loss)``.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Dict with ``loss``, ``ppl``, ``entropy`` and ``acc`` (per-batch values
        from ``compute_metrics`` averaged with batch size as weight; note that
        ``ppl`` is therefore a mean of per-batch perplexities, not
        ``exp(mean loss)``) and ``tokens``, the number of label entries that
        are not ``-100``. All averages are ``0.0`` for an empty loader.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_ppl = 0.0
    total_entropy = 0.0
    total_acc = 0.0
    total_tokens = 0
    n = 0
    try:
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            # The loss returned by the model is not used; metrics are recomputed
            # from the logits so that all four values share one code path.
            logits, _ = model(x, labels=y)
            loss_val, ppl_val, ent_val, acc_val = compute_metrics(logits, y)
            batch_size = x.size(0)
            total_loss += loss_val * batch_size
            total_ppl += ppl_val * batch_size
            total_entropy += ent_val * batch_size
            total_acc += acc_val * batch_size
            total_tokens += (y != -100).sum().item()
            n += batch_size
    finally:
        # Restore the caller's train/eval mode even if a batch raised.
        model.train(was_training)
    denom = max(n, 1)
    return {
        "loss": total_loss / denom,
        "ppl": total_ppl / denom,
        "entropy": total_entropy / denom,
        "acc": total_acc / denom,
        "tokens": total_tokens,
    }
def upload_to_hf(local_dir, repo_id, token, files):
    """Upload selected files from ``local_dir`` to a Hugging Face repo.

    Each entry of ``files`` is a path relative to ``local_dir`` and is also
    used as the destination path inside ``repo_id``. Entries that do not exist
    on disk are skipped without any message.
    """
    from huggingface_hub import HfApi

    client = HfApi()
    for rel_name in files:
        local_path = os.path.join(local_dir, rel_name)
        if not os.path.exists(local_path):
            continue
        client.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=rel_name,
            repo_id=repo_id,
            token=token,
        )
        print(f" Uploaded {rel_name}")
def train():
    """Run the fine-tuning loop described by ``config/train_tiny.yaml``.

    Steps, in order:

    1. Read the YAML config (``training``, ``data`` and the optional ``qlora``
       / ``cft`` sections) and pick the device.
    2. Build the tokenizer from ``tokenizer/tokenizer.json``.
    3. Create the model: either a fresh ``TinyModel`` optionally loaded from a
       CFT checkpoint, or ``restore_from_v2("checkpoint.pt")``; optionally
       wrap it with QLoRA and ``torch.compile``.
    4. Build the train/eval datasets (streamed from the Hub when
       ``data.hf_repo`` is set, otherwise from the JSONL ``data.train_file``).
    5. Create AdamW with a warmup + cosine schedule and resume from
       ``<output_dir>/checkpoint.pt`` if it exists.
    6. Train with gradient accumulation, logging every ``logging_steps`` and
       evaluating/checkpointing every ``save_steps``.
    7. Save ``final/model.pt`` and, when ``hf_repo_id`` and ``HF_TOKEN`` are
       both set, upload the artefacts.

    Returns:
        Path of the ``final`` directory inside the output directory.
    """
    config_path = os.path.join(os.path.dirname(__file__), "..", "config", "train_tiny.yaml")
    with open(config_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    t_cfg = cfg["training"]
    d_cfg = cfg["data"]
    device = "cuda" if torch.cuda.is_available() and not t_cfg.get("use_cpu", False) else "cpu"
    print(f"Device: {device}")

    # Kaggle secrets for HF_TOKEN
    if not os.environ.get("HF_TOKEN"):
        try:
            from kaggle_secrets import UserSecretsClient
            secret = UserSecretsClient().get_secret("HF_TOKEN")
            if secret:
                os.environ["HF_TOKEN"] = secret
        except Exception:
            # Deliberately best-effort: outside Kaggle the import fails and the
            # run simply continues without a token (uploads are then skipped).
            pass

    # Use actual vocab size from tokenizer
    from tokenizers import Tokenizer as Tk
    from transformers import PreTrainedTokenizerFast
    tok_obj = Tk.from_file(os.path.join(os.path.dirname(__file__), "..", "tokenizer", "tokenizer.json"))
    tok = PreTrainedTokenizerFast(tokenizer_object=tok_obj)
    tok.add_special_tokens({"pad_token": "<pad>", "bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>"})
    tok.chat_template = "{% for message in messages %}{% if message['role'] == 'system' %}<|system|>\n{{ message['content'] }}\n{% elif message['role'] == 'user' %}<|user|>\n{{ message['content'] }}\n{% elif message['role'] == 'assistant' %}<|assistant|>\n{{ message['content'] }}\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
    # NOTE(review): `vocab_size` presumably excludes tokens added by
    # add_special_tokens() if they were not already in tokenizer.json --
    # confirm the four special tokens exist in the base vocabulary.
    vocab_size = tok.vocab_size
    max_seq = d_cfg.get("max_seq_length", 2048)

    # ── QLoRA / CFT config ──────────────────────────────────────────────────
    qlora_cfg = cfg.get("qlora", {})
    cft_cfg = cfg.get("cft", {})
    qlora_enabled = qlora_cfg.get("enabled", False)
    cft_enabled = cft_cfg.get("enabled", False)

    # ── Create model ────────────────────────────────────────────────────────
    if cft_enabled:
        # CFT: load a prior checkpoint and apply QLoRA
        resume_ckpt = cft_cfg.get("resume_checkpoint", "")
        if resume_ckpt and not os.path.exists(resume_ckpt):
            # Download if missing
            print(f"[CFT] Downloading {resume_ckpt} from HF ...")
            from huggingface_hub import hf_hub_download
            ckpt = hf_hub_download("samcheng0/lumia-tiny", resume_ckpt, repo_type="model")
            import shutil
            shutil.copy(ckpt, resume_ckpt)
            print(f"[CFT] Downloaded to {resume_ckpt}")
        model = TinyModel(vocab_size=vocab_size, hidden=128, code_dim=96,
                          num_layers=6, num_heads=8, num_kv_heads=4,
                          max_seq_len=max_seq, tie_weights=True)
        model.reset_weights()
        if resume_ckpt and os.path.exists(resume_ckpt):
            raw = torch.load(resume_ckpt, map_location="cpu", weights_only=True)
            # Accept both a bare state dict and a {"model": state_dict, ...} wrapper.
            sd = raw if "model" not in raw else raw["model"]
            missing, unexpected = model.load_state_dict(sd, strict=False)
            if missing:
                print(f" Missing keys: {len(missing)} (e.g. {missing[:3]})")
            if unexpected:
                print(f" Unexpected keys: {len(unexpected)} (e.g. {unexpected[:3]})")
            print(f"[CFT] Loaded checkpoint from {resume_ckpt}")
        if cft_cfg.get("reset_embeddings", False):
            nn.init.normal_(model.token_embed.weight, std=0.02)
            print(f"[CFT] Reset token embeddings (incompatible vocab)")
            # With tied weights lm_head shares the embedding tensor and was
            # already re-initialised by the call above.
            if model.lm_head.weight is not model.token_embed.weight:
                nn.init.normal_(model.lm_head.weight, std=0.02)
    else:
        # Standard: download V2 checkpoint and restore
        if not os.path.exists("checkpoint.pt"):
            print("[Restore] Downloading checkpoint.pt from samcheng0/lumia-tiny ...")
            from huggingface_hub import hf_hub_download
            ckpt = hf_hub_download("samcheng0/lumia-tiny", "checkpoint.pt", repo_type="model")
            import shutil
            shutil.copy(ckpt, "checkpoint.pt")
            print(f"[Restore] Downloaded to checkpoint.pt")
        model = restore_from_v2("checkpoint.pt")
    model = model.to(device)

    # ── Apply QLoRA ─────────────────────────────────────────────────────────
    if qlora_enabled:
        from scripts.model_tiny import apply_qlora
        model = apply_qlora(model,
                            r=qlora_cfg.get("r", 8),
                            alpha=qlora_cfg.get("alpha", 16),
                            dropout=qlora_cfg.get("dropout", 0.0))
        # Unfreeze embeddings if they were reset (new vocab needs training)
        if cft_cfg.get("reset_embeddings", False):
            for name, param in model.named_parameters():
                if "token_embed" in name or "lm_head" in name:
                    param.requires_grad = True

    # ── torch.compile ──────────────────────────────────────────────────────
    if t_cfg.get("compile", False):
        try:
            model = torch.compile(model, dynamic=True)
            print("[Compile] torch.compile enabled (dynamic=True)")
        except Exception as e:
            print(f"[Compile] Failed: {e} — continuing without compile")

    # ── Data ───────────────────────────────────────────────────────────────
    hf_repo = d_cfg.get("hf_repo")
    is_streaming = bool(hf_repo)
    if is_streaming:
        train_ds = StreamingSFTDataset(hf_repo, tok, max_seq, split=d_cfg.get("hf_split", "train"), hf_name=d_cfg.get("hf_repo_name"))
        eval_ds = make_eval_dataset(hf_repo, tok, max_seq, num_eval=d_cfg.get("hf_num_eval", 500), hf_name=d_cfg.get("hf_repo_name"))
        print(f"Train: streaming from {hf_repo} Eval: {len(eval_ds.samples)} samples")
    else:
        path = d_cfg["train_file"]
        with open(path, encoding="utf-8") as f:
            # One JSON object per line; blank lines (e.g. a trailing newline) are skipped.
            raw = [json.loads(line) for line in f if line.strip()]
        split = int(len(raw) * (1 - d_cfg.get("eval_split_ratio", 0.1)))
        train_raw = raw[:split]
        eval_raw = raw[split:]
        train_ds = ListDataset([_encode(_format_item(x, tok), tok, max_seq) for x in train_raw])
        eval_ds = ListDataset([_encode(_format_item(x, tok), tok, max_seq) for x in eval_raw])
        print(f"Train: {len(train_ds.samples)} Eval: {len(eval_ds.samples)}")

    bs = t_cfg.get("per_device_train_batch_size", 8)
    eval_bs = t_cfg.get("per_device_eval_batch_size", 8)
    ga_steps = t_cfg.get("gradient_accumulation_steps", 4)
    lr = t_cfg.get("learning_rate", 3e-4)
    epochs = t_cfg.get("num_train_epochs", 1)
    max_grad_norm = t_cfg.get("max_grad_norm", 1.0)
    log_steps = t_cfg.get("logging_steps", 5)
    save_steps = t_cfg.get("save_steps", 200)
    output_dir = t_cfg.get("output_dir", "outputs/tiny-sft")
    # An iterable (streaming) dataset cannot be shuffled by the DataLoader.
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=not is_streaming, collate_fn=collate_fn, num_workers=0)
    eval_loader = DataLoader(eval_ds, batch_size=eval_bs, shuffle=False, collate_fn=collate_fn, num_workers=0)

    # HF upload config
    hf_repo_id = t_cfg.get("hf_repo_id", "")
    hf_token = os.environ.get("HF_TOKEN", "")
    n_params = sum(p.numel() for p in model.parameters())
    step = 0
    opt_step = 0
    start_epoch = 1
    best_loss = float('inf')
    best_acc = 0.0
    cumul_tokens = 0
    global_step_offset = 0
    ckpt_path = os.path.join(output_dir, "checkpoint.pt")

    # ── Optimiser ──────────────────────────────────────────────────────────
    # Trainable parameters whose name contains "norm", "ln_" or "lora_" get no
    # weight decay; every other trainable parameter is decayed.
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "norm" in name or "ln_" in name or "lora_" in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    if not decay_params and not no_decay_params:
        print("[WARNING] No trainable parameters found!")
    optimizer = AdamW([
        {"params": decay_params, "weight_decay": t_cfg.get("weight_decay", 0.1)},
        {"params": no_decay_params, "weight_decay": 0.0},
    ], lr=lr, betas=(0.9, 0.95))

    warmup_ratio = t_cfg.get("warmup_ratio", 0.1)
    max_steps_cfg = t_cfg.get("max_steps", -1)
    if is_streaming and max_steps_cfg > 0:
        total_model_steps = max_steps_cfg
    elif not is_streaming:
        total_model_steps = len(train_ds) // bs * epochs
    else:
        # Streaming without max_steps: the length is unknown, so a fixed
        # horizon is used for the LR schedule.
        total_model_steps = 50000
    total_opt_steps = total_model_steps // ga_steps
    warmup_opt_steps = int(total_opt_steps * warmup_ratio)

    def _lr_lambda(current_step):
        # Linear warmup followed by cosine decay, as a multiplier on `lr`.
        if current_step < warmup_opt_steps:
            return float(current_step) / max(1, warmup_opt_steps)
        progress = float(current_step - warmup_opt_steps) / max(1, total_opt_steps - warmup_opt_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    scheduler = LambdaLR(optimizer, lr_lambda=_lr_lambda)

    # ── Resume ─────────────────────────────────────────────────────────────
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        # NOTE(review): checkpoints are written mid-epoch (every save_steps),
        # yet resuming starts at the epoch after the saved one, so the
        # remainder of the interrupted epoch is not replayed -- confirm intended.
        start_epoch = ckpt["epoch"] + 1
        global_step_offset = ckpt["step"]
        opt_step = ckpt.get("opt_step", 0)
        best_loss = ckpt.get("best_loss", float('inf'))
        best_acc = ckpt.get("best_acc", 0.0)
        cumul_tokens = ckpt.get("cumul_tokens", 0)
        print(f"Resumed: epoch={ckpt['epoch']}, step={ckpt['step']}, best_loss={best_loss:.4f}, best_acc={best_acc:.3f}, total_tok={cumul_tokens}")

    os.makedirs(output_dir, exist_ok=True)
    t_start = time.time()
    t_epoch = time.time()
    eff_bs = bs * ga_steps
    print()
    print("=" * 72)
    print(f" PCT-V3 Training")
    print(f" Model: {n_params:,} params | Device: {device.upper()}")
    print(f" Vocab: {vocab_size} | Depth: 6 | Hidden: 128 | Code: 96 | Heads: 8/4 | RPW+GPP+VCR")
    print(f" Epochs: {epochs} | Batch: {bs} | GA: {ga_steps} | Eff BS: {eff_bs} | LR: {lr}")
    print(f" Warmup: {warmup_opt_steps} opt steps | Total: {total_opt_steps} opt steps")
    print("=" * 72)
    print(f" {'Time':>8} {'Ep':>2} {'Step':>8} {'Loss':>9} {'PPL':>9} {'Entropy':>9} {'Acc':>6} {'GradNorm':>9} {'LR':>10} {'tok/s':>7} {'smp/s':>7}")
    print(f" {'-'*8:>8} {'-'*2:>2} {'-'*8:>8} {'-'*9:>9} {'-'*9:>9} {'-'*9:>9} {'-'*6:>6} {'-'*9:>9} {'-'*10:>10} {'-'*7:>7} {'-'*7:>7}")
    print()

    # ── Training loop ──────────────────────────────────────────────────────
    for epoch in range(start_epoch, epochs + 1):
        model.train()
        epoch_loss = 0
        epoch_tokens = 0
        epoch_steps = 0
        optimizer.zero_grad()
        for x, y in train_loader:
            step += 1
            epoch_steps += 1
            x, y = x.to(device), y.to(device)
            n_tokens = y.numel()
            logits, loss = model(x, labels=y)
            loss_val = loss.item()
            # Scale so that gradients accumulated over ga_steps micro-batches
            # average rather than sum.
            loss = loss / ga_steps
            loss.backward()
            if step % ga_steps == 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm).item()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                opt_step += 1
            else:
                # No optimiser step on this micro-batch; the log column shows 0.
                grad_norm = 0.0
            if max_steps_cfg > 0 and step >= max_steps_cfg:
                break
            epoch_loss += loss_val
            epoch_tokens += n_tokens
            cumul_tokens += n_tokens
            if step % log_steps == 0:
                tok_per_sec = epoch_tokens / max(time.time() - t_epoch, 1)
                smp_per_sec = epoch_steps * bs / max(time.time() - t_epoch, 1)
                cur_lr = optimizer.param_groups[0]['lr']
                ts = time.strftime("%H:%M:%S")
                with torch.no_grad():
                    met = compute_metrics(logits, y)
                loss_s, ppl_s, ent_s, acc_s = met
                print(f" {ts:>8} {epoch:>2} {step:>8} {loss_s:>9.4f} {ppl_s:>9.2f} {ent_s:>9.4f} {acc_s:>6.3f} {grad_norm:>9.2e} {cur_lr:>10.2e} {tok_per_sec:>7.0f} {smp_per_sec:>7.1f} cum_tok={cumul_tokens}")
                sys.stdout.flush()
            if step % save_steps == 0:
                global_step = global_step_offset + step
                eval_metrics = evaluate(model, eval_loader, device)
                # evaluate() calls model.eval(); make sure the remaining steps
                # of this epoch run in training mode again.
                model.train()
                elapsed = time.time() - t_start
                pct = min(100, step / max_steps_cfg * 100) if max_steps_cfg > 0 else 0
                print(f" {'':>8} {'':>2} {'':>8} {'-'*9} {'-'*9} {'-'*9} {'-'*6} {'-'*9} {'-'*10} {'-'*7} {'-'*7}")
                print(f" -- Eval: loss={eval_metrics['loss']:.4f} ppl={eval_metrics['ppl']:.2f} ent={eval_metrics['entropy']:.4f} acc={eval_metrics['acc']:.3f} tokens={eval_metrics['tokens']} | best={best_loss:.4f} | step={global_step} [{pct:.1f}%] [{elapsed:.0f}s]")
                sys.stdout.flush()
                torch.save({
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "step": global_step,
                    "opt_step": opt_step,
                    "epoch": epoch,
                    "best_loss": best_loss,
                    "best_acc": best_acc,
                    "cumul_tokens": cumul_tokens,
                    "eval_metrics": eval_metrics,
                }, ckpt_path)
                print(f" + Checkpoint saved (step={global_step})")
                sys.stdout.flush()
                if eval_metrics['loss'] < best_loss:
                    best_loss = eval_metrics['loss']
                    best_acc = eval_metrics['acc']
                    torch.save(model.state_dict(), os.path.join(output_dir, "best.pt"))
                    print(f" + New best! loss={best_loss:.4f} ppl={eval_metrics['ppl']:.2f} acc={best_acc:.3f}")
                    if hf_repo_id and hf_token:
                        upload_to_hf(output_dir, hf_repo_id, hf_token, ["best.pt"])
                        print(f" + Uploaded best.pt to {hf_repo_id}")
                elif eval_metrics['acc'] > best_acc:
                    best_acc = eval_metrics['acc']
                sys.stdout.flush()
        # The inner loop stops at max_steps; stop the epoch loop as well and
        # skip the end-of-epoch evaluation below.
        if max_steps_cfg > 0 and step >= max_steps_cfg:
            break
        t_epoch_end = time.time()
        avg_loss = epoch_loss / max(epoch_steps, 1)
        eval_metrics = evaluate(model, eval_loader, device)
        epoch_time = t_epoch_end - t_epoch
        t_epoch = t_epoch_end
        print(f" {'':>8} {'':>2} {'':>8} {'-'*9} {'-'*9} {'-'*9} {'-'*6} {'-'*9} {'-'*10} {'-'*7} {'-'*7}")
        print(f" -- Epoch {epoch}/{epochs} done | loss={avg_loss:.4f} | eval={eval_metrics['loss']:.4f} | ppl={eval_metrics['ppl']:.2f} | ent={eval_metrics['entropy']:.4f} | acc={eval_metrics['acc']:.3f} | tok={epoch_tokens} | cum_tok={cumul_tokens} | time={epoch_time:.0f}s --")
        sys.stdout.flush()
        torch.save(model.state_dict(), os.path.join(output_dir, f"epoch_{epoch}.pt"))

    # ── Final save / summary ───────────────────────────────────────────────
    final = os.path.join(output_dir, "final")
    os.makedirs(final, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(final, "model.pt"))
    total_time = time.time() - t_start
    print()
    print("=" * 72)
    print(f" Training complete!")
    print(f" Total time: {total_time:.0f}s ({total_time/3600:.1f}h)")
    print(f" Total tokens: {cumul_tokens:,}")
    print(f" Total steps: {step}")
    print(f" Best loss: {best_loss:.4f}")
    print(f" Best acc: {best_acc:.3f}")
    print(f" Avg tok/s: {cumul_tokens/max(total_time, 1):.0f}")
    print(f" Saved: {final}/model.pt")
    print("=" * 72)
    print()
    if hf_repo_id and hf_token:
        print(f"Uploading to HF: {hf_repo_id} ...")
        upload_files = ["best.pt", "final/model.pt"]
        # Only per-epoch snapshots that were actually written are uploaded.
        for ep in range(start_epoch, epochs + 1):
            ep_path = os.path.join(output_dir, f"epoch_{ep}.pt")
            if os.path.exists(ep_path):
                upload_files.append(f"epoch_{ep}.pt")
        upload_to_hf(output_dir, hf_repo_id, hf_token, upload_files)
        print(f"Uploaded to https://huggingface.co/{hf_repo_id}")
    return final
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    train()