#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import json
from collections import OrderedDict
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedTokenizerFast
MODEL_DIR = Path("./nlp_1b_h100_2h")
DEFAULT_CHECKPOINT = MODEL_DIR / "model_best.pt"
DEFAULT_CONFIG = MODEL_DIR / "config.json"
DEFAULT_TOKENIZER_DIR = Path("./nlp_1b_h100_opt/tokenizer_32k")
def get_device() -> torch.device:
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
return torch.device("cpu")
def autocast_context(device: torch.device):
if device.type == "cuda":
return torch.autocast("cuda", dtype=torch.bfloat16)
return nullcontext()
def normalize_state_dict_keys(state_dict: dict) -> OrderedDict:
    # Strip the "module." (DistributedDataParallel) and "_orig_mod."
    # (torch.compile) wrapper prefixes so keys match a plain GPT module.
    normalized = OrderedDict()
for k, v in state_dict.items():
nk = k
if nk.startswith("module._orig_mod."):
nk = nk[len("module._orig_mod."):]
elif nk.startswith("_orig_mod."):
nk = nk[len("_orig_mod."):]
elif nk.startswith("module."):
nk = nk[len("module."):]
normalized[nk] = v
return normalized
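# A minimal, illustrative sanity check (not invoked anywhere in this script):
# it assumes the key layout produced by wrapping a model in DDP and torch.compile.
def _demo_normalize_keys() -> None:
    raw = {"module._orig_mod.tok_emb.weight": torch.zeros(1)}
    assert list(normalize_state_dict_keys(raw)) == ["tok_emb.weight"]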
def clean_text(text: str) -> str:
text = text.replace("\x00", " ").strip()
return " ".join(text.split())
@dataclass
class GPTConfig:
vocab_size: int
block_size: int
d_model: int
n_heads: int
n_layers: int
d_ff: int
    dropout: float = 0.0  # kept for config.json compatibility; unused at inference
    use_checkpointing: bool = False  # kept for config.json compatibility; unused at inference
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
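# Illustrative check (not invoked anywhere): with the default unit weight,
# RMSNorm rescales each vector by its root-mean-square, so the normalized
# activations have RMS very close to 1.
def _demo_rmsnorm() -> None:
    x = torch.randn(2, 5, 64)
    rms = RMSNorm(64)(x).pow(2).mean(-1).sqrt()
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)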
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, base: int = 10000, max_seq: int = 4096):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(max_seq).float()
freqs = torch.outer(t, inv_freq)
self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False)
self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False)
    def forward(self, seq_len: int, dtype: torch.dtype):
        # The caches cover max_seq positions; a longer sequence would otherwise
        # fail with an opaque broadcast error inside apply_rope, so fail loudly here.
        assert seq_len <= self.cos_cache.size(0), "sequence longer than RoPE cache"
        return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Interleaved RoPE layout: channels form (even, odd) pairs, and each pair
    # (a, b) maps to (-b, a), matching the repeat_interleave'd cos/sin caches.
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
return x * cos + rotate_half(x) * sin
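# Illustrative check (not invoked anywhere): RoPE is a pure rotation of each
# channel pair, so it preserves per-position vector norms.
def _demo_rope_norm_preserving() -> None:
    rope = RotaryEmbedding(dim=64, max_seq=16)
    cos, sin = rope(16, torch.float32)
    x = torch.randn(1, 4, 16, 64)  # (batch, heads, seq, head_dim)
    y = apply_rope(x, cos, sin)
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-4)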
class CausalSelfAttention(nn.Module):
def __init__(self, cfg: GPTConfig):
super().__init__()
assert cfg.d_model % cfg.n_heads == 0
self.n_heads = cfg.n_heads
self.head_dim = cfg.d_model // cfg.n_heads
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
self.rope = RotaryEmbedding(self.head_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, t, c = x.shape
q, k, v = self.qkv(x).split(c, dim=-1)
q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rope(t, x.dtype)
q = apply_rope(q, cos, sin)
k = apply_rope(k, cos, sin)
y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
y = y.transpose(1, 2).contiguous().view(b, t, c)
return self.proj(y)
class SwiGLU(nn.Module):
def __init__(self, cfg: GPTConfig):
super().__init__()
self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w3(F.silu(self.w1(x)) * self.w2(x))
class Block(nn.Module):
def __init__(self, cfg: GPTConfig):
super().__init__()
self.ln1 = RMSNorm(cfg.d_model)
self.attn = CausalSelfAttention(cfg)
self.ln2 = RMSNorm(cfg.d_model)
self.ff = SwiGLU(cfg)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.ln1(x))
x = x + self.ff(self.ln2(x))
return x
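# Illustrative check (not invoked anywhere): a pre-norm residual block maps
# (batch, seq, d_model) to the same shape; the config values here are arbitrary.
def _demo_block_shapes() -> None:
    cfg = GPTConfig(vocab_size=64, block_size=32, d_model=64, n_heads=4, n_layers=1, d_ff=128)
    x = torch.randn(2, 8, 64)
    assert Block(cfg)(x).shape == x.shape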
class GPT(nn.Module):
def __init__(self, cfg: GPTConfig):
super().__init__()
self.cfg = cfg
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the output head shares the embedding matrix.
        self.lm_head.weight = self.tok_emb.weight
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
x = self.tok_emb(input_ids)
for block in self.blocks:
x = block(x)
return self.lm_head(self.ln_f(x))
@torch.inference_mode()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 96,
temperature: float = 0.2,
top_k: int = 20,
top_p: float = 0.8,
repetition_penalty: float = 1.2,
eos_token_id: Optional[int] = None,
no_repeat_ngram_size: int = 3,
) -> torch.Tensor:
        self.eval()
        # No KV cache: each step re-runs the full (block_size-truncated) prefix,
        # which is simple but quadratic in context length per generated token.
        for _ in range(max_new_tokens):
            idx_cond = input_ids[:, -self.cfg.block_size:]
logits = self(idx_cond)
logits = logits[:, -1, :]
            if repetition_penalty != 1.0:
                # CTRL-style penalty: make every already-seen token less likely
                # (divide positive logits by the penalty, multiply negative ones).
for b in range(input_ids.size(0)):
seen = torch.unique(input_ids[b])
seen_logits = logits[b, seen]
logits[b, seen] = torch.where(
seen_logits < 0,
seen_logits * repetition_penalty,
seen_logits / repetition_penalty,
)
            if no_repeat_ngram_size > 0 and input_ids.size(1) >= no_repeat_ngram_size - 1:
                # Ban any token that would complete an n-gram already present
                # in the sequence (the usual no_repeat_ngram_size scheme).
                n = no_repeat_ngram_size
for b in range(input_ids.size(0)):
prefix = tuple(input_ids[b, -(n - 1):].tolist())
banned = set()
toks = input_ids[b].tolist()
for i in range(len(toks) - n + 1):
if tuple(toks[i:i+n-1]) == prefix:
banned.add(toks[i+n-1])
if banned:
logits[b, list(banned)] = -float("inf")
if temperature <= 0:
next_token = torch.argmax(logits, dim=-1, keepdim=True)
else:
logits = logits / max(temperature, 1e-6)
if top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float("inf")
                if 0 < top_p < 1.0:
                    # Nucleus (top-p) sampling: mask the sorted tail once the
                    # cumulative probability exceeds top_p; the one-position
                    # shift below always keeps at least the top token.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(probs, dim=-1)
sorted_mask = cumulative_probs > top_p
sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
sorted_mask[..., 0] = False
mask = torch.zeros_like(logits, dtype=torch.bool)
mask.scatter_(1, sorted_indices, sorted_mask)
logits = logits.masked_fill(mask, -float("inf"))
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
input_ids = torch.cat([input_ids, next_token], dim=1)
if eos_token_id is not None and (next_token == eos_token_id).all():
break
return input_ids
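# Illustrative smoke test (not invoked anywhere): even an untrained model with
# a tiny, made-up config should append exactly max_new_tokens ids greedily.
def _demo_generate_smoke() -> None:
    cfg = GPTConfig(vocab_size=64, block_size=32, d_model=64, n_heads=4, n_layers=2, d_ff=128)
    model = GPT(cfg).eval()
    out = model.generate(torch.zeros(1, 4, dtype=torch.long), max_new_tokens=8, temperature=0.0)
    assert out.shape == (1, 12)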
def load_model_and_tokenizer(checkpoint_path: Path, config_path: Path, tokenizer_dir: Path, device: torch.device):
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    if not config_path.exists():
        raise FileNotFoundError(f"Config not found: {config_path}")
    if not tokenizer_dir.exists():
        raise FileNotFoundError(f"Tokenizer not found: {tokenizer_dir}")
cfg_dict = json.loads(config_path.read_text(encoding="utf-8"))
cfg = GPTConfig(**cfg_dict)
tokenizer = PreTrainedTokenizerFast.from_pretrained(str(tokenizer_dir))
model = GPT(cfg).to(device)
ckpt = torch.load(checkpoint_path, map_location=device)
state_dict = normalize_state_dict_keys(ckpt["model"])
model.load_state_dict(state_dict, strict=True)
model.eval()
return model, tokenizer, ckpt
def build_prompt(text: str, mode: str) -> str:
    # The qa/instruction templates are deliberately French: they are the exact
    # strings the model sees, so translating them would change its behavior.
    if mode in {"raw", "completion"}:
        return text
    if mode == "qa":
        return f"Réponds brièvement en français.\nQuestion: {text}\nRéponse:"
    if mode == "instruction":
        return f"Instruction: Réponds de façon concise.\nEntrée: {text}\nSortie:"
    raise ValueError(f"Unknown mode: {mode}")
def encode_prompt(tokenizer: PreTrainedTokenizerFast, prompt: str, device: torch.device) -> torch.Tensor:
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = encoded["input_ids"].to(device)
    # Special tokens are disabled above, so prepend BOS manually when the
    # tokenizer defines one (presumably matching how training sequences begin).
    if tokenizer.bos_token_id is not None:
        bos = torch.tensor([[tokenizer.bos_token_id]], device=device, dtype=input_ids.dtype)
        input_ids = torch.cat([bos, input_ids], dim=1)
return input_ids
def generate_text(model, tokenizer, prompt, device, max_new_tokens, temperature, top_k, top_p, repetition_penalty):
input_ids = encode_prompt(tokenizer, prompt, device)
prompt_len = input_ids.shape[1]
with autocast_context(device):
output_ids = model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
eos_token_id=tokenizer.eos_token_id,
no_repeat_ngram_size=3,
)
generated_ids = output_ids[0][prompt_len:]
return clean_text(tokenizer.decode(generated_ids, skip_special_tokens=True))
@torch.inference_mode()
def score_text(model, tokenizer, text: str, device: torch.device) -> dict:
ids = encode_prompt(tokenizer, text, device)
if ids.size(1) < 2:
return {"tokens": int(ids.size(1)), "loss": None, "ppl": None}
inp = ids[:, :-1]
tgt = ids[:, 1:]
with autocast_context(device):
logits = model(inp)
loss = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
tgt.reshape(-1),
reduction="mean",
)
    # Perplexity is exp of the mean token-level cross-entropy.
    return {"tokens": int(tgt.numel()), "loss": float(loss.item()), "ppl": float(torch.exp(loss).item())}
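# Illustrative check (not invoked anywhere): under a uniform distribution over
# V tokens, the loss is log(V) and the perplexity is exactly V.
def _demo_perplexity_definition() -> None:
    vocab = 32
    logits = torch.zeros(1, 10, vocab)  # uniform after softmax
    targets = torch.randint(0, vocab, (1, 10))
    loss = F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1))
    assert torch.isclose(torch.exp(loss), torch.tensor(float(vocab)))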
def built_in_tests() -> List[tuple[str, str]]:
return [
("completion", "Deep learning is a method of machine learning that"),
("completion", "Le deep learning est une méthode d'apprentissage qui"),
("completion", "الذكاء الاصطناعي هو مجال يهدف إلى"),
("qa", "What is machine learning?"),
("qa", "Qu'est-ce que l'apprentissage automatique ?"),
("instruction", "Give a short HTML page with a title and one paragraph."),
]
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str, default=str(DEFAULT_CHECKPOINT))
parser.add_argument("--config", type=str, default=str(DEFAULT_CONFIG))
parser.add_argument("--tokenizer_dir", type=str, default=str(DEFAULT_TOKENIZER_DIR))
parser.add_argument("--prompt", type=str, default="Deep learning is a method of machine learning that")
parser.add_argument("--mode", type=str, default="completion", choices=["completion", "qa", "instruction", "raw"])
parser.add_argument("--max_new_tokens", type=int, default=96)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top_k", type=int, default=20)
parser.add_argument("--top_p", type=float, default=0.8)
parser.add_argument("--repetition_penalty", type=float, default=1.2)
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--run_tests", action="store_true")
parser.add_argument("--score_only", action="store_true")
args = parser.parse_args()
device = get_device()
if device.type == "cuda":
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")
model, tokenizer, ckpt = load_model_and_tokenizer(
checkpoint_path=Path(args.checkpoint),
config_path=Path(args.config),
tokenizer_dir=Path(args.tokenizer_dir),
device=device,
)
print(f"Device: {device}")
print(f"Checkpoint: {args.checkpoint}")
print(f"epoch={ckpt.get('epoch', 'N/A')} | step={ckpt.get('step', 'N/A')} | best_loss={ckpt.get('best_loss', 'N/A')}")
if args.run_tests:
        print("\n=== Built-in tests ===")
for i, (mode, text) in enumerate(built_in_tests(), start=1):
prompt = build_prompt(text, mode)
print(f"\n[{i}] mode={mode}")
            print(f"Input: {text}")
            print("Output:")
print(generate_text(
model=model,
tokenizer=tokenizer,
prompt=prompt,
device=device,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty,
))
return
if args.interactive:
        print("Interactive mode.")
        print("Commands: /mode completion|qa|instruction|raw, /score <text>, exit\n")
current_mode = args.mode
while True:
user_in = input(f"{current_mode}> ").strip()
if user_in.lower() in {"exit", "quit"}:
break
if not user_in:
continue
if user_in.startswith("/mode "):
new_mode = user_in.split(maxsplit=1)[1].strip()
if new_mode in {"completion", "qa", "instruction", "raw"}:
current_mode = new_mode
                    print(f"Mode changed: {current_mode}\n")
else:
                    print("Invalid mode.\n")
continue
if user_in.startswith("/score "):
sample = user_in.split(maxsplit=1)[1]
print(score_text(model, tokenizer, sample, device))
print()
continue
prompt = build_prompt(user_in, current_mode)
            print("\n=== Output ===")
print(generate_text(
model=model,
tokenizer=tokenizer,
prompt=prompt,
device=device,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty,
))
print()
return
if args.score_only:
print(json.dumps(score_text(model, tokenizer, args.prompt, device), ensure_ascii=False, indent=2))
return
prompt = build_prompt(args.prompt, args.mode)
print(generate_text(
model=model,
tokenizer=tokenizer,
prompt=prompt,
device=device,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty,
))
if __name__ == "__main__":
main()