#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ simple_qa_test_aramix_v2.py Version corrigée du test QA simple pour repo Aramix. Correction principale : - gère les checkpoints sauvés depuis torch.compile() avec préfixe "_orig_mod." - contourne load_checkpoint() du script train si celui-ci échoue sur ce cas Usage ----- python simple_qa_test_aramix_v2.py --repo_dir ./aramix_h100 --save_report """ from __future__ import annotations import argparse import importlib.util import json import os import re import sys import time import unicodedata from pathlib import Path from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F DEFAULT_QUESTIONS = [ {"category": "Géographie", "question": "Quelle est la capitale de la France ?", "reference": "Paris"}, {"category": "Géographie", "question": "Quel est le plus long fleuve d'Afrique ?", "reference": "Le Nil"}, {"category": "Science", "question": "Qu'est-ce que la photosynthèse ?", "reference": "Processus par lequel les plantes convertissent la lumière en énergie"}, {"category": "Science", "question": "Combien d'os compte le corps humain adulte ?", "reference": "206"}, {"category": "Histoire", "question": "En quelle année a eu lieu la Révolution française ?", "reference": "1789"}, {"category": "Histoire", "question": "Qui a écrit Les Misérables ?", "reference": "Victor Hugo"}, {"category": "Mathématiques", "question": "Quelle est la formule de l'aire d'un cercle ?", "reference": "pi r carre"}, {"category": "Langage", "question": "Donne un synonyme du mot heureux.", "reference": "joyeux"}, {"category": "Raisonnement", "question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?", "reference": "3"}, {"category": "Dialogue", "question": "Comment vas-tu aujourd'hui ?", "reference": None}, ] def load_module_from_file(py_path: Path): spec = importlib.util.spec_from_file_location(py_path.stem, py_path) if spec is None or spec.loader is None: raise RuntimeError(f"Impossible de charger le module: {py_path}") module = importlib.util.module_from_spec(spec) sys.modules[py_path.stem] = module spec.loader.exec_module(module) return module def normalize_text(text: str) -> str: text = (text or "").strip().lower() text = unicodedata.normalize("NFKD", text) text = "".join(ch for ch in text if not unicodedata.combining(ch)) text = text.replace("π", "pi") text = re.sub(r"[\W_]+", " ", text, flags=re.UNICODE) text = re.sub(r"\s+", " ", text).strip() return text def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]: if not reference: return None ref = set(normalize_text(reference).split()) ans = set(normalize_text(answer).split()) if not ref: return None return len(ref & ans) / len(ref) def exact_match(reference: Optional[str], answer: str) -> Optional[bool]: if not reference: return None return normalize_text(reference) == normalize_text(answer) def infer_repo_defaults(repo_dir: Path): train_script = repo_dir.parent / "train_aramix_h100_full.py" if not train_script.exists(): train_script = repo_dir / "train_aramix_h100_full.py" ckpt = repo_dir / "model_best.pt" if not ckpt.exists(): ckpt = repo_dir / "model.pt" config = repo_dir / "config.json" tokenizer_dir = repo_dir / "tokenizer_32k" return train_script, ckpt, config, tokenizer_dir def safe_get(cfg: Dict[str, Any], *names: str, default=None): for name in names: if name in cfg: return cfg[name] return default def build_model_config_dict(cfg_json: Dict[str, Any], vocab_size: int) -> Dict[str, Any]: block_size = safe_get(cfg_json, "block_size", "max_seq_len", "seq_len", default=512) d_model = safe_get(cfg_json, "d_model", "n_embd", "dim", default=768) n_heads = safe_get(cfg_json, "n_heads", "n_head", "num_heads", default=12) n_layers = safe_get(cfg_json, "n_layers", "n_layer", "num_layers", default=12) d_ff = safe_get(cfg_json, "d_ff", "ffn_dim", "intermediate_size", default=d_model * 4) return { "vocab_size": vocab_size, "block_size": int(block_size), "d_model": int(d_model), "n_heads": int(n_heads), "n_layers": int(n_layers), "d_ff": int(d_ff), } def strip_orig_mod_prefix(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: if any(k.startswith("_orig_mod.") for k in state.keys()): return {k.replace("_orig_mod.", "", 1): v for k, v in state.items()} return state class AramixChatTester: def __init__( self, repo_dir: Path, train_script: Path, ckpt_path: Path, config_path: Path, device: Optional[str] = None, ): self.repo_dir = repo_dir self.train_script = train_script self.ckpt_path = ckpt_path self.config_path = config_path self.device = torch.device(device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")) self.M = load_module_from_file(self.train_script) required = ["GPT", "GPTConfig", "train_or_load_tokenizer", "DOMAINS"] missing = [x for x in required if not hasattr(self.M, x)] if missing: raise RuntimeError(f"Le fichier {self.train_script.name} ne contient pas les symboles attendus: {missing}") self.cfg_json: Dict[str, Any] = {} if self.config_path.exists(): with open(self.config_path, "r", encoding="utf-8") as f: self.cfg_json = json.load(f) self.tokenizer = self._load_tokenizer() self.model = self._load_model() def _load_tokenizer(self): old_cwd = Path.cwd() try: os.chdir(self.repo_dir.parent) tok = self.M.train_or_load_tokenizer(self.M.DOMAINS) finally: os.chdir(old_cwd) return tok def _make_gpt_config(self): kwargs = build_model_config_dict(self.cfg_json, vocab_size=len(self.tokenizer)) try: return self.M.GPTConfig(**kwargs) except TypeError: return self.M.GPTConfig(vocab_size=len(self.tokenizer)) def _manual_load_state(self, model: torch.nn.Module): ckpt = torch.load(self.ckpt_path, map_location=self.device) if isinstance(ckpt, dict) and "model" in ckpt: state = ckpt["model"] else: state = ckpt if not isinstance(state, dict): raise RuntimeError("Checkpoint non reconnu: pas de state_dict exploitable.") state = strip_orig_mod_prefix(state) missing, unexpected = model.load_state_dict(state, strict=False) # tolérance uniquement sur lm_head/tied weights éventuels, sinon on échoue hard_missing = [k for k in missing if not k.endswith("lm_head.weight")] hard_unexpected = [k for k in unexpected if not k.startswith("_orig_mod.")] if hard_missing or hard_unexpected: raise RuntimeError( "Chargement manuel incomplet.\n" f"Missing: {hard_missing[:20]}\n" f"Unexpected: {hard_unexpected[:20]}" ) def _load_model(self): cfg = self._make_gpt_config() model = self.M.GPT(cfg).to(self.device) # 1) tentative via load_checkpoint du script train if hasattr(self.M, "load_checkpoint"): try: try: self.M.load_checkpoint(model, None, self.ckpt_path, self.device) model.eval() return model except TypeError: self.M.load_checkpoint(model, self.ckpt_path, self.device) model.eval() return model except RuntimeError as e: msg = str(e) if "_orig_mod." not in msg: raise # 2) fallback robuste self._manual_load_state(model) model.eval() return model def encode_prompt(self, question: str) -> List[int]: bos = getattr(self.tokenizer, "bos_token_id", None) eos = getattr(self.tokenizer, "eos_token_id", None) prompt = f"Question: {question}\nRéponse:" ids = self.tokenizer.encode(prompt, add_special_tokens=False) if bos is not None: ids = [bos] + ids if eos is not None and ids and ids[-1] == eos: ids = ids[:-1] return ids @torch.no_grad() def generate( self, question: str, max_new_tokens: int = 96, temperature: float = 0.4, top_k: int = 40, repetition_penalty: float = 1.12, ) -> str: ids = self.encode_prompt(question) x = torch.tensor([ids], dtype=torch.long, device=self.device) eos_id = getattr(self.tokenizer, "eos_token_id", None) block_size = getattr(getattr(self.model, "cfg", None), "block_size", None) if block_size is None: block_size = safe_get(self.cfg_json, "block_size", "max_seq_len", default=512) for step in range(max_new_tokens): x_ctx = x[:, -int(block_size):] out = self.model(x_ctx) logits = out[0] if isinstance(out, tuple) else out logits = logits[:, -1, :] recent = x[0, -64:].tolist() for tok in set(recent): logits[0, tok] /= repetition_penalty if temperature <= 0: next_tok = torch.argmax(logits, dim=-1, keepdim=True) else: logits = logits / max(temperature, 1e-5) if top_k is not None and top_k > 0: values, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) kth = values[:, -1].unsqueeze(-1) logits = torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits) probs = F.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) x = torch.cat([x, next_tok], dim=1) if eos_id is not None and next_tok.item() == eos_id and step >= 2: break new_ids = x[0, len(ids):].tolist() text = self.tokenizer.decode(new_ids).strip() text = re.sub(r"\s+", " ", text).strip() text = text.replace("Réponse :", "").replace("Réponse:", "").strip() return text def load_questions(path: Optional[str]) -> List[Dict[str, Any]]: if not path: return DEFAULT_QUESTIONS with open(path, "r", encoding="utf-8") as f: data = json.load(f) if not isinstance(data, list): raise ValueError("Le fichier questions doit contenir une liste JSON.") return data def format_bar(score: float, width: int = 20) -> str: n = max(0, min(width, int(round(score * width)))) return "█" * n + "░" * (width - n) def save_reports(output_dir: Path, summary: Dict[str, Any]) -> None: output_dir.mkdir(parents=True, exist_ok=True) json_path = output_dir / "qa_test_report_simple.json" txt_path = output_dir / "qa_test_report_simple.txt" with open(json_path, "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) with open(txt_path, "w", encoding="utf-8") as f: f.write("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ\n") f.write("=" * 60 + "\n\n") for r in summary["results"]: f.write(f"[{r['id']:02d}] {r['category']}\n") f.write(f" User : {r['question']}\n") f.write(f" Assistant : {r['answer']}\n") if r["reference"]: f.write(f" Référence : {r['reference']}\n") if r["overlap_score"] is not None: f.write(f" Overlap : {r['overlap_score']:.0%}\n") if r["exact_match"] is not None: f.write(f" ExactMatch: {r['exact_match']}\n") f.write(f" Latence : {r['latency_s']}s\n\n") def main(): parser = argparse.ArgumentParser("Test QA simple pour modèle Aramix déjà entraîné") parser.add_argument("--repo_dir", type=str, default="./aramix_h100") parser.add_argument("--train_script", type=str, default=None) parser.add_argument("--ckpt", type=str, default=None) parser.add_argument("--config", type=str, default=None) parser.add_argument("--questions", type=str, default=None) parser.add_argument("--max_new_tokens", type=int, default=96) parser.add_argument("--temperature", type=float, default=0.4) parser.add_argument("--top_k", type=int, default=40) parser.add_argument("--repetition_penalty", type=float, default=1.12) parser.add_argument("--device", type=str, default=None) parser.add_argument("--save_report", action="store_true") args = parser.parse_args() repo_dir = Path(args.repo_dir).resolve() train_script, ckpt_path, config_path, tokenizer_dir = infer_repo_defaults(repo_dir) if args.train_script: train_script = Path(args.train_script).resolve() if args.ckpt: ckpt_path = Path(args.ckpt).resolve() if args.config: config_path = Path(args.config).resolve() if not train_script.exists(): raise FileNotFoundError(f"Script train introuvable: {train_script}") if not ckpt_path.exists(): raise FileNotFoundError(f"Checkpoint introuvable: {ckpt_path}") if not config_path.exists(): print(f"[WARN] config.json introuvable: {config_path} — fallback sur GPTConfig(vocab_size=...).") questions = load_questions(args.questions) tester = AramixChatTester( repo_dir=repo_dir, train_script=train_script, ckpt_path=ckpt_path, config_path=config_path, device=args.device, ) results: List[Dict[str, Any]] = [] categories: Dict[str, List[Dict[str, Any]]] = {} print("\n" + "═" * 70) print(" TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ") print("═" * 70) print(f"Repo : {repo_dir}") print(f"Train script: {train_script}") print(f"Checkpoint : {ckpt_path}") print(f"Config : {config_path}") print(f"Tokenizer : {tokenizer_dir}") print(f"Device : {tester.device}") print(f"Questions : {len(questions)}") print("═" * 70 + "\n") for i, item in enumerate(questions, 1): q = item["question"] ref = item.get("reference") cat = item.get("category", "Général") t0 = time.time() ans = tester.generate( q, max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_k=args.top_k, repetition_penalty=args.repetition_penalty, ) latency = time.time() - t0 overlap = lexical_overlap(ref, ans) em = exact_match(ref, ans) entry = { "id": i, "category": cat, "question": q, "answer": ans, "reference": ref, "latency_s": round(latency, 3), "tokens_generated_approx": len(ans.split()), "overlap_score": None if overlap is None else round(overlap, 4), "exact_match": em, } results.append(entry) categories.setdefault(cat, []).append(entry) overlap_str = f"{overlap:.0%}" if overlap is not None else "n/a" em_str = "✓" if em else ("✗" if em is not None else "n/a") print("─" * 70) print(f"[{i:02d}] [{cat}] overlap={overlap_str} | EM={em_str}") print(f" User : {q}") print(f" Assistant : {ans}") if ref: print(f" Référence : {ref}") print(f" ⏱ {latency:.2f}s | ~{entry['tokens_generated_approx']} mots\n") scored_overlap = [r["overlap_score"] for r in results if r["overlap_score"] is not None] scored_em = [r["exact_match"] for r in results if r["exact_match"] is not None] avg_overlap = sum(scored_overlap) / len(scored_overlap) if scored_overlap else 0.0 em_rate = sum(1 for x in scored_em if x) / len(scored_em) if scored_em else 0.0 avg_latency = sum(r["latency_s"] for r in results) / len(results) if results else 0.0 avg_words = sum(r["tokens_generated_approx"] for r in results) / len(results) if results else 0.0 cat_scores: Dict[str, float] = {} for cat, items in categories.items(): vals = [r["overlap_score"] for r in items if r["overlap_score"] is not None] cat_scores[cat] = (sum(vals) / len(vals)) if vals else 0.0 summary = { "repo_dir": str(repo_dir), "train_script": str(train_script), "checkpoint": str(ckpt_path), "config_path": str(config_path), "tokenizer_dir": str(tokenizer_dir), "device": str(tester.device), "total_questions": len(results), "avg_overlap_score": round(avg_overlap, 4), "exact_match_rate": round(em_rate, 4), "avg_latency_s": round(avg_latency, 3), "avg_words_generated": round(avg_words, 1), "scores_by_category": {k: round(v, 4) for k, v in cat_scores.items()}, "results": results, } print("═" * 70) print(" RÉSUMÉ") print("═" * 70) print(f"Questions testées : {len(results)}") print(f"Overlap moyen : {avg_overlap:.1%}") print(f"Exact match : {em_rate:.1%}") print(f"Latence moyenne : {avg_latency:.2f}s") print(f"Mots moyens : {avg_words:.1f}") print("Scores / catégorie:") for cat, score in sorted(cat_scores.items()): print(f" {cat:<15} {format_bar(score)} {score:.0%}") print("═" * 70) if args.save_report: save_reports(repo_dir, summary) print(f"Rapports sauvegardés dans : {repo_dir / 'qa_test_report_simple.json'}") print(f" {repo_dir / 'qa_test_report_simple.txt'}") if __name__ == "__main__": main()