#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
train_nlp_h100_optimized.py — v2 (device-mismatch bugfix)
=========================================================

Changes vs v1:
  • apply_qlora() is called AFTER model.to(device), so lora_A / lora_B are
    created on CUDA.
  • LoRALinear.__init__ creates the adapters explicitly on the device of the
    wrapped base layer.
  • torch.compile is skipped when USE_CHECKPOINTING=True (dynamo + checkpoint
    conflict with custom sub-modules); set USE_CHECKPOINTING=False together
    with USE_COMPILE=True to compile.
  • Clean fallback: if compile crashes, training continues without it.
"""

from __future__ import annotations

import itertools
import json
import math
import os
import random
import time
from collections import OrderedDict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterator, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
# Imported explicitly: the script accesses torch.utils.checkpoint.checkpoint and
# torch.utils.data.* by attribute, which only works if the submodules are loaded.
import torch.utils.checkpoint
import torch.utils.data

# Optional dependency: bitsandbytes (8-bit paged optimizer / 4-bit params).
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Params4bit
    HAS_BNB = True
except ImportError:
    HAS_BNB = False
    print("[warn] bitsandbytes non disponible – quantification 4-bit désactivée")

# Optional dependency: flash-attn; the attention layer falls back to
# F.scaled_dot_product_attention when it is missing.
try:
    from flash_attn import flash_attn_func
    HAS_FLASH = True
except ImportError:
    HAS_FLASH = False
    print("[warn] flash-attn non disponible – fallback F.scaled_dot_product_attention")

from datasets import load_dataset
from torch.nn.parallel import DistributedDataParallel as DDP
from tokenizers import (
    Tokenizer,
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
)
from transformers import PreTrainedTokenizerFast

# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  PATHS                                                                       ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
OUT_DIR = Path("./nlp_1b_h100_opt")
OUT_DIR.mkdir(parents=True, exist_ok=True)  # side effect at import time
TOKENIZER_DIR = OUT_DIR / "tokenizer_32k"
CONFIG_FILE = OUT_DIR / "config.json"
MODEL_FILE = OUT_DIR / "model.pt"
BEST_MODEL_FILE = OUT_DIR / "model_best.pt"
STATE_FILE = OUT_DIR / "train_state.pt"
# Optional checkpoint used to initialise the model weights; None disables it.
BASE_CHECKPOINT: Optional[Path] = None

# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  HYPERPARAMETERS                                                             ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
SEED = 42
TARGET_VRAM_GIB = 78.0

# Model shape
BLOCK_SIZE = 1024
VOCAB_SIZE = 32_000
D_MODEL = 1536
N_HEADS = 24
N_LAYERS = 24
D_FF = 6144
DROPOUT = 0.0

# LoRA adapters
USE_QLORA = True
LORA_R = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ["qkv", "proj", "w1", "w2", "w3"]

# Optimisation schedule
NUM_EPOCHS = 10
LEARNING_RATE = 3e-4
MIN_LR = 3e-5
WEIGHT_DECAY = 0.1
WARMUP_STEPS = 500

# Batch-size tuning towards ~78 GB of VRAM:
#   - start with BATCH_SIZE=8, GRAD_ACCUM_STEPS=2;
#   - raise BATCH_SIZE by +2 until max_reserved ≈ 77 GB in the logs;
#   - on OOM: BATCH_SIZE -= 1 or USE_CHECKPOINTING=True.
BATCH_SIZE = 16
GRAD_ACCUM_STEPS = 1
MAX_GRAD_NORM = 1.0
EVAL_EVERY = 500
SAVE_EVERY = 1_000
DTYPE = torch.bfloat16

# Compile: disabled when USE_CHECKPOINTING=True to avoid the
# dynamo <-> checkpoint <-> custom sub-module (LoRALinear) conflict.
# Set USE_CHECKPOINTING=False AND USE_COMPILE=True for maximum speed.
USE_CHECKPOINTING = False  # activation checkpointing (original note: saves ~8x activation VRAM)
USE_COMPILE = True  # only honoured when USE_CHECKPOINTING=False
COMPILE_MODE = "reduce-overhead"

TRAIN_NUM_WORKERS = 4
EVAL_NUM_WORKERS = 2
PREFETCH_FACTOR = 2

TOKENIZER_SAMPLE_DOCS_PER_SOURCE = 15_000
TOKENIZER_CHAR_LIMIT = 2_000
TEXT_CHAR_LIMIT = 4_000

# FIX: the four entries were empty strings (the angle-bracket literals had been
# stripped), which made PAD/BOS/EOS/UNK indistinguishable and broke the
# post-processor template. A tokenizer trained with the empty tokens must be
# deleted and retrained.
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = SPECIAL_TOKENS

WIKI_CONFIGS = ["20231101.en", "20231101.fr", "20231101.ar"]
FINEWEB_CONFIG = "sample-10BT"

DEV_DOCS_PER_WIKI_CONFIG = 1_500
DEV_DOCS_FINEWEB = 3_000
TRAIN_DOCS_PER_WIKI_CONFIG = 30_000
TRAIN_DOCS_FINEWEB = 60_000


# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  DISTRIBUTED                                                                 ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
def is_distributed() -> bool:
    """Return True when torch.distributed is available and initialised."""
    return dist.is_available() and dist.is_initialized()


def get_rank() -> int:
    """Return this process's rank, or 0 outside distributed mode."""
    return dist.get_rank() if is_distributed() else 0


