#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations

import json
import math
import os
import random
import time
from collections import OrderedDict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterator, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # explicit import: not guaranteed to be loaded by `import torch` alone
from torch.nn.parallel import DistributedDataParallel as DDP

from datasets import load_dataset
from transformers import PreTrainedTokenizerFast

# ============================================================
# Base model / tokenizer / config
# ============================================================
BASE_CHECKPOINT = Path("./wikipedia_ar_h100_codealpaca/model_best.pt")
BASE_TOKENIZER_DIR = Path("./wikipedia_ar_h100/tokenizer_32k")
BASE_CONFIG_FILE = Path("./wikipedia_ar_h100/config.json")

OUT_DIR = Path("./wikipedia_ar_h100_multicode_10x2000")
OUT_DIR.mkdir(parents=True, exist_ok=True)
MODEL_FILE = OUT_DIR / "model.pt"
BEST_MODEL_FILE = OUT_DIR / "model_best.pt"
STATE_FILE = OUT_DIR / "train_state.pt"
CONFIG_FILE = OUT_DIR / "config.json"

# ============================================================
# Datasets
# ============================================================
TRAIN_SOURCES = [
    {
        "name": "HuggingFaceH4/CodeAlpaca_20K",
        "subset": None,
        "split": "train",
        "kind": "codealpaca",
        "weight": 0.45,
        "streaming": False,
    },
    {
        "name": "open-r1/codeforces",
        "subset": "verifiable-prompts",
        "split": "train",
        "kind": "codeforces_python",
        "weight": 0.35,
        "streaming": False,
    },
    {
        "name": "wikimedia/wikipedia",
        "subset": "20231101.ar",
        "split": "train",
        "kind": "wikipedia_ar",
        "weight": 0.20,
        "streaming": True,
    },
]

EVAL_SOURCE = {
    "name": "HuggingFaceH4/CodeAlpaca_20K",
    "subset": None,
    "split": "test",
    "kind": "codealpaca",
    "streaming": False,
}

CODEFORCES_LANGUAGE = "python"

# ============================================================
# Hyperparameters
# ============================================================
SEED = 42
TARGET_VRAM_GIB = 75.0

LEARNING_RATE = 5e-5
MIN_LR = 5e-6
WEIGHT_DECAY = 0.1
WARMUP_STEPS = 200

NUM_ROUNDS = 10
STEPS_PER_ROUND = 2000
MAX_STEPS = NUM_ROUNDS * STEPS_PER_ROUND  # 20000

BATCH_SIZE = 24
GRAD_ACCUM_STEPS = 1
MAX_GRAD_NORM = 1.0

EVAL_EVERY = 250
SAVE_EVERY = 500
MAX_EVAL_EXAMPLES = 2000
TEXT_CHAR_LIMIT = 6000  # truncate formatted examples to this many characters

DTYPE = torch.bfloat16
USE_COMPILE = True
COMPILE_MODE = "default"
USE_CHECKPOINTING = False
TRAIN_NUM_WORKERS = 0
EVAL_NUM_WORKERS = 0

# ============================================================
# Helpers
# ============================================================
def is_distributed() -> bool:
    return dist.is_available() and dist.is_initialized()


def get_rank() -> int:
    return dist.get_rank() if is_distributed() else 0


def get_world_size() -> int:
    return dist.get_world_size() if is_distributed() else 1


def is_main() -> bool:
    return get_rank() == 0


def init_distributed() -> Optional[torch.device]:
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if local_rank == -1:
        return None
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    return torch.device(f"cuda:{local_rank}")


def set_seed(seed: int) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_device(ddp_device: Optional[torch.device] = None) -> torch.device:
    if ddp_device is not None:
        return ddp_device
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")
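
# ------------------------------------------------------------
# Launch note (illustrative; the script file name is an assumption):
# plain `python train.py` runs single-process, since LOCAL_RANK is
# unset and init_distributed() above returns None. For multi-GPU DDP,
# torchrun sets LOCAL_RANK per process, e.g.:
#
#   torchrun --standalone --nproc_per_node=8 train.py
# ------------------------------------------------------------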
def current_cuda_index(device: torch.device) -> int:
    if device.type != "cuda":
        raise ValueError("Non-CUDA device")
    return device.index if device.index is not None else torch.cuda.current_device()


def autocast_context(device: torch.device):
    if device.type == "cuda":
        return torch.autocast("cuda", dtype=DTYPE)
    return nullcontext()


def unwrap_model(model: nn.Module) -> nn.Module:
    # Peel off the DDP wrapper, then the torch.compile wrapper, if present.
    m = model.module if isinstance(model, DDP) else model
    if hasattr(m, "_orig_mod"):
        return m._orig_mod
    return m


def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def normalize_state_dict_keys(state_dict: dict) -> OrderedDict:
    # Strip "module." (DDP) and "_orig_mod." (torch.compile) prefixes so
    # checkpoints load into a bare model regardless of how they were saved.
    normalized = OrderedDict()
    for k, v in state_dict.items():
        nk = k
        if nk.startswith("module._orig_mod."):
            nk = nk[len("module._orig_mod."):]
        elif nk.startswith("_orig_mod."):
            nk = nk[len("_orig_mod."):]
        elif nk.startswith("module."):
            nk = nk[len("module."):]
        normalized[nk] = v
    return normalized


def normalize_text(text: str) -> str:
    return " ".join(text.strip().split())


# ============================================================
# Dataset loading / formatting
# ============================================================
def load_one_dataset(spec: dict):
    kwargs = {
        "path": spec["name"],
        "split": spec["split"],
        "streaming": spec["streaming"],
    }
    if spec["subset"] is not None:
        kwargs["name"] = spec["subset"]
    return load_dataset(**kwargs)


def format_record(row: dict, kind: str) -> str:
    if kind == "codealpaca":
        prompt = row.get("prompt", "")
        completion = row.get("completion", "")
        if not isinstance(prompt, str):
            prompt = str(prompt)
        if not isinstance(completion, str):
            completion = str(completion)
        text = (
            "### Instruction\n"
            f"{prompt.strip()}\n\n"
            "### Response\n"
            f"{completion.strip()}"
        )
        return normalize_text(text)
    if kind == "codeforces_python":
        language = row.get("language", "")
        if language != CODEFORCES_LANGUAGE:
            return ""
        prompt = row.get("prompt", "")
        title = row.get("title", "")
        if not isinstance(prompt, str):
            prompt = str(prompt)
        if not isinstance(title, str):
            title = str(title)
        text = (
            f"### Competitive Programming Problem ({language})\n"
            f"{title.strip()}\n\n"
            f"{prompt.strip()}"
        )
        return normalize_text(text)
    if kind == "wikipedia_ar":
        text = row.get("text", "")
        if not isinstance(text, str):
            text = str(text)
        return normalize_text(text)
    return ""


def example_text_iter(spec: dict, max_examples: Optional[int] = None) -> Iterator[str]:
    ds = load_one_dataset(spec)
    n = 0
    for row in ds:
        text = format_record(row, spec["kind"])
        if not text or len(text) < 20:
            continue
        if TEXT_CHAR_LIMIT is not None:
            text = text[:TEXT_CHAR_LIMIT]
        yield text
        n += 1
        if max_examples is not None and n >= max_examples:
            break


class MixedTextSource:
    """Samples formatted examples from several sources with fixed weights."""

    def __init__(self, specs: list[dict]):
        self.specs = specs
        self.weights = [s["weight"] for s in specs]
        self.streams = [example_text_iter(s) for s in specs]

    def next_text(self) -> str:
        while True:
            idx = random.choices(range(len(self.specs)), weights=self.weights, k=1)[0]
            try:
                return next(self.streams[idx])
            except StopIteration:
                # Exhausted source: restart its stream and resample.
                self.streams[idx] = example_text_iter(self.specs[idx])


def packed_block_stream_mixed(
    tokenizer: PreTrainedTokenizerFast,
    specs: list[dict],
    block_size: int,
) -> Iterator[list[int]]:
    bos, eos = tokenizer.bos_token_id, tokenizer.eos_token_id
    buffer: list[int] = []
    source = MixedTextSource(specs)
    while True:
        text = source.next_text()
        ids = tokenizer.encode(text, add_special_tokens=False)
        if not ids:
            continue
        buffer.extend([bos] + ids + [eos])
        while len(buffer) >= block_size + 1:
            yield buffer[: block_size + 1]
            buffer = buffer[block_size + 1:]
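
# ------------------------------------------------------------
# Packing layout (illustrative; the real block_size comes from
# config.json). With block_size = 1024, the token stream
#   [BOS] doc0 ... [EOS] [BOS] doc1 ... [EOS] ...
# is sliced into chunks of 1025 tokens, and each chunk becomes
#   input_ids = chunk[:-1]   # positions 0 .. 1023
#   labels    = chunk[1:]    # positions 1 .. 1024
# i.e. one next-token target per input position, with no padding.
# A document may straddle two consecutive chunks.
# ------------------------------------------------------------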
class PackedMixedBlocks(torch.utils.data.IterableDataset):
    def __init__(self, tokenizer, specs, block_size):
        super().__init__()
        self.tokenizer = tokenizer
        self.specs = specs
        self.block_size = block_size

    def __iter__(self):
        worker = torch.utils.data.get_worker_info()
        rank = get_rank()
        world_size = get_world_size()
        if worker is None:
            shard_mod = world_size
            shard_id = rank
        else:
            shard_mod = worker.num_workers * world_size
            shard_id = rank * worker.num_workers + worker.id
        for idx, chunk in enumerate(
            packed_block_stream_mixed(
                tokenizer=self.tokenizer,
                specs=self.specs,
                block_size=self.block_size,
            )
        ):
            if idx % shard_mod != shard_id:
                continue
            yield {
                "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                "labels": torch.tensor(chunk[1:], dtype=torch.long),
            }


class PackedEvalBlocks(torch.utils.data.IterableDataset):
    def __init__(self, tokenizer, spec, block_size, max_examples):
        super().__init__()
        self.tokenizer = tokenizer
        self.spec = spec
        self.block_size = block_size
        self.max_examples = max_examples

    def __iter__(self):
        worker = torch.utils.data.get_worker_info()
        rank = get_rank()
        world_size = get_world_size()
        if worker is None:
            shard_mod = world_size
            shard_id = rank
        else:
            shard_mod = worker.num_workers * world_size
            shard_id = rank * worker.num_workers + worker.id
        bos, eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
        buffer: list[int] = []
        for ex_idx, text in enumerate(example_text_iter(self.spec, max_examples=self.max_examples)):
            if ex_idx % shard_mod != shard_id:
                continue
            ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not ids:
                continue
            buffer.extend([bos] + ids + [eos])
            while len(buffer) >= self.block_size + 1:
                chunk = buffer[: self.block_size + 1]
                buffer = buffer[self.block_size + 1:]
                yield {
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "labels": torch.tensor(chunk[1:], dtype=torch.long),
                }


# ============================================================
# Architecture
# ============================================================
@dataclass
class GPTConfig:
    vocab_size: int
    block_size: int
    d_model: int
    n_heads: int
    n_layers: int
    d_ff: int
    dropout: float = 0.0
    use_checkpointing: bool = False


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, base: int = 10000, max_seq: int = 4096):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        t = torch.arange(max_seq).float()
        freqs = torch.outer(t, inv_freq)
        self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False)
        self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False)

    def forward(self, seq_len: int, dtype: torch.dtype):
        return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    return x * cos + rotate_half(x) * sin
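
# ------------------------------------------------------------
# RoPE in plain terms (matches the interleaved code above): each
# adjacent feature pair (x_{2i}, x_{2i+1}) at position t is rotated
# by the angle t * theta_i, with theta_i = base^(-2i/dim):
#   x_{2i}'   = x_{2i}   * cos(t * theta_i) - x_{2i+1} * sin(t * theta_i)
#   x_{2i+1}' = x_{2i+1} * cos(t * theta_i) + x_{2i}   * sin(t * theta_i)
# apply_rope vectorises this as x * cos + rotate_half(x) * sin, the
# cos/sin caches being duplicated pairwise by repeat_interleave, so
# q·k dot products end up depending only on relative offsets.
# ------------------------------------------------------------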
class CausalSelfAttention(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout_p = cfg.dropout
        # Size the RoPE cache to at least the actual context length so
        # long blocks never index past the precomputed tables.
        self.rope = RotaryEmbedding(self.head_dim, max_seq=max(4096, cfg.block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, c = x.shape
        q, k, v = self.qkv(x).split(c, dim=-1)
        q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        cos, sin = self.rope(t, x.dtype)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)
        y = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=True,
        )
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)


class SwiGLU(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class Block(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = RMSNorm(cfg.d_model)
        self.ff = SwiGLU(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual blocks.
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # tie input/output embeddings
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
        x = self.tok_emb(input_ids)
        for block in self.blocks:
            if self.cfg.use_checkpointing and self.training:
                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        logits = self.lm_head(self.ln_f(x))
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return logits, loss


# ============================================================
# Optimizer / LR
# ============================================================
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
    # Weight decay on matrices only; norms, biases and other 1-D
    # parameters are excluded.
    decay, no_decay = [], []
    for name, p in unwrap_model(model).named_parameters():
        if not p.requires_grad:
            continue
        (decay if p.ndim >= 2 and "weight" in name else no_decay).append(p)
    return torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": WEIGHT_DECAY},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=LEARNING_RATE,
        betas=(0.9, 0.95),
        eps=1e-8,
        fused=torch.cuda.is_available(),
    )


def cosine_lr(step: int) -> float:
    # Linear warmup to LEARNING_RATE, then cosine decay to MIN_LR.
    if step < WARMUP_STEPS:
        return LEARNING_RATE * step / max(1, WARMUP_STEPS)
    p = min(1.0, (step - WARMUP_STEPS) / max(1, MAX_STEPS - WARMUP_STEPS))
    return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * p))
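
# ------------------------------------------------------------
# Schedule sanity check (values follow from the constants above):
#   cosine_lr(0)            = 0.0     (warmup starts from zero)
#   cosine_lr(100)          = 2.5e-5  (halfway through warmup)
#   cosine_lr(WARMUP_STEPS) = 5e-5    (peak = LEARNING_RATE)
#   cosine_lr(MAX_STEPS)    = 5e-6    (floor = MIN_LR)
# ------------------------------------------------------------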
# ============================================================
# Checkpoints
# ============================================================
def load_base_config() -> GPTConfig:
    cfg_dict = json.loads(BASE_CONFIG_FILE.read_text(encoding="utf-8"))
    cfg_dict["use_checkpointing"] = USE_CHECKPOINTING
    return GPTConfig(**cfg_dict)


def initialize_model_from_base(model: nn.Module, device: torch.device) -> None:
    if not BASE_CHECKPOINT.exists():
        raise FileNotFoundError(f"Base checkpoint not found: {BASE_CHECKPOINT}")
    ckpt = torch.load(BASE_CHECKPOINT, map_location=device)
    state_dict = normalize_state_dict_keys(ckpt["model"])
    unwrap_model(model).load_state_dict(state_dict, strict=True)


def save_checkpoint(model, optimizer, step, best_loss, path):
    raw = unwrap_model(model)
    model_state = normalize_state_dict_keys(raw.state_dict())
    torch.save(
        {
            "model": model_state,
            "optimizer": optimizer.state_dict(),
            "step": step,
            "best_loss": best_loss,
            "config": asdict(raw.cfg),
        },
        path,
    )


def load_resume_checkpoint(model, optimizer, path, device) -> tuple[int, float]:
    ckpt = torch.load(path, map_location=device)
    raw = unwrap_model(model)
    model_state = normalize_state_dict_keys(ckpt["model"])
    raw.load_state_dict(model_state, strict=True)
    try:
        optimizer.load_state_dict(ckpt["optimizer"])
    except Exception as e:
        print(f"[warn] Optimizer state not restored: {e}")
    return int(ckpt.get("step", 0)), float(ckpt.get("best_loss", 1e9))


# ============================================================
# Evaluation
# ============================================================
@torch.no_grad()
def evaluate(model, loader, device, max_batches: int = 100) -> float:
    model.eval()
    losses = []
    for i, batch in enumerate(loader):
        if i >= max_batches:
            break
        inp = batch["input_ids"].to(device, non_blocking=True)
        lbl = batch["labels"].to(device, non_blocking=True)
        with autocast_context(device):
            _, loss = model(inp, lbl)
        losses.append(loss.item())
    model.train()
    return sum(losses) / max(1, len(losses))
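
# ------------------------------------------------------------
# Reading the eval number (illustrative): evaluate() returns the mean
# token-level cross-entropy in nats, so for example
#   val_loss = 2.0  ->  perplexity = math.exp(2.0) ≈ 7.39
# Comparisons are only meaningful against the same tokenizer and the
# same eval slice (here: CodeAlpaca test, packed into fixed blocks).
# ------------------------------------------------------------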
# ============================================================
# Main
# ============================================================
def main() -> None:
    ddp_device = init_distributed()
    set_seed(SEED + get_rank())
    device = get_device(ddp_device)
    cuda_device_index = None
    vram_fraction = None
    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
        cuda_device_index = current_cuda_index(device)
        _, total_mem_bytes = torch.cuda.mem_get_info(cuda_device_index)
        target_bytes = int(TARGET_VRAM_GIB * (1024 ** 3))
        vram_fraction = min(target_bytes / total_mem_bytes, 0.999)
        torch.cuda.memory.set_per_process_memory_fraction(
            vram_fraction,
            device=cuda_device_index,
        )

    if is_main():
        print("=" * 60)
        print(" Re-train same model | 10 x 2000 steps")
        print("=" * 60)
        print(f"Device: {device} | World: {get_world_size()} GPU(s)")
        if device.type == "cuda":
            free_mem, total_mem = torch.cuda.mem_get_info(cuda_device_index)
            print(f"GPU: {torch.cuda.get_device_name(cuda_device_index)}")
            print(f"Target VRAM: {TARGET_VRAM_GIB:.1f} GiB")
            print(f"PyTorch memory fraction: {vram_fraction:.4f}")
            print(f"GPU total: {total_mem / 1024**3:.2f} GiB | free: {free_mem / 1024**3:.2f} GiB")
        print(f"Rounds: {NUM_ROUNDS} | Steps/round: {STEPS_PER_ROUND} | MAX_STEPS: {MAX_STEPS}")

    tokenizer = PreTrainedTokenizerFast.from_pretrained(str(BASE_TOKENIZER_DIR))
    cfg = load_base_config()
    cfg.vocab_size = len(tokenizer)

    if is_main():
        CONFIG_FILE.write_text(
            json.dumps(asdict(cfg), indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        print(f"Base checkpoint: {BASE_CHECKPOINT}")
        print(f"Tokenizer: {BASE_TOKENIZER_DIR}")

    model = GPT(cfg).to(device)
    initialize_model_from_base(model, device)

    if USE_COMPILE and hasattr(torch, "compile"):
        model = torch.compile(model, mode=COMPILE_MODE)
        if is_main():
            print(f"torch.compile enabled ({COMPILE_MODE})")

    if is_distributed():
        model = DDP(model, device_ids=[device.index])

    optimizer = build_optimizer(model)

    start_step, best_eval = 0, 1e9
    if STATE_FILE.exists():
        try:
            if is_main():
                print(f"Resuming from {STATE_FILE}")
            start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device)
        except Exception as e:
            if is_main():
                bad_path = STATE_FILE.with_suffix(".corrupt.pt")
                print(f"[warn] Unreadable checkpoint: {e}")
                try:
                    STATE_FILE.rename(bad_path)
                    print(f"[warn] Corrupt checkpoint renamed to {bad_path}")
                except Exception:
                    pass
                print("[warn] Resume skipped, starting from the base checkpoint.")
            start_step, best_eval = 0, 1e9

    if start_step >= MAX_STEPS:
        if is_main():
            print(f"[warn] start_step={start_step} >= MAX_STEPS={MAX_STEPS}")
            print("[warn] Nothing left to train.")
        return

    train_ds = PackedMixedBlocks(
        tokenizer=tokenizer,
        specs=TRAIN_SOURCES,
        block_size=cfg.block_size,
    )
    eval_ds = PackedEvalBlocks(
        tokenizer=tokenizer,
        spec=EVAL_SOURCE,
        block_size=cfg.block_size,
        max_examples=MAX_EVAL_EXAMPLES,
    )
    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        num_workers=TRAIN_NUM_WORKERS,
        pin_memory=(device.type == "cuda"),
    )
    eval_loader = torch.utils.data.DataLoader(
        eval_ds,
        batch_size=BATCH_SIZE,
        num_workers=EVAL_NUM_WORKERS,
        pin_memory=(device.type == "cuda"),
    )

    if is_main():
        raw_model = unwrap_model(model)
        n_params = count_parameters(raw_model)
        print(f"Parameters: {n_params / 1e6:.1f}M")
        print(f"Architecture: d={cfg.d_model} | heads={cfg.n_heads} | layers={cfg.n_layers} | block={cfg.block_size}")
        print(f"Batch size: {BATCH_SIZE} | Grad accum: {GRAD_ACCUM_STEPS}")
        print(f"Dtype: {DTYPE} | Compile: {USE_COMPILE} ({COMPILE_MODE if USE_COMPILE else 'off'})")

    model.train()
    optimizer.zero_grad(set_to_none=True)
    train_iter = iter(train_loader)
    step = start_step
    t0 = time.time()
    log_loss_sum = 0.0
    log_loss_count = 0
    tokens_since_log = 0
    last_log = time.time()
    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats(cuda_device_index)
    current_round = (step // STEPS_PER_ROUND) + 1
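
    # --------------------------------------------------------
    # Training loop: each global step accumulates GRAD_ACCUM_STEPS
    # micro-batches, sets the cosine LR, clips gradients, and steps
    # the optimizer. Rank 0 additionally logs every 50 steps,
    # evaluates every EVAL_EVERY steps, checkpoints every SAVE_EVERY
    # steps, and snapshots the model at each round boundary.
    # --------------------------------------------------------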
    while step < MAX_STEPS:
        for _ in range(GRAD_ACCUM_STEPS):
            batch = next(train_iter)
            inp = batch["input_ids"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)
            with autocast_context(device):
                _, loss = model(inp, lbl)
            # Scale so accumulated gradients match one large batch.
            (loss / GRAD_ACCUM_STEPS).backward()
            log_loss_sum += loss.item()
            log_loss_count += 1
            tokens_since_log += inp.numel()

        lr = cosine_lr(step)
        for group in optimizer.param_groups:
            group["lr"] = lr

        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        step += 1

        new_round = ((step - 1) // STEPS_PER_ROUND) + 1
        if new_round != current_round and is_main():
            current_round = new_round
            print(f"\n===== Round {current_round}/{NUM_ROUNDS} =====")

        if step % 50 == 0 and is_main():
            now = time.time()
            elapsed = max(1e-6, now - last_log)
            tok_s = tokens_since_log / elapsed
            avg_loss = log_loss_sum / max(1, log_loss_count)
            round_idx = ((step - 1) // STEPS_PER_ROUND) + 1
            step_in_round = ((step - 1) % STEPS_PER_ROUND) + 1
            print(
                f"round {round_idx:2d}/{NUM_ROUNDS} | "
                f"step {step_in_round:4d}/{STEPS_PER_ROUND} | "
                f"global {step:5d}/{MAX_STEPS} | "
                f"loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s"
            )
            if device.type == "cuda":
                allocated = torch.cuda.memory_allocated(cuda_device_index) / 1024**3
                reserved = torch.cuda.memory_reserved(cuda_device_index) / 1024**3
                max_alloc = torch.cuda.max_memory_allocated(cuda_device_index) / 1024**3
                max_reserved = torch.cuda.max_memory_reserved(cuda_device_index) / 1024**3
                print(
                    f"GPU mem | alloc={allocated:.2f} GiB | reserved={reserved:.2f} GiB | "
                    f"max_alloc={max_alloc:.2f} GiB | max_reserved={max_reserved:.2f} GiB"
                )
            last_log = now
            tokens_since_log = 0
            log_loss_sum = 0.0
            log_loss_count = 0

        if step % EVAL_EVERY == 0 and is_main():
            # Evaluate on the unwrapped module: running the DDP wrapper
            # on rank 0 alone would issue buffer broadcasts that the
            # other (still-training) ranks never join.
            val_loss = evaluate(unwrap_model(model), eval_loader, device)
            print(f"[eval] global step {step:5d} | val_loss={val_loss:.4f}")
            if val_loss < best_eval:
                best_eval = val_loss
                save_checkpoint(model, optimizer, step, best_eval, BEST_MODEL_FILE)
                print(f"✓ Best model → {BEST_MODEL_FILE}")

        if step % SAVE_EVERY == 0 and is_main():
            save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
            save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
            print(f"✓ Checkpoint → {MODEL_FILE}")

        if step % STEPS_PER_ROUND == 0 and is_main():
            round_no = step // STEPS_PER_ROUND
            round_ckpt = OUT_DIR / f"model_round_{round_no:02d}.pt"
            save_checkpoint(model, optimizer, step, best_eval, round_ckpt)
            print(f"✓ End of round {round_no}/{NUM_ROUNDS} → {round_ckpt}")

    if is_main():
        save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
        save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
        total = (time.time() - t0) / 60
        print(f"\nFinal model → {MODEL_FILE}")
        print(f"Best model → {BEST_MODEL_FILE}")
        print(f"Total time: {total:.1f} min")
        print(f"Steps completed: {step}")

    if is_distributed():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()