#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations

import argparse
import json
from collections import OrderedDict
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedTokenizerFast

# Default artifact locations from the training run.
MODEL_DIR = Path("./nlp_1b_h100_2h")
DEFAULT_CHECKPOINT = MODEL_DIR / "model_best.pt"
DEFAULT_CONFIG = MODEL_DIR / "config.json"
DEFAULT_TOKENIZER_DIR = Path("./nlp_1b_h100_opt/tokenizer_32k")


def get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")


def autocast_context(device: torch.device):
    # bfloat16 autocast on GPU; a no-op context on CPU.
    if device.type == "cuda":
        return torch.autocast("cuda", dtype=torch.bfloat16)
    return nullcontext()


def normalize_state_dict_keys(state_dict: dict) -> OrderedDict:
    # Strip the prefixes added by DistributedDataParallel ("module.") and
    # torch.compile ("_orig_mod.") so the checkpoint loads into a plain model.
    normalized = OrderedDict()
    for k, v in state_dict.items():
        nk = k
        if nk.startswith("module._orig_mod."):
            nk = nk[len("module._orig_mod."):]
        elif nk.startswith("_orig_mod."):
            nk = nk[len("_orig_mod."):]
        elif nk.startswith("module."):
            nk = nk[len("module."):]
        normalized[nk] = v
    return normalized


def clean_text(text: str) -> str:
    text = text.replace("\x00", " ").strip()
    return " ".join(text.split())


@dataclass
class GPTConfig:
    vocab_size: int
    block_size: int
    d_model: int
    n_heads: int
    n_layers: int
    d_ff: int
    dropout: float = 0.0
    use_checkpointing: bool = False


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, base: int = 10000, max_seq: int = 4096):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        t = torch.arange(max_seq).float()
        freqs = torch.outer(t, inv_freq)
        # Interleaved layout: each frequency is repeated twice so it lines up
        # with the (even, odd) coordinate pairs rotated in apply_rope.
        self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False)
        self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False)

    def forward(self, seq_len: int, dtype: torch.dtype):
        return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Map (a, b) pairs on the last dim to (-b, a), the 90-degree rotation
    # used by the interleaved RoPE formulation.
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    return x * cos + rotate_half(x) * sin

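
# Optional sanity check for the rotary embedding above (a minimal sketch;
# `_rope_smoke_test` is illustrative and is never called by this script).
# RoPE rotates each (even, odd) coordinate pair by a position-dependent
# angle, so it must preserve the per-token L2 norm.
def _rope_smoke_test() -> None:
    rope = RotaryEmbedding(dim=64, max_seq=128)
    x = torch.randn(2, 4, 16, 64)  # (batch, n_heads, seq_len, head_dim)
    cos, sin = rope(16, x.dtype)
    y = apply_rope(x, cos, sin)
    assert y.shape == x.shape
    # Pure rotation: norms along the head dimension are unchanged.
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-4)
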

class CausalSelfAttention(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        # Size the RoPE cache to the model's context length; the 4096 default
        # would index out of bounds if block_size were ever larger.
        self.rope = RotaryEmbedding(self.head_dim, max_seq=cfg.block_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, c = x.shape
        q, k, v = self.qkv(x).split(c, dim=-1)
        q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        cos, sin = self.rope(t, x.dtype)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)
        # Fused attention kernel with an implicit causal mask.
        y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)


class SwiGLU(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class Block(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = RMSNorm(cfg.d_model)
        self.ff = SwiGLU(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual block.
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying between the token embedding and the output head.
        self.lm_head.weight = self.tok_emb.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.tok_emb(input_ids)
        for block in self.blocks:
            x = block(x)
        return self.lm_head(self.ln_f(x))

    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 96,
        temperature: float = 0.2,
        top_k: int = 20,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        eos_token_id: Optional[int] = None,
        no_repeat_ngram_size: int = 3,
    ) -> torch.Tensor:
        self.eval()
        for _ in range(max_new_tokens):
            idx_cond = input_ids[:, -self.cfg.block_size:]
            logits = self(idx_cond)
            logits = logits[:, -1, :]
            # Down-weight tokens that already appear in the context.
            if repetition_penalty != 1.0:
                for b in range(input_ids.size(0)):
                    seen = torch.unique(input_ids[b])
                    seen_logits = logits[b, seen]
                    logits[b, seen] = torch.where(
                        seen_logits < 0,
                        seen_logits * repetition_penalty,
                        seen_logits / repetition_penalty,
                    )
            # Ban any token that would complete an n-gram already present.
            if no_repeat_ngram_size > 0 and input_ids.size(1) >= no_repeat_ngram_size - 1:
                n = no_repeat_ngram_size
                for b in range(input_ids.size(0)):
                    prefix = tuple(input_ids[b, -(n - 1):].tolist())
                    banned = set()
                    toks = input_ids[b].tolist()
                    for i in range(len(toks) - n + 1):
                        if tuple(toks[i:i + n - 1]) == prefix:
                            banned.add(toks[i + n - 1])
                    if banned:
                        logits[b, list(banned)] = -float("inf")
            if temperature <= 0:
                # Greedy decoding.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / max(temperature, 1e-6)
                if top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("inf")
                if 0 < top_p < 1.0:
                    # Nucleus sampling: keep the smallest set of tokens whose
                    # cumulative probability exceeds top_p.
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    probs = F.softmax(sorted_logits, dim=-1)
                    cumulative_probs = torch.cumsum(probs, dim=-1)
                    sorted_mask = cumulative_probs > top_p
                    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                    sorted_mask[..., 0] = False  # always keep the top token
                    mask = torch.zeros_like(logits, dtype=torch.bool)
                    mask.scatter_(1, sorted_indices, sorted_mask)
                    logits = logits.masked_fill(mask, -float("inf"))
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
        return input_ids

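
# Optional helper (a sketch; nothing in this script calls it): analytic
# parameter count for a GPTConfig, handy for checking that config.json
# matches the checkpoint size. It should agree with
# sum(p.numel() for p in GPT(cfg).parameters()).
def count_parameters(cfg: GPTConfig) -> int:
    attn = 4 * cfg.d_model * cfg.d_model      # qkv (3 d^2) + output proj (d^2)
    ffn = 3 * cfg.d_model * cfg.d_ff          # SwiGLU: w1, w2, w3
    norms = 2 * cfg.d_model                   # ln1 + ln2 RMSNorm weights
    per_block = attn + ffn + norms
    embedding = cfg.vocab_size * cfg.d_model  # tied with lm_head, counted once
    return embedding + cfg.n_layers * per_block + cfg.d_model  # + final ln_f
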

def load_model_and_tokenizer(checkpoint_path: Path, config_path: Path, tokenizer_dir: Path, device: torch.device):
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    if not config_path.exists():
        raise FileNotFoundError(f"Config not found: {config_path}")
    if not tokenizer_dir.exists():
        raise FileNotFoundError(f"Tokenizer not found: {tokenizer_dir}")
    cfg_dict = json.loads(config_path.read_text(encoding="utf-8"))
    cfg = GPTConfig(**cfg_dict)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(str(tokenizer_dir))
    model = GPT(cfg).to(device)
    ckpt = torch.load(checkpoint_path, map_location=device)
    state_dict = normalize_state_dict_keys(ckpt["model"])
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model, tokenizer, ckpt


def build_prompt(text: str, mode: str) -> str:
    if mode == "raw":
        return text
    if mode == "completion":
        return text
    # The qa/instruction templates are kept in French verbatim; they appear
    # to be the prompt format the model expects (see the French built-in
    # tests below).
    if mode == "qa":
        return f"Réponds brièvement en français.\nQuestion: {text}\nRéponse:"
    if mode == "instruction":
        return f"Instruction: Réponds de façon concise.\nEntrée: {text}\nSortie:"
    raise ValueError(f"Unknown mode: {mode}")


def encode_prompt(tokenizer: PreTrainedTokenizerFast, prompt: str, device: torch.device) -> torch.Tensor:
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = encoded["input_ids"].to(device)
    # Prepend BOS manually, since special tokens were disabled above.
    if tokenizer.bos_token_id is not None:
        bos = torch.tensor([[tokenizer.bos_token_id]], device=device, dtype=input_ids.dtype)
        input_ids = torch.cat([bos, input_ids], dim=1)
    return input_ids


def generate_text(model, tokenizer, prompt, device, max_new_tokens, temperature, top_k, top_p, repetition_penalty):
    input_ids = encode_prompt(tokenizer, prompt, device)
    prompt_len = input_ids.shape[1]
    with autocast_context(device):
        output_ids = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
        )
    generated_ids = output_ids[0][prompt_len:]
    return clean_text(tokenizer.decode(generated_ids, skip_special_tokens=True))


@torch.inference_mode()
def score_text(model, tokenizer, text: str, device: torch.device) -> dict:
    # Teacher-forced negative log-likelihood and perplexity of `text`.
    ids = encode_prompt(tokenizer, text, device)
    if ids.size(1) < 2:
        return {"tokens": int(ids.size(1)), "loss": None, "ppl": None}
    inp = ids[:, :-1]
    tgt = ids[:, 1:]
    with autocast_context(device):
        logits = model(inp)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            tgt.reshape(-1),
            reduction="mean",
        )
    return {"tokens": int(tgt.numel()), "loss": float(loss.item()), "ppl": float(torch.exp(loss).item())}


def built_in_tests() -> List[tuple[str, str]]:
    # Multilingual smoke tests (English, French, Arabic) kept verbatim.
    return [
        ("completion", "Deep learning is a method of machine learning that"),
        ("completion", "Le deep learning est une méthode d'apprentissage qui"),
        ("completion", "الذكاء الاصطناعي هو مجال يهدف إلى"),
        ("qa", "What is machine learning?"),
        ("qa", "Qu'est-ce que l'apprentissage automatique ?"),
        ("instruction", "Give a short HTML page with a title and one paragraph."),
    ]

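
# Programmatic usage (a minimal sketch, assuming the default checkpoint,
# config, and tokenizer paths exist; `_example_usage` mirrors what `main`
# does below without the CLI and is not called anywhere).
def _example_usage() -> str:
    device = get_device()
    model, tokenizer, _ = load_model_and_tokenizer(
        DEFAULT_CHECKPOINT, DEFAULT_CONFIG, DEFAULT_TOKENIZER_DIR, device
    )
    prompt = build_prompt("Qu'est-ce que l'apprentissage automatique ?", "qa")
    return generate_text(
        model, tokenizer, prompt, device,
        max_new_tokens=64, temperature=0.2, top_k=20, top_p=0.8,
        repetition_penalty=1.2,
    )
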

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default=str(DEFAULT_CHECKPOINT))
    parser.add_argument("--config", type=str, default=str(DEFAULT_CONFIG))
    parser.add_argument("--tokenizer_dir", type=str, default=str(DEFAULT_TOKENIZER_DIR))
    parser.add_argument("--prompt", type=str, default="Deep learning is a method of machine learning that")
    parser.add_argument("--mode", type=str, default="completion", choices=["completion", "qa", "instruction", "raw"])
    parser.add_argument("--max_new_tokens", type=int, default=96)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_k", type=int, default=20)
    parser.add_argument("--top_p", type=float, default=0.8)
    parser.add_argument("--repetition_penalty", type=float, default=1.2)
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--run_tests", action="store_true")
    parser.add_argument("--score_only", action="store_true")
    args = parser.parse_args()

    device = get_device()
    if device.type == "cuda":
        # Allow TF32 matmuls for faster inference on Ampere and newer GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")

    model, tokenizer, ckpt = load_model_and_tokenizer(
        checkpoint_path=Path(args.checkpoint),
        config_path=Path(args.config),
        tokenizer_dir=Path(args.tokenizer_dir),
        device=device,
    )
    print(f"Device: {device}")
    print(f"Checkpoint: {args.checkpoint}")
    print(f"epoch={ckpt.get('epoch', 'N/A')} | step={ckpt.get('step', 'N/A')} | best_loss={ckpt.get('best_loss', 'N/A')}")

    if args.run_tests:
        print("\n=== Built-in tests ===")
        for i, (mode, text) in enumerate(built_in_tests(), start=1):
            prompt = build_prompt(text, mode)
            print(f"\n[{i}] mode={mode}")
            print(f"Input: {text}")
            print("Output:")
            print(generate_text(
                model=model,
                tokenizer=tokenizer,
                prompt=prompt,
                device=device,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
            ))
        return

    if args.interactive:
        print("Interactive mode.")
        print("Commands: /mode completion|qa|instruction|raw, /score <text>, exit\n")
        current_mode = args.mode
        while True:
            user_in = input(f"{current_mode}> ").strip()
            if user_in.lower() in {"exit", "quit"}:
                break
            if not user_in:
                continue
            if user_in.startswith("/mode "):
                new_mode = user_in.split(maxsplit=1)[1].strip()
                if new_mode in {"completion", "qa", "instruction", "raw"}:
                    current_mode = new_mode
                    print(f"Mode changed: {current_mode}\n")
                else:
                    print("Invalid mode.\n")
                continue
            if user_in.startswith("/score "):
                sample = user_in.split(maxsplit=1)[1]
                print(score_text(model, tokenizer, sample, device))
                print()
                continue
            prompt = build_prompt(user_in, current_mode)
            print("\n=== Output ===")
            print(generate_text(
                model=model,
                tokenizer=tokenizer,
                prompt=prompt,
                device=device,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
            ))
            print()
        return

    if args.score_only:
        print(json.dumps(score_text(model, tokenizer, args.prompt, device), ensure_ascii=False, indent=2))
        return

    prompt = build_prompt(args.prompt, args.mode)
    print(generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        device=device,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
    ))


if __name__ == "__main__":
    main()
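
# CLI examples (sketches; the filename `infer.py` is a placeholder for
# whatever name this script is saved under):
#   python infer.py --prompt "Le deep learning est" --max_new_tokens 64
#   python infer.py --mode qa --prompt "Qu'est-ce que l'apprentissage automatique ?"
#   python infer.py --run_tests
#   python infer.py --interactive
#   python infer.py --score_only --prompt "Text whose perplexity you want."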