def get_world_size() -> int:
    """Return the number of processes, or 1 outside distributed mode."""
    return dist.get_world_size() if is_distributed() else 1


def is_main() -> bool:
    """Return True on rank 0 (the process that logs and saves)."""
    return get_rank() == 0


def init_distributed() -> Optional[torch.device]:
    """Initialise the NCCL process group when LOCAL_RANK is set.

    Returns:
        The CUDA device assigned to this process, or None when the LOCAL_RANK
        environment variable is absent (single-process run).
    """
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if local_rank == -1:
        return None
    # Bind the process to its GPU before creating the NCCL communicator so
    # collectives are issued on the intended device.
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return torch.device(f"cuda:{local_rank}")


def set_seed(seed: int) -> None:
    """Seed Python's and torch's (CPU and, if present, CUDA) RNGs."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_device(ddp_device: Optional[torch.device] = None) -> torch.device:
    """Pick the compute device: the DDP device, else current CUDA, else CPU."""
    if ddp_device is not None:
        return ddp_device
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")


def current_cuda_index(device: torch.device) -> int:
    """Return the device's CUDA index, defaulting to the current CUDA device."""
    return device.index if device.index is not None else torch.cuda.current_device()


def autocast_context(device: torch.device):
    """Return a CUDA autocast context in DTYPE, or a no-op context on CPU."""
    if device.type == "cuda":
        return torch.autocast("cuda", dtype=DTYPE)
    return nullcontext()


def unwrap_model(model: nn.Module) -> nn.Module:
    """Strip the DDP wrapper and the torch.compile wrapper, if present."""
    m = model.module if isinstance(model, DDP) else model
    return m._orig_mod if hasattr(m, "_orig_mod") else m


