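"""Inference script for a GPT-style causal language model.

Loads a trained checkpoint, its JSON config, and a fast tokenizer, then
generates text with top-k / top-p sampling, a repetition penalty, and
no-repeat-n-gram blocking. Also supports perplexity scoring
(--score_only), built-in multilingual smoke tests (--run_tests), and an
interactive prompt loop (--interactive). See the CLI flags in main().
"""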
from __future__ import annotations

import argparse
import json
from collections import OrderedDict
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedTokenizerFast

# Default artifact paths; override via --checkpoint / --config / --tokenizer_dir.
MODEL_DIR = Path("./nlp_1b_h100_2h")
DEFAULT_CHECKPOINT = MODEL_DIR / "model_best.pt"
DEFAULT_CONFIG = MODEL_DIR / "config.json"
DEFAULT_TOKENIZER_DIR = Path("./nlp_1b_h100_opt/tokenizer_32k")


def get_device() -> torch.device:
    """Prefer the current CUDA device, falling back to CPU."""
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")


def autocast_context(device: torch.device):
    """bfloat16 autocast on CUDA; a no-op context manager on CPU."""
    if device.type == "cuda":
        return torch.autocast("cuda", dtype=torch.bfloat16)
    return nullcontext()


def normalize_state_dict_keys(state_dict: dict) -> OrderedDict:
    """Strip the prefixes added by DistributedDataParallel ("module.")
    and torch.compile ("_orig_mod.") so weights load into a plain model."""
    normalized = OrderedDict()
    for k, v in state_dict.items():
        nk = k
        if nk.startswith("module._orig_mod."):
            nk = nk[len("module._orig_mod."):]
        elif nk.startswith("_orig_mod."):
            nk = nk[len("_orig_mod."):]
        elif nk.startswith("module."):
            nk = nk[len("module."):]
        normalized[nk] = v
    return normalized


def clean_text(text: str) -> str:
    """Drop NUL bytes and collapse all whitespace runs into single spaces."""
    text = text.replace("\x00", " ").strip()
    return " ".join(text.split())


@dataclass
class GPTConfig:
    vocab_size: int
    block_size: int
    d_model: int
    n_heads: int
    n_layers: int
    d_ff: int
    dropout: float = 0.0
    use_checkpointing: bool = False


class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm: scale-only, no mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class RotaryEmbedding(nn.Module):
    """Precomputed rotary position embeddings (RoPE).

    The cos/sin caches are repeat-interleaved so each even/odd feature
    pair shares a frequency, matching the pairing in rotate_half below.
    """

    def __init__(self, dim: int, base: int = 10000, max_seq: int = 4096):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        t = torch.arange(max_seq).float()
        freqs = torch.outer(t, inv_freq)
        self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False)
        self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False)

    def forward(self, seq_len: int, dtype: torch.dtype):
        return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Map each interleaved pair (x0, x1) to (-x1, x0).
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Broadcast the (seq, head_dim) caches over the (batch, heads) dims.
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    return x * cos + rotate_half(x) * sin


class CausalSelfAttention(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        # Size the RoPE cache to the context window so a block_size larger
        # than the default 4096 cannot overrun the cache.
        self.rope = RotaryEmbedding(self.head_dim, max_seq=cfg.block_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, c = x.shape
        q, k, v = self.qkv(x).split(c, dim=-1)
        q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rope(t, x.dtype)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # Fused attention kernel with an implicit causal mask.
        y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)


class SwiGLU(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class Block(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = RMSNorm(cfg.d_model)
        self.ff = SwiGLU(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Tie the input embedding and output projection weights.
        self.lm_head.weight = self.tok_emb.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.tok_emb(input_ids)
        for block in self.blocks:
            x = block(x)
        return self.lm_head(self.ln_f(x))

    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 96,
        temperature: float = 0.2,
        top_k: int = 20,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        eos_token_id: Optional[int] = None,
        no_repeat_ngram_size: int = 3,
    ) -> torch.Tensor:
        """Autoregressively extend input_ids token by token."""
        self.eval()

        for _ in range(max_new_tokens):
            # Condition on at most block_size trailing tokens.
            idx_cond = input_ids[:, -self.cfg.block_size:]
            logits = self(idx_cond)
            logits = logits[:, -1, :]

            # Penalize every token already present in the sequence.
            if repetition_penalty != 1.0:
                for b in range(input_ids.size(0)):
                    seen = torch.unique(input_ids[b])
                    seen_logits = logits[b, seen]
                    logits[b, seen] = torch.where(
                        seen_logits < 0,
                        seen_logits * repetition_penalty,
                        seen_logits / repetition_penalty,
                    )

            # Ban any token that would complete an n-gram seen earlier.
            if no_repeat_ngram_size > 0 and input_ids.size(1) >= no_repeat_ngram_size - 1:
                n = no_repeat_ngram_size
                for b in range(input_ids.size(0)):
                    prefix = tuple(input_ids[b, -(n - 1):].tolist())
                    banned = set()
                    toks = input_ids[b].tolist()
                    for i in range(len(toks) - n + 1):
                        if tuple(toks[i:i + n - 1]) == prefix:
                            banned.add(toks[i + n - 1])
                    if banned:
                        logits[b, list(banned)] = -float("inf")

            if temperature <= 0:
                # Temperature 0 (or below) means greedy decoding.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / max(temperature, 1e-6)

                # Top-k: keep only the k highest logits.
                if top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("inf")

                # Top-p (nucleus): keep the smallest set of tokens whose
                # cumulative probability exceeds top_p.
                if 0 < top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    probs = F.softmax(sorted_logits, dim=-1)
                    cumulative_probs = torch.cumsum(probs, dim=-1)

                    sorted_mask = cumulative_probs > top_p
                    # Shift right so the token that crosses the threshold is kept.
                    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                    sorted_mask[..., 0] = False

                    mask = torch.zeros_like(logits, dtype=torch.bool)
                    mask.scatter_(1, sorted_indices, sorted_mask)
                    logits = logits.masked_fill(mask, -float("inf"))

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Stop once every sequence in the batch has emitted EOS.
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return input_ids


def load_model_and_tokenizer(checkpoint_path: Path, config_path: Path, tokenizer_dir: Path, device: torch.device):
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    if not config_path.exists():
        raise FileNotFoundError(f"Config not found: {config_path}")
    if not tokenizer_dir.exists():
        raise FileNotFoundError(f"Tokenizer not found: {tokenizer_dir}")

    cfg_dict = json.loads(config_path.read_text(encoding="utf-8"))
    cfg = GPTConfig(**cfg_dict)

    tokenizer = PreTrainedTokenizerFast.from_pretrained(str(tokenizer_dir))
    model = GPT(cfg).to(device)

    # Checkpoint is a dict holding "model" plus training metadata
    # (epoch / step / best_loss, printed in main()).
    ckpt = torch.load(checkpoint_path, map_location=device)
    state_dict = normalize_state_dict_keys(ckpt["model"])
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model, tokenizer, ckpt


def build_prompt(text: str, mode: str) -> str:
    # The qa/instruction templates are deliberately French; they match the
    # French-language prompts exercised in built_in_tests().
    if mode == "raw":
        return text
    if mode == "completion":
        return text
    if mode == "qa":
        return f"Réponds brièvement en français.\nQuestion: {text}\nRéponse:"
    if mode == "instruction":
        return f"Instruction: Réponds de façon concise.\nEntrée: {text}\nSortie:"
    raise ValueError(f"Unknown mode: {mode}")


def encode_prompt(tokenizer: PreTrainedTokenizerFast, prompt: str, device: torch.device) -> torch.Tensor:
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = encoded["input_ids"].to(device)
    # Special tokens were disabled above, so prepend BOS manually if defined.
    if tokenizer.bos_token_id is not None:
        bos = torch.tensor([[tokenizer.bos_token_id]], device=device, dtype=input_ids.dtype)
        input_ids = torch.cat([bos, input_ids], dim=1)
    return input_ids


def generate_text(model, tokenizer, prompt, device, max_new_tokens, temperature, top_k, top_p, repetition_penalty):
    input_ids = encode_prompt(tokenizer, prompt, device)
    prompt_len = input_ids.shape[1]

    with autocast_context(device):
        output_ids = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
        )

    # Decode only the newly generated continuation, not the prompt.
    generated_ids = output_ids[0][prompt_len:]
    return clean_text(tokenizer.decode(generated_ids, skip_special_tokens=True))


@torch.inference_mode()
def score_text(model, tokenizer, text: str, device: torch.device) -> dict:
    """Next-token cross-entropy loss and perplexity of `text` under the model."""
    ids = encode_prompt(tokenizer, text, device)
    if ids.size(1) < 2:
        return {"tokens": int(ids.size(1)), "loss": None, "ppl": None}

    # Shift by one: predict token t+1 from tokens up to t.
    inp = ids[:, :-1]
    tgt = ids[:, 1:]

    with autocast_context(device):
        logits = model(inp)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            tgt.reshape(-1),
            reduction="mean",
        )

    return {"tokens": int(tgt.numel()), "loss": float(loss.item()), "ppl": float(torch.exp(loss).item())}


def built_in_tests() -> list[tuple[str, str]]:
    # (mode, prompt) smoke tests covering English, French, and Arabic.
    return [
        ("completion", "Deep learning is a method of machine learning that"),
        ("completion", "Le deep learning est une méthode d'apprentissage qui"),
        ("completion", "الذكاء الاصطناعي هو مجال يهدف إلى"),
        ("qa", "What is machine learning?"),
        ("qa", "Qu'est-ce que l'apprentissage automatique ?"),
        ("instruction", "Give a short HTML page with a title and one paragraph."),
    ]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default=str(DEFAULT_CHECKPOINT))
    parser.add_argument("--config", type=str, default=str(DEFAULT_CONFIG))
    parser.add_argument("--tokenizer_dir", type=str, default=str(DEFAULT_TOKENIZER_DIR))
    parser.add_argument("--prompt", type=str, default="Deep learning is a method of machine learning that")
    parser.add_argument("--mode", type=str, default="completion", choices=["completion", "qa", "instruction", "raw"])
    parser.add_argument("--max_new_tokens", type=int, default=96)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_k", type=int, default=20)
    parser.add_argument("--top_p", type=float, default=0.8)
    parser.add_argument("--repetition_penalty", type=float, default=1.2)
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--run_tests", action="store_true")
    parser.add_argument("--score_only", action="store_true")
    args = parser.parse_args()

    device = get_device()
    if device.type == "cuda":
        # Allow TF32 matmuls for faster float32 work on Ampere+ GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")

    model, tokenizer, ckpt = load_model_and_tokenizer(
        checkpoint_path=Path(args.checkpoint),
        config_path=Path(args.config),
        tokenizer_dir=Path(args.tokenizer_dir),
        device=device,
    )

    print(f"Device: {device}")
    print(f"Checkpoint: {args.checkpoint}")
    print(f"epoch={ckpt.get('epoch', 'N/A')} | step={ckpt.get('step', 'N/A')} | best_loss={ckpt.get('best_loss', 'N/A')}")

    if args.run_tests:
        print("\n=== Built-in tests ===")
        for i, (mode, text) in enumerate(built_in_tests(), start=1):
            prompt = build_prompt(text, mode)
            print(f"\n[{i}] mode={mode}")
            print(f"Input: {text}")
            print("Output:")
            print(generate_text(
                model=model,
                tokenizer=tokenizer,
                prompt=prompt,
                device=device,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
            ))
        return

    if args.interactive:
        print("Interactive mode.")
        print("Commands: /mode completion|qa|instruction|raw, /score <text>, exit\n")
        current_mode = args.mode

        while True:
            user_in = input(f"{current_mode}> ").strip()
            if user_in.lower() in {"exit", "quit"}:
                break
            if not user_in:
                continue

            if user_in.startswith("/mode "):
                new_mode = user_in.split(maxsplit=1)[1].strip()
                if new_mode in {"completion", "qa", "instruction", "raw"}:
                    current_mode = new_mode
                    print(f"Mode changed: {current_mode}\n")
                else:
                    print("Invalid mode.\n")
                continue

            if user_in.startswith("/score "):
                sample = user_in.split(maxsplit=1)[1]
                print(score_text(model, tokenizer, sample, device))
                print()
                continue

            prompt = build_prompt(user_in, current_mode)
            print("\n=== Output ===")
            print(generate_text(
                model=model,
                tokenizer=tokenizer,
                prompt=prompt,
                device=device,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
            ))
            print()
        return

    if args.score_only:
        print(json.dumps(score_text(model, tokenizer, args.prompt, device), ensure_ascii=False, indent=2))
        return

    prompt = build_prompt(args.prompt, args.mode)
    print(generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        device=device,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
    ))


if __name__ == "__main__":
    main()
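# Example invocations (the script name "infer.py" is illustrative; the flags
# are the ones defined in main() above):
#   python infer.py --run_tests
#   python infer.py --interactive
#   python infer.py --mode qa --prompt "Qu'est-ce que l'apprentissage automatique ?"
#   python infer.py --score_only --prompt "Some text to score."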