"""Continue training a GPT-style decoder from a base checkpoint on a weighted mix
of CodeAlpaca, Codeforces (Python) and Arabic Wikipedia data, for
NUM_ROUNDS x STEPS_PER_ROUND optimizer steps."""

from __future__ import annotations

import json
import math
import os
import random
import time
from collections import OrderedDict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterator, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # used via torch.utils.checkpoint.checkpoint below
import torch.utils.data  # IterableDataset / DataLoader / get_worker_info
from torch.nn.parallel import DistributedDataParallel as DDP

from datasets import load_dataset
from transformers import PreTrainedTokenizerFast

# =============================================================================
# Paths
# =============================================================================

BASE_CHECKPOINT = Path("./wikipedia_ar_h100_codealpaca/model_best.pt")
BASE_TOKENIZER_DIR = Path("./wikipedia_ar_h100/tokenizer_32k")
BASE_CONFIG_FILE = Path("./wikipedia_ar_h100/config.json")

OUT_DIR = Path("./wikipedia_ar_h100_multicode_10x2000")
OUT_DIR.mkdir(parents=True, exist_ok=True)

MODEL_FILE = OUT_DIR / "model.pt"
BEST_MODEL_FILE = OUT_DIR / "model_best.pt"
STATE_FILE = OUT_DIR / "train_state.pt"
CONFIG_FILE = OUT_DIR / "config.json"

# =============================================================================
# Data sources
# =============================================================================

TRAIN_SOURCES = [
    {
        "name": "HuggingFaceH4/CodeAlpaca_20K",
        "subset": None,
        "split": "train",
        "kind": "codealpaca",
        "weight": 0.45,
        "streaming": False,
    },
    {
        "name": "open-r1/codeforces",
        "subset": "verifiable-prompts",
        "split": "train",
        "kind": "codeforces_python",
        "weight": 0.35,
        "streaming": False,
    },
    {
        "name": "wikimedia/wikipedia",
        "subset": "20231101.ar",
        "split": "train",
        "kind": "wikipedia_ar",
        "weight": 0.20,
        "streaming": True,
    },
]
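# Relative sampling weights for random.choices() in MixedTextSource (they happen
# to sum to 1.0 here); raising a weight makes that source appear more often.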

EVAL_SOURCE = {
    "name": "HuggingFaceH4/CodeAlpaca_20K",
    "subset": None,
    "split": "test",
    "kind": "codealpaca",
    "streaming": False,
}

CODEFORCES_LANGUAGE = "python"
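# format_record() drops codeforces rows whose "language" field differs from this,
# so only Python-tagged prompts enter the training mix.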

# =============================================================================
# Hyperparameters
# =============================================================================

SEED = 42
TARGET_VRAM_GIB = 75.0
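# Soft cap for the CUDA caching allocator, applied in main() via
# set_per_process_memory_fraction; 75 GiB presumably targets an 80 GiB H100.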

LEARNING_RATE = 5e-5
MIN_LR = 5e-6
WEIGHT_DECAY = 0.1
WARMUP_STEPS = 200

NUM_ROUNDS = 10
STEPS_PER_ROUND = 2000
MAX_STEPS = NUM_ROUNDS * STEPS_PER_ROUND

BATCH_SIZE = 24
GRAD_ACCUM_STEPS = 1
MAX_GRAD_NORM = 1.0

EVAL_EVERY = 250
SAVE_EVERY = 500
MAX_EVAL_EXAMPLES = 2000
TEXT_CHAR_LIMIT = 6000

DTYPE = torch.bfloat16
USE_COMPILE = True
COMPILE_MODE = "default"
USE_CHECKPOINTING = False
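# torch.compile needs PyTorch >= 2.0 (guarded by hasattr below); activation
# checkpointing, off here, would trade extra compute for lower memory.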

TRAIN_NUM_WORKERS = 0
EVAL_NUM_WORKERS = 0


# =============================================================================
# Distributed & runtime helpers
# =============================================================================

def is_distributed() -> bool:
    return dist.is_available() and dist.is_initialized()


def get_rank() -> int:
    return dist.get_rank() if is_distributed() else 0


def get_world_size() -> int:
    return dist.get_world_size() if is_distributed() else 1


def is_main() -> bool:
    return get_rank() == 0


def init_distributed() -> Optional[torch.device]:
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if local_rank == -1:
        return None
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    return torch.device(f"cuda:{local_rank}")


def set_seed(seed: int) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_device(ddp_device: Optional[torch.device] = None) -> torch.device:
    if ddp_device is not None:
        return ddp_device
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")


def current_cuda_index(device: torch.device) -> int:
    if device.type != "cuda":
        raise ValueError("Device is not CUDA")
    return device.index if device.index is not None else torch.cuda.current_device()


def autocast_context(device: torch.device):
    if device.type == "cuda":
        return torch.autocast("cuda", dtype=DTYPE)
    return nullcontext()


def unwrap_model(model: nn.Module) -> nn.Module:
    m = model.module if isinstance(model, DDP) else model
    if hasattr(m, "_orig_mod"):  # unwrap torch.compile
        return m._orig_mod
    return m


def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def normalize_state_dict_keys(state_dict: dict) -> OrderedDict:
    # Strip DDP ("module.") and torch.compile ("_orig_mod.") prefixes so
    # checkpoints load regardless of how the model was wrapped when saved.
    normalized = OrderedDict()
    for k, v in state_dict.items():
        nk = k
        if nk.startswith("module._orig_mod."):
            nk = nk[len("module._orig_mod."):]
        elif nk.startswith("_orig_mod."):
            nk = nk[len("_orig_mod."):]
        elif nk.startswith("module."):
            nk = nk[len("module."):]
        normalized[nk] = v
    return normalized


def normalize_text(text: str) -> str:
    # NOTE: collapses all whitespace runs, including newlines and indentation
    # inside code snippets, into single spaces.
    return " ".join(text.strip().split())


# =============================================================================
# Data pipeline
# =============================================================================

def load_one_dataset(spec: dict):
    kwargs = {
        "path": spec["name"],
        "split": spec["split"],
        "streaming": spec["streaming"],
    }
    if spec["subset"] is not None:
        kwargs["name"] = spec["subset"]
    return load_dataset(**kwargs)


def format_record(row: dict, kind: str) -> str:
    if kind == "codealpaca":
        prompt = row.get("prompt", "")
        completion = row.get("completion", "")
        if not isinstance(prompt, str):
            prompt = str(prompt)
        if not isinstance(completion, str):
            completion = str(completion)
        text = (
            "### Instruction\n"
            f"{prompt.strip()}\n\n"
            "### Response\n"
            f"{completion.strip()}"
        )
        return normalize_text(text)

    if kind == "codeforces_python":
        language = row.get("language", "")
        if language != CODEFORCES_LANGUAGE:
            return ""

        prompt = row.get("prompt", "")
        title = row.get("title", "")
        if not isinstance(prompt, str):
            prompt = str(prompt)
        if not isinstance(title, str):
            title = str(title)

        text = (
            f"### Competitive Programming Problem ({language})\n"
            f"{title.strip()}\n\n"
            f"{prompt.strip()}"
        )
        return normalize_text(text)

    if kind == "wikipedia_ar":
        text = row.get("text", "")
        if not isinstance(text, str):
            text = str(text)
        return normalize_text(text)

    return ""


def example_text_iter(spec: dict, max_examples: Optional[int] = None) -> Iterator[str]:
    ds = load_one_dataset(spec)
    n = 0
    for row in ds:
        text = format_record(row, spec["kind"])
        if not text or len(text) < 20:
            continue
        if TEXT_CHAR_LIMIT is not None:
            text = text[:TEXT_CHAR_LIMIT]
        yield text
        n += 1
        if max_examples is not None and n >= max_examples:
            break


class MixedTextSource:
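    """Sample texts from several dataset streams in proportion to their weights,
    transparently restarting any stream that runs out."""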
    def __init__(self, specs: list[dict]):
        self.specs = specs
        self.weights = [s["weight"] for s in specs]
        self.streams = [example_text_iter(s) for s in specs]

    def next_text(self) -> str:
        while True:
            idx = random.choices(range(len(self.specs)), weights=self.weights, k=1)[0]
            try:
                return next(self.streams[idx])
            except StopIteration:
                self.streams[idx] = example_text_iter(self.specs[idx])


def packed_block_stream_mixed(
    tokenizer: PreTrainedTokenizerFast,
    specs: list[dict],
    block_size: int,
) -> Iterator[list[int]]:
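    """Yield blocks of block_size + 1 token ids: each document is wrapped in
    BOS/EOS and packed back to back; the extra token supplies next-token labels."""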
    bos, eos = tokenizer.bos_token_id, tokenizer.eos_token_id
    assert bos is not None and eos is not None, "tokenizer must define BOS and EOS ids"
    buffer: list[int] = []
    source = MixedTextSource(specs)

    while True:
        text = source.next_text()
        ids = tokenizer.encode(text, add_special_tokens=False)
        if not ids:
            continue

        buffer.extend([bos] + ids + [eos])

        while len(buffer) >= block_size + 1:
            yield buffer[: block_size + 1]
            buffer = buffer[block_size + 1:]


class PackedMixedBlocks(torch.utils.data.IterableDataset):
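    """Infinite IterableDataset over packed mixed blocks, sharded across DDP ranks
    and DataLoader workers (each consumer keeps every shard_mod-th block)."""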
    def __init__(self, tokenizer, specs, block_size):
        super().__init__()
        self.tokenizer = tokenizer
        self.specs = specs
        self.block_size = block_size

    def __iter__(self):
        worker = torch.utils.data.get_worker_info()
        rank = get_rank()
        world_size = get_world_size()

        if worker is None:
            shard_mod = world_size
            shard_id = rank
        else:
            shard_mod = worker.num_workers * world_size
            shard_id = rank * worker.num_workers + worker.id

        for idx, chunk in enumerate(
            packed_block_stream_mixed(
                tokenizer=self.tokenizer,
                specs=self.specs,
                block_size=self.block_size,
            )
        ):
            if idx % shard_mod != shard_id:
                continue

            yield {
                "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                "labels": torch.tensor(chunk[1:], dtype=torch.long),
            }


class PackedEvalBlocks(torch.utils.data.IterableDataset):
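    """Finite eval dataset: packs up to max_examples documents into blocks,
    sharding by example index (not block index) across ranks and workers."""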
    def __init__(self, tokenizer, spec, block_size, max_examples):
        super().__init__()
        self.tokenizer = tokenizer
        self.spec = spec
        self.block_size = block_size
        self.max_examples = max_examples

    def __iter__(self):
        worker = torch.utils.data.get_worker_info()
        rank = get_rank()
        world_size = get_world_size()

        if worker is None:
            shard_mod = world_size
            shard_id = rank
        else:
            shard_mod = worker.num_workers * world_size
            shard_id = rank * worker.num_workers + worker.id

        bos, eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
        buffer: list[int] = []

        for ex_idx, text in enumerate(example_text_iter(self.spec, max_examples=self.max_examples)):
            if ex_idx % shard_mod != shard_id:
                continue

            ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not ids:
                continue

            buffer.extend([bos] + ids + [eos])

            while len(buffer) >= self.block_size + 1:
                chunk = buffer[: self.block_size + 1]
                buffer = buffer[self.block_size + 1:]
                yield {
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "labels": torch.tensor(chunk[1:], dtype=torch.long),
                }


# =============================================================================
# Model
# =============================================================================


@dataclass
class GPTConfig:
    vocab_size: int
    block_size: int
    d_model: int
    n_heads: int
    n_layers: int
    d_ff: int
    dropout: float = 0.0
    use_checkpointing: bool = False


class RMSNorm(nn.Module):
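    """Root-mean-square norm: rescale by 1/sqrt(mean(x^2) + eps) with a learned
    gain; no mean subtraction and no bias."""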
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class RotaryEmbedding(nn.Module):
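    """Rotary position embeddings in interleaved (even, odd) pair layout, with
    cos/sin caches precomputed for max_seq positions."""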
    def __init__(self, dim: int, base: int = 10000, max_seq: int = 4096):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(max_seq).float()
        freqs = torch.outer(t, inv_freq)

        # Repeat each frequency twice so cos/sin line up with the interleaved
        # (even, odd) channel pairs used by rotate_half().
        self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False)
        self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False)

    def forward(self, seq_len: int, dtype: torch.dtype):
        return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Map interleaved pairs (x0, x1) -> (-x1, x0).
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Broadcast the (seq, head_dim) caches over the (batch, heads) dimensions.
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    return x * cos + rotate_half(x) * sin


class CausalSelfAttention(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout_p = cfg.dropout
        # Size the RoPE cache to the configured context length (a fixed default
        # of 4096 would break if block_size ever exceeded it).
        self.rope = RotaryEmbedding(self.head_dim, max_seq=max(4096, cfg.block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, c = x.shape
        q, k, v = self.qkv(x).split(c, dim=-1)

        q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rope(t, x.dtype)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        y = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=True,
        )
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)


class SwiGLU(nn.Module):
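    """SwiGLU feed-forward: w3(silu(w1(x)) * w2(x)), all projections bias-free."""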
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class Block(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = RMSNorm(cfg.d_model)
        self.ff = SwiGLU(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
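        # Weight tying: the output projection shares the token-embedding matrix.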
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
        x = self.tok_emb(input_ids)

        for block in self.blocks:
            if self.cfg.use_checkpointing and self.training:
                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)

        logits = self.lm_head(self.ln_f(x))
        loss = None

        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )

        return logits, loss


# =============================================================================
# Optimizer & LR schedule
# =============================================================================


def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
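    # 2-D "weight" tensors (linear/embedding matrices) get weight decay;
    # 1-D parameters such as RMSNorm gains do not.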
    decay, no_decay = [], []
    for name, p in unwrap_model(model).named_parameters():
        if not p.requires_grad:
            continue
        (decay if p.ndim >= 2 and "weight" in name else no_decay).append(p)

    return torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": WEIGHT_DECAY},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=LEARNING_RATE,
        betas=(0.9, 0.95),
        eps=1e-8,
        fused=torch.cuda.is_available(),
    )


def cosine_lr(step: int) -> float:
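    """Linear warmup to LEARNING_RATE over WARMUP_STEPS, then cosine decay to
    MIN_LR over the remaining steps up to MAX_STEPS."""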
    if step < WARMUP_STEPS:
        return LEARNING_RATE * step / max(1, WARMUP_STEPS)
    p = min(1.0, (step - WARMUP_STEPS) / max(1, MAX_STEPS - WARMUP_STEPS))
    return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * p))


# =============================================================================
# Checkpoint I/O
# =============================================================================


def load_base_config() -> GPTConfig:
    cfg_dict = json.loads(BASE_CONFIG_FILE.read_text(encoding="utf-8"))
    cfg_dict["use_checkpointing"] = USE_CHECKPOINTING
    return GPTConfig(**cfg_dict)


def initialize_model_from_base(model: nn.Module, device: torch.device) -> None:
    if not BASE_CHECKPOINT.exists():
        raise FileNotFoundError(f"Base checkpoint not found: {BASE_CHECKPOINT}")
    ckpt = torch.load(BASE_CHECKPOINT, map_location=device)
    state_dict = normalize_state_dict_keys(ckpt["model"])
    unwrap_model(model).load_state_dict(state_dict, strict=True)


def save_checkpoint(model, optimizer, step, best_loss, path):
    raw = unwrap_model(model)
    model_state = normalize_state_dict_keys(raw.state_dict())
    torch.save(
        {
            "model": model_state,
            "optimizer": optimizer.state_dict(),
            "step": step,
            "best_loss": best_loss,
            "config": asdict(raw.cfg),
        },
        path,
    )


def load_resume_checkpoint(model, optimizer, path, device) -> tuple[int, float]:
    ckpt = torch.load(path, map_location=device)
    raw = unwrap_model(model)
    model_state = normalize_state_dict_keys(ckpt["model"])
    raw.load_state_dict(model_state, strict=True)

    try:
        optimizer.load_state_dict(ckpt["optimizer"])
    except Exception as e:
        print(f"[warn] Optimizer state not restored: {e}")

    return int(ckpt.get("step", 0)), float(ckpt.get("best_loss", 1e9))


# =============================================================================
# Evaluation
# =============================================================================


@torch.no_grad()
def evaluate(model, loader, device, max_batches: int = 100) -> float:
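    """Mean next-token cross-entropy over at most max_batches eval batches."""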
    model.eval()
    losses = []

    for i, batch in enumerate(loader):
        if i >= max_batches:
            break

        inp = batch["input_ids"].to(device, non_blocking=True)
        lbl = batch["labels"].to(device, non_blocking=True)

        with autocast_context(device):
            _, loss = model(inp, lbl)

        losses.append(loss.item())

    model.train()
    return sum(losses) / max(1, len(losses))


# =============================================================================
# Training entry point
# =============================================================================

def main() -> None:
    ddp_device = init_distributed()
    set_seed(SEED + get_rank())
    device = get_device(ddp_device)

    cuda_device_index = None
    vram_fraction = None

    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")

        cuda_device_index = current_cuda_index(device)
        _, total_mem_bytes = torch.cuda.mem_get_info(cuda_device_index)
        target_bytes = int(TARGET_VRAM_GIB * (1024 ** 3))
        vram_fraction = min(target_bytes / total_mem_bytes, 0.999)

        torch.cuda.memory.set_per_process_memory_fraction(
            vram_fraction,
            device=cuda_device_index,
        )

    if is_main():
        print("=" * 60)
        print(" Re-train same model | 10 x 2000 steps")
        print("=" * 60)
        print(f"Device: {device} | World: {get_world_size()} GPU(s)")
        if device.type == "cuda":
            free_mem, total_mem = torch.cuda.mem_get_info(cuda_device_index)
            print(f"GPU: {torch.cuda.get_device_name(cuda_device_index)}")
            print(f"Target VRAM: {TARGET_VRAM_GIB:.1f} GiB")
            print(f"PyTorch memory fraction: {vram_fraction:.4f}")
            print(f"GPU total: {total_mem / 1024**3:.2f} GiB | free: {free_mem / 1024**3:.2f} GiB")
        print(f"Rounds: {NUM_ROUNDS} | Steps/round: {STEPS_PER_ROUND} | MAX_STEPS: {MAX_STEPS}")

    tokenizer = PreTrainedTokenizerFast.from_pretrained(str(BASE_TOKENIZER_DIR))
    cfg = load_base_config()
    cfg.vocab_size = len(tokenizer)

    if is_main():
        CONFIG_FILE.write_text(
            json.dumps(asdict(cfg), indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        print(f"Base checkpoint: {BASE_CHECKPOINT}")
        print(f"Tokenizer: {BASE_TOKENIZER_DIR}")

    model = GPT(cfg).to(device)
    initialize_model_from_base(model, device)

    if USE_COMPILE and hasattr(torch, "compile"):
        model = torch.compile(model, mode=COMPILE_MODE)
        if is_main():
            print(f"torch.compile enabled ({COMPILE_MODE})")

    if is_distributed():
        model = DDP(model, device_ids=[device.index])

    optimizer = build_optimizer(model)

    start_step, best_eval = 0, 1e9
    if STATE_FILE.exists():
        try:
            if is_main():
                print(f"Resuming from {STATE_FILE}")
            start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device)
        except Exception as e:
            if is_main():
                bad_path = STATE_FILE.with_suffix(".corrupt.pt")
                print(f"[warn] Unreadable checkpoint: {e}")
                try:
                    STATE_FILE.rename(bad_path)
                    print(f"[warn] Corrupt checkpoint renamed to {bad_path}")
                except Exception:
                    pass
                print("[warn] Resume skipped; starting from the base checkpoint.")
            start_step, best_eval = 0, 1e9

    if start_step >= MAX_STEPS:
        if is_main():
            print(f"[warn] start_step={start_step} >= MAX_STEPS={MAX_STEPS}")
            print("[warn] Nothing left to train.")
        return

    train_ds = PackedMixedBlocks(
        tokenizer=tokenizer,
        specs=TRAIN_SOURCES,
        block_size=cfg.block_size,
    )

    eval_ds = PackedEvalBlocks(
        tokenizer=tokenizer,
        spec=EVAL_SOURCE,
        block_size=cfg.block_size,
        max_examples=MAX_EVAL_EXAMPLES,
    )

    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        num_workers=TRAIN_NUM_WORKERS,
        pin_memory=(device.type == "cuda"),
    )

    eval_loader = torch.utils.data.DataLoader(
        eval_ds,
        batch_size=BATCH_SIZE,
        num_workers=EVAL_NUM_WORKERS,
        pin_memory=(device.type == "cuda"),
    )

    if is_main():
        raw_model = unwrap_model(model)
        n_params = count_parameters(raw_model)
        print(f"Parameters: {n_params / 1e6:.1f}M")
        print(f"Architecture: d={cfg.d_model} | heads={cfg.n_heads} | layers={cfg.n_layers} | block={cfg.block_size}")
        print(f"Batch size: {BATCH_SIZE} | Grad accum: {GRAD_ACCUM_STEPS}")
        print(f"Dtype: {DTYPE} | Compile: {USE_COMPILE} ({COMPILE_MODE if USE_COMPILE else 'off'})")

    model.train()
    optimizer.zero_grad(set_to_none=True)

    train_iter = iter(train_loader)
    step = start_step
    t0 = time.time()
    log_loss_sum = 0.0
    log_loss_count = 0
    tokens_since_log = 0
    last_log = time.time()

    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats(cuda_device_index)

    current_round = (step // STEPS_PER_ROUND) + 1

    while step < MAX_STEPS:
        for _ in range(GRAD_ACCUM_STEPS):
            batch = next(train_iter)

            inp = batch["input_ids"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)

            with autocast_context(device):
                _, loss = model(inp, lbl)

            # NOTE: under DDP with GRAD_ACCUM_STEPS > 1 this syncs gradients on
            # every backward; model.no_sync() would skip the intermediate ones.
            (loss / GRAD_ACCUM_STEPS).backward()

            log_loss_sum += loss.item()
            log_loss_count += 1
            tokens_since_log += inp.numel()

        lr = cosine_lr(step)
        for group in optimizer.param_groups:
            group["lr"] = lr

        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        step += 1

        new_round = ((step - 1) // STEPS_PER_ROUND) + 1
        if new_round != current_round and is_main():
            current_round = new_round
            print(f"\n===== Round {current_round}/{NUM_ROUNDS} =====")

        if step % 50 == 0 and is_main():
            now = time.time()
            elapsed = max(1e-6, now - last_log)
            tok_s = tokens_since_log / elapsed
            avg_loss = log_loss_sum / max(1, log_loss_count)
            round_idx = ((step - 1) // STEPS_PER_ROUND) + 1
            step_in_round = ((step - 1) % STEPS_PER_ROUND) + 1

            print(
                f"round {round_idx:2d}/{NUM_ROUNDS} | "
                f"step {step_in_round:4d}/{STEPS_PER_ROUND} | "
                f"global {step:5d}/{MAX_STEPS} | "
                f"loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s"
            )

            if device.type == "cuda":
                allocated = torch.cuda.memory_allocated(cuda_device_index) / 1024**3
                reserved = torch.cuda.memory_reserved(cuda_device_index) / 1024**3
                max_alloc = torch.cuda.max_memory_allocated(cuda_device_index) / 1024**3
                max_reserved = torch.cuda.max_memory_reserved(cuda_device_index) / 1024**3
                print(
                    f"GPU mem | alloc={allocated:.2f} GiB | reserved={reserved:.2f} GiB | "
                    f"max_alloc={max_alloc:.2f} GiB | max_reserved={max_reserved:.2f} GiB"
                )

            last_log = now
            tokens_since_log = 0
            log_loss_sum = 0.0
            log_loss_count = 0

        if step % EVAL_EVERY == 0:
            # Evaluate on every rank: a DDP forward pass participates in
            # collective ops (e.g. buffer broadcasts), so letting only rank 0
            # run it can deadlock the other ranks.
            val_loss = evaluate(model, eval_loader, device)
            if is_main():
                print(f"[eval] global step {step:5d} | val_loss={val_loss:.4f}")

                if val_loss < best_eval:
                    best_eval = val_loss
                    save_checkpoint(model, optimizer, step, best_eval, BEST_MODEL_FILE)
                    print(f"✓ Best model → {BEST_MODEL_FILE}")

        if step % SAVE_EVERY == 0 and is_main():
            save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
            save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
            print(f"✓ Checkpoint → {MODEL_FILE}")

        if step % STEPS_PER_ROUND == 0 and is_main():
            round_no = step // STEPS_PER_ROUND
            round_ckpt = OUT_DIR / f"model_round_{round_no:02d}.pt"
            save_checkpoint(model, optimizer, step, best_eval, round_ckpt)
            print(f"✓ End of round {round_no}/{NUM_ROUNDS} → {round_ckpt}")

    if is_main():
        save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
        save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
        total = (time.time() - t0) / 60
        print(f"\nFinal model → {MODEL_FILE}")
        print(f"Best model → {BEST_MODEL_FILE}")
        print(f"Total time: {total:.1f} min")
        print(f"Steps completed: {step}")

    if is_distributed():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
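
# Launch sketch (assumed usage; the file name below is a placeholder):
#   single GPU : python train_multicode.py
#   multi GPU  : torchrun --standalone --nproc_per_node=8 train_multicode.py
# init_distributed() only enables DDP when a launcher such as torchrun sets the
# LOCAL_RANK environment variable; otherwise the script runs single-process.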