"""Continued pretraining of Qwen3-1.7B on a raw French corpus.

Muon (Newton-Schulz orthogonalized momentum) drives the 2-D weight matrices
and AdamW the rest, with logit soft-capping, a trapezoidal LR schedule,
periodic French "vibe check" generations, and a Proust perplexity probe.
"""

import os
import time
import torch
import torch.nn.functional as F
import requests
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator
from datasets import load_dataset
from tqdm import tqdm
|
|
# Paths and model.
MODEL_ID = "Qwen/Qwen3-1.7B"
DATA_PATH = "/workspace/French_ASR_Corpus_Raw"
OUTPUT_DIR = "/workspace/checkpoints"

# Training budget and batch geometry.
TOTAL_TOKENS = 3_000_000_000
SEQ_LEN = 8192
BATCH_SIZE = 1
GRAD_ACCUM = 64
LEARNING_RATE = 0.002   # peak LR for Muon; AdamW runs at a tenth of this
SOFTCAP_VAL = 30.0      # tanh soft cap applied to the logits

# Held-out "hard French" validation text (Proust, via Project Gutenberg).
PROUST_URL = "https://www.gutenberg.org/cache/epub/2650/pg2650.txt"

# Qualitative generation probes run periodically during training.
VIBE_PROMPTS = [
    {"name": "Lipogram (No 'e')", "prompt": "Écris une phrase sur l'hiver sans utiliser la lettre 'e'.\nPhrase:"},
    {"name": "Slang Translate", "prompt": "Traduis en langage soutenu: 'Wesh le sang, c'est comment ? T'as capté ?'\nTraduction:"},
    {"name": "ASR Context", "prompt": "Corrige la transcription: 'Il a pris son pain seau pour peindre le mur.' -> "},
    {"name": "Thinking (Logic)", "prompt": "Si je suis à Paris et que je regarde le soleil se coucher, quelle direction est derrière moi ?\nRéponse:"},
]
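# Effective batch: BATCH_SIZE * SEQ_LEN * GRAD_ACCUM = 1 * 8192 * 64 = 524,288
# tokens per optimizer step, so the 3B-token budget works out to ~5,722 steps.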
|
|
def zeropower_via_newtonschulz5(G, steps=5):
    """Approximately orthogonalize G with a quintic Newton-Schulz iteration."""
    assert G.ndim == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= (X.norm() + 1e-7)  # normalize so the iteration converges
    if G.size(0) > G.size(1):
        X = X.T  # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X
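# The quintic coefficients (3.4445, -4.7750, 2.0315) and the bfloat16 iteration
# follow the reference Muon implementation; the orthogonalization only needs to
# be approximate, so a few low-precision steps suffice in practice.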
|
|
class Muon(torch.optim.Optimizer):
    """Momentum SGD whose update matrix is orthogonalized via Newton-Schulz (Muon)."""

    def __init__(self, params, lr=0.002, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue
                g = p.grad
                if g.ndim != 2:
                    continue  # Muon only applies to weight matrices
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(g)
                buf = state['momentum_buffer']
                buf.mul_(momentum).add_(g)
                g = g.add(buf, alpha=momentum) if group['nesterov'] else buf
                g_ortho = zeropower_via_newtonschulz5(g, steps=group['ns_steps'])
                # Shape-aware scale keeps the update magnitude consistent
                # across differently shaped matrices.
                scale = max(1, g.size(0) / g.size(1)) ** 0.5
                p.data.add_(g_ortho, alpha=-lr * scale)
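# Note: step() omits the optional `closure` argument that torch optimizers
# conventionally accept; that is fine here because the training loop below
# never passes one.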
|
|
def apply_softcapping(logits, cap=30.0):
    # Soft cap in the style of Gemma 2: squashes logits smoothly into (-cap, cap).
    return torch.tanh(logits / cap) * cap
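# Capping is applied manually in the training loop; calc_perplexity() and
# model.generate() therefore operate on the model's raw, uncapped logits.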
|
|
def get_trapezoidal_schedule(optimizer, num_training_steps, warmup_ratio=0.1, hold_ratio=0.3):
    """Warmup, then hold at peak LR, then linear decay to zero (a trapezoid)."""
    def lr_lambda(current_step):
        warmup_steps = int(num_training_steps * warmup_ratio)
        hold_steps = int(num_training_steps * hold_ratio)
        decay_start = warmup_steps + hold_steps
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        elif current_step < decay_start:
            return 1.0
        else:
            decay_steps = num_training_steps - decay_start
            progress = (current_step - decay_start) / max(1, decay_steps)
            return max(0.0, 1.0 - progress)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
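# With the ~5,722 total steps implied by the config above, this gives roughly
# 572 warmup steps, a 1,716-step hold at peak LR, then a linear decay over the
# remaining ~3,434 steps.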
|
|
def download_proust():
    """Fetch the validation text; fall back to a synthetic snippet if offline."""
    try:
        r = requests.get(PROUST_URL, timeout=30)
        r.raise_for_status()
        text = r.text
        start = text.find("Longtemps, je me suis couché de bonne heure.")
        if start == -1:
            start = 1000  # skip the Gutenberg header if the opening line isn't found
        return text[start:start + 50000]
    except Exception:
        return "Longtemps, je me suis couché de bonne heure. " * 500
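# The offline fallback is a single repeated sentence, so perplexity against it
# is trivially low; treat it as a smoke test only, not a real validation set.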
|
|
def calc_perplexity(model, tokenizer, text, device):
    """Perplexity of `text` over a single window of up to SEQ_LEN tokens.

    Uses the model's built-in loss (no soft-capping), so values are not
    directly comparable to the capped training loss.
    """
    encodings = tokenizer(text, return_tensors="pt")
    input_ids = encodings.input_ids[:, :SEQ_LEN].to(device)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    neg_log_likelihood = outputs.loss
    return torch.exp(neg_log_likelihood).item()
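# A single SEQ_LEN window keeps the probe cheap; a strided sliding window over
# the full excerpt would give a tighter perplexity estimate at higher cost.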
|
|
def main():
    torch.cuda.empty_cache()
    torch.set_float32_matmul_precision('high')
    os.makedirs(OUTPUT_DIR, exist_ok=True)
|
    print(f"Loading {MODEL_ID}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        ).to("cuda")
    except OSError:
        # Fall back to the instruct variant if the base repo is unavailable.
        MODEL_ID_FALLBACK = "Qwen/Qwen3-1.7B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID_FALLBACK)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID_FALLBACK,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        ).to("cuda")
|
    model.gradient_checkpointing_enable()
    print("✅ Gradient Checkpointing: ENABLED")
    print("✅ Torch Compile: DISABLED (Stability Mode)")
|
    # Muon drives the 2-D transformer matrices; AdamW handles everything else.
    # Embeddings and the LM head are 2-D too, but stay on AdamW, following the
    # standard Muon recommendation.
    muon_params = [p for n, p in model.named_parameters()
                   if p.requires_grad and p.ndim == 2
                   and "embed_tokens" not in n and "lm_head" not in n]
    adam_params = [p for n, p in model.named_parameters()
                   if p.requires_grad and (p.ndim != 2
                                           or "embed_tokens" in n or "lm_head" in n)]

    optim_muon = Muon(muon_params, lr=LEARNING_RATE)
    optim_adam = torch.optim.AdamW(adam_params, lr=LEARNING_RATE * 0.1, weight_decay=0.01)
|
    print("Loading Data...")
    dataset = load_dataset("parquet", data_files=f"{DATA_PATH}/*.parquet", split="train", streaming=True)
    # Peek at one example to learn the raw column names, then drop them all
    # after tokenization so default_data_collator only sees tensor fields.
    raw_columns = list(list(dataset.take(1))[0].keys())
    dataset = dataset.map(
        lambda x: tokenizer(x["text"], truncation=True, max_length=SEQ_LEN, padding="max_length"),
        batched=True,
    ).remove_columns(raw_columns)

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        collate_fn=default_data_collator,
        num_workers=8,
        pin_memory=True,
    )
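    # NB: with a streaming dataset, DataLoader workers split the parquet shards
    # among themselves; if the corpus has fewer than 8 shards, some workers may
    # sit idle (exact behavior depends on the installed `datasets` version).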
|
    val_text = download_proust()
    tokens_per_step = BATCH_SIZE * SEQ_LEN * GRAD_ACCUM
    total_steps = TOTAL_TOKENS // tokens_per_step

    scheduler_muon = get_trapezoidal_schedule(optim_muon, total_steps)
    scheduler_adam = get_trapezoidal_schedule(optim_adam, total_steps)
|
    print(f"🚀 STARTING RUN: {total_steps} steps | Target: 3B Tokens")

    model.train()
    pbar = tqdm(total=total_steps, unit="step", dynamic_ncols=True)
    accum_loss = 0.0
    t0 = time.time()
|
    for i, batch in enumerate(dataloader):
        batch = {k: v.to("cuda", non_blocking=True) for k, v in batch.items()}

        outputs = model(**batch)
        logits = apply_softcapping(outputs.logits, cap=SOFTCAP_VAL)

        # Next-token prediction: shift logits/labels and mask padding out of the loss.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = batch["input_ids"][..., 1:].contiguous()
        if "attention_mask" in batch:
            shift_labels = shift_labels.masked_fill(batch["attention_mask"][..., 1:] == 0, -100)
        loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                               shift_labels.view(-1), ignore_index=-100)

        (loss / GRAD_ACCUM).backward()
        accum_loss += loss.item()

        if (i + 1) % GRAD_ACCUM == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim_muon.step(); optim_adam.step()
            optim_muon.zero_grad(); optim_adam.zero_grad()
            scheduler_muon.step(); scheduler_adam.step()

            dt = time.time() - t0
            if dt > 0:
                tps = tokens_per_step / dt
                # Report the mean micro-batch loss rather than the sum over GRAD_ACCUM.
                pbar.set_postfix({"Loss": f"{accum_loss / GRAD_ACCUM:.4f}", "Kt/s": f"{tps / 1000:.1f}"})
            t0 = time.time()

            pbar.update(1)
            accum_loss = 0.0
|
            if pbar.n % 100 == 0:
                print(f"\n--- 🎭 VIBE CHECK (Step {pbar.n}) ---")
                model.eval()
                for v in VIBE_PROMPTS:
                    inp = tokenizer(v["prompt"], return_tensors="pt").to("cuda")
                    with torch.no_grad():
                        gen = model.generate(**inp, max_new_tokens=50, do_sample=True, temperature=0.7)
                    res = tokenizer.decode(gen[0], skip_special_tokens=True)
                    res = res.replace(v["prompt"], "").strip().replace("\n", " ")
                    print(f"🔹 {v['name']}: {res}")
                model.train()
|
            if pbar.n % 500 == 0:
                print(f"\n--- 🧐 PROUST CHECK (Step {pbar.n}) ---")
                model.eval()
                ppl = calc_perplexity(model, tokenizer, val_text, "cuda")
                print(f"📚 Hard French Perplexity: {ppl:.2f}")
                model.train()
                model.save_pretrained(f"{OUTPUT_DIR}/step_{pbar.n}")
                tokenizer.save_pretrained(f"{OUTPUT_DIR}/step_{pbar.n}")
|
            if pbar.n >= total_steps:
                break
|
    model.save_pretrained(f"{OUTPUT_DIR}/final")
    tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
    print("🏁 DONE.")


if __name__ == "__main__":
    main()