#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ simple_qa_test_aramix_v3.py Test QA simple et strict pour un modèle déjà entraîné dans une repo de type : - train_aramix_h100_full.py - aramix_h100/ - config.json - model_best.pt - model.pt - tokenizer_32k/ Cette version améliore le test QA en : - gérant les checkpoints torch.compile() avec préfixe "_orig_mod." - utilisant un prompt plus directif pour des réponses courtes - privilégiant une génération greedy / quasi-greedy - tronquant proprement les réponses trop longues - ajoutant une métrique "contains_reference" Usage ----- python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --save_report python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --temperature 0 --max_new_tokens 16 python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --ckpt ./aramix_h100/model_best.pt python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --questions qa_questions.json """ from __future__ import annotations import argparse import importlib.util import json import os import re import sys import time import unicodedata from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn.functional as F DEFAULT_QUESTIONS = [ {"category": "Géographie", "question": "Quelle est la capitale de la France ?", "reference": "Paris"}, {"category": "Géographie", "question": "Quel est le plus long fleuve d'Afrique ?", "reference": "Le Nil"}, {"category": "Science", "question": "Qu'est-ce que la photosynthèse ?", "reference": "Processus par lequel les plantes convertissent la lumière en énergie"}, {"category": "Science", "question": "Combien d'os compte le corps humain adulte ?", "reference": "206"}, {"category": "Histoire", "question": "En quelle année a eu lieu la Révolution française ?", "reference": "1789"}, {"category": "Histoire", "question": "Qui a écrit Les Misérables ?", "reference": "Victor Hugo"}, {"category": "Mathématiques", "question": "Quelle est la formule de l'aire d'un cercle ?", "reference": "pi r carre"}, {"category": "Langage", "question": "Donne un synonyme du mot heureux.", "reference": "joyeux"}, {"category": "Raisonnement", "question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?", "reference": "3"}, {"category": "Dialogue", "question": "Comment vas-tu aujourd'hui ?", "reference": None}, ] def load_module_from_file(py_path: Path): spec = importlib.util.spec_from_file_location(py_path.stem, py_path) if spec is None or spec.loader is None: raise RuntimeError(f"Impossible de charger le module: {py_path}") module = importlib.util.module_from_spec(spec) sys.modules[py_path.stem] = module spec.loader.exec_module(module) return module def normalize_text(text: str) -> str: text = (text or "").strip().lower() text = unicodedata.normalize("NFKD", text) text = "".join(ch for ch in text if not unicodedata.combining(ch)) text = text.replace("π", "pi") text = re.sub(r"[\W_]+", " ", text, flags=re.UNICODE) text = re.sub(r"\s+", " ", text).strip() return text def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]: if not reference: return None ref = set(normalize_text(reference).split()) ans = set(normalize_text(answer).split()) if not ref: return None return len(ref & ans) / len(ref) def exact_match(reference: Optional[str], answer: str) -> Optional[bool]: if not reference: return None return normalize_text(reference) == normalize_text(answer) def contains_reference(reference: Optional[str], answer: str) -> Optional[bool]: if not reference: return None ref = normalize_text(reference) ans = normalize_text(answer) if not ref: return None return ref in ans def infer_repo_defaults(repo_dir: Path): train_script = repo_dir.parent / "train_aramix_h100_full.py" if not train_script.exists(): train_script = repo_dir / "train_aramix_h100_full.py" ckpt = repo_dir / "model_best.pt" if not ckpt.exists(): ckpt = repo_dir / "model.pt" config = repo_dir / "config.json" tokenizer_dir = repo_dir / "tokenizer_32k" return train_script, ckpt, config, tokenizer_dir def safe_get(cfg: Dict[str, Any], *names: str, default=None): for name in names: if name in cfg: return cfg[name] return default def build_model_config_dict(cfg_json: Dict[str, Any], vocab_size: int) -> Dict[str, Any]: block_size = safe_get(cfg_json, "block_size", "max_seq_len", "seq_len", default=512) d_model = safe_get(cfg_json, "d_model", "n_embd", "dim", default=768) n_heads = safe_get(cfg_json, "n_heads", "n_head", "num_heads", default=12) n_layers = safe_get(cfg_json, "n_layers", "n_layer", "num_layers", default=12) d_ff = safe_get(cfg_json, "d_ff", "ffn_dim", "intermediate_size", default=d_model * 4) return { "vocab_size": vocab_size, "block_size": int(block_size), "d_model": int(d_model), "n_heads": int(n_heads), "n_layers": int(n_layers), "d_ff": int(d_ff), } def strip_orig_mod_prefix(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: if any(k.startswith("_orig_mod.") for k in state.keys()): return {k.replace("_orig_mod.", "", 1): v for k, v in state.items()} return state def clean_answer_text(text: str, max_words: int = 16) -> str: text = (text or "").strip() # Retire quelques marqueurs fréquents text = text.replace("", " ") text = text.replace("", " ") text = text.replace("", " ") text = text.replace("Réponse :", " ") text = text.replace("Réponse:", " ") text = text.replace("Answer:", " ") # Garde la première ligne text = text.split("\n")[0].strip() # Coupe à la première vraie fin de phrase courte m = re.search(r"([.!?])", text) if m and m.start() < 120: text = text[: m.start() + 1] # Compacte les espaces text = re.sub(r"\s+", " ", text).strip() # Tronque au nombre de mots voulu words = text.split() if len(words) > max_words: text = " ".join(words[:max_words]).strip() # Supprime ponctuation finale excessive text = re.sub(r"[,\s;:]+$", "", text).strip() return text def safe_top_k_logits(logits: torch.Tensor, top_k: int) -> torch.Tensor: if top_k is None or top_k <= 0: return logits k = min(int(top_k), logits.size(-1)) values, _ = torch.topk(logits, k=k) kth = values[:, -1].unsqueeze(-1) return torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits) class AramixChatTester: def __init__( self, repo_dir: Path, train_script: Path, ckpt_path: Path, config_path: Path, device: Optional[str] = None, prompt_style: str = "strict_qa", ): self.repo_dir = repo_dir self.train_script = train_script self.ckpt_path = ckpt_path self.config_path = config_path self.prompt_style = prompt_style self.device = torch.device(device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")) self.M = load_module_from_file(self.train_script) required = ["GPT", "GPTConfig", "train_or_load_tokenizer", "DOMAINS"] missing = [x for x in required if not hasattr(self.M, x)] if missing: raise RuntimeError(f"Le fichier {self.train_script.name} ne contient pas les symboles attendus: {missing}") self.cfg_json: Dict[str, Any] = {} if self.config_path.exists(): with open(self.config_path, "r", encoding="utf-8") as f: self.cfg_json = json.load(f) self.tokenizer = self._load_tokenizer() self.model = self._load_model() def _load_tokenizer(self): old_cwd = Path.cwd() try: os.chdir(self.repo_dir.parent) tok = self.M.train_or_load_tokenizer(self.M.DOMAINS) finally: os.chdir(old_cwd) return tok def _make_gpt_config(self): kwargs = build_model_config_dict(self.cfg_json, vocab_size=len(self.tokenizer)) try: return self.M.GPTConfig(**kwargs) except TypeError: return self.M.GPTConfig(vocab_size=len(self.tokenizer)) def _manual_load_state(self, model: torch.nn.Module): ckpt = torch.load(self.ckpt_path, map_location=self.device) if isinstance(ckpt, dict) and "model" in ckpt: state = ckpt["model"] else: state = ckpt if not isinstance(state, dict): raise RuntimeError("Checkpoint non reconnu: pas de state_dict exploitable.") state = strip_orig_mod_prefix(state) missing, unexpected = model.load_state_dict(state, strict=False) # tolérance minimale uniquement sur le tying éventuel hard_missing = [k for k in missing if not k.endswith("lm_head.weight")] hard_unexpected = [k for k in unexpected if not k.startswith("_orig_mod.")] if hard_missing or hard_unexpected: raise RuntimeError( "Chargement manuel incomplet.\n" f"Missing: {hard_missing[:20]}\n" f"Unexpected: {hard_unexpected[:20]}" ) def _load_model(self): cfg = self._make_gpt_config() model = self.M.GPT(cfg).to(self.device) if hasattr(self.M, "load_checkpoint"): try: try: self.M.load_checkpoint(model, None, self.ckpt_path, self.device) model.eval() return model except TypeError: self.M.load_checkpoint(model, self.ckpt_path, self.device) model.eval() return model except RuntimeError as e: if "_orig_mod." not in str(e): raise self._manual_load_state(model) model.eval() return model def build_prompt(self, question: str) -> str: if self.prompt_style == "strict_qa": return ( "Réponds très brièvement et uniquement en français.\n" "Donne seulement la réponse finale, sans explication.\n\n" f"Question : {question}\n" "Réponse :" ) if self.prompt_style == "qa": return f"Question: {question}\nRéponse:" return question def encode_prompt(self, question: str) -> List[int]: bos = getattr(self.tokenizer, "bos_token_id", None) eos = getattr(self.tokenizer, "eos_token_id", None) prompt = self.build_prompt(question) ids = self.tokenizer.encode(prompt, add_special_tokens=False) if bos is not None: ids = [bos] + ids if eos is not None and ids and ids[-1] == eos: ids = ids[:-1] return ids @torch.no_grad() def generate( self, question: str, max_new_tokens: int = 16, temperature: float = 0.0, top_k: int = 1, repetition_penalty: float = 1.10, max_answer_words: int = 16, ) -> str: ids = self.encode_prompt(question) x = torch.tensor([ids], dtype=torch.long, device=self.device) eos_id = getattr(self.tokenizer, "eos_token_id", None) block_size = getattr(getattr(self.model, "cfg", None), "block_size", None) if block_size is None: block_size = safe_get(self.cfg_json, "block_size", "max_seq_len", default=512) generated_word_budget_hit = False for step in range(max_new_tokens): x_ctx = x[:, -int(block_size):] out = self.model(x_ctx) logits = out[0] if isinstance(out, tuple) else out logits = logits[:, -1, :] # Pénalité légère de répétition recent = x[0, -48:].tolist() for tok in set(recent): logits[0, tok] /= repetition_penalty # Greedy par défaut pour QA courte if temperature <= 0: if top_k is not None and top_k > 1: masked = safe_top_k_logits(logits, top_k) probs = F.softmax(masked, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) else: next_tok = torch.argmax(logits, dim=-1, keepdim=True) else: logits = logits / max(temperature, 1e-5) logits = safe_top_k_logits(logits, top_k) probs = F.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) x = torch.cat([x, next_tok], dim=1) if eos_id is not None and next_tok.item() == eos_id and step >= 1: break # arrêt anticipé si la réponse devient déjà trop longue partial = self.tokenizer.decode(x[0, len(ids):].tolist()).strip() partial = clean_answer_text(partial, max_words=max_answer_words) if len(partial.split()) >= max_answer_words: generated_word_budget_hit = True break new_ids = x[0, len(ids):].tolist() text = self.tokenizer.decode(new_ids).strip() text = clean_answer_text(text, max_words=max_answer_words) if generated_word_budget_hit: text = clean_answer_text(text, max_words=max_answer_words) return text def load_questions(path: Optional[str]) -> List[Dict[str, Any]]: if not path: return DEFAULT_QUESTIONS with open(path, "r", encoding="utf-8") as f: data = json.load(f) if not isinstance(data, list): raise ValueError("Le fichier questions doit contenir une liste JSON.") return data def format_bar(score: float, width: int = 20) -> str: n = max(0, min(width, int(round(score * width)))) return "█" * n + "░" * (width - n) def save_reports(output_dir: Path, summary: Dict[str, Any]) -> Tuple[Path, Path]: output_dir.mkdir(parents=True, exist_ok=True) json_path = output_dir / "qa_test_report_simple_v3.json" txt_path = output_dir / "qa_test_report_simple_v3.txt" with open(json_path, "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) with open(txt_path, "w", encoding="utf-8") as f: f.write("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ (V3)\n") f.write("=" * 60 + "\n\n") for r in summary["results"]: f.write(f"[{r['id']:02d}] {r['category']}\n") f.write(f" User : {r['question']}\n") f.write(f" Assistant : {r['answer']}\n") if r["reference"]: f.write(f" Référence : {r['reference']}\n") if r["overlap_score"] is not None: f.write(f" Overlap : {r['overlap_score']:.0%}\n") if r["exact_match"] is not None: f.write(f" ExactMatch: {r['exact_match']}\n") if r["contains_reference"] is not None: f.write(f" Contains : {r['contains_reference']}\n") f.write(f" Latence : {r['latency_s']}s\n\n") return json_path, txt_path def main(): parser = argparse.ArgumentParser("Test QA simple et strict pour modèle Aramix déjà entraîné") parser.add_argument("--repo_dir", type=str, default="./aramix_h100") parser.add_argument("--train_script", type=str, default=None) parser.add_argument("--ckpt", type=str, default=None) parser.add_argument("--config", type=str, default=None) parser.add_argument("--questions", type=str, default=None) parser.add_argument("--max_new_tokens", type=int, default=16) parser.add_argument("--temperature", type=float, default=0.0) parser.add_argument("--top_k", type=int, default=1) parser.add_argument("--repetition_penalty", type=float, default=1.10) parser.add_argument("--max_answer_words", type=int, default=16) parser.add_argument("--prompt_style", type=str, default="strict_qa", choices=["strict_qa", "qa", "raw"]) parser.add_argument("--device", type=str, default=None) parser.add_argument("--save_report", action="store_true") args = parser.parse_args() repo_dir = Path(args.repo_dir).resolve() train_script, ckpt_path, config_path, tokenizer_dir = infer_repo_defaults(repo_dir) if args.train_script: train_script = Path(args.train_script).resolve() if args.ckpt: ckpt_path = Path(args.ckpt).resolve() if args.config: config_path = Path(args.config).resolve() if not train_script.exists(): raise FileNotFoundError(f"Script train introuvable: {train_script}") if not ckpt_path.exists(): raise FileNotFoundError(f"Checkpoint introuvable: {ckpt_path}") if not config_path.exists(): print(f"[WARN] config.json introuvable: {config_path} — fallback sur GPTConfig(vocab_size=...).") questions = load_questions(args.questions) tester = AramixChatTester( repo_dir=repo_dir, train_script=train_script, ckpt_path=ckpt_path, config_path=config_path, device=args.device, prompt_style=args.prompt_style, ) results: List[Dict[str, Any]] = [] categories: Dict[str, List[Dict[str, Any]]] = {} print("\n" + "=" * 60) print("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ") print("=" * 60) print(f"Repo : {repo_dir}") print(f"Train script: {train_script}") print(f"Checkpoint : {ckpt_path}") print(f"Config : {config_path}") print(f"Tokenizer : {tokenizer_dir}") print(f"Device : {tester.device}") print(f"Prompt : {args.prompt_style}") print(f"Questions : {len(questions)}") print("=" * 60 + "\n") for i, item in enumerate(questions, 1): q = item["question"] ref = item.get("reference") cat = item.get("category", "Général") t0 = time.time() ans = tester.generate( q, max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_k=args.top_k, repetition_penalty=args.repetition_penalty, max_answer_words=args.max_answer_words, ) latency = time.time() - t0 overlap = lexical_overlap(ref, ans) em = exact_match(ref, ans) contains_ref = contains_reference(ref, ans) entry = { "id": i, "category": cat, "question": q, "answer": ans, "reference": ref, "latency_s": round(latency, 3), "tokens_generated_approx": len(ans.split()), "overlap_score": None if overlap is None else round(overlap, 4), "exact_match": em, "contains_reference": contains_ref, } results.append(entry) categories.setdefault(cat, []).append(entry) overlap_str = f"{overlap:.0%}" if overlap is not None else "n/a" em_str = "✓" if em else ("✗" if em is not None else "n/a") contains_str = "✓" if contains_ref else ("✗" if contains_ref is not None else "n/a") print(f"[{i:02d}] {cat}") print(f" User : {q}") print(f" Assistant : {ans}") if ref: print(f" Référence : {ref}") print(f" Overlap : {overlap_str}") print(f" ExactMatch: {em_str}") print(f" Contains : {contains_str}") print(f" Latence : {latency:.3f}s\n") scored_overlap = [r["overlap_score"] for r in results if r["overlap_score"] is not None] scored_em = [r["exact_match"] for r in results if r["exact_match"] is not None] scored_contains = [r["contains_reference"] for r in results if r["contains_reference"] is not None] avg_overlap = sum(scored_overlap) / len(scored_overlap) if scored_overlap else 0.0 em_rate = sum(1 for x in scored_em if x) / len(scored_em) if scored_em else 0.0 contains_rate = sum(1 for x in scored_contains if x) / len(scored_contains) if scored_contains else 0.0 avg_latency = sum(r["latency_s"] for r in results) / len(results) if results else 0.0 avg_words = sum(r["tokens_generated_approx"] for r in results) / len(results) if results else 0.0 cat_scores: Dict[str, float] = {} for cat, items in categories.items(): vals = [r["overlap_score"] for r in items if r["overlap_score"] is not None] cat_scores[cat] = (sum(vals) / len(vals)) if vals else 0.0 summary = { "repo_dir": str(repo_dir), "train_script": str(train_script), "checkpoint": str(ckpt_path), "config_path": str(config_path), "tokenizer_dir": str(tokenizer_dir), "device": str(tester.device), "prompt_style": args.prompt_style, "total_questions": len(results), "avg_overlap_score": round(avg_overlap, 4), "exact_match_rate": round(em_rate, 4), "contains_reference_rate": round(contains_rate, 4), "avg_latency_s": round(avg_latency, 3), "avg_words_generated": round(avg_words, 1), "scores_by_category": {k: round(v, 4) for k, v in cat_scores.items()}, "results": results, } print("=" * 60) print("RÉSUMÉ") print("=" * 60) print(f"Questions testées : {len(results)}") print(f"Overlap moyen : {avg_overlap:.1%}") print(f"Exact match : {em_rate:.1%}") print(f"Contains ref : {contains_rate:.1%}") print(f"Latence moyenne : {avg_latency:.2f}s") print(f"Mots moyens : {avg_words:.1f}") print("Scores / catégorie:") for cat, score in sorted(cat_scores.items()): print(f" {cat:<15} {format_bar(score)} {score:.0%}") print("=" * 60) if args.save_report: json_path, txt_path = save_reports(repo_dir, summary) print(f"Rapports sauvegardés dans : {json_path}") print(f" {txt_path}") if __name__ == "__main__": main()