def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Count parameter elements, optionally restricted to trainable ones."""
    return sum(
        p.numel()
        for p in model.parameters()
        if not trainable_only or p.requires_grad
    )


def normalize_state_dict_keys(sd: dict) -> OrderedDict:
    """Return a copy of *sd* with a leading DDP / compile prefix removed.

    Only the first matching prefix is stripped from each key; prefixes are
    tried longest first.
    """
    out = OrderedDict()
    for k, v in sd.items():
        for prefix in ("module._orig_mod.", "_orig_mod.", "module."):
            if k.startswith(prefix):
                k = k[len(prefix):]
                break
        out[k] = v
    return out


def normalize_text(t: str) -> str:
    """Collapse every run of whitespace to one space and trim the ends."""
    return " ".join(t.strip().split())


def safe_str(x) -> str:
    """Coerce *x* to str; None becomes the empty string."""
    return x if isinstance(x, str) else ("" if x is None else str(x))


# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  DATASETS                                                                    ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
def load_wiki_stream(cfg_name: str):
    """Open the streaming train split of wikimedia/wikipedia for *cfg_name*."""
    return load_dataset("wikimedia/wikipedia", cfg_name, split="train", streaming=True)


def load_fineweb_stream():
    """Open the streaming train split of fineweb-edu (FINEWEB_CONFIG)."""
    return load_dataset("HuggingFaceFW/fineweb-edu", FINEWEB_CONFIG, split="train", streaming=True)


def stream_texts(ds, start: int, count: int, char_limit: int) -> Iterator[str]:
    """Yield normalised texts for rows [start, start + count) of *ds*.

    Rows whose normalised "text" field is shorter than 20 characters are
    skipped; the rest are truncated to *char_limit* characters.
    """
    for row in itertools.islice(ds, start, start + count):
        text = normalize_text(safe_str(row.get("text", "")))
        if len(text) >= 20:
            yield text[:char_limit]


def tokenizer_training_iterator() -> Iterator[str]:
    """Yield the tokenizer training corpus: each wiki config, then fineweb."""
    for c in WIKI_CONFIGS:
        yield from stream_texts(
            load_wiki_stream(c), 0, TOKENIZER_SAMPLE_DOCS_PER_SOURCE, TOKENIZER_CHAR_LIMIT
        )
    yield from stream_texts(
        load_fineweb_stream(), 0, TOKENIZER_SAMPLE_DOCS_PER_SOURCE, TOKENIZER_CHAR_LIMIT
    )


def build_epoch_train_texts(epoch: int) -> list[str]:
    """Materialise the training texts for *epoch*.

    Each epoch reads a fresh window of every source, placed after the dev
    window, and shuffles it with a seed derived from SEED and *epoch*, so the
    result is deterministic for a given epoch.
    """
    texts: list[str] = []
    for c in WIKI_CONFIGS:
        start = DEV_DOCS_PER_WIKI_CONFIG + epoch * TRAIN_DOCS_PER_WIKI_CONFIG
        texts.extend(
            stream_texts(load_wiki_stream(c), start, TRAIN_DOCS_PER_WIKI_CONFIG, TEXT_CHAR_LIMIT)
        )
    start = DEV_DOCS_FINEWEB + epoch * TRAIN_DOCS_FINEWEB
    texts.extend(stream_texts(load_fineweb_stream(), start, TRAIN_DOCS_FINEWEB, TEXT_CHAR_LIMIT))
    random.Random(SEED + epoch).shuffle(texts)
    return texts


def build_eval_texts() -> list[str]:
    """Materialise the evaluation texts: the first (dev) window of each source."""
    texts: list[str] = []
    for c in WIKI_CONFIGS:
        texts.extend(
            stream_texts(load_wiki_stream(c), 0, DEV_DOCS_PER_WIKI_CONFIG, TEXT_CHAR_LIMIT)
        )
    texts.extend(stream_texts(load_fineweb_stream(), 0, DEV_DOCS_FINEWEB, TEXT_CHAR_LIMIT))
    return texts


# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  PACKED DATASET                                                              ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
class PackedTextList(torch.utils.data.IterableDataset):
    """Tokenise a list of texts and pack them into fixed-length LM samples.

    Texts are shuffled with *epoch_seed*, sharded across (rank, worker), wrapped
    in BOS/EOS and concatenated; every block_size + 1 tokens produce one sample
    whose labels are the inputs shifted by one position.
    """

    def __init__(self, texts, tokenizer, block_size, epoch_seed=0):
        super().__init__()
        self.texts = texts
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.epoch_seed = epoch_seed

    def __iter__(self):
        worker = torch.utils.data.get_worker_info()
        rank, ws = get_rank(), get_world_size()
        if worker is None:
            shard_mod, shard_id = ws, rank
        else:
            # One disjoint shard per (rank, dataloader worker) pair.
            shard_mod = worker.num_workers * ws
            shard_id = rank * worker.num_workers + worker.id

        # Same permutation in every shard, so the modulo split is a partition.
        rng = random.Random(self.epoch_seed)
        indices = list(range(len(self.texts)))
        rng.shuffle(indices)

        bos, eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
        buf: list[int] = []
        for li, ti in enumerate(indices):
            if li % shard_mod != shard_id:
                continue
            ids = self.tokenizer.encode(self.texts[ti], add_special_tokens=False)
            if not ids:
                continue
            buf.extend([bos] + ids + [eos])
            while len(buf) >= self.block_size + 1:
                chunk = buf[: self.block_size + 1]
                buf = buf[self.block_size + 1:]
                yield {
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "labels": torch.tensor(chunk[1:], dtype=torch.long),
                }
        # Tokens left in buf (fewer than block_size + 1) are dropped.


# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  TOKENIZER                                                                   ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
def tokenizer_ready() -> bool:
    """Return True when both tokenizer files exist in TOKENIZER_DIR."""
    return (
        (TOKENIZER_DIR / "tokenizer.json").exists()
        and (TOKENIZER_DIR / "tokenizer_config.json").exists()
    )


def train_tokenizer_once() -> None:
    """Train a byte-level BPE tokenizer and save it under TOKENIZER_DIR."""
    TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
    tok = Tokenizer(models.BPE(unk_token=UNK_TOKEN))
    tok.normalizer = normalizers.Sequence([normalizers.NFKC()])
    tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tok.decoder = decoders.ByteLevel()
    trainer = trainers.BpeTrainer(
        vocab_size=VOCAB_SIZE,
        min_frequency=2,
        show_progress=is_main(),
        special_tokens=SPECIAL_TOKENS,
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )
    tok.train_from_iterator(tokenizer_training_iterator(), trainer=trainer)

    bos_id, eos_id = tok.token_to_id(BOS_TOKEN), tok.token_to_id(EOS_TOKEN)
    tok.post_processor = processors.TemplateProcessing(
        single=f"{BOS_TOKEN} $A {EOS_TOKEN}",
        pair=f"{BOS_TOKEN} $A {EOS_TOKEN} $B:1 {EOS_TOKEN}:1",
        special_tokens=[(BOS_TOKEN, bos_id), (EOS_TOKEN, eos_id)],
    )
    tok.save(str(TOKENIZER_DIR / "tokenizer.json"))

    fast = PreTrainedTokenizerFast(
        tokenizer_file=str(TOKENIZER_DIR / "tokenizer.json"),
        bos_token=BOS_TOKEN,
        eos_token=EOS_TOKEN,
        unk_token=UNK_TOKEN,
        pad_token=PAD_TOKEN,
    )
    fast.save_pretrained(str(TOKENIZER_DIR))


def train_or_load_tokenizer() -> PreTrainedTokenizerFast:
    """Load the tokenizer, training it first on rank 0 if it does not exist.

    In distributed mode every rank reaches the barrier unconditionally.
    Previously the barrier sat inside the "tokenizer missing" branch, so a rank
    that checked after rank 0 had finished writing would skip it and leave the
    ranks with mismatched collectives.
    """
    TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
    if is_main() and not tokenizer_ready():
        print("Entraînement tokenizer 32k…")
        train_tokenizer_once()
    if is_distributed():
        dist.barrier()
    return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))


# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  MODEL                                                                       ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
@dataclass
class GPTConfig:
    """Model hyperparameters; defaults come from the module-level constants."""

    vocab_size: int = VOCAB_SIZE
    block_size: int = BLOCK_SIZE
    d_model: int = D_MODEL
    n_heads: int = N_HEADS
    n_layers: int = N_LAYERS
    d_ff: int = D_FF
    dropout: float = DROPOUT
    use_checkpointing: bool = USE_CHECKPOINTING


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class RotaryEmbedding(nn.Module):
    """Precomputed rotary cos/sin tables for positions 0 .. max_seq - 1.

    Each frequency is repeated twice along the last dimension so the tables
    line up with the interleaved (even, odd) pairing used by rotate_half.
    """

    def __init__(self, dim: int, base: int = 10_000, max_seq: int = 4_096):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent buffers: they follow .to(device) but are not saved in
        # the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        t = torch.arange(max_seq).float()
        freqs = torch.outer(t, inv_freq)
        self.register_buffer(
            "cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False
        )
        self.register_buffer(
            "sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False
        )

    def forward(self, seq_len: int, dtype: torch.dtype):
        """Return (cos, sin) for the first *seq_len* positions, cast to *dtype*."""
        return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)


def rotate_half(x):
    """Map each interleaved pair (x_even, x_odd) to (-x_odd, x_even)."""
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rope(x, cos, sin):
    """Apply rotary embedding; cos/sin gain two leading broadcast dimensions."""
    return x * cos.unsqueeze(0).unsqueeze(0) + rotate_half(x) * sin.unsqueeze(0).unsqueeze(0)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embedding."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout_p = cfg.dropout
        self.rope = RotaryEmbedding(self.head_dim)

    def forward(self, x):
        b, t, c = x.shape
        q, k, v = self.qkv(x).split(c, dim=-1)
        # (b, t, c) -> (b, n_heads, t, head_dim)
        q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)

        # FIX: build cos/sin in q's dtype, not x's. Under autocast the qkv
        # projection can return DTYPE while x is still float32; multiplying by
        # float32 tables would promote q and k to float32 and leave v behind,
        # giving the attention kernel mixed-dtype inputs.
        cos, sin = self.rope(t, q.dtype)
        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)

        drop_p = self.dropout_p if self.training else 0.0
        # flash-attn is only used for CUDA half-precision tensors; everything
        # else goes through PyTorch's SDPA.
        use_flash = HAS_FLASH and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16)
        if use_flash:
            # flash_attn_func is given (b, t, n_heads, head_dim)
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            y = flash_attn_func(q, k, v, dropout_p=drop_p, causal=True)
            y = y.reshape(b, t, c)
        else:
            y = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p, is_causal=True)
            y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)


class SwiGLU(nn.Module):
    """Gated feed-forward: w3(silu(w1(x)) * w2(x))."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class Block(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, both residual."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = RMSNorm(cfg.d_model)
        self.ff = SwiGLU(cfg)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class GPT(nn.Module):
    """Decoder-only language model with tied input/output embeddings."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # weight tying
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        """Normal(0, 0.02) for Linear/Embedding weights, zeros for Linear biases."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, input_ids, labels=None):
        """Return (logits, loss); loss is None when *labels* is not given.

        The loss is token-level cross-entropy; label positions equal to -100
        are ignored.
        """
        x = self.tok_emb(input_ids)
        for block in self.blocks:
            if self.cfg.use_checkpointing and self.training:
                # Recompute the block's activations in backward instead of
                # storing them.
                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        logits = self.lm_head(self.ln_f(x))
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100
            )
        return logits, loss


# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  QLORA                                                                       ║
# ║                                                                              ║
# ║  KEY FIX:                                                                    ║
# ║  apply_qlora() MUST be called AFTER model.to(device).                        ║
# ║  LoRALinear detects the base layer's device and creates lora_A / lora_B      ║
# ║  there directly, so no separate .to() call is needed.                        ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
class LoRALinear(nn.Module):
    """LoRA adapter wrapped around an existing nn.Linear.

    The output is base(x) + lora_B(lora_A(dropout(x))) * (alpha / r). lora_A
    and lora_B are created on the same device as the base layer's parameters
    (this is the fix for the v1 'cuda:0 vs cpu' mismatch). The base layer's
    parameters are frozen.

    Args:
        base_layer: the linear layer to wrap; kept as ``self.base``.
        r: adapter rank; must be positive.
        alpha: scaling numerator; the adapter output is scaled by alpha / r.
        dropout: dropout probability applied to the adapter input only.

    Raises:
        ValueError: if *r* is not positive.
    """

    def __init__(self, base_layer: nn.Linear, r: int = LORA_R, alpha: int = LORA_ALPHA,
                 dropout: float = LORA_DROPOUT):
        super().__init__()
        if r <= 0:
            # Previously surfaced as a ZeroDivisionError (r == 0) further down.
            raise ValueError(f"LoRA rank must be positive, got {r}")
        self.base = base_layer
        self.r = r
        self.scale = alpha / r
        in_f, out_f = base_layer.in_features, base_layer.out_features

        # Take the device from the first parameter of the base layer; fall back
        # to CPU for a layer that exposes no parameters.
        try:
            dev = next(base_layer.parameters()).device
        except StopIteration:
            dev = torch.device("cpu")

        # Create the adapters DIRECTLY on the right device.
        self.lora_A = nn.Linear(in_f, r, bias=False, device=dev)
        self.lora_B = nn.Linear(r, out_f, bias=False, device=dev)
        self.drop = nn.Dropout(dropout)

        # Standard LoRA init: lora_B starts at zero, so the wrapped layer
        # initially computes exactly what the base layer does.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

        # Freeze the base weights.
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.base(x) + self.lora_B(self.lora_A(self.drop(x))) * self.scale


