# french-qwen-checkpoints / train_final.py
import os
import time
import math
import torch
import torch.nn.functional as F
import requests
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator
from datasets import load_dataset
from tqdm import tqdm
# --- CONFIGURATION ---
MODEL_ID = "Qwen/Qwen3-1.7B"
DATA_PATH = "/workspace/French_ASR_Corpus_Raw"
OUTPUT_DIR = "/workspace/checkpoints"
# --- INTELLECTUAL CONFIG (3B Tokens) ---
TOTAL_TOKENS = 3_000_000_000
SEQ_LEN = 8192
BATCH_SIZE = 1 # Safe for VRAM
GRAD_ACCUM = 64 # High accumulation = stable updates (Effective Batch = 64)
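# With these settings, each optimizer step consumes
# BATCH_SIZE * SEQ_LEN * GRAD_ACCUM = 1 * 8192 * 64 = 524,288 tokens,
# so the 3B-token budget works out to 3_000_000_000 // 524_288 ≈ 5,722 steps.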
LEARNING_RATE = 0.002
SOFTCAP_VAL = 30.0
PROUST_URL = "https://www.gutenberg.org/cache/epub/2650/pg2650.txt"
VIBE_PROMPTS = [
    {"name": "Lipogram (No 'e')", "prompt": "Écris une phrase sur l'hiver sans utiliser la lettre 'e'.\nPhrase:"},
    {"name": "Slang Translate", "prompt": "Traduis en langage soutenu: 'Wesh le sang, c'est comment ? T'as capté ?'\nTraduction:"},
    {"name": "ASR Context", "prompt": "Corrige la transcription: 'Il a pris son pain seau pour peindre le mur.' -> "},
    {"name": "Thinking (Logic)", "prompt": "Si je suis à Paris et que je regarde le soleil se coucher, quelle direction est derrière moi ?\nRéponse:"}
]
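# Four qualitative probes, in English: write about winter without the letter
# 'e' (a lipogram); translate street slang into formal French; fix an ASR
# homophone error ("pain seau" heard for "pinceau", a paintbrush); and answer
# a simple spatial-reasoning question (facing the sunset in Paris means
# facing west, so east is behind you).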
# --- OPTIMIZER (Muon with Newton-Schulz orthogonalization) ---
def zeropower_via_newtonschulz5(G, steps=5):
    # Approximately orthogonalize a 2D matrix via a quintic Newton-Schulz iteration.
    assert G.ndim == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= (X.norm() + 1e-7)  # normalize so the iteration converges
    if G.size(0) > G.size(1): X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(0) > G.size(1): X = X.T
    return X
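# Hypothetical sanity check (not part of the training run): the iteration
# returns an approximately semi-orthogonal matrix, so for a tall input the
# columns end up near-orthonormal:
#   G = torch.randn(256, 128)
#   X = zeropower_via_newtonschulz5(G).float()
#   print(torch.dist(X.T @ X, torch.eye(128)))  # modest bf16 residual expected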
class Muon(torch.optim.Optimizer):
    def __init__(self, params, lr=0.002, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            for p in group['params']:
                if p.grad is None: continue
                g = p.grad
                if g.ndim != 2: continue  # Muon only updates 2D weight matrices
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(g)
                buf = state['momentum_buffer']
                buf.mul_(momentum).add_(g)
                if group['nesterov']: g = g.add(buf, alpha=momentum)
                else: g = buf
                g_ortho = zeropower_via_newtonschulz5(g, steps=group['ns_steps'])
                # Scale compensates for non-square matrices (taller = larger step).
                scale = max(1, g.size(0) / g.size(1)) ** 0.5
                p.data.add_(g_ortho, alpha=-lr * scale)
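# Muon is paired with AdamW in main(): 2D weight matrices get the
# orthogonalized momentum update above, while the remaining parameters
# (norms, biases) fall back to AdamW at a 10x lower learning rate.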
# --- UTILITIES ---
def apply_softcapping(logits, cap=30.0):
    # Smoothly bound logits to (-cap, cap) with a scaled tanh.
    return torch.tanh(logits / cap) * cap
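# tanh(x / cap) * cap is near-identity for |x| << cap and saturates at ±cap,
# so extreme logits are squashed without a hard clip: with cap=30, a logit
# of 10 maps to ~9.65 while a logit of 100 maps to ~29.9.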
def get_trapezoidal_schedule(optimizer, num_training_steps, warmup_ratio=0.1, hold_ratio=0.3):
    def lr_lambda(current_step):
        warmup_steps = int(num_training_steps * warmup_ratio)
        hold_steps = int(num_training_steps * hold_ratio)
        decay_start = warmup_steps + hold_steps
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        elif current_step < decay_start:
            return 1.0
        else:
            decay_steps = num_training_steps - decay_start
            progress = (current_step - decay_start) / max(1, decay_steps)
            return max(0.0, 1.0 - progress)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
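# Worked example: with the default ratios and the ~5,722 steps implied by the
# config above, LR warms up linearly for ~572 steps (10%), holds at the peak
# for ~1,716 steps (30%), then decays linearly to zero over the remaining
# ~3,434 steps (60%).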
def download_proust():
    try:
        r = requests.get(PROUST_URL, timeout=30)
        r.raise_for_status()
        text = r.text
        start = text.find("Longtemps, je me suis couché de bonne heure.")
        if start == -1: start = 1000  # skip the Gutenberg header if the opening isn't found
        return text[start:start + 50000]
    except requests.RequestException:
        # Offline fallback: repeat the famous opening sentence.
        return "Longtemps, je me suis couché de bonne heure. " * 500
def calc_perplexity(model, tokenizer, text, device):
    encodings = tokenizer(text, return_tensors="pt")
    max_len = 8192
    input_ids = encodings.input_ids[:, :max_len].to(device)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        neg_log_likelihood = outputs.loss
    return torch.exp(neg_log_likelihood).item()
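# Perplexity = exp(mean next-token NLL), computed here over at most the first
# 8,192 tokens of the validation text; lower is better. The Proust excerpt
# serves as a deliberately hard literary-French probe.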
# --- MAIN ---
def main():
    torch.cuda.empty_cache()
    torch.set_float32_matmul_precision('high')
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print(f"Loading {MODEL_ID}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2"
        ).to("cuda")
    except OSError:
        MODEL_ID_FALLBACK = "Qwen/Qwen3-1.7B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID_FALLBACK)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID_FALLBACK,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2"
        ).to("cuda")
    # CRITICAL: Enable Checkpointing to fit in VRAM
    model.gradient_checkpointing_enable()
    print("✅ Gradient Checkpointing: ENABLED")
    # Disable compile for model to avoid OOM/Instability
    print("✅ Torch Compile: DISABLED (Stability Mode)")
    # Route 2D matrices to Muon (note: this includes the embedding and LM-head
    # matrices) and everything else to AdamW.
    muon_params = [p for p in model.parameters() if p.requires_grad and p.ndim == 2]
    adam_params = [p for p in model.parameters() if p.requires_grad and p.ndim != 2]
    optim_muon = Muon(muon_params, lr=LEARNING_RATE)
    optim_adam = torch.optim.AdamW(adam_params, lr=LEARNING_RATE * 0.1, weight_decay=0.01)
print("Loading Data...")
dataset = load_dataset("parquet", data_files=f"{DATA_PATH}/*.parquet", split="train", streaming=True)
dataset = dataset.map(
lambda x: tokenizer(x["text"], truncation=True, max_length=SEQ_LEN, padding="max_length"),
batched=True
).remove_columns(["text", "url", "category", "ttr", "token_est"] if "ttr" in list(dataset.take(1))[0] else [])
# CRITICAL: 8 Workers to feed the H100
dataloader = DataLoader(
dataset,
batch_size=BATCH_SIZE,
collate_fn=default_data_collator,
num_workers=8,
pin_memory=True
)
    val_text = download_proust()
    tokens_per_step = BATCH_SIZE * SEQ_LEN * GRAD_ACCUM
    total_steps = TOTAL_TOKENS // tokens_per_step
    scheduler_muon = get_trapezoidal_schedule(optim_muon, total_steps)
    scheduler_adam = get_trapezoidal_schedule(optim_adam, total_steps)
    print(f"🚀 STARTING RUN: {total_steps} steps | Target: 3B Tokens")
    model.train()
    pbar = tqdm(total=total_steps, unit="step", dynamic_ncols=True)
    accum_loss = 0.0
    t0 = time.time()
    for i, batch in enumerate(dataloader):
        batch = {k: v.to("cuda", non_blocking=True) for k, v in batch.items()}
        outputs = model(**batch)
        logits = apply_softcapping(outputs.logits, cap=SOFTCAP_VAL)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = batch["input_ids"][..., 1:].contiguous()
        # Mask padded positions so the loss only covers real tokens.
        shift_mask = batch["attention_mask"][..., 1:].contiguous()
        shift_labels = shift_labels.masked_fill(shift_mask == 0, -100)
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )
        (loss / GRAD_ACCUM).backward()
        accum_loss += loss.item() / GRAD_ACCUM  # report the mean loss over the accumulation window
        if (i + 1) % GRAD_ACCUM == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim_muon.step(); optim_adam.step()
            optim_muon.zero_grad(); optim_adam.zero_grad()
            scheduler_muon.step(); scheduler_adam.step()
            dt = time.time() - t0
            if dt > 0:
                tps = tokens_per_step / dt
                pbar.set_postfix({"Loss": f"{accum_loss:.4f}", "Kt/s": f"{tps/1000:.1f}"})
            t0 = time.time()
            pbar.update(1)
            accum_loss = 0.0
            if pbar.n % 100 == 0:
                print(f"\n--- 🎭 VIBE CHECK (Step {pbar.n}) ---")
                model.eval()
                for v in VIBE_PROMPTS:
                    inp = tokenizer(v["prompt"], return_tensors="pt").to("cuda")
                    with torch.no_grad():
                        gen = model.generate(**inp, max_new_tokens=50, do_sample=True, temperature=0.7)
                    res = tokenizer.decode(gen[0], skip_special_tokens=True).replace(v['prompt'], '').strip().replace('\n', ' ')
                    print(f"🔹 {v['name']}: {res}")
                model.train()
            if pbar.n % 500 == 0:
                print(f"\n--- 🧐 PROUST CHECK (Step {pbar.n}) ---")
                model.eval()
                ppl = calc_perplexity(model, tokenizer, val_text, "cuda")
                print(f"📚 Hard French Perplexity: {ppl:.2f}")
                model.train()
                model.save_pretrained(f"{OUTPUT_DIR}/step_{pbar.n}")
                tokenizer.save_pretrained(f"{OUTPUT_DIR}/step_{pbar.n}")
            if pbar.n >= total_steps: break
model.save_pretrained(f"{OUTPUT_DIR}/final")
tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
print("🏁 DONE.")
if __name__ == "__main__":
main()