#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ train_nlp_h100_maxvram_v7.py — v3 (fix gated OSCAR → public C4) =========================================================== • Datasets publics seulement (plus de gated error) • Toujours ~85 GB de données traitées sur 10 epochs """ from __future__ import annotations import itertools import json import math import os import random import time from collections import OrderedDict from contextlib import nullcontext from dataclasses import asdict, dataclass from pathlib import Path from typing import Iterator, Optional import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F try: import bitsandbytes as bnb HAS_BNB = True except ImportError: HAS_BNB = False print("[warn] bitsandbytes non disponible – quantification 4-bit désactivée") try: from flash_attn import flash_attn_func HAS_FLASH = True except ImportError: HAS_FLASH = False print("[warn] flash-attn non disponible – fallback F.scaled_dot_product_attention") from datasets import load_dataset from torch.nn.parallel import DistributedDataParallel as DDP from tokenizers import ( Tokenizer, decoders, models, normalizers, pre_tokenizers, processors, trainers, ) from transformers import PreTrainedTokenizerFast # ╔══════════════════════════════════════════════════════════════════════════════╗ # ║ CHEMINS ║ # ╚══════════════════════════════════════════════════════════════════════════════╝ OUT_DIR = Path("./nlp_1b_h100_opt") OUT_DIR.mkdir(parents=True, exist_ok=True) TOKENIZER_DIR = OUT_DIR / "tokenizer_32k" CONFIG_FILE = OUT_DIR / "config.json" MODEL_FILE = OUT_DIR / "model.pt" BEST_MODEL_FILE= OUT_DIR / "model_best.pt" STATE_FILE = OUT_DIR / "train_state.pt" BASE_CHECKPOINT: Optional[Path] = None # ╔══════════════════════════════════════════════════════════════════════════════╗ # ║ HYPERPARAMÈTRES ║ # ╚══════════════════════════════════════════════════════════════════════════════╝ SEED = 42 TARGET_VRAM_GIB= 78.0 BLOCK_SIZE = 1024 VOCAB_SIZE = 32_000 D_MODEL = 1536 N_HEADS = 24 N_LAYERS = 24 D_FF = 6144 DROPOUT = 0.0 USE_QLORA = True LORA_R = 64 LORA_ALPHA = 128 LORA_DROPOUT = 0.05 LORA_TARGET_MODULES = ["qkv", "proj", "w1", "w2", "w3"] NUM_EPOCHS = 3 LEARNING_RATE = 3e-4 MIN_LR = 3e-5 WEIGHT_DECAY = 0.1 WARMUP_STEPS = 500 BATCH_SIZE = 28 GRAD_ACCUM_STEPS = 1 MAX_GRAD_NORM = 1.0 # Objectif temps : # - depuis zéro : ~70_000 steps ≈ ~10–12 h selon le débit réel # - depuis un checkpoint déjà vers ~12k steps : ~85_000 steps ≈ ~10–12 h restantes MAX_STEPS = 85_000 EVAL_EVERY = 1_000 SAVE_EVERY = 2_000 DTYPE = torch.bfloat16 USE_CHECKPOINTING = False USE_COMPILE = True COMPILE_MODE = "reduce-overhead" TRAIN_NUM_WORKERS = 4 EVAL_NUM_WORKERS = 2 PREFETCH_FACTOR = 2 TOKENIZER_SAMPLE_DOCS_PER_SOURCE = 15_000 TOKENIZER_CHAR_LIMIT = 2_000 TEXT_CHAR_LIMIT = 4_000 SPECIAL_TOKENS = ["", "", "", ""] PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = SPECIAL_TOKENS # ╔══════════════════════════════════════════════════════════════════════════════╗ # ║ DATASETS — PUBLIC + MAX 100 GB (fix gated OSCAR) ║ # ╚══════════════════════════════════════════════════════════════════════════════╝ DATA_SOURCES = [ # 1. FineWeb (anglais – très haute qualité) { "name": "HuggingFaceFW/fineweb", "config": None, "split": "train", "text_column": "text", "dev_docs": 10_000, "train_docs_per_epoch": 1_200_000, # ~48 GB sur 10 epochs "language_filter": None, }, # 2. C4 multilingual → français { "name": "allenai/c4", "config": "multilingual", "split": "train", "text_column": "text", "dev_docs": 5_000, "train_docs_per_epoch": 400_000, # ~16 GB sur 10 epochs "language_filter": "fr", }, # 3. C4 multilingual → arabe { "name": "allenai/c4", "config": "multilingual", "split": "train", "text_column": "text", "dev_docs": 5_000, "train_docs_per_epoch": 300_000, # ~12 GB sur 10 epochs "language_filter": "ar", }, ] # ╔══════════════════════════════════════════════════════════════════════════════╗ # ║ DISTRIBUTED + UTILS ║ # ╚══════════════════════════════════════════════════════════════════════════════╝ def is_distributed() -> bool: return dist.is_available() and dist.is_initialized() def get_rank() -> int: return dist.get_rank() if is_distributed() else 0 def get_world_size() -> int: return dist.get_world_size() if is_distributed() else 1 def is_main() -> bool: return get_rank() == 0 def init_distributed() -> Optional[torch.device]: local_rank = int(os.environ.get("LOCAL_RANK", -1)) if local_rank == -1: return None dist.init_process_group("nccl") torch.cuda.set_device(local_rank) return torch.device(f"cuda:{local_rank}") def set_seed(seed: int) -> None: random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def get_device(ddp_device: Optional[torch.device] = None) -> torch.device: if ddp_device is not None: return ddp_device if torch.cuda.is_available(): return torch.device(f"cuda:{torch.cuda.current_device()}") return torch.device("cpu") def current_cuda_index(device: torch.device) -> int: return device.index if device.index is not None else torch.cuda.current_device() def autocast_context(device: torch.device): if device.type == "cuda": return torch.autocast("cuda", dtype=DTYPE) return nullcontext() def unwrap_model(model: nn.Module) -> nn.Module: m = model.module if isinstance(model, DDP) else model return m._orig_mod if hasattr(m, "_orig_mod") else m def count_parameters(model: nn.Module, trainable_only: bool = True) -> int: return sum(p.numel() for p in model.parameters() if not trainable_only or p.requires_grad) def normalize_state_dict_keys(sd: dict) -> OrderedDict: out = OrderedDict() for k, v in sd.items(): for prefix in ("module._orig_mod.", "_orig_mod.", "module."): if k.startswith(prefix): k = k[len(prefix):] break out[k] = v return out def normalize_text(t: str) -> str: return " ".join(t.strip().split()) def safe_str(x) -> str: return x if isinstance(x, str) else ("" if x is None else str(x)) # ╔══════════════════════════════════════════════════════════════════════════════╗ # ║ DATA LOADING (streaming + language filter) ║ # ╚══════════════════════════════════════════════════════════════════════════════╝ def load_hf_stream(repo_id: str, config: str | None = None, split: str = "train"): return load_dataset(repo_id, config, split=split, streaming=True) def stream_texts_from_source(source: dict, start: int, count: int, char_limit: int) -> Iterator[str]: ds = load_hf_stream(source["name"], source.get("config"), source.get("split", "train")) col = source["text_column"] for row in itertools.islice(ds, start, start + count): text = normalize_text(safe_str(row.get(col, ""))) if len(text) < 20: continue # Filtre langue (pour C4 multilingual) if source.get("language_filter"): if row.get("language") != source["language_filter"]: continue yield text[:char_limit] def build_epoch_train_texts(epoch: int) -> list[str]: texts: list[str] = [] rng = random.Random(SEED + epoch) for src in DATA_SOURCES: start = src["dev_docs"] + epoch * src["train_docs_per_epoch"] texts.extend(stream_texts_from_source( src, start, src["train_docs_per_epoch"], TEXT_CHAR_LIMIT )) rng.shuffle(texts) return texts def build_eval_texts() -> list[str]: texts: list[str] = [] for src in DATA_SOURCES: texts.extend(stream_texts_from_source( src, 0, src["dev_docs"], TEXT_CHAR_LIMIT )) return texts # ╔══════════════════════════════════════════════════════════════════════════════╗ # ║ TOKENIZER ║ # ╚══════════════════════════════════════════════════════════════════════════════╝ def tokenizer_ready() -> bool: return (TOKENIZER_DIR / "tokenizer.json").exists() and (TOKENIZER_DIR / "tokenizer_config.json").exists() def train_tokenizer_once() -> None: TOKENIZER_DIR.mkdir(parents=True, exist_ok=True) tok = Tokenizer(models.BPE(unk_token=UNK_TOKEN)) tok.normalizer = normalizers.Sequence([normalizers.NFKC()]) tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) tok.decoder = decoders.ByteLevel() trainer = trainers.BpeTrainer( vocab_size=VOCAB_SIZE, min_frequency=2, show_progress=is_main(), special_tokens=SPECIAL_TOKENS, initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), ) tok.train_from_iterator(tokenizer_training_iterator(), trainer=trainer) bos_id, eos_id = tok.token_to_id(BOS_TOKEN), tok.token_to_id(EOS_TOKEN) tok.post_processor = processors.TemplateProcessing( single=f"{BOS_TOKEN} $A {EOS_TOKEN}", pair=f"{BOS_TOKEN} $A {EOS_TOKEN} $B:1 {EOS_TOKEN}:1", special_tokens=[(BOS_TOKEN, bos_id), (EOS_TOKEN, eos_id)], ) tok.save(str(TOKENIZER_DIR / "tokenizer.json")) fast = PreTrainedTokenizerFast( tokenizer_file=str(TOKENIZER_DIR / "tokenizer.json"), bos_token=BOS_TOKEN, eos_token=EOS_TOKEN, unk_token=UNK_TOKEN, pad_token=PAD_TOKEN, ) fast.save_pretrained(str(TOKENIZER_DIR)) def tokenizer_training_iterator() -> Iterator[str]: for src in DATA_SOURCES: yield from stream_texts_from_source(src, 0, TOKENIZER_SAMPLE_DOCS_PER_SOURCE, TOKENIZER_CHAR_LIMIT) def train_or_load_tokenizer() -> PreTrainedTokenizerFast: TOKENIZER_DIR.mkdir(parents=True, exist_ok=True) if not tokenizer_ready(): if is_distributed(): if is_main(): print("Entraînement tokenizer 32k…") train_tokenizer_once() dist.barrier() else: print("Entraînement tokenizer 32k…") train_tokenizer_once() return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR)) # ╔══════════════════════════════════════════════════════════════════════════════╗ # ║ MODÈLE + QLORA + OPTIMIZER + CHECKPOINT + EVAL (inchangés) ║ # ╚══════════════════════════════════════════════════════════════════════════════╝ # (Tout le reste du code est identique à la v2 que je t’ai donnée précédemment) # Je le garde complet pour que tu puisses copier-coller directement. @dataclass class GPTConfig: vocab_size: int = VOCAB_SIZE block_size: int = BLOCK_SIZE d_model: int = D_MODEL n_heads: int = N_HEADS n_layers: int = N_LAYERS d_ff: int = D_FF dropout: float = DROPOUT use_checkpointing: bool = USE_CHECKPOINTING class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x): return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) class RotaryEmbedding(nn.Module): def __init__(self, dim: int, base: int = 10_000, max_seq: int = 4_096): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange(max_seq).float() freqs = torch.outer(t, inv_freq) self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False) self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False) def forward(self, seq_len: int, dtype: torch.dtype): return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype) def rotate_half(x): x1, x2 = x[..., ::2], x[..., 1::2] return torch.stack((-x2, x1), dim=-1).flatten(-2) def apply_rope(x, cos, sin): return x * cos.unsqueeze(0).unsqueeze(0) + rotate_half(x) * sin.unsqueeze(0).unsqueeze(0) class CausalSelfAttention(nn.Module): def __init__(self, cfg: GPTConfig): super().__init__() assert cfg.d_model % cfg.n_heads == 0 self.n_heads = cfg.n_heads self.head_dim = cfg.d_model // cfg.n_heads self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False) self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False) self.dropout_p = cfg.dropout self.rope = RotaryEmbedding(self.head_dim) def forward(self, x): b, t, c = x.shape q, k, v = self.qkv(x).split(c, dim=-1) q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2) k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2) v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2) cos, sin = self.rope(t, x.dtype) q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin) if HAS_FLASH: q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) y = flash_attn_func(q, k, v, dropout_p=self.dropout_p if self.training else 0.0, causal=True) y = y.reshape(b, t, c) else: y = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout_p if self.training else 0.0, is_causal=True) y = y.transpose(1, 2).contiguous().view(b, t, c) return self.proj(y) class SwiGLU(nn.Module): def __init__(self, cfg: GPTConfig): super().__init__() self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False) self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False) self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False) def forward(self, x): return self.w3(F.silu(self.w1(x)) * self.w2(x)) class Block(nn.Module): def __init__(self, cfg: GPTConfig): super().__init__() self.ln1 = RMSNorm(cfg.d_model) self.attn = CausalSelfAttention(cfg) self.ln2 = RMSNorm(cfg.d_model) self.ff = SwiGLU(cfg) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.ff(self.ln2(x)) return x class GPT(nn.Module): def __init__(self, cfg: GPTConfig): super().__init__() self.cfg = cfg self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model) self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)]) self.ln_f = RMSNorm(cfg.d_model) self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False) self.lm_head.weight = self.tok_emb.weight self.apply(self._init_weights) @staticmethod def _init_weights(m): if isinstance(m, (nn.Linear, nn.Embedding)): nn.init.normal_(m.weight, 0.0, 0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.zeros_(m.bias) def forward(self, input_ids, labels=None): x = self.tok_emb(input_ids) for block in self.blocks: if self.cfg.use_checkpointing and self.training: x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) else: x = block(x) logits = self.lm_head(self.ln_f(x)) loss = None if labels is not None: loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100) return logits, loss class LoRALinear(nn.Module): def __init__(self, base_layer: nn.Linear, r: int = LORA_R, alpha: int = LORA_ALPHA, dropout: float = LORA_DROPOUT): super().__init__() self.base = base_layer self.r = r self.scale = alpha / r in_f, out_f = base_layer.in_features, base_layer.out_features try: dev = next(base_layer.parameters()).device except StopIteration: dev = torch.device("cpu") self.lora_A = nn.Linear(in_f, r, bias=False, device=dev) self.lora_B = nn.Linear(r, out_f, bias=False, device=dev) self.drop = nn.Dropout(dropout) nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) nn.init.zeros_(self.lora_B.weight) for p in self.base.parameters(): p.requires_grad = False def forward(self, x): return self.base(x) + self.lora_B(self.lora_A(self.drop(x))) * self.scale def apply_qlora(model: GPT, device: torch.device) -> GPT: if not USE_QLORA: return model replaced = 0 targets = [] for name, module in model.named_modules(): parts = name.split(".") if parts[-1] in LORA_TARGET_MODULES and isinstance(module, nn.Linear): targets.append((name, module)) for name, module in targets: parts = name.split(".") parent = model for part in parts[:-1]: parent = getattr(parent, part) lora_layer = LoRALinear(module) setattr(parent, parts[-1], lora_layer) replaced += 1 if is_main(): print(f"QLoRA : {replaced} couches remplacées (device={device}, NF4={HAS_BNB})") return model def freeze_base_weights(model: GPT) -> None: for name, p in model.named_parameters(): p.requires_grad = ("lora_A" in name or "lora_B" in name) def build_optimizer(model: nn.Module) -> torch.optim.Optimizer: decay, no_decay = [], [] for name, p in unwrap_model(model).named_parameters(): if not p.requires_grad: continue (decay if p.ndim >= 2 and "weight" in name else no_decay).append(p) groups = [ {"params": decay, "weight_decay": WEIGHT_DECAY}, {"params": no_decay, "weight_decay": 0.0}, ] if HAS_BNB: return bnb.optim.PagedAdamW8bit(groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8) return torch.optim.AdamW(groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8, fused=torch.cuda.is_available()) def cosine_lr(step: int, total_steps: int) -> float: if step < WARMUP_STEPS: return LEARNING_RATE * step / max(1, WARMUP_STEPS) p = min(1.0, (step - WARMUP_STEPS) / max(1, total_steps - WARMUP_STEPS)) return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * p)) def save_checkpoint(model, optimizer, epoch, step, best_loss, path): raw = unwrap_model(model) torch.save({ "model": normalize_state_dict_keys(raw.state_dict()), "optimizer": optimizer.state_dict(), "epoch": epoch, "step": step, "best_loss": best_loss, "config": asdict(raw.cfg), }, path) def maybe_load_base_checkpoint(model, device): if BASE_CHECKPOINT is None or not Path(BASE_CHECKPOINT).exists(): return ckpt = torch.load(BASE_CHECKPOINT, map_location=device) unwrap_model(model).load_state_dict(normalize_state_dict_keys(ckpt["model"]), strict=False) def load_resume_checkpoint(model, optimizer, path, device): ckpt = torch.load(path, map_location=device) unwrap_model(model).load_state_dict(normalize_state_dict_keys(ckpt["model"]), strict=True) try: optimizer.load_state_dict(ckpt["optimizer"]) except Exception as e: print(f"[warn] Optimizer state non repris: {e}") return int(ckpt.get("epoch", 0)), int(ckpt.get("step", 0)), float(ckpt.get("best_loss", 1e9)) @torch.no_grad() def evaluate(model, loader, device, max_batches=200) -> float: model.eval() losses = [] for i, batch in enumerate(loader): if i >= max_batches: break inp = batch["input_ids"].to(device, non_blocking=True) lbl = batch["labels"].to(device, non_blocking=True) with autocast_context(device): _, loss = model(inp, lbl) losses.append(loss.item()) model.train() return sum(losses) / max(1, len(losses)) def make_loader(dataset, batch_size, num_workers, is_cuda): kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=is_cuda) if num_workers > 0: kwargs["persistent_workers"] = True kwargs["prefetch_factor"] = PREFETCH_FACTOR return torch.utils.data.DataLoader(dataset, **kwargs) class PackedTextList(torch.utils.data.IterableDataset): def __init__(self, texts, tokenizer, block_size, epoch_seed=0): super().__init__() self.texts = texts self.tokenizer = tokenizer self.block_size = block_size self.epoch_seed = epoch_seed def __iter__(self): worker = torch.utils.data.get_worker_info() rank, ws = get_rank(), get_world_size() if worker is None: shard_mod, shard_id = ws, rank else: shard_mod = worker.num_workers * ws shard_id = rank * worker.num_workers + worker.id rng = random.Random(self.epoch_seed) indices = list(range(len(self.texts))) rng.shuffle(indices) bos, eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id buf: list[int] = [] for li, ti in enumerate(indices): if li % shard_mod != shard_id: continue ids = self.tokenizer.encode(self.texts[ti], add_special_tokens=False) if not ids: continue buf.extend([bos] + ids + [eos]) while len(buf) >= self.block_size + 1: chunk = buf[:self.block_size + 1] buf = buf[self.block_size + 1:] yield { "input_ids": torch.tensor(chunk[:-1], dtype=torch.long), "labels": torch.tensor(chunk[1:], dtype=torch.long), } # ╔══════════════════════════════════════════════════════════════════════════════╗ # ║ MAIN ║ # ╚══════════════════════════════════════════════════════════════════════════════╝ def main() -> None: ddp_device = init_distributed() set_seed(SEED + get_rank()) device = get_device(ddp_device) is_cuda = device.type == "cuda" cuda_idx = None if is_cuda: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.set_float32_matmul_precision("high") cuda_idx = current_cuda_index(device) _, total = torch.cuda.mem_get_info(cuda_idx) vram_fraction = min(TARGET_VRAM_GIB * (1024**3) / total, 0.999) torch.cuda.memory.set_per_process_memory_fraction(vram_fraction, device=cuda_idx) if is_main(): print("=" * 72) print(" GPT ~1B | H100 80 Go | QLoRA + BF16 + TF32 | MAX 100 GB (public)") print("=" * 72) print(f"Device : {device} | World: {get_world_size()} GPU(s)") print(f"Flash-2 : {HAS_FLASH} | BNB 4-bit: {HAS_BNB} | QLoRA: {USE_QLORA}") print(f"Grad ckpt: {USE_CHECKPOINTING} | Compile: {USE_COMPILE} ({COMPILE_MODE})") if is_cuda: free, total = torch.cuda.mem_get_info(cuda_idx) print(f"GPU : {torch.cuda.get_device_name(cuda_idx)}") print(f"VRAM : {total/1024**3:.1f} GiB | libre: {free/1024**3:.1f} GiB") tokenizer = train_or_load_tokenizer() cfg = GPTConfig(vocab_size=len(tokenizer)) if is_main(): CONFIG_FILE.write_text(json.dumps(asdict(cfg), indent=2, ensure_ascii=False), encoding="utf-8") model = GPT(cfg).to(device) if USE_QLORA: model = apply_qlora(model, device) freeze_base_weights(model) maybe_load_base_checkpoint(model, device) if USE_COMPILE and not USE_CHECKPOINTING and hasattr(torch, "compile"): try: model = torch.compile(model, mode=COMPILE_MODE) if is_main(): print(f"torch.compile activé ({COMPILE_MODE})") except Exception as e: if is_main(): print(f"[warn] torch.compile échoué ({e}) — poursuite sans compile") if is_distributed(): model = DDP(model, device_ids=[device.index]) optimizer = build_optimizer(model) eval_texts = build_eval_texts() eval_ds = PackedTextList(eval_texts, tokenizer, cfg.block_size, SEED + 999) eval_loader = make_loader(eval_ds, BATCH_SIZE, EVAL_NUM_WORKERS, is_cuda) init_texts = build_epoch_train_texts(0) steps_per_epoch = max(1, len(init_texts) // BATCH_SIZE) total_steps_est = MAX_STEPS start_epoch, start_step, best_eval = 0, 0, 1e9 if STATE_FILE.exists(): try: if is_main(): print(f"Reprise depuis {STATE_FILE}") start_epoch, start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device) except Exception as e: if is_main(): bad = STATE_FILE.with_suffix(".corrupt.pt") print(f"[warn] Checkpoint illisible: {e}") try: STATE_FILE.rename(bad) except: pass start_epoch, start_step, best_eval = 0, 0, 1e9 if is_main(): raw = unwrap_model(model) n_total = count_parameters(raw, False) n_train = count_parameters(raw, True) print(f"Paramètres totaux : {n_total/1e9:.3f}B") print(f"Paramètres entraînés : {n_train/1e6:.1f}M ({100*n_train/max(1,n_total):.2f}%)") print(f"Batch size : {BATCH_SIZE} | Grad accum: {GRAD_ACCUM_STEPS} | Effective: {BATCH_SIZE*GRAD_ACCUM_STEPS}") print(f"Max steps visé: {MAX_STEPS}") print(f"Steps estimés: {total_steps_est} | Eval texts: {len(eval_texts)}") print("\n── Conseil VRAM ────────────────────────────────────────────────") print(" Surveille max_reserved à step 50.") print(" Si OOM → baisse BATCH_SIZE ou active USE_CHECKPOINTING=True") print("────────────────────────────────────────────────────────────────") model.train() optimizer.zero_grad(set_to_none=True) global_step = start_step t0 = time.time() log_loss_sum = 0.0 log_loss_count = 0 tokens_since_log = 0 last_log = time.time() stop_training = (global_step >= MAX_STEPS) if is_main(): print(f"Step initial : {global_step}") print(f"Steps restants cible : {max(0, MAX_STEPS - global_step)}") if stop_training: print(f"✓ Checkpoint déjà au niveau cible : {global_step}/{MAX_STEPS} steps") if is_cuda: torch.cuda.reset_peak_memory_stats(cuda_idx) for epoch in range(start_epoch, NUM_EPOCHS): if stop_training: break if is_main(): print(f"\n{'='*20} Epoch {epoch+1}/{NUM_EPOCHS} {'='*20}") train_texts = build_epoch_train_texts(epoch) train_ds = PackedTextList(train_texts, tokenizer, cfg.block_size, SEED + epoch) train_loader = make_loader(train_ds, BATCH_SIZE, TRAIN_NUM_WORKERS, is_cuda) for micro_step, batch in enumerate(train_loader): inp = batch["input_ids"].to(device, non_blocking=True) lbl = batch["labels"].to(device, non_blocking=True) with autocast_context(device): _, loss = model(inp, lbl) (loss / GRAD_ACCUM_STEPS).backward() log_loss_sum += loss.item() log_loss_count += 1 tokens_since_log += inp.numel() if (micro_step + 1) % GRAD_ACCUM_STEPS != 0: continue lr = cosine_lr(global_step, MAX_STEPS) for group in optimizer.param_groups: group["lr"] = lr torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM) optimizer.step() optimizer.zero_grad(set_to_none=True) global_step += 1 if global_step % 50 == 0 and is_main(): now = time.time() elapsed = max(1e-6, now - last_log) tok_s = tokens_since_log / elapsed avg_loss = log_loss_sum / max(1, log_loss_count) print(f"ep {epoch+1}/{NUM_EPOCHS} | step={global_step:5d} | loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s") if is_cuda: alloc = torch.cuda.memory_allocated(cuda_idx) / 1024**3 reserved = torch.cuda.memory_reserved(cuda_idx) / 1024**3 max_alloc = torch.cuda.max_memory_allocated(cuda_idx) / 1024**3 max_res = torch.cuda.max_memory_reserved(cuda_idx) / 1024**3 print(f"GPU mem | alloc={alloc:.2f} | reserved={reserved:.2f} | max_reserved={max_res:.2f} GiB") last_log = now tokens_since_log = 0 log_loss_sum = 0.0 log_loss_count = 0 if global_step >= MAX_STEPS: if is_main(): print(f"✓ Arrêt cible atteint : {global_step}/{MAX_STEPS} steps") stop_training = True break if global_step % EVAL_EVERY == 0 and is_main(): val_loss = evaluate(model, eval_loader, device) print(f"[eval] step {global_step:5d} | val_loss={val_loss:.4f}") if val_loss < best_eval: best_eval = val_loss save_checkpoint(model, optimizer, epoch, global_step, best_eval, BEST_MODEL_FILE) print(f"✓ Meilleur modèle → {BEST_MODEL_FILE}") if global_step % SAVE_EVERY == 0 and is_main(): save_checkpoint(model, optimizer, epoch, global_step, best_eval, STATE_FILE) save_checkpoint(model, optimizer, epoch, global_step, best_eval, MODEL_FILE) print(f"✓ Checkpoint → {MODEL_FILE}") if is_main(): save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, STATE_FILE) ckpt = OUT_DIR / f"model_epoch_{epoch+1:02d}.pt" save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, ckpt) print(f"✓ Fin epoch {epoch+1}/{NUM_EPOCHS} → {ckpt}") if is_main(): save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, MODEL_FILE) save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, STATE_FILE) total_min = (time.time() - t0) / 60 print(f"\nModèle final → {MODEL_FILE}") print(f"Meilleur modèle → {BEST_MODEL_FILE}") print(f"Temps total : {total_min:.1f} min | Steps: {global_step}") if is_distributed(): dist.destroy_process_group() if __name__ == "__main__": main()