#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ simple_qa_test_aramix.py Test QA simple pour un modèle déjà entraîné dans une repo de type : - train_aramix_h100_full.py - aramix_h100/ - config.json - model_best.pt - model.pt - tokenizer_32k/ Hypothèses alignées avec ton repo : - le module d'entraînement expose : GPT, GPTConfig, train_or_load_tokenizer, load_checkpoint, DOMAINS - le tokenizer est géré par train_or_load_tokenizer(DOMAINS) - le checkpoint se recharge avec load_checkpoint(model, opt, ckpt_path, device) Usage ----- python simple_qa_test_aramix.py python simple_qa_test_aramix.py --repo_dir ./aramix_h100 python simple_qa_test_aramix.py --ckpt ./aramix_h100/model.pt python simple_qa_test_aramix.py --questions qa_questions.json python simple_qa_test_aramix.py --max_new_tokens 96 --temperature 0.4 --top_k 40 python simple_qa_test_aramix.py --save_report """ from __future__ import annotations import argparse import importlib.util import json import os import re import sys import time import unicodedata from pathlib import Path from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F DEFAULT_QUESTIONS = [ { "category": "Géographie", "question": "Quelle est la capitale de la France ?", "reference": "Paris", }, { "category": "Géographie", "question": "Quel est le plus long fleuve d'Afrique ?", "reference": "Le Nil", }, { "category": "Science", "question": "Qu'est-ce que la photosynthèse ?", "reference": "Processus par lequel les plantes convertissent la lumière en énergie", }, { "category": "Science", "question": "Combien d'os compte le corps humain adulte ?", "reference": "206", }, { "category": "Histoire", "question": "En quelle année a eu lieu la Révolution française ?", "reference": "1789", }, { "category": "Histoire", "question": "Qui a écrit Les Misérables ?", "reference": "Victor Hugo", }, { "category": "Mathématiques", "question": "Quelle est la formule de l'aire d'un cercle ?", "reference": "pi r carre", }, { "category": "Langage", "question": "Donne un synonyme du mot heureux.", "reference": "joyeux", }, { "category": "Raisonnement", "question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?", "reference": "3", }, { "category": "Dialogue", "question": "Comment vas-tu aujourd'hui ?", "reference": None, }, ] def load_module_from_file(py_path: Path): spec = importlib.util.spec_from_file_location(py_path.stem, py_path) if spec is None or spec.loader is None: raise RuntimeError(f"Impossible de charger le module: {py_path}") module = importlib.util.module_from_spec(spec) sys.modules[py_path.stem] = module spec.loader.exec_module(module) return module def normalize_text(text: str) -> str: text = (text or "").strip().lower() text = unicodedata.normalize("NFKD", text) text = "".join(ch for ch in text if not unicodedata.combining(ch)) text = text.replace("π", "pi") text = re.sub(r"[\W_]+", " ", text, flags=re.UNICODE) text = re.sub(r"\s+", " ", text).strip() return text def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]: if not reference: return None ref = set(normalize_text(reference).split()) ans = set(normalize_text(answer).split()) if not ref: return None return len(ref & ans) / len(ref) def exact_match(reference: Optional[str], answer: str) -> Optional[bool]: if not reference: return None return normalize_text(reference) == normalize_text(answer) def infer_repo_defaults(repo_dir: Path): train_script = repo_dir.parent / "train_aramix_h100_full.py" if not train_script.exists(): train_script = repo_dir / "train_aramix_h100_full.py" ckpt = repo_dir / "model_best.pt" if not ckpt.exists(): ckpt = repo_dir / "model.pt" config = repo_dir / "config.json" tokenizer_dir = repo_dir / "tokenizer_32k" return train_script, ckpt, config, tokenizer_dir def safe_get(cfg: Dict[str, Any], *names: str, default=None): for name in names: if name in cfg: return cfg[name] return default def build_model_config_dict(cfg_json: Dict[str, Any], vocab_size: int) -> Dict[str, Any]: block_size = safe_get(cfg_json, "block_size", "max_seq_len", "seq_len", default=512) d_model = safe_get(cfg_json, "d_model", "n_embd", "dim", default=768) n_heads = safe_get(cfg_json, "n_heads", "n_head", "num_heads", default=12) n_layers = safe_get(cfg_json, "n_layers", "n_layer", "num_layers", default=12) d_ff = safe_get(cfg_json, "d_ff", "ffn_dim", "intermediate_size", default=d_model * 4) return { "vocab_size": vocab_size, "block_size": int(block_size), "d_model": int(d_model), "n_heads": int(n_heads), "n_layers": int(n_layers), "d_ff": int(d_ff), } class AramixChatTester: def __init__( self, repo_dir: Path, train_script: Path, ckpt_path: Path, config_path: Path, device: Optional[str] = None, ): self.repo_dir = repo_dir self.train_script = train_script self.ckpt_path = ckpt_path self.config_path = config_path self.device = torch.device( device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu") ) self.M = load_module_from_file(self.train_script) required = ["GPT", "GPTConfig", "train_or_load_tokenizer", "load_checkpoint", "DOMAINS"] missing = [x for x in required if not hasattr(self.M, x)] if missing: raise RuntimeError( f"Le fichier {self.train_script.name} ne contient pas les symboles attendus: {missing}" ) self.cfg_json: Dict[str, Any] = {} if self.config_path.exists(): with open(self.config_path, "r", encoding="utf-8") as f: self.cfg_json = json.load(f) self.tokenizer = self._load_tokenizer() self.model = self._load_model() def _load_tokenizer(self): old_cwd = Path.cwd() try: os.chdir(self.repo_dir.parent) tok = self.M.train_or_load_tokenizer(self.M.DOMAINS) finally: os.chdir(old_cwd) return tok def _make_gpt_config(self): kwargs = build_model_config_dict(self.cfg_json, vocab_size=len(self.tokenizer)) try: return self.M.GPTConfig(**kwargs) except TypeError: return self.M.GPTConfig(vocab_size=len(self.tokenizer)) def _load_model(self): cfg = self._make_gpt_config() model = self.M.GPT(cfg).to(self.device) try: self.M.load_checkpoint(model, None, self.ckpt_path, self.device) except TypeError: try: self.M.load_checkpoint(model, self.ckpt_path, self.device) except TypeError: ckpt = torch.load(self.ckpt_path, map_location=self.device) state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt if any(k.startswith("_orig_mod.") for k in state): state = {k.replace("_orig_mod.", ""): v for k, v in state.items()} model.load_state_dict(state, strict=False) model.eval() return model def encode_prompt(self, question: str) -> List[int]: bos = getattr(self.tokenizer, "bos_token_id", None) eos = getattr(self.tokenizer, "eos_token_id", None) prompt = f"Question: {question}\nRéponse:" ids = self.tokenizer.encode(prompt, add_special_tokens=False) if bos is not None: ids = [bos] + ids if eos is not None and len(ids) > 0 and ids[-1] == eos: ids = ids[:-1] return ids @torch.no_grad() def generate( self, question: str, max_new_tokens: int = 96, temperature: float = 0.4, top_k: int = 40, repetition_penalty: float = 1.12, ) -> str: ids = self.encode_prompt(question) x = torch.tensor([ids], dtype=torch.long, device=self.device) eos_id = getattr(self.tokenizer, "eos_token_id", None) block_size = getattr(getattr(self.model, "cfg", None), "block_size", None) if block_size is None: block_size = safe_get(self.cfg_json, "block_size", "max_seq_len", default=512) for step in range(max_new_tokens): x_ctx = x[:, -int(block_size):] try: logits, _ = self.model(x_ctx) except TypeError: out = self.model(x_ctx) logits = out[0] if isinstance(out, tuple) else out logits = logits[:, -1, :] recent = x[0, -64:].tolist() for tok in set(recent): logits[0, tok] /= repetition_penalty if temperature <= 0: next_tok = torch.argmax(logits, dim=-1, keepdim=True) else: logits = logits / max(temperature, 1e-5) if top_k is not None and top_k > 0: values, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) kth = values[:, -1].unsqueeze(-1) logits = torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits) probs = F.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) x = torch.cat([x, next_tok], dim=1) if eos_id is not None and next_tok.item() == eos_id and step >= 2: break new_ids = x[0, len(ids):].tolist() text = self.tokenizer.decode(new_ids).strip() text = re.sub(r"\s+", " ", text).strip() text = text.replace("Réponse :", "").replace("Réponse:", "").strip() return text def load_questions(path: Optional[str]) -> List[Dict[str, Any]]: if not path: return DEFAULT_QUESTIONS with open(path, "r", encoding="utf-8") as f: data = json.load(f) if not isinstance(data, list): raise ValueError("Le fichier questions doit contenir une liste JSON.") return data def format_bar(score: float, width: int = 20) -> str: n = max(0, min(width, int(round(score * width)))) return "█" * n + "░" * (width - n) def save_reports(output_dir: Path, summary: Dict[str, Any]) -> None: output_dir.mkdir(parents=True, exist_ok=True) json_path = output_dir / "qa_test_report_simple.json" txt_path = output_dir / "qa_test_report_simple.txt" with open(json_path, "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) with open(txt_path, "w", encoding="utf-8") as f: f.write("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ\n") f.write("=" * 60 + "\n\n") for r in summary["results"]: f.write(f"[{r['id']:02d}] {r['category']}\n") f.write(f" User : {r['question']}\n") f.write(f" Assistant : {r['answer']}\n") if r["reference"]: f.write(f" Référence : {r['reference']}\n") if r["overlap_score"] is not None: f.write(f" Overlap : {r['overlap_score']:.0%}\n") if r["exact_match"] is not None: f.write(f" ExactMatch: {r['exact_match']}\n") f.write(f" Latence : {r['latency_s']}s\n\n") def main(): parser = argparse.ArgumentParser("Test QA simple pour modèle Aramix déjà entraîné") parser.add_argument("--repo_dir", type=str, default="./aramix_h100") parser.add_argument("--train_script", type=str, default=None) parser.add_argument("--ckpt", type=str, default=None) parser.add_argument("--config", type=str, default=None) parser.add_argument("--questions", type=str, default=None) parser.add_argument("--max_new_tokens", type=int, default=96) parser.add_argument("--temperature", type=float, default=0.4) parser.add_argument("--top_k", type=int, default=40) parser.add_argument("--repetition_penalty", type=float, default=1.12) parser.add_argument("--device", type=str, default=None) parser.add_argument("--save_report", action="store_true") args = parser.parse_args() repo_dir = Path(args.repo_dir).resolve() train_script, ckpt_path, config_path, tokenizer_dir = infer_repo_defaults(repo_dir) if args.train_script: train_script = Path(args.train_script).resolve() if args.ckpt: ckpt_path = Path(args.ckpt).resolve() if args.config: config_path = Path(args.config).resolve() if not train_script.exists(): raise FileNotFoundError(f"Script train introuvable: {train_script}") if not ckpt_path.exists(): raise FileNotFoundError(f"Checkpoint introuvable: {ckpt_path}") if not config_path.exists(): print(f"[WARN] config.json introuvable: {config_path} — fallback sur GPTConfig(vocab_size=...).") questions = load_questions(args.questions) tester = AramixChatTester( repo_dir=repo_dir, train_script=train_script, ckpt_path=ckpt_path, config_path=config_path, device=args.device, ) results: List[Dict[str, Any]] = [] categories: Dict[str, List[Dict[str, Any]]] = {} print("\n" + "═" * 70) print(" TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ") print("═" * 70) print(f"Repo : {repo_dir}") print(f"Train script: {train_script}") print(f"Checkpoint : {ckpt_path}") print(f"Config : {config_path}") print(f"Tokenizer : {tokenizer_dir}") print(f"Device : {tester.device}") print(f"Questions : {len(questions)}") print("═" * 70 + "\n") for i, item in enumerate(questions, 1): q = item["question"] ref = item.get("reference") cat = item.get("category", "Général") t0 = time.time() ans = tester.generate( q, max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_k=args.top_k, repetition_penalty=args.repetition_penalty, ) latency = time.time() - t0 overlap = lexical_overlap(ref, ans) em = exact_match(ref, ans) entry = { "id": i, "category": cat, "question": q, "answer": ans, "reference": ref, "latency_s": round(latency, 3), "tokens_generated_approx": len(ans.split()), "overlap_score": None if overlap is None else round(overlap, 4), "exact_match": em, } results.append(entry) categories.setdefault(cat, []).append(entry) overlap_str = f"{overlap:.0%}" if overlap is not None else "n/a" em_str = "✓" if em else ("✗" if em is not None else "n/a") print("─" * 70) print(f"[{i:02d}] [{cat}] overlap={overlap_str} | EM={em_str}") print(f" User : {q}") print(f" Assistant : {ans}") if ref: print(f" Référence : {ref}") print(f" ⏱ {latency:.2f}s | ~{entry['tokens_generated_approx']} mots\n") scored_overlap = [r["overlap_score"] for r in results if r["overlap_score"] is not None] scored_em = [r["exact_match"] for r in results if r["exact_match"] is not None] avg_overlap = sum(scored_overlap) / len(scored_overlap) if scored_overlap else 0.0 em_rate = sum(1 for x in scored_em if x) / len(scored_em) if scored_em else 0.0 avg_latency = sum(r["latency_s"] for r in results) / len(results) if results else 0.0 avg_words = sum(r["tokens_generated_approx"] for r in results) / len(results) if results else 0.0 cat_scores: Dict[str, float] = {} for cat, items in categories.items(): vals = [r["overlap_score"] for r in items if r["overlap_score"] is not None] cat_scores[cat] = (sum(vals) / len(vals)) if vals else 0.0 summary = { "repo_dir": str(repo_dir), "train_script": str(train_script), "checkpoint": str(ckpt_path), "config_path": str(config_path), "tokenizer_dir": str(tokenizer_dir), "device": str(tester.device), "total_questions": len(results), "avg_overlap_score": round(avg_overlap, 4), "exact_match_rate": round(em_rate, 4), "avg_latency_s": round(avg_latency, 3), "avg_words_generated": round(avg_words, 1), "scores_by_category": {k: round(v, 4) for k, v in cat_scores.items()}, "results": results, } print("═" * 70) print(" RÉSUMÉ") print("═" * 70) print(f"Questions testées : {len(results)}") print(f"Overlap moyen : {avg_overlap:.1%}") print(f"Exact match : {em_rate:.1%}") print(f"Latence moyenne : {avg_latency:.2f}s") print(f"Mots moyens : {avg_words:.1f}") print("Scores / catégorie:") for cat, score in sorted(cat_scores.items()): print(f" {cat:<15} {format_bar(score)} {score:.0%}") print("═" * 70) if args.save_report: save_reports(repo_dir, summary) print(f"Rapports sauvegardés dans : {repo_dir / 'qa_test_report_simple.json'}") print(f" {repo_dir / 'qa_test_report_simple.txt'}") if __name__ == "__main__": main()