"""
train_nlp_h100_optimized.py -- v2 (device-mismatch bugfix)
===========================================================
Fixes vs v1:
  * apply_qlora() is called AFTER model.to(device), so lora_A/lora_B are created on CUDA
  * LoRALinear.__init__: the adapters are explicitly created on the device of base_layer
  * torch.compile is disabled when USE_CHECKPOINTING=True (dynamo + checkpoint conflict
    with custom submodules); compilation is only attempted when USE_CHECKPOINTING=False
    (the guard is `USE_COMPILE and not USE_CHECKPOINTING`)
  * Clean fallback: if compile crashes, training continues without compile
"""

from __future__ import annotations

import itertools
import json
import math
import os
import random
import time
from collections import OrderedDict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterator, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # used via torch.utils.checkpoint.checkpoint in GPT.forward
import torch.utils.data

try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Params4bit
    HAS_BNB = True
except ImportError:
    HAS_BNB = False
    print("[warn] bitsandbytes not available -- 4-bit quantization disabled")

try:
    from flash_attn import flash_attn_func
    HAS_FLASH = True
except ImportError:
    HAS_FLASH = False
    print("[warn] flash-attn not available -- falling back to F.scaled_dot_product_attention")

from datasets import load_dataset
from torch.nn.parallel import DistributedDataParallel as DDP
from tokenizers import (
    Tokenizer, decoders, models, normalizers,
    pre_tokenizers, processors, trainers,
)
from transformers import PreTrainedTokenizerFast


# ---------------------------------------------------------------------------
# Paths and output files
# ---------------------------------------------------------------------------

OUT_DIR = Path("./nlp_1b_h100_opt")
OUT_DIR.mkdir(parents=True, exist_ok=True)
TOKENIZER_DIR   = OUT_DIR / "tokenizer_32k"
CONFIG_FILE     = OUT_DIR / "config.json"
MODEL_FILE      = OUT_DIR / "model.pt"
BEST_MODEL_FILE = OUT_DIR / "model_best.pt"
STATE_FILE      = OUT_DIR / "train_state.pt"
BASE_CHECKPOINT: Optional[Path] = None

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------

SEED            = 42
TARGET_VRAM_GIB = 78.0

# Model architecture (~1B parameters)
BLOCK_SIZE = 1024
VOCAB_SIZE = 32_000
D_MODEL    = 1536
N_HEADS    = 24
N_LAYERS   = 24
D_FF       = 6144
DROPOUT    = 0.0

# LoRA / QLoRA
USE_QLORA    = True
LORA_R       = 64
LORA_ALPHA   = 128
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ["qkv", "proj", "w1", "w2", "w3"]

# Optimization
NUM_EPOCHS    = 10
LEARNING_RATE = 3e-4
MIN_LR        = 3e-5
WEIGHT_DECAY  = 0.1
WARMUP_STEPS  = 500

# Per-GPU batch size -- raise it until reserved VRAM approaches TARGET_VRAM_GIB
# (see the VRAM tip printed at startup).
BATCH_SIZE       = 16
GRAD_ACCUM_STEPS = 1
MAX_GRAD_NORM    = 1.0
EVAL_EVERY       = 500
SAVE_EVERY       = 1_000

DTYPE = torch.bfloat16

# Memory / speed toggles
USE_CHECKPOINTING = False
USE_COMPILE       = True
COMPILE_MODE      = "reduce-overhead"

# DataLoader workers
TRAIN_NUM_WORKERS = 4
EVAL_NUM_WORKERS  = 2
PREFETCH_FACTOR   = 2

# Tokenizer / text sampling limits
TOKENIZER_SAMPLE_DOCS_PER_SOURCE = 15_000
TOKENIZER_CHAR_LIMIT = 2_000
TEXT_CHAR_LIMIT      = 4_000

SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = SPECIAL_TOKENS

# Streaming data sources
WIKI_CONFIGS   = ["20231101.en", "20231101.fr", "20231101.ar"]
FINEWEB_CONFIG = "sample-10BT"
DEV_DOCS_PER_WIKI_CONFIG   = 1_500
DEV_DOCS_FINEWEB           = 3_000
TRAIN_DOCS_PER_WIKI_CONFIG = 30_000
TRAIN_DOCS_FINEWEB         = 60_000


# ---------------------------------------------------------------------------
# Distributed and misc utilities
# ---------------------------------------------------------------------------

def is_distributed() -> bool:
    return dist.is_available() and dist.is_initialized()


def get_rank() -> int:
    return dist.get_rank() if is_distributed() else 0


def get_world_size() -> int:
    return dist.get_world_size() if is_distributed() else 1


def is_main() -> bool:
    return get_rank() == 0


def init_distributed() -> Optional[torch.device]:
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if local_rank == -1:
        return None
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    return torch.device(f"cuda:{local_rank}")


def set_seed(seed: int) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_device(ddp_device: Optional[torch.device] = None) -> torch.device:
    if ddp_device is not None:
        return ddp_device
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")


def current_cuda_index(device: torch.device) -> int:
    return device.index if device.index is not None else torch.cuda.current_device()


def autocast_context(device: torch.device):
    if device.type == "cuda":
        return torch.autocast("cuda", dtype=DTYPE)
    return nullcontext()


def unwrap_model(model: nn.Module) -> nn.Module:
    m = model.module if isinstance(model, DDP) else model
    return m._orig_mod if hasattr(m, "_orig_mod") else m


def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    return sum(p.numel() for p in model.parameters() if not trainable_only or p.requires_grad)


def normalize_state_dict_keys(sd: dict) -> OrderedDict:
    out = OrderedDict()
    for k, v in sd.items():
        for prefix in ("module._orig_mod.", "_orig_mod.", "module."):
            if k.startswith(prefix):
                k = k[len(prefix):]
                break
        out[k] = v
    return out
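
# normalize_state_dict_keys example: a key saved from a compiled + DDP-wrapped model,
# e.g. "module._orig_mod.blocks.0.attn.qkv.weight", is mapped back to
# "blocks.0.attn.qkv.weight" so it can be loaded into the bare GPT module.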


def normalize_text(t: str) -> str:
    return " ".join(t.strip().split())


def safe_str(x) -> str:
    return x if isinstance(x, str) else ("" if x is None else str(x))

# ---------------------------------------------------------------------------
# Streaming data sources (Wikipedia + FineWeb-Edu)
# ---------------------------------------------------------------------------

def load_wiki_stream(cfg_name: str):
    return load_dataset("wikimedia/wikipedia", cfg_name, split="train", streaming=True)


def load_fineweb_stream():
    return load_dataset("HuggingFaceFW/fineweb-edu", FINEWEB_CONFIG, split="train", streaming=True)


def stream_texts(ds, start: int, count: int, char_limit: int) -> Iterator[str]:
    for row in itertools.islice(ds, start, start + count):
        text = normalize_text(safe_str(row.get("text", "")))
        if len(text) >= 20:
            yield text[:char_limit]


def tokenizer_training_iterator() -> Iterator[str]:
    for c in WIKI_CONFIGS:
        yield from stream_texts(load_wiki_stream(c), 0, TOKENIZER_SAMPLE_DOCS_PER_SOURCE, TOKENIZER_CHAR_LIMIT)
    yield from stream_texts(load_fineweb_stream(), 0, TOKENIZER_SAMPLE_DOCS_PER_SOURCE, TOKENIZER_CHAR_LIMIT)


def build_epoch_train_texts(epoch: int) -> list[str]:
    texts: list[str] = []
    for c in WIKI_CONFIGS:
        # Skip the dev documents, then take a fresh, non-overlapping slice per epoch.
        start = DEV_DOCS_PER_WIKI_CONFIG + epoch * TRAIN_DOCS_PER_WIKI_CONFIG
        texts.extend(stream_texts(load_wiki_stream(c), start, TRAIN_DOCS_PER_WIKI_CONFIG, TEXT_CHAR_LIMIT))
    start = DEV_DOCS_FINEWEB + epoch * TRAIN_DOCS_FINEWEB
    texts.extend(stream_texts(load_fineweb_stream(), start, TRAIN_DOCS_FINEWEB, TEXT_CHAR_LIMIT))
    random.Random(SEED + epoch).shuffle(texts)
    return texts


def build_eval_texts() -> list[str]:
    # The dev split is simply the first documents of each stream.
    texts: list[str] = []
    for c in WIKI_CONFIGS:
        texts.extend(stream_texts(load_wiki_stream(c), 0, DEV_DOCS_PER_WIKI_CONFIG, TEXT_CHAR_LIMIT))
    texts.extend(stream_texts(load_fineweb_stream(), 0, DEV_DOCS_FINEWEB, TEXT_CHAR_LIMIT))
    return texts

# ---------------------------------------------------------------------------
# Packed next-token dataset
# ---------------------------------------------------------------------------

class PackedTextList(torch.utils.data.IterableDataset):
    """Packs tokenized documents into contiguous blocks of block_size tokens and
    yields next-token pairs (input_ids, labels) shifted by one position."""

    def __init__(self, texts, tokenizer, block_size, epoch_seed=0):
        super().__init__()
        self.texts = texts
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.epoch_seed = epoch_seed
        # Capture the DDP rank/world size at construction time: DataLoader worker
        # processes do not have torch.distributed initialized, so querying it from
        # __iter__ would collapse every rank onto the same shard.
        self.rank = get_rank()
        self.world_size = get_world_size()

    def __iter__(self):
        worker = torch.utils.data.get_worker_info()
        rank, ws = self.rank, self.world_size
        if worker is None:
            shard_mod, shard_id = ws, rank
        else:
            shard_mod = worker.num_workers * ws
            shard_id = rank * worker.num_workers + worker.id

        rng = random.Random(self.epoch_seed)
        indices = list(range(len(self.texts)))
        rng.shuffle(indices)

        bos, eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
        buf: list[int] = []

        for li, ti in enumerate(indices):
            if li % shard_mod != shard_id:
                continue
            ids = self.tokenizer.encode(self.texts[ti], add_special_tokens=False)
            if not ids:
                continue
            buf.extend([bos] + ids + [eos])
            while len(buf) >= self.block_size + 1:
                chunk = buf[: self.block_size + 1]
                buf = buf[self.block_size + 1 :]
                yield {
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "labels": torch.tensor(chunk[1:], dtype=torch.long),
                }
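
# Sharding example for PackedTextList: with 2 DDP ranks and TRAIN_NUM_WORKERS=4,
# shard_mod = 8; rank 1 / worker 2 has shard_id = 1*4 + 2 = 6 and keeps the documents
# at shuffled positions 6, 14, 22, ..., so each document is consumed exactly once.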

# ---------------------------------------------------------------------------
# Tokenizer (32k byte-level BPE)
# ---------------------------------------------------------------------------

def tokenizer_ready() -> bool:
    return (TOKENIZER_DIR / "tokenizer.json").exists() and (TOKENIZER_DIR / "tokenizer_config.json").exists()


def train_tokenizer_once() -> None:
    TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
    tok = Tokenizer(models.BPE(unk_token=UNK_TOKEN))
    tok.normalizer = normalizers.Sequence([normalizers.NFKC()])
    tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tok.decoder = decoders.ByteLevel()
    trainer = trainers.BpeTrainer(
        vocab_size=VOCAB_SIZE, min_frequency=2, show_progress=is_main(),
        special_tokens=SPECIAL_TOKENS, initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )
    tok.train_from_iterator(tokenizer_training_iterator(), trainer=trainer)
    bos_id, eos_id = tok.token_to_id(BOS_TOKEN), tok.token_to_id(EOS_TOKEN)
    tok.post_processor = processors.TemplateProcessing(
        single=f"{BOS_TOKEN} $A {EOS_TOKEN}",
        pair=f"{BOS_TOKEN} $A {EOS_TOKEN} $B:1 {EOS_TOKEN}:1",
        special_tokens=[(BOS_TOKEN, bos_id), (EOS_TOKEN, eos_id)],
    )
    tok.save(str(TOKENIZER_DIR / "tokenizer.json"))
    fast = PreTrainedTokenizerFast(
        tokenizer_file=str(TOKENIZER_DIR / "tokenizer.json"),
        bos_token=BOS_TOKEN, eos_token=EOS_TOKEN, unk_token=UNK_TOKEN, pad_token=PAD_TOKEN,
    )
    fast.save_pretrained(str(TOKENIZER_DIR))


def train_or_load_tokenizer() -> PreTrainedTokenizerFast:
    TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
    if not tokenizer_ready():
        if is_distributed():
            # Only rank 0 trains the tokenizer; the other ranks wait at the barrier.
            if is_main():
                print("Training the 32k tokenizer...")
                train_tokenizer_once()
            dist.barrier()
        else:
            print("Training the 32k tokenizer...")
            train_tokenizer_once()
    return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))


# ---------------------------------------------------------------------------
# Model: GPT with RMSNorm, rotary embeddings and SwiGLU
# ---------------------------------------------------------------------------

@dataclass
class GPTConfig:
    vocab_size: int = VOCAB_SIZE
    block_size: int = BLOCK_SIZE
    d_model: int = D_MODEL
    n_heads: int = N_HEADS
    n_layers: int = N_LAYERS
    d_ff: int = D_FF
    dropout: float = DROPOUT
    use_checkpointing: bool = USE_CHECKPOINTING


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, base: int = 10_000, max_seq: int = 4_096):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        t = torch.arange(max_seq).float()
        freqs = torch.outer(t, inv_freq)
        # Interleaved layout: each frequency is repeated for its (even, odd) pair.
        self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False)
        self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False)

    def forward(self, seq_len: int, dtype: torch.dtype):
        return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)


def rotate_half(x):
    # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rope(x, cos, sin):
    # x: (batch, heads, seq, head_dim); cos/sin: (seq, head_dim)
    return x * cos.unsqueeze(0).unsqueeze(0) + rotate_half(x) * sin.unsqueeze(0).unsqueeze(0)


class CausalSelfAttention(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout_p = cfg.dropout
        self.rope = RotaryEmbedding(self.head_dim)

    def forward(self, x):
        b, t, c = x.shape
        q, k, v = self.qkv(x).split(c, dim=-1)
        q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        cos, sin = self.rope(t, x.dtype)
        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)

        if HAS_FLASH:
            # flash_attn_func expects (batch, seq, n_heads, head_dim)
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            y = flash_attn_func(q, k, v, dropout_p=self.dropout_p if self.training else 0.0, causal=True)
            y = y.reshape(b, t, c)
        else:
            y = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout_p if self.training else 0.0, is_causal=True)
            y = y.transpose(1, 2).contiguous().view(b, t, c)

        return self.proj(y)


class SwiGLU(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class Block(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = RMSNorm(cfg.d_model)
        self.ff = SwiGLU(cfg)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the LM head shares the token-embedding matrix.
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, input_ids, labels=None):
        x = self.tok_emb(input_ids)
        for block in self.blocks:
            if self.cfg.use_checkpointing and self.training:
                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        logits = self.lm_head(self.ln_f(x))
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100)
        return logits, loss
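
# Rough size with the defaults above (back-of-the-envelope, not measured here):
#   per block: qkv+proj ~ 9.4M and SwiGLU (w1, w2, w3) ~ 28.3M  ->  ~37.7M params
#   24 blocks ~ 906M, plus the tied 32k x 1536 embedding (~49M)  ->  ~0.96B total.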

# ---------------------------------------------------------------------------
# QLoRA / LoRA adapters
# ---------------------------------------------------------------------------

class LoRALinear(nn.Module):
    """
    LoRA adapter wrapped around an existing nn.Linear.

    IMPORTANT: the lora_A and lora_B submodules are created on the SAME
    device as base_layer.weight via the explicit device= arguments below.
    This is the fix for the 'cuda:0 vs cpu' bug of v1.
    """
    def __init__(self, base_layer: nn.Linear, r: int = LORA_R, alpha: int = LORA_ALPHA, dropout: float = LORA_DROPOUT):
        super().__init__()
        self.base = base_layer
        self.r = r
        self.scale = alpha / r
        in_f, out_f = base_layer.in_features, base_layer.out_features

        # Device of the wrapped layer (falls back to CPU for a parameter-less layer).
        try:
            dev = next(base_layer.parameters()).device
        except StopIteration:
            dev = torch.device("cpu")

        # Create the adapters directly on that device (v1 created them on CPU).
        self.lora_A = nn.Linear(in_f, r, bias=False, device=dev)
        self.lora_B = nn.Linear(r, out_f, bias=False, device=dev)
        self.drop = nn.Dropout(dropout)

        # Standard LoRA init: A random, B zero, so the adapter starts as a no-op.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

        # The wrapped base layer is frozen; only A and B receive gradients.
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.base(x) + self.lora_B(self.lora_A(self.drop(x))) * self.scale


def apply_qlora(model: GPT, device: torch.device) -> GPT:
    """
    Replace the target linear layers with LoRALinear wrappers.
    MUST be called after model.to(device).
    """
    if not USE_QLORA:
        return model

    replaced = 0
    # Collect the targets first: mutating the module tree while iterating over
    # named_modules() would invalidate the iterator.
    targets = []
    for name, module in model.named_modules():
        parts = name.split(".")
        if parts[-1] in LORA_TARGET_MODULES and isinstance(module, nn.Linear):
            targets.append((name, module))

    for name, module in targets:
        parts = name.split(".")
        parent = model
        for part in parts[:-1]:
            parent = getattr(parent, part)

        lora_layer = LoRALinear(module)
        setattr(parent, parts[-1], lora_layer)
        replaced += 1

    if is_main():
        print(f"QLoRA: {replaced} layers wrapped (device={device}, NF4={HAS_BNB})")
    return model


def freeze_base_weights(model: GPT) -> None:
    """Only lora_A and lora_B remain trainable."""
    for name, p in model.named_parameters():
        p.requires_grad = ("lora_A" in name or "lora_B" in name)
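
# With the defaults (r=64 on qkv, proj, w1, w2, w3 across 24 blocks) roughly 49.5M
# adapter parameters stay trainable, i.e. about 5% of the ~1B base model
# (a back-of-the-envelope figure; the exact counts are printed at startup).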

# ---------------------------------------------------------------------------
# Optimizer and LR schedule
# ---------------------------------------------------------------------------

def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
    decay, no_decay = [], []
    for name, p in unwrap_model(model).named_parameters():
        if not p.requires_grad:
            continue
        (decay if p.ndim >= 2 and "weight" in name else no_decay).append(p)

    groups = [
        {"params": decay, "weight_decay": WEIGHT_DECAY},
        {"params": no_decay, "weight_decay": 0.0},
    ]

    # Paged 8-bit AdamW keeps optimizer state in 8 bits and pages it to CPU
    # under memory pressure, when bitsandbytes is available.
    if HAS_BNB:
        return bnb.optim.PagedAdamW8bit(groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)

    return torch.optim.AdamW(groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8, fused=torch.cuda.is_available())


def cosine_lr(step: int, total_steps: int) -> float:
    if step < WARMUP_STEPS:
        return LEARNING_RATE * step / max(1, WARMUP_STEPS)
    p = min(1.0, (step - WARMUP_STEPS) / max(1, total_steps - WARMUP_STEPS))
    return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * p))
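
# Schedule sanity check with the defaults (LEARNING_RATE=3e-4, MIN_LR=3e-5, WARMUP_STEPS=500):
#   step   0 -> lr = 0.0       (linear warmup)
#   step 250 -> lr = 1.5e-4
#   step 500 -> lr = 3.0e-4    (peak)
#   afterwards cosine decay down to 3e-5 when step reaches total_steps.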

# ---------------------------------------------------------------------------
# Checkpointing
# ---------------------------------------------------------------------------

def save_checkpoint(model, optimizer, epoch, step, best_loss, path):
    raw = unwrap_model(model)
    torch.save({
        "model": normalize_state_dict_keys(raw.state_dict()),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch, "step": step, "best_loss": best_loss,
        "config": asdict(raw.cfg),
    }, path)


def maybe_load_base_checkpoint(model, device):
    if BASE_CHECKPOINT is None or not Path(BASE_CHECKPOINT).exists():
        return
    ckpt = torch.load(BASE_CHECKPOINT, map_location=device)
    unwrap_model(model).load_state_dict(normalize_state_dict_keys(ckpt["model"]), strict=False)


def load_resume_checkpoint(model, optimizer, path, device):
    ckpt = torch.load(path, map_location=device)
    unwrap_model(model).load_state_dict(normalize_state_dict_keys(ckpt["model"]), strict=True)
    try:
        optimizer.load_state_dict(ckpt["optimizer"])
    except Exception as e:
        print(f"[warn] Optimizer state not restored: {e}")
    return int(ckpt.get("epoch", 0)), int(ckpt.get("step", 0)), float(ckpt.get("best_loss", 1e9))

# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------

@torch.no_grad()
def evaluate(model, loader, device, max_batches=200) -> float:
    model.eval()
    losses = []
    for i, batch in enumerate(loader):
        if i >= max_batches:
            break
        inp = batch["input_ids"].to(device, non_blocking=True)
        lbl = batch["labels"].to(device, non_blocking=True)
        with autocast_context(device):
            _, loss = model(inp, lbl)
        losses.append(loss.item())
    model.train()
    return sum(losses) / max(1, len(losses))
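
# The returned value is the mean token-level cross-entropy (in nats);
# perplexity, if wanted, is simply math.exp(val_loss).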

# ---------------------------------------------------------------------------
# DataLoader
# ---------------------------------------------------------------------------

def make_loader(dataset, batch_size, num_workers, is_cuda):
    kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=is_cuda)
    if num_workers > 0:
        kwargs["persistent_workers"] = True
        kwargs["prefetch_factor"] = PREFETCH_FACTOR
    return torch.utils.data.DataLoader(dataset, **kwargs)

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    ddp_device = init_distributed()
    set_seed(SEED + get_rank())
    device = get_device(ddp_device)
    is_cuda = device.type == "cuda"

    cuda_idx = None
    vram_fraction = None

    if is_cuda:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
        cuda_idx = current_cuda_index(device)
        _, total = torch.cuda.mem_get_info(cuda_idx)
        vram_fraction = min(TARGET_VRAM_GIB * (1024**3) / total, 0.999)
        torch.cuda.memory.set_per_process_memory_fraction(vram_fraction, device=cuda_idx)

    if is_main():
        print("=" * 72)
        print(" GPT ~1B | H100 80 GB | QLoRA + BF16 + TF32 | v2 (device fix)")
        print("=" * 72)
        print(f"Device   : {device} | World: {get_world_size()} GPU(s)")
        print(f"Flash-2  : {HAS_FLASH} | BNB 4-bit: {HAS_BNB} | QLoRA: {USE_QLORA}")
        print(f"Grad ckpt: {USE_CHECKPOINTING} | Compile: {USE_COMPILE} ({COMPILE_MODE})")
        if is_cuda:
            free, total = torch.cuda.mem_get_info(cuda_idx)
            print(f"GPU      : {torch.cuda.get_device_name(cuda_idx)}")
            print(f"VRAM     : {total/1024**3:.1f} GiB | free: {free/1024**3:.1f} GiB")

    tokenizer = train_or_load_tokenizer()
    cfg = GPTConfig(vocab_size=len(tokenizer))

    if is_main():
        CONFIG_FILE.write_text(json.dumps(asdict(cfg), indent=2, ensure_ascii=False), encoding="utf-8")

    # 1) Build the dense model and move it to the target device FIRST.
    model = GPT(cfg).to(device)

    # 2) Only then wrap the target layers with LoRA, so the adapters are created
    #    on the same device as the base weights (the v1 device-mismatch bug).
    if USE_QLORA:
        model = apply_qlora(model, device)
        freeze_base_weights(model)

    maybe_load_base_checkpoint(model, device)

    # 3) torch.compile is only attempted without gradient checkpointing
    #    (dynamo + checkpoint + custom submodules conflict); any compile failure
    #    falls back to eager mode.
    if USE_COMPILE and not USE_CHECKPOINTING and hasattr(torch, "compile"):
        try:
            model = torch.compile(model, mode=COMPILE_MODE)
            if is_main():
                print(f"torch.compile enabled ({COMPILE_MODE})")
        except Exception as e:
            if is_main():
                print(f"[warn] torch.compile failed ({e}) -- continuing without compile")

    # 4) DDP wrapping comes last.
    if is_distributed():
        model = DDP(model, device_ids=[device.index])

    optimizer = build_optimizer(model)

    # Evaluation data (fixed dev slice, never used for training).
    eval_texts = build_eval_texts()
    eval_ds = PackedTextList(eval_texts, tokenizer, cfg.block_size, SEED + 999)
    eval_loader = make_loader(eval_ds, BATCH_SIZE, EVAL_NUM_WORKERS, is_cuda)

    init_texts = build_epoch_train_texts(0)
    steps_per_epoch = max(1, len(init_texts) // BATCH_SIZE)
    total_steps_est = steps_per_epoch * NUM_EPOCHS

    # Resume from the training-state checkpoint if one exists.
    start_epoch, start_step, best_eval = 0, 0, 1e9
    if STATE_FILE.exists():
        try:
            if is_main():
                print(f"Resuming from {STATE_FILE}")
            start_epoch, start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device)
        except Exception as e:
            if is_main():
                bad = STATE_FILE.with_suffix(".corrupt.pt")
                print(f"[warn] Unreadable checkpoint: {e}")
                try:
                    STATE_FILE.rename(bad)
                except Exception:
                    pass
            start_epoch, start_step, best_eval = 0, 0, 1e9

    if is_main():
        raw = unwrap_model(model)
        n_total = count_parameters(raw, False)
        n_train = count_parameters(raw, True)
        print(f"Total parameters     : {n_total/1e9:.3f}B")
        print(f"Trainable parameters : {n_train/1e6:.1f}M ({100*n_train/max(1,n_total):.2f}%)")
        print(f"Batch size : {BATCH_SIZE} | Grad accum: {GRAD_ACCUM_STEPS} | Effective: {BATCH_SIZE*GRAD_ACCUM_STEPS}")
        print(f"Estimated steps: {total_steps_est} | Eval texts: {len(eval_texts)}")
        print()
        print("-- VRAM tip ----------------------------------------------------")
        print("  Watch 'max_reserved=XX GiB' around step 50.")
        print("  Raise BATCH_SIZE by +2 until about 77 GiB are reserved.")
        print("  On OOM: BATCH_SIZE -= 1 or USE_CHECKPOINTING=True.")
        print("----------------------------------------------------------------")

    # Training-loop state.
    model.train()
    optimizer.zero_grad(set_to_none=True)

    global_step = start_step
    t0 = time.time()
    log_loss_sum = 0.0
    log_loss_count = 0
    tokens_since_log = 0
    last_log = time.time()

    if is_cuda:
        torch.cuda.reset_peak_memory_stats(cuda_idx)

    for epoch in range(start_epoch, NUM_EPOCHS):
        if is_main():
            print(f"\n{'='*20} Epoch {epoch+1}/{NUM_EPOCHS} {'='*20}")

        train_texts = build_epoch_train_texts(epoch)
        train_ds = PackedTextList(train_texts, tokenizer, cfg.block_size, SEED + epoch)
        train_loader = make_loader(train_ds, BATCH_SIZE, TRAIN_NUM_WORKERS, is_cuda)

        for micro_step, batch in enumerate(train_loader):
            inp = batch["input_ids"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)

            with autocast_context(device):
                _, loss = model(inp, lbl)

            (loss / GRAD_ACCUM_STEPS).backward()

            log_loss_sum += loss.item()
            log_loss_count += 1
            tokens_since_log += inp.numel()

            if (micro_step + 1) % GRAD_ACCUM_STEPS != 0:
                continue

            lr = cosine_lr(global_step, total_steps_est)
            for group in optimizer.param_groups:
                group["lr"] = lr

            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            if global_step % 50 == 0 and is_main():
                now = time.time()
                elapsed = max(1e-6, now - last_log)
                tok_s = tokens_since_log / elapsed
                avg_loss = log_loss_sum / max(1, log_loss_count)
                print(
                    f"ep {epoch+1}/{NUM_EPOCHS} | step={global_step:5d} | "
                    f"loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s"
                )
                if is_cuda:
                    alloc = torch.cuda.memory_allocated(cuda_idx) / 1024**3
                    reserved = torch.cuda.memory_reserved(cuda_idx) / 1024**3
                    max_alloc = torch.cuda.max_memory_allocated(cuda_idx) / 1024**3
                    max_res = torch.cuda.max_memory_reserved(cuda_idx) / 1024**3
                    print(
                        f"GPU mem | alloc={alloc:.2f} | reserved={reserved:.2f} | "
                        f"max_alloc={max_alloc:.2f} | max_reserved={max_res:.2f} (GiB)"
                    )
                last_log = now
                tokens_since_log = 0
                log_loss_sum = 0.0
                log_loss_count = 0

            if global_step % EVAL_EVERY == 0 and is_main():
                val_loss = evaluate(model, eval_loader, device)
                print(f"[eval] step {global_step:5d} | val_loss={val_loss:.4f}")
                if val_loss < best_eval:
                    best_eval = val_loss
                    save_checkpoint(model, optimizer, epoch, global_step, best_eval, BEST_MODEL_FILE)
                    print(f"Best model saved -> {BEST_MODEL_FILE}")

            if global_step % SAVE_EVERY == 0 and is_main():
                save_checkpoint(model, optimizer, epoch, global_step, best_eval, STATE_FILE)
                save_checkpoint(model, optimizer, epoch, global_step, best_eval, MODEL_FILE)
                print(f"Checkpoint saved -> {MODEL_FILE}")

        if is_main():
            save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, STATE_FILE)
            ckpt = OUT_DIR / f"model_epoch_{epoch+1:02d}.pt"
            save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, ckpt)
            print(f"End of epoch {epoch+1}/{NUM_EPOCHS} -> {ckpt}")

    if is_main():
        save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, MODEL_FILE)
        save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, STATE_FILE)
        total_min = (time.time() - t0) / 60
        print(f"\nFinal model -> {MODEL_FILE}")
        print(f"Best model  -> {BEST_MODEL_FILE}")
        print(f"Total time: {total_min:.1f} min | Steps: {global_step}")

    if is_distributed():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()