def apply_qlora(model: GPT, device: torch.device) -> GPT:
    """Replace every target nn.Linear of *model* with a LoRALinear, in place.

    A module is a target when the last component of its dotted name is listed
    in LORA_TARGET_MODULES. Must be called after model.to(device), because the
    adapters take their device from the wrapped layer.

    NOTE(review): this function only adds LoRA adapters; nothing here quantises
    the base weights, although the log line reports ``NF4={HAS_BNB}``.

    Args:
        model: the model to modify.
        device: used for the log message only.

    Returns:
        The same *model* object (unchanged when USE_QLORA is False).
    """
    if not USE_QLORA:
        return model

    # Collect first, so the module tree is not modified while it is traversed.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if name.split(".")[-1] in LORA_TARGET_MODULES and isinstance(module, nn.Linear)
    ]

    replaced = 0
    for name, module in targets:
        parts = name.split(".")
        parent = model
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], LoRALinear(module))
        replaced += 1

    if is_main():
        print(f"QLoRA : {replaced} couches remplacées (device={device}, NF4={HAS_BNB})")
    return model


def freeze_base_weights(model: GPT) -> None:
    """Leave only parameters whose name contains lora_A or lora_B trainable."""
    for name, p in model.named_parameters():
        p.requires_grad = ("lora_A" in name or "lora_B" in name)


# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  OPTIMIZER                                                                   ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
    """Build AdamW over the trainable parameters of *model*.

    Parameters with at least two dimensions and "weight" in their name receive
    WEIGHT_DECAY; all other trainable parameters receive none. Both groups are
    always present (possibly empty), so the param-group layout is stable.

    The bitsandbytes paged 8-bit optimizer and the fused AdamW kernel are only
    selected when every trainable parameter lives on CUDA. Previously the
    choice depended on library / CUDA availability alone, which selected a
    CUDA-only optimizer for a model that was still on the CPU.
    """
    decay, no_decay = [], []
    for name, p in unwrap_model(model).named_parameters():
        if not p.requires_grad:
            continue
        (decay if p.ndim >= 2 and "weight" in name else no_decay).append(p)

    groups = [
        {"params": decay, "weight_decay": WEIGHT_DECAY},
        {"params": no_decay, "weight_decay": 0.0},
    ]

    trainable = decay + no_decay
    on_cuda = bool(trainable) and all(p.is_cuda for p in trainable)

    if HAS_BNB and on_cuda:
        return bnb.optim.PagedAdamW8bit(groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)
    return torch.optim.AdamW(
        groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8, fused=on_cuda
    )


def cosine_lr(step: int, total_steps: int) -> float:
    """Learning rate for *step*: linear warm-up, then cosine decay to MIN_LR.

    Warm-up uses (step + 1) / (WARMUP_STEPS + 1) so the very first optimizer
    step has a non-zero learning rate (the previous step / WARMUP_STEPS form
    returned 0.0 at step 0 and wasted that update). From WARMUP_STEPS onwards
    the rate follows a half cosine from LEARNING_RATE to MIN_LR and stays at
    MIN_LR once *total_steps* is reached.
    """
    if step < WARMUP_STEPS:
        return LEARNING_RATE * (step + 1) / (WARMUP_STEPS + 1)
    # max(1, ...) guards against total_steps <= WARMUP_STEPS.
    p = min(1.0, (step - WARMUP_STEPS) / max(1, total_steps - WARMUP_STEPS))
    return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * p))


# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  CHECKPOINT                                                                  ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
def save_checkpoint(model, optimizer, epoch, step, best_loss, path):
    """Write a training checkpoint to *path* atomically.

    The payload holds the unwrapped model's state dict (with wrapper prefixes
    removed), the optimizer state, the epoch / step counters, the best eval
    loss and the model config. It is written to a sibling ``<name>.tmp`` file
    and then moved into place, so an interruption during the write cannot
    leave a truncated checkpoint at *path*.
    """
    raw = unwrap_model(model)
    payload = {
        "model": normalize_state_dict_keys(raw.state_dict()),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "step": step,
        "best_loss": best_loss,
        "config": asdict(raw.cfg),
    }
    target = Path(path)
    tmp = target.with_name(target.name + ".tmp")
    torch.save(payload, tmp)
    os.replace(tmp, target)


