#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import json
import math
import os
import random
import time
from collections import OrderedDict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterator, Optional
import torch
import torch.utils.checkpoint  # explicit import; checkpoint() is used when USE_CHECKPOINTING is set
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast
# ============================================================
# Base model / tokenizer / config
# ============================================================
BASE_CHECKPOINT = Path("./wikipedia_ar_h100_codealpaca/model_best.pt")
BASE_TOKENIZER_DIR = Path("./wikipedia_ar_h100/tokenizer_32k")
BASE_CONFIG_FILE = Path("./wikipedia_ar_h100/config.json")
OUT_DIR = Path("./wikipedia_ar_h100_multicode_10x2000")
OUT_DIR.mkdir(parents=True, exist_ok=True)
MODEL_FILE = OUT_DIR / "model.pt"
BEST_MODEL_FILE = OUT_DIR / "model_best.pt"
STATE_FILE = OUT_DIR / "train_state.pt"
CONFIG_FILE = OUT_DIR / "config.json"
# ============================================================
# Datasets
# ============================================================
TRAIN_SOURCES = [
{
"name": "HuggingFaceH4/CodeAlpaca_20K",
"subset": None,
"split": "train",
"kind": "codealpaca",
"weight": 0.45,
"streaming": False,
},
{
"name": "open-r1/codeforces",
"subset": "verifiable-prompts",
"split": "train",
"kind": "codeforces_python",
"weight": 0.35,
"streaming": False,
},
{
"name": "wikimedia/wikipedia",
"subset": "20231101.ar",
"split": "train",
"kind": "wikipedia_ar",
"weight": 0.20,
"streaming": True,
},
]
EVAL_SOURCE = {
"name": "HuggingFaceH4/CodeAlpaca_20K",
"subset": None,
"split": "test",
"kind": "codealpaca",
"streaming": False,
}
CODEFORCES_LANGUAGE = "python"
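# Sampling sketch (what MixedTextSource.next_text does, in miniature):
#   idx = random.choices(range(3), weights=[0.45, 0.35, 0.20], k=1)[0]
# so over 1,000 draws expect roughly 450 CodeAlpaca, 350 Codeforces (Python
# prompts only; other languages are dropped in format_record), and 200
# Arabic Wikipedia examples.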
# ============================================================
# Hyperparameters
# ============================================================
SEED = 42
TARGET_VRAM_GIB = 75.0
LEARNING_RATE = 5e-5
MIN_LR = 5e-6
WEIGHT_DECAY = 0.1
WARMUP_STEPS = 200
NUM_ROUNDS = 10
STEPS_PER_ROUND = 2000
MAX_STEPS = NUM_ROUNDS * STEPS_PER_ROUND # 20000
BATCH_SIZE = 24
GRAD_ACCUM_STEPS = 1
MAX_GRAD_NORM = 1.0
EVAL_EVERY = 250
SAVE_EVERY = 500
MAX_EVAL_EXAMPLES = 2000
TEXT_CHAR_LIMIT = 6000
DTYPE = torch.bfloat16
USE_COMPILE = True
COMPILE_MODE = "default"
USE_CHECKPOINTING = False
TRAIN_NUM_WORKERS = 0
EVAL_NUM_WORKERS = 0
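# Tokens seen per optimizer step = BATCH_SIZE * block_size * GRAD_ACCUM_STEPS
# per GPU. E.g., assuming block_size = 1024 (the real value comes from the
# base config at runtime): 24 * 1024 * 1 = 24,576 tokens/step per GPU.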
# ============================================================
# Helpers
# ============================================================
def is_distributed() -> bool:
return dist.is_available() and dist.is_initialized()
def get_rank() -> int:
return dist.get_rank() if is_distributed() else 0
def get_world_size() -> int:
return dist.get_world_size() if is_distributed() else 1
def is_main() -> bool:
return get_rank() == 0
def init_distributed() -> Optional[torch.device]:
local_rank = int(os.environ.get("LOCAL_RANK", -1))
if local_rank == -1:
return None
dist.init_process_group("nccl")
torch.cuda.set_device(local_rank)
return torch.device(f"cuda:{local_rank}")
def set_seed(seed: int) -> None:
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def get_device(ddp_device: Optional[torch.device] = None) -> torch.device:
if ddp_device is not None:
return ddp_device
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
return torch.device("cpu")
def current_cuda_index(device: torch.device) -> int:
if device.type != "cuda":
raise ValueError("Device non CUDA")
return device.index if device.index is not None else torch.cuda.current_device()
def autocast_context(device: torch.device):
if device.type == "cuda":
return torch.autocast("cuda", dtype=DTYPE)
return nullcontext()
def unwrap_model(model: nn.Module) -> nn.Module:
m = model.module if isinstance(model, DDP) else model
if hasattr(m, "_orig_mod"):
return m._orig_mod
return m
def count_parameters(model: nn.Module) -> int:
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def normalize_state_dict_keys(state_dict: dict) -> OrderedDict:
normalized = OrderedDict()
for k, v in state_dict.items():
nk = k
if nk.startswith("module._orig_mod."):
nk = nk[len("module._orig_mod."):]
elif nk.startswith("_orig_mod."):
nk = nk[len("_orig_mod."):]
elif nk.startswith("module."):
nk = nk[len("module."):]
normalized[nk] = v
return normalized
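# e.g. "module._orig_mod.blocks.0.attn.qkv.weight" -> "blocks.0.attn.qkv.weight",
# covering checkpoints saved from DDP-wrapped and/or torch.compile'd models.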
def normalize_text(text: str) -> str:
return " ".join(text.strip().split())
# ============================================================
# Dataset loading / formatting
# ============================================================
def load_one_dataset(spec: dict):
kwargs = {
"path": spec["name"],
"split": spec["split"],
"streaming": spec["streaming"],
}
if spec["subset"] is not None:
kwargs["name"] = spec["subset"]
return load_dataset(**kwargs)
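# e.g. the Wikipedia spec above expands to roughly:
#   load_dataset("wikimedia/wikipedia", name="20231101.ar", split="train",
#                streaming=True)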
def format_record(row: dict, kind: str) -> str:
if kind == "codealpaca":
prompt = row.get("prompt", "")
completion = row.get("completion", "")
if not isinstance(prompt, str):
prompt = str(prompt)
if not isinstance(completion, str):
completion = str(completion)
text = (
"### Instruction\n"
f"{prompt.strip()}\n\n"
"### Response\n"
f"{completion.strip()}"
)
return normalize_text(text)
if kind == "codeforces_python":
language = row.get("language", "")
if language != CODEFORCES_LANGUAGE:
return ""
prompt = row.get("prompt", "")
title = row.get("title", "")
if not isinstance(prompt, str):
prompt = str(prompt)
if not isinstance(title, str):
title = str(title)
text = (
f"### Competitive Programming Problem ({language})\n"
f"{title.strip()}\n\n"
f"{prompt.strip()}"
)
return normalize_text(text)
if kind == "wikipedia_ar":
text = row.get("text", "")
if not isinstance(text, str):
text = str(text)
return normalize_text(text)
return ""
def example_text_iter(spec: dict, max_examples: Optional[int] = None) -> Iterator[str]:
ds = load_one_dataset(spec)
n = 0
for row in ds:
text = format_record(row, spec["kind"])
if not text or len(text) < 20:
continue
if TEXT_CHAR_LIMIT is not None:
text = text[:TEXT_CHAR_LIMIT]
yield text
n += 1
if max_examples is not None and n >= max_examples:
break
class MixedTextSource:
def __init__(self, specs: list[dict]):
self.specs = specs
self.weights = [s["weight"] for s in specs]
self.streams = [example_text_iter(s) for s in specs]
def next_text(self) -> str:
while True:
idx = random.choices(range(len(self.specs)), weights=self.weights, k=1)[0]
try:
return next(self.streams[idx])
except StopIteration:
self.streams[idx] = example_text_iter(self.specs[idx])
def packed_block_stream_mixed(
tokenizer: PreTrainedTokenizerFast,
specs: list[dict],
block_size: int,
) -> Iterator[list[int]]:
bos, eos = tokenizer.bos_token_id, tokenizer.eos_token_id
buffer: list[int] = []
source = MixedTextSource(specs)
while True:
text = source.next_text()
ids = tokenizer.encode(text, add_special_tokens=False)
if not ids:
continue
buffer.extend([bos] + ids + [eos])
while len(buffer) >= block_size + 1:
yield buffer[: block_size + 1]
buffer = buffer[block_size + 1:]
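# Each yielded chunk is block_size + 1 ids long so the consumer can form
# (input_ids, labels) = (chunk[:-1], chunk[1:]). The buffer advances by
# block_size + 1, so consecutive chunks are disjoint and the single
# prediction across each chunk boundary is never trained.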
class PackedMixedBlocks(torch.utils.data.IterableDataset):
def __init__(self, tokenizer, specs, block_size):
super().__init__()
self.tokenizer = tokenizer
self.specs = specs
self.block_size = block_size
def __iter__(self):
worker = torch.utils.data.get_worker_info()
rank = get_rank()
world_size = get_world_size()
if worker is None:
shard_mod = world_size
shard_id = rank
else:
shard_mod = worker.num_workers * world_size
shard_id = rank * worker.num_workers + worker.id
for idx, chunk in enumerate(
packed_block_stream_mixed(
tokenizer=self.tokenizer,
specs=self.specs,
block_size=self.block_size,
)
):
if idx % shard_mod != shard_id:
continue
yield {
"input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
"labels": torch.tensor(chunk[1:], dtype=torch.long),
}
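# Rank/worker sharding: each (rank, dataloader worker) pair keeps only chunks
# with index % shard_mod == shard_id. E.g. 2 ranks x 2 workers gives
# shard_mod = 4 and ids rank0/worker0 -> 0, rank0/worker1 -> 1,
# rank1/worker0 -> 2, rank1/worker1 -> 3, so shards are disjoint. Every shard
# still tokenizes the full stream and discards what it does not keep.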
class PackedEvalBlocks(torch.utils.data.IterableDataset):
def __init__(self, tokenizer, spec, block_size, max_examples):
super().__init__()
self.tokenizer = tokenizer
self.spec = spec
self.block_size = block_size
self.max_examples = max_examples
def __iter__(self):
worker = torch.utils.data.get_worker_info()
rank = get_rank()
world_size = get_world_size()
if worker is None:
shard_mod = world_size
shard_id = rank
else:
shard_mod = worker.num_workers * world_size
shard_id = rank * worker.num_workers + worker.id
bos, eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
buffer: list[int] = []
for ex_idx, text in enumerate(example_text_iter(self.spec, max_examples=self.max_examples)):
if ex_idx % shard_mod != shard_id:
continue
ids = self.tokenizer.encode(text, add_special_tokens=False)
if not ids:
continue
buffer.extend([bos] + ids + [eos])
while len(buffer) >= self.block_size + 1:
chunk = buffer[: self.block_size + 1]
buffer = buffer[self.block_size + 1:]
yield {
"input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
"labels": torch.tensor(chunk[1:], dtype=torch.long),
}
# ============================================================
# Architecture
# ============================================================
@dataclass
class GPTConfig:
vocab_size: int
block_size: int
d_model: int
n_heads: int
n_layers: int
d_ff: int
dropout: float = 0.0
use_checkpointing: bool = False
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
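# RMSNorm: y = w * x / sqrt(mean(x^2) + eps). Unlike LayerNorm it subtracts no
# mean and has no bias, which saves a little compute and memory.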
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, base: int = 10000, max_seq: int = 4096):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(max_seq).float()
freqs = torch.outer(t, inv_freq)
self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False)
self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False)
def forward(self, seq_len: int, dtype: torch.dtype):
return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x[..., ::2], x[..., 1::2]
return torch.stack((-x2, x1), dim=-1).flatten(-2)
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
return x * cos + rotate_half(x) * sin
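# Interleaved RoPE: each feature pair (x[2i], x[2i+1]) is rotated by angle
# t * theta_i at position t:
#   out[2i]   = x[2i]   * cos - x[2i+1] * sin
#   out[2i+1] = x[2i+1] * cos + x[2i]   * sin
# The repeat_interleave'd cos/sin caches above line up with this pairing.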
class CausalSelfAttention(nn.Module):
def __init__(self, cfg: GPTConfig):
super().__init__()
assert cfg.d_model % cfg.n_heads == 0
self.n_heads = cfg.n_heads
self.head_dim = cfg.d_model // cfg.n_heads
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
self.dropout_p = cfg.dropout
self.rope = RotaryEmbedding(self.head_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, t, c = x.shape
q, k, v = self.qkv(x).split(c, dim=-1)
q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rope(t, x.dtype)
q = apply_rope(q, cos, sin)
k = apply_rope(k, cos, sin)
y = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.dropout_p if self.training else 0.0,
is_causal=True,
)
y = y.transpose(1, 2).contiguous().view(b, t, c)
return self.proj(y)
class SwiGLU(nn.Module):
def __init__(self, cfg: GPTConfig):
super().__init__()
self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w3(F.silu(self.w1(x)) * self.w2(x))
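# SwiGLU feed-forward: w3(silu(w1 x) * w2 x), i.e. a gated MLP. With three
# d_model x d_ff projections it costs 3 * d_model * d_ff parameters per block,
# versus 2 * d_model * d_ff for a plain two-matrix MLP of the same width.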
class Block(nn.Module):
def __init__(self, cfg: GPTConfig):
super().__init__()
self.ln1 = RMSNorm(cfg.d_model)
self.attn = CausalSelfAttention(cfg)
self.ln2 = RMSNorm(cfg.d_model)
self.ff = SwiGLU(cfg)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.ln1(x))
x = x + self.ff(self.ln2(x))
return x
class GPT(nn.Module):
def __init__(self, cfg: GPTConfig):
super().__init__()
self.cfg = cfg
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
self.ln_f = RMSNorm(cfg.d_model)
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
self.lm_head.weight = self.tok_emb.weight
self.apply(self._init_weights)
@staticmethod
def _init_weights(m: nn.Module) -> None:
if isinstance(m, (nn.Linear, nn.Embedding)):
nn.init.normal_(m.weight, mean=0.0, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
x = self.tok_emb(input_ids)
for block in self.blocks:
if self.cfg.use_checkpointing and self.training:
x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
else:
x = block(x)
logits = self.lm_head(self.ln_f(x))
loss = None
if labels is not None:
loss = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
labels.reshape(-1),
ignore_index=-100,
)
return logits, loss
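# Note on weight tying: lm_head.weight is tok_emb.weight, so the output
# projection adds no parameters of its own. _init_weights then touches the
# shared tensor twice (via the Embedding and the Linear branch), which is
# harmless here since both draw from N(0, 0.02) and the base checkpoint
# overwrites the weights anyway.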
# ============================================================
# Optimizer / LR
# ============================================================
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
decay, no_decay = [], []
for name, p in unwrap_model(model).named_parameters():
if not p.requires_grad:
continue
(decay if p.ndim >= 2 and "weight" in name else no_decay).append(p)
return torch.optim.AdamW(
[
{"params": decay, "weight_decay": WEIGHT_DECAY},
{"params": no_decay, "weight_decay": 0.0},
],
lr=LEARNING_RATE,
betas=(0.9, 0.95),
eps=1e-8,
fused=torch.cuda.is_available(),
)
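# Grouping rule: only tensors that are >= 2-D and named "*weight*" (linear and
# embedding matrices) get weight decay; 1-D tensors such as RMSNorm gains land
# in the no-decay group. The tied lm_head/tok_emb weight appears only once in
# named_parameters(), so it is not decayed twice.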
def cosine_lr(step: int) -> float:
if step < WARMUP_STEPS:
return LEARNING_RATE * step / max(1, WARMUP_STEPS)
p = min(1.0, (step - WARMUP_STEPS) / max(1, MAX_STEPS - WARMUP_STEPS))
return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * p))
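# Shape of the schedule: linear warmup 0 -> LEARNING_RATE over WARMUP_STEPS,
# then cosine decay to MIN_LR at MAX_STEPS. With the constants above:
#   cosine_lr(0) = 0.0, cosine_lr(200) = 5e-5,
#   cosine_lr(10100) ≈ 2.75e-5 (midpoint), cosine_lr(20000) = 5e-6.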
# ============================================================
# Checkpoints
# ============================================================
def load_base_config() -> GPTConfig:
cfg_dict = json.loads(BASE_CONFIG_FILE.read_text(encoding="utf-8"))
cfg_dict["use_checkpointing"] = USE_CHECKPOINTING
return GPTConfig(**cfg_dict)
def initialize_model_from_base(model: nn.Module, device: torch.device) -> None:
if not BASE_CHECKPOINT.exists():
        raise FileNotFoundError(f"Base checkpoint not found: {BASE_CHECKPOINT}")
ckpt = torch.load(BASE_CHECKPOINT, map_location=device)
state_dict = normalize_state_dict_keys(ckpt["model"])
unwrap_model(model).load_state_dict(state_dict, strict=True)
def save_checkpoint(model, optimizer, step, best_loss, path):
raw = unwrap_model(model)
model_state = normalize_state_dict_keys(raw.state_dict())
torch.save(
{
"model": model_state,
"optimizer": optimizer.state_dict(),
"step": step,
"best_loss": best_loss,
"config": asdict(raw.cfg),
},
path,
)
def load_resume_checkpoint(model, optimizer, path, device) -> tuple[int, float]:
ckpt = torch.load(path, map_location=device)
raw = unwrap_model(model)
model_state = normalize_state_dict_keys(ckpt["model"])
raw.load_state_dict(model_state, strict=True)
try:
optimizer.load_state_dict(ckpt["optimizer"])
except Exception as e:
print(f"[warn] Optimizer state non repris: {e}")
return int(ckpt.get("step", 0)), float(ckpt.get("best_loss", 1e9))
# ============================================================
# Evaluation
# ============================================================
@torch.no_grad()
def evaluate(model, loader, device, max_batches: int = 100) -> float:
model.eval()
losses = []
for i, batch in enumerate(loader):
if i >= max_batches:
break
inp = batch["input_ids"].to(device, non_blocking=True)
lbl = batch["labels"].to(device, non_blocking=True)
with autocast_context(device):
_, loss = model(inp, lbl)
losses.append(loss.item())
model.train()
return sum(losses) / max(1, len(losses))
# ============================================================
# Main
# ============================================================
def main() -> None:
ddp_device = init_distributed()
set_seed(SEED + get_rank())
device = get_device(ddp_device)
cuda_device_index = None
vram_fraction = None
if device.type == "cuda":
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")
cuda_device_index = current_cuda_index(device)
_, total_mem_bytes = torch.cuda.mem_get_info(cuda_device_index)
target_bytes = int(TARGET_VRAM_GIB * (1024 ** 3))
vram_fraction = min(target_bytes / total_mem_bytes, 0.999)
torch.cuda.memory.set_per_process_memory_fraction(
vram_fraction,
device=cuda_device_index,
)
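        # e.g. on an 80 GiB device (an assumption; mem_get_info reports the
        # real total), a 75 GiB target gives fraction = 75 / 80 = 0.9375.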
if is_main():
print("=" * 60)
print(" Re-train même modèle | 10 x 2000 steps")
print("=" * 60)
print(f"Device: {device} | World: {get_world_size()} GPU(s)")
if device.type == "cuda":
free_mem, total_mem = torch.cuda.mem_get_info(cuda_device_index)
print(f"GPU: {torch.cuda.get_device_name(cuda_device_index)}")
print(f"VRAM cible: {TARGET_VRAM_GIB:.1f} GiB")
print(f"Fraction PyTorch: {vram_fraction:.4f}")
print(f"GPU total: {total_mem / 1024**3:.2f} GiB | libre: {free_mem / 1024**3:.2f} GiB")
print(f"Rounds: {NUM_ROUNDS} | Steps/round: {STEPS_PER_ROUND} | MAX_STEPS: {MAX_STEPS}")
tokenizer = PreTrainedTokenizerFast.from_pretrained(str(BASE_TOKENIZER_DIR))
cfg = load_base_config()
cfg.vocab_size = len(tokenizer)
if is_main():
CONFIG_FILE.write_text(
json.dumps(asdict(cfg), indent=2, ensure_ascii=False),
encoding="utf-8",
)
print(f"Base checkpoint: {BASE_CHECKPOINT}")
print(f"Tokenizer: {BASE_TOKENIZER_DIR}")
model = GPT(cfg).to(device)
initialize_model_from_base(model, device)
if USE_COMPILE and hasattr(torch, "compile"):
model = torch.compile(model, mode=COMPILE_MODE)
if is_main():
print(f"torch.compile activé ({COMPILE_MODE})")
if is_distributed():
model = DDP(model, device_ids=[device.index])
optimizer = build_optimizer(model)
start_step, best_eval = 0, 1e9
if STATE_FILE.exists():
try:
if is_main():
print(f"Reprise depuis {STATE_FILE}")
start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device)
except Exception as e:
if is_main():
bad_path = STATE_FILE.with_suffix(".corrupt.pt")
print(f"[warn] Checkpoint illisible: {e}")
try:
STATE_FILE.rename(bad_path)
print(f"[warn] Checkpoint corrompu renommé vers {bad_path}")
except Exception:
pass
print("[warn] Reprise ignorée, démarrage depuis le checkpoint de base.")
start_step, best_eval = 0, 1e9
if start_step >= MAX_STEPS:
if is_main():
print(f"[warn] start_step={start_step} >= MAX_STEPS={MAX_STEPS}")
print("[warn] Rien à entraîner.")
return
train_ds = PackedMixedBlocks(
tokenizer=tokenizer,
specs=TRAIN_SOURCES,
block_size=cfg.block_size,
)
eval_ds = PackedEvalBlocks(
tokenizer=tokenizer,
spec=EVAL_SOURCE,
block_size=cfg.block_size,
max_examples=MAX_EVAL_EXAMPLES,
)
train_loader = torch.utils.data.DataLoader(
train_ds,
batch_size=BATCH_SIZE,
num_workers=TRAIN_NUM_WORKERS,
pin_memory=(device.type == "cuda"),
)
eval_loader = torch.utils.data.DataLoader(
eval_ds,
batch_size=BATCH_SIZE,
num_workers=EVAL_NUM_WORKERS,
pin_memory=(device.type == "cuda"),
)
if is_main():
raw_model = unwrap_model(model)
n_params = count_parameters(raw_model)
print(f"Paramètres: {n_params / 1e6:.1f}M")
print(f"Architecture: d={cfg.d_model} | heads={cfg.n_heads} | layers={cfg.n_layers} | block={cfg.block_size}")
print(f"Batch size: {BATCH_SIZE} | Grad accum: {GRAD_ACCUM_STEPS}")
print(f"Dtype: {DTYPE} | Compile: {USE_COMPILE} ({COMPILE_MODE if USE_COMPILE else 'off'})")
model.train()
optimizer.zero_grad(set_to_none=True)
train_iter = iter(train_loader)
step = start_step
t0 = time.time()
log_loss_sum = 0.0
log_loss_count = 0
tokens_since_log = 0
last_log = time.time()
if device.type == "cuda":
torch.cuda.reset_peak_memory_stats(cuda_device_index)
current_round = (step // STEPS_PER_ROUND) + 1
while step < MAX_STEPS:
for _ in range(GRAD_ACCUM_STEPS):
batch = next(train_iter)
inp = batch["input_ids"].to(device, non_blocking=True)
lbl = batch["labels"].to(device, non_blocking=True)
with autocast_context(device):
_, loss = model(inp, lbl)
(loss / GRAD_ACCUM_STEPS).backward()
log_loss_sum += loss.item()
log_loss_count += 1
tokens_since_log += inp.numel()
lr = cosine_lr(step)
for group in optimizer.param_groups:
group["lr"] = lr
torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
optimizer.step()
optimizer.zero_grad(set_to_none=True)
step += 1
new_round = ((step - 1) // STEPS_PER_ROUND) + 1
if new_round != current_round and is_main():
current_round = new_round
print(f"\n===== Round {current_round}/{NUM_ROUNDS} =====")
if step % 50 == 0 and is_main():
now = time.time()
elapsed = max(1e-6, now - last_log)
tok_s = tokens_since_log / elapsed
avg_loss = log_loss_sum / max(1, log_loss_count)
round_idx = ((step - 1) // STEPS_PER_ROUND) + 1
step_in_round = ((step - 1) % STEPS_PER_ROUND) + 1
print(
f"round {round_idx:2d}/{NUM_ROUNDS} | "
f"step {step_in_round:4d}/{STEPS_PER_ROUND} | "
f"global {step:5d}/{MAX_STEPS} | "
f"loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s"
)
if device.type == "cuda":
allocated = torch.cuda.memory_allocated(cuda_device_index) / 1024**3
reserved = torch.cuda.memory_reserved(cuda_device_index) / 1024**3
max_alloc = torch.cuda.max_memory_allocated(cuda_device_index) / 1024**3
max_reserved = torch.cuda.max_memory_reserved(cuda_device_index) / 1024**3
print(
f"GPU mem | alloc={allocated:.2f} GiB | reserved={reserved:.2f} GiB | "
f"max_alloc={max_alloc:.2f} GiB | max_reserved={max_reserved:.2f} GiB"
)
last_log = now
tokens_since_log = 0
log_loss_sum = 0.0
log_loss_count = 0
        if step % EVAL_EVERY == 0 and is_main():
            # Evaluate via the unwrapped module so rank 0 does not enter a DDP
            # forward (and its buffer-broadcast collectives) that the other
            # ranks never join.
            val_loss = evaluate(unwrap_model(model), eval_loader, device)
print(f"[eval] global step {step:5d} | val_loss={val_loss:.4f}")
if val_loss < best_eval:
best_eval = val_loss
save_checkpoint(model, optimizer, step, best_eval, BEST_MODEL_FILE)
print(f"✓ Meilleur modèle → {BEST_MODEL_FILE}")
if step % SAVE_EVERY == 0 and is_main():
save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
print(f"✓ Checkpoint → {MODEL_FILE}")
if step % STEPS_PER_ROUND == 0 and is_main():
round_no = step // STEPS_PER_ROUND
round_ckpt = OUT_DIR / f"model_round_{round_no:02d}.pt"
save_checkpoint(model, optimizer, step, best_eval, round_ckpt)
print(f"✓ Fin round {round_no}/{NUM_ROUNDS}{round_ckpt}")
if is_main():
save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
total = (time.time() - t0) / 60
print(f"\nModèle final → {MODEL_FILE}")
print(f"Meilleur modèle → {BEST_MODEL_FILE}")
print(f"Temps total : {total:.1f} min")
print(f"Steps effectués : {step}")
if is_distributed():
dist.destroy_process_group()
if __name__ == "__main__":
main()