import os import time import math import torch import torch.nn.functional as F import requests from torch.utils.data import DataLoader from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator from datasets import load_dataset from tqdm import tqdm # --- CONFIGURATION --- MODEL_ID = "Qwen/Qwen3-1.7B" DATA_PATH = "/workspace/French_ASR_Corpus_Raw" OUTPUT_DIR = "/workspace/checkpoints" # --- INTELLECTUAL CONFIG (3B Tokens) --- TOTAL_TOKENS = 3_000_000_000 SEQ_LEN = 8192 BATCH_SIZE = 1 # Safe for VRAM GRAD_ACCUM = 64 # High accumulation = stable updates (Effective Batch = 64) LEARNING_RATE = 0.002 SOFTCAP_VAL = 30.0 PROUST_URL = "https://www.gutenberg.org/cache/epub/2650/pg2650.txt" VIBE_PROMPTS = [ {"name": "Lipogram (No 'e')", "prompt": "Écris une phrase sur l'hiver sans utiliser la lettre 'e'.\nPhrase:"}, {"name": "Slang Translate", "prompt": "Traduis en langage soutenu: 'Wesh le sang, c'est comment ? T'as capté ?'\nTraduction:"}, {"name": "ASR Context", "prompt": "Corrige la transcription: 'Il a pris son pain seau pour peindre le mur.' -> "}, {"name": "Thinking (Logic)", "prompt": "Si je suis à Paris et que je regarde le soleil se coucher, quelle direction est derrière moi ?\nRéponse:"} ] # --- OPTIMIZER (Compiled for Speed) --- def zeropower_via_newtonschulz5(G, steps=5): assert len(G.shape) == 2 a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() X /= (X.norm() + 1e-7) if G.size(0) > G.size(1): X = X.T for _ in range(steps): A = X @ X.T B = b * A + c * A @ A X = a * X + B @ X if G.size(0) > G.size(1): X = X.T return X class Muon(torch.optim.Optimizer): def __init__(self, params, lr=0.002, momentum=0.95, nesterov=True, ns_steps=5): defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) super().__init__(params, defaults) @torch.no_grad() def step(self): for group in self.param_groups: lr = group['lr'] momentum = group['momentum'] for p in group['params']: if p.grad is None: continue g = p.grad if g.ndim != 2: continue state = self.state[p] if 'momentum_buffer' not in state: state['momentum_buffer'] = torch.zeros_like(g) buf = state['momentum_buffer'] buf.mul_(momentum).add_(g) if group['nesterov']: g = g.add(buf, alpha=momentum) else: g = buf g_ortho = zeropower_via_newtonschulz5(g, steps=group['ns_steps']) scale = max(1, g.size(0)/g.size(1))**0.5 p.data.add_(g_ortho, alpha=-lr * scale) # --- UTILITIES --- def apply_softcapping(logits, cap=30.0): return torch.tanh(logits / cap) * cap def get_trapezoidal_schedule(optimizer, num_training_steps, warmup_ratio=0.1, hold_ratio=0.3): def lr_lambda(current_step): warmup_steps = int(num_training_steps * warmup_ratio) hold_steps = int(num_training_steps * hold_ratio) decay_start = warmup_steps + hold_steps if current_step < warmup_steps: return float(current_step) / float(max(1, warmup_steps)) elif current_step < decay_start: return 1.0 else: decay_steps = num_training_steps - decay_start progress = (current_step - decay_start) / max(1, decay_steps) return max(0.0, 1.0 - progress) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) def download_proust(): try: r = requests.get(PROUST_URL) text = r.text start = text.find("Longtemps, je me suis couché de bonne heure.") if start == -1: start = 1000 return text[start:start+50000] except: return "Longtemps, je me suis couché de bonne heure. " * 500 def calc_perplexity(model, tokenizer, text, device): encodings = tokenizer(text, return_tensors="pt") max_len = 8192 input_ids = encodings.input_ids[:, :max_len].to(device) with torch.no_grad(): outputs = model(input_ids, labels=input_ids) neg_log_likelihood = outputs.loss return torch.exp(neg_log_likelihood).item() # --- MAIN --- def main(): torch.cuda.empty_cache() torch.set_float32_matmul_precision('high') os.makedirs(OUTPUT_DIR, exist_ok=True) print(f"Loading {MODEL_ID}...") try: tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ).to("cuda") except OSError: MODEL_ID_FALLBACK = "Qwen/Qwen3-1.7B-Instruct" tokenizer = AutoTokenizer.from_pretrained(MODEL_ID_FALLBACK) model = AutoModelForCausalLM.from_pretrained( MODEL_ID_FALLBACK, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ).to("cuda") # CRITICAL: Enable Checkpointing to fit in VRAM model.gradient_checkpointing_enable() print("✅ Gradient Checkpointing: ENABLED") # Disable compile for model to avoid OOM/Instability print("✅ Torch Compile: DISABLED (Stability Mode)") muon_params = [p for n, p in model.named_parameters() if p.requires_grad and p.ndim == 2] adam_params = [p for n, p in model.named_parameters() if p.requires_grad and p.ndim != 2] optim_muon = Muon(muon_params, lr=LEARNING_RATE) optim_adam = torch.optim.AdamW(adam_params, lr=LEARNING_RATE * 0.1, weight_decay=0.01) print("Loading Data...") dataset = load_dataset("parquet", data_files=f"{DATA_PATH}/*.parquet", split="train", streaming=True) dataset = dataset.map( lambda x: tokenizer(x["text"], truncation=True, max_length=SEQ_LEN, padding="max_length"), batched=True ).remove_columns(["text", "url", "category", "ttr", "token_est"] if "ttr" in list(dataset.take(1))[0] else []) # CRITICAL: 8 Workers to feed the H100 dataloader = DataLoader( dataset, batch_size=BATCH_SIZE, collate_fn=default_data_collator, num_workers=8, pin_memory=True ) val_text = download_proust() tokens_per_step = BATCH_SIZE * SEQ_LEN * GRAD_ACCUM total_steps = TOTAL_TOKENS // tokens_per_step scheduler_muon = get_trapezoidal_schedule(optim_muon, total_steps) scheduler_adam = get_trapezoidal_schedule(optim_adam, total_steps) print(f"🚀 STARTING RUN: {total_steps} steps | Target: 3B Tokens") model.train() pbar = tqdm(total=total_steps, unit="step", dynamic_ncols=True) accum_loss = 0 t0 = time.time() for i, batch in enumerate(dataloader): batch = {k: v.to("cuda", non_blocking=True) for k, v in batch.items()} outputs = model(**batch) logits = apply_softcapping(outputs.logits, cap=SOFTCAP_VAL) shift_logits = logits[..., :-1, :].contiguous() shift_labels = batch["input_ids"][..., 1:].contiguous() loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) (loss / GRAD_ACCUM).backward() accum_loss += loss.item() if (i + 1) % GRAD_ACCUM == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optim_muon.step(); optim_adam.step() optim_muon.zero_grad(); optim_adam.zero_grad() scheduler_muon.step(); scheduler_adam.step() dt = time.time() - t0 if dt > 0: tps = tokens_per_step / dt pbar.set_postfix({"Loss": f"{accum_loss:.4f}", "Kt/s": f"{tps/1000:.1f}"}) t0 = time.time() pbar.update(1) accum_loss = 0 if pbar.n % 100 == 0: print(f"\n--- 🎭 VIBE CHECK (Step {pbar.n}) ---") model.eval() for v in VIBE_PROMPTS: inp = tokenizer(v["prompt"], return_tensors="pt").to("cuda") with torch.no_grad(): gen = model.generate(**inp, max_new_tokens=50, do_sample=True, temperature=0.7) res = tokenizer.decode(gen[0], skip_special_tokens=True).replace(v['prompt'], '').strip().replace('\n', ' ') print(f"🔹 {v['name']}: {res}") model.train() if pbar.n % 500 == 0: print(f"\n--- 🧐 PROUST CHECK (Step {pbar.n}) ---") model.eval() ppl = calc_perplexity(model, tokenizer, val_text, "cuda") print(f"📚 Hard French Perplexity: {ppl:.2f}") model.train() model.save_pretrained(f"{OUTPUT_DIR}/step_{pbar.n}") tokenizer.save_pretrained(f"{OUTPUT_DIR}/step_{pbar.n}") if pbar.n >= total_steps: break model.save_pretrained(f"{OUTPUT_DIR}/final") tokenizer.save_pretrained(f"{OUTPUT_DIR}/final") print("🏁 DONE.") if __name__ == "__main__": main()