def _align_checkpoint_keys(sd, model_keys):
    """Rename checkpoint keys between plain and LoRA-wrapped layer layouts.

    A key that is absent from *model_keys* is retried as ``<prefix>.base.<leaf>``
    (plain checkpoint, wrapped model) and with ``.base.`` removed (wrapped
    checkpoint, plain model). Keys that match neither form are kept unchanged.
    """
    aligned = OrderedDict()
    for key, value in sd.items():
        if key not in model_keys:
            head, dot, leaf = key.rpartition(".")
            candidates = []
            if dot:
                candidates.append(f"{head}.base.{leaf}")
            candidates.append(key.replace(".base.", "."))
            key = next((c for c in candidates if c in model_keys), key)
        aligned[key] = value
    return aligned


def maybe_load_base_checkpoint(model, device):
    """Load BASE_CHECKPOINT into *model* if it is configured and exists.

    FIX: main() calls this after the target layers have been wrapped in
    LoRALinear, where the base weight is named ``...qkv.base.weight``. A
    checkpoint saved from an unwrapped model uses ``...qkv.weight``; with
    strict=False those keys were silently ignored and no base weight was
    loaded. Keys are now aligned to the model's layout first, and any remaining
    mismatch (other than the expected missing LoRA adapters) is reported.
    """
    if BASE_CHECKPOINT is None or not Path(BASE_CHECKPOINT).exists():
        return
    # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
    ckpt = torch.load(BASE_CHECKPOINT, map_location=device)
    raw = unwrap_model(model)
    sd = _align_checkpoint_keys(
        normalize_state_dict_keys(ckpt["model"]), set(raw.state_dict().keys())
    )
    result = raw.load_state_dict(sd, strict=False)
    missing = [k for k in result.missing_keys if "lora_A" not in k and "lora_B" not in k]
    if is_main() and (missing or result.unexpected_keys):
        print(
            f"[warn] Checkpoint de base : {len(missing)} clés manquantes, "
            f"{len(result.unexpected_keys)} clés inattendues"
        )


def load_resume_checkpoint(model, optimizer, path, device):
    """Restore model (strictly) and optimizer state from *path*.

    A failure to restore the optimizer state is logged and tolerated; a failure
    to restore the model propagates to the caller.

    Returns:
        (epoch, step, best_loss) read from the checkpoint, defaulting to
        (0, 0, 1e9) for missing entries.
    """
    # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
    ckpt = torch.load(path, map_location=device)
    unwrap_model(model).load_state_dict(normalize_state_dict_keys(ckpt["model"]), strict=True)
    try:
        optimizer.load_state_dict(ckpt["optimizer"])
    except Exception as e:  # deliberate best-effort: resume with a fresh optimizer
        print(f"[warn] Optimizer state non repris: {e}")
    return int(ckpt.get("epoch", 0)), int(ckpt.get("step", 0)), float(ckpt.get("best_loss", 1e9))


# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  EVALUATION                                                                  ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
@torch.no_grad()
def evaluate(model, loader, device, max_batches=200) -> float:
    """Return the mean loss over at most *max_batches* batches of *loader*.

    The model is put in eval mode for the duration of the call and returned to
    train mode afterwards. Returns 0.0 when the loader yields nothing.

    FIX: main() calls this on rank 0 only. The forward pass therefore goes
    through the module inside the DDP wrapper, not the wrapper itself: a DDP
    forward may issue collectives that the other ranks do not mirror, which
    desynchronises the process group.
    """
    net = model.module if isinstance(model, DDP) else model
    model.eval()
    losses = []
    try:
        for i, batch in enumerate(loader):
            if i >= max_batches:
                break
            inp = batch["input_ids"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)
            with autocast_context(device):
                _, loss = net(inp, lbl)
            losses.append(loss.item())
    finally:
        model.train()
    return sum(losses) / max(1, len(losses))


# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  DATALOADER                                                                  ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
def make_loader(dataset, batch_size, num_workers, is_cuda):
    """Build a DataLoader; worker-only options are set only when workers exist.

    persistent_workers and prefetch_factor are rejected by DataLoader when
    num_workers is 0, hence the conditional.
    """
    kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=is_cuda)
    if num_workers > 0:
        kwargs["persistent_workers"] = True
        kwargs["prefetch_factor"] = PREFETCH_FACTOR
    return torch.utils.data.DataLoader(dataset, **kwargs)


def _synced_batches(loader, device):
    """Yield batches from *loader*, stopping on all ranks at the same time.

    Each rank's shard packs into a different number of batches. Without
    coordination the ranks run different numbers of steps and the gradient
    all-reduce of the longer ranks eventually waits forever. In distributed
    mode a one-element MIN all-reduce per batch makes every rank stop as soon
    as any rank is exhausted. Outside distributed mode this is a plain
    pass-through.
    """
    if not is_distributed():
        yield from loader
        return
    it = iter(loader)
    flag = torch.zeros(1, device=device)
    while True:
        batch = next(it, None)
        flag.fill_(0.0 if batch is None else 1.0)
        dist.all_reduce(flag, op=dist.ReduceOp.MIN)
        if flag.item() == 0.0:
            return
        yield batch


# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║  MAIN                                                                        ║
# ╚══════════════════════════════════════════════════════════════════════════════╝
def main() -> None:
    """Train the model end to end: setup, optional resume, epochs, final save."""
    ddp_device = init_distributed()
    set_seed(SEED + get_rank())
    device = get_device(ddp_device)
    is_cuda = device.type == "cuda"
    cuda_idx = None
    vram_fraction = None

    if is_cuda:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
        cuda_idx = current_cuda_index(device)
        _, total = torch.cuda.mem_get_info(cuda_idx)
        # Cap this process at TARGET_VRAM_GIB (never more than 99.9% of the card).
        vram_fraction = min(TARGET_VRAM_GIB * (1024**3) / total, 0.999)
        torch.cuda.memory.set_per_process_memory_fraction(vram_fraction, device=cuda_idx)

    if is_main():
        print("=" * 72)
        print(" GPT ~1B | H100 80 Go | QLoRA + BF16 + TF32 | v2 (device fix)")
        print("=" * 72)
        print(f"Device : {device} | World: {get_world_size()} GPU(s)")
        print(f"Flash-2 : {HAS_FLASH} | BNB 4-bit: {HAS_BNB} | QLoRA: {USE_QLORA}")
        print(f"Grad ckpt: {USE_CHECKPOINTING} | Compile: {USE_COMPILE} ({COMPILE_MODE})")
        if is_cuda:
            free, total = torch.cuda.mem_get_info(cuda_idx)
            print(f"GPU : {torch.cuda.get_device_name(cuda_idx)}")
            print(f"VRAM : {total/1024**3:.1f} GiB | libre: {free/1024**3:.1f} GiB")

    tokenizer = train_or_load_tokenizer()
    cfg = GPTConfig(vocab_size=len(tokenizer))
    if is_main():
        CONFIG_FILE.write_text(
            json.dumps(asdict(cfg), indent=2, ensure_ascii=False), encoding="utf-8"
        )

    # ── 1. Create the model ──────────────────────────────────────────────────
    model = GPT(cfg).to(device)

    # ── 2. Apply LoRA AFTER .to(device) ──────────────────────────────────────
    # The adapters take their device from the layer they wrap.
    if USE_QLORA:
        model = apply_qlora(model, device)
        freeze_base_weights(model)
    # Runs after wrapping; the loader aligns checkpoint keys to the wrapped layout.
    maybe_load_base_checkpoint(model, device)

    # ── 3. torch.compile (only when USE_CHECKPOINTING=False) ─────────────────
    # compile + checkpoint + the custom LoRALinear is reported unstable with
    # torch.dynamo on PyTorch 2.x, so only one of the two is used.
    # NOTE(review): compilation is lazy, so most compile errors would surface
    # at the first forward pass rather than inside this try block — confirm.
    if USE_COMPILE and not USE_CHECKPOINTING and hasattr(torch, "compile"):
        try:
            model = torch.compile(model, mode=COMPILE_MODE)
            if is_main():
                print(f"torch.compile activé ({COMPILE_MODE})")
        except Exception as e:
            if is_main():
                print(f"[warn] torch.compile échoué ({e}) — poursuite sans compile")

    # ── 4. DDP ───────────────────────────────────────────────────────────────
    if is_distributed():
        model = DDP(model, device_ids=[device.index])

    optimizer = build_optimizer(model)

    # ── Datasets ─────────────────────────────────────────────────────────────
    eval_texts = build_eval_texts()
    eval_ds = PackedTextList(eval_texts, tokenizer, cfg.block_size, SEED + 999)
    eval_loader = make_loader(eval_ds, BATCH_SIZE, EVAL_NUM_WORKERS, is_cuda)

    # Epoch-0 texts give the step estimate and are reused by the first epoch.
    epoch0_texts = build_epoch_train_texts(0)
    # Rough estimate (one document ≈ one sample). One optimizer step consumes
    # BATCH_SIZE * GRAD_ACCUM_STEPS samples on each of the world_size ranks.
    samples_per_step = BATCH_SIZE * GRAD_ACCUM_STEPS * get_world_size()
    steps_per_epoch = max(1, len(epoch0_texts) // samples_per_step)
    total_steps_est = steps_per_epoch * NUM_EPOCHS

    # ── Resume ───────────────────────────────────────────────────────────────
    start_epoch, start_step, best_eval = 0, 0, 1e9
    if STATE_FILE.exists():
        try:
            if is_main():
                print(f"Reprise depuis {STATE_FILE}")
            start_epoch, start_step, best_eval = load_resume_checkpoint(
                model, optimizer, STATE_FILE, device
            )
        except Exception as e:
            # Unreadable checkpoint: set it aside and start from scratch.
            if is_main():
                bad = STATE_FILE.with_suffix(".corrupt.pt")
                print(f"[warn] Checkpoint illisible: {e}")
                try:
                    STATE_FILE.rename(bad)
                except OSError as rename_err:
                    print(f"[warn] Renommage impossible: {rename_err}")
            start_epoch, start_step, best_eval = 0, 0, 1e9

    if is_main():
        raw = unwrap_model(model)
        n_total = count_parameters(raw, False)
        n_train = count_parameters(raw, True)
        print(f"Paramètres totaux : {n_total/1e9:.3f}B")
        print(f"Paramètres entraînés : {n_train/1e6:.1f}M ({100*n_train/max(1,n_total):.2f}%)")
        print(f"Batch size : {BATCH_SIZE} | Grad accum: {GRAD_ACCUM_STEPS} | Effective: {BATCH_SIZE*GRAD_ACCUM_STEPS}")
        print(f"Steps estimés: {total_steps_est} | Eval texts: {len(eval_texts)}")
        print()
        print("── Conseil VRAM ────────────────────────────────────────────────")
        print(" Surveille 'max_reserved=XX GiB' à step 50.")
        print(" Augmente BATCH_SIZE par +2 jusqu'à ~77 Go réservés.")
        print(" Si OOM : BATCH_SIZE -= 1 ou USE_CHECKPOINTING=True.")
        print("────────────────────────────────────────────────────────────────")

    # ── Training loop ────────────────────────────────────────────────────────
    model.train()
    optimizer.zero_grad(set_to_none=True)
    global_step = start_step
    t0 = time.time()
    log_loss_sum = 0.0
    log_loss_count = 0
    tokens_since_log = 0
    last_log = time.time()
    if is_cuda:
        torch.cuda.reset_peak_memory_stats(cuda_idx)

    for epoch in range(start_epoch, NUM_EPOCHS):
        if is_main():
            print(f"\n{'='*20} Epoch {epoch+1}/{NUM_EPOCHS} {'='*20}")

        # build_epoch_train_texts is deterministic per epoch, so the list built
        # for the step estimate is reused instead of streaming it a second time.
        if epoch == 0 and epoch0_texts is not None:
            train_texts = epoch0_texts
        else:
            train_texts = build_epoch_train_texts(epoch)
        epoch0_texts = None  # allow the epoch-0 list to be freed afterwards

        train_ds = PackedTextList(train_texts, tokenizer, cfg.block_size, SEED + epoch)
        train_loader = make_loader(train_ds, BATCH_SIZE, TRAIN_NUM_WORKERS, is_cuda)

        for micro_step, batch in enumerate(_synced_batches(train_loader, device)):
            inp = batch["input_ids"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)
            with autocast_context(device):
                _, loss = model(inp, lbl)
            (loss / GRAD_ACCUM_STEPS).backward()

            log_loss_sum += loss.item()
            log_loss_count += 1
            tokens_since_log += inp.numel()

            # Keep accumulating until a full group of micro-batches is in.
            if (micro_step + 1) % GRAD_ACCUM_STEPS != 0:
                continue

            lr = cosine_lr(global_step, total_steps_est)
            for group in optimizer.param_groups:
                group["lr"] = lr
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            if global_step % 50 == 0 and is_main():
                now = time.time()
                elapsed = max(1e-6, now - last_log)
                tok_s = tokens_since_log / elapsed
                avg_loss = log_loss_sum / max(1, log_loss_count)
                print(
                    f"ep {epoch+1}/{NUM_EPOCHS} | step={global_step:5d} | "
                    f"loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s"
                )
                if is_cuda:
                    alloc = torch.cuda.memory_allocated(cuda_idx) / 1024**3
                    reserved = torch.cuda.memory_reserved(cuda_idx) / 1024**3
                    max_alloc = torch.cuda.max_memory_allocated(cuda_idx) / 1024**3
                    max_res = torch.cuda.max_memory_reserved(cuda_idx) / 1024**3
                    print(
                        f"GPU mem | alloc={alloc:.2f} | reserved={reserved:.2f} | "
                        f"max_alloc={max_alloc:.2f} | max_reserved={max_res:.2f} (GiB)"
                    )
                last_log = now
                tokens_since_log = 0
                log_loss_sum = 0.0
                log_loss_count = 0

            if global_step % EVAL_EVERY == 0 and is_main():
                val_loss = evaluate(model, eval_loader, device)
                print(f"[eval] step {global_step:5d} | val_loss={val_loss:.4f}")
                if val_loss < best_eval:
                    best_eval = val_loss
                    save_checkpoint(model, optimizer, epoch, global_step, best_eval, BEST_MODEL_FILE)
                    print(f"✓ Meilleur modèle → {BEST_MODEL_FILE}")

            if global_step % SAVE_EVERY == 0 and is_main():
                save_checkpoint(model, optimizer, epoch, global_step, best_eval, STATE_FILE)
                save_checkpoint(model, optimizer, epoch, global_step, best_eval, MODEL_FILE)
                print(f"✓ Checkpoint → {MODEL_FILE}")

        # Drop gradients from an incomplete accumulation group so they do not
        # leak into the first optimizer step of the next epoch.
        optimizer.zero_grad(set_to_none=True)

        if is_main():
            save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, STATE_FILE)
            ckpt = OUT_DIR / f"model_epoch_{epoch+1:02d}.pt"
            save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, ckpt)
            print(f"✓ Fin epoch {epoch+1}/{NUM_EPOCHS} → {ckpt}")

    if is_main():
        save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, MODEL_FILE)
        save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, STATE_FILE)
        total_min = (time.time() - t0) / 60
        print(f"\nModèle final → {MODEL_FILE}")
        print(f"Meilleur modèle → {BEST_MODEL_FILE}")
        print(f"Temps total : {total_min:.1f} min | Steps: {global_step}")

    if is_distributed():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()