#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Simple QA test for an already-trained model.

Works with the training script: train_chatbot_100m_large.py

- loads the config from output_dir/train_config.json when available
- loads a finished checkpoint (default: output_dir/sft_best.pt)
- asks a small list of QA questions
- computes a simple lexical-overlap score
- saves a JSON + TXT report

Examples
--------
python simple_qa_test_finished_model.py --output_dir ./fr_100m
python simple_qa_test_finished_model.py --output_dir ./fr_100m --ckpt ./fr_100m/sft_final.pt
python simple_qa_test_finished_model.py --output_dir ./fr_100m --questions qa_questions.json
"""

from __future__ import annotations

import argparse
import importlib.util
import json
import os
import re
import sys
import time
import unicodedata
from pathlib import Path
from typing import Dict, List, Optional

DEFAULT_QUESTIONS = [
    {
        "category": "Géographie",
        "question": "Quelle est la capitale de la France ?",
        "reference": "Paris",
    },
    {
        "category": "Géographie",
        "question": "Quel est le plus long fleuve d'Afrique ?",
        "reference": "Le Nil",
    },
    {
        "category": "Science",
        "question": "Qu'est-ce que la photosynthèse ?",
        "reference": "Processus par lequel les plantes convertissent la lumière en énergie",
    },
    {
        "category": "Science",
        "question": "Combien d'os compte le corps humain adulte ?",
        "reference": "206",
    },
    {
        "category": "Histoire",
        "question": "En quelle année a eu lieu la Révolution française ?",
        "reference": "1789",
    },
    {
        "category": "Histoire",
        "question": "Qui a écrit Les Misérables ?",
        "reference": "Victor Hugo",
    },
    {
        "category": "Mathématiques",
        "question": "Quelle est la formule de l'aire d'un cercle ?",
        "reference": "π × r²",
    },
    {
        "category": "Langage",
        "question": "Donne un synonyme du mot heureux.",
        "reference": "joyeux",
    },
    {
        "category": "Raisonnement",
        "question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?",
        "reference": "3",
    },
    {
        "category": "Dialogue",
        "question": "Comment vas-tu aujourd'hui ?",
        "reference": None,
    },
]


def normalize_text(s: str) -> str:
    """Lowercase, strip accents and punctuation, collapse whitespace."""
    s = (s or "").strip().lower()
    s = unicodedata.normalize("NFKD", s)
    s = "".join(ch for ch in s if not unicodedata.combining(ch))
    s = re.sub(r"[^\w\s]", " ", s, flags=re.UNICODE)
    s = re.sub(r"\s+", " ", s).strip()
    return s


def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
    """Fraction of reference tokens found in the answer (None when there is no reference)."""
    if not reference:
        return None
    ref_tokens = set(normalize_text(reference).split())
    ans_tokens = set(normalize_text(answer).split())
    if not ref_tokens:
        return 0.0
    return len(ref_tokens & ans_tokens) / len(ref_tokens)


def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
    """Strict equality after normalization (None when there is no reference)."""
    if not reference:
        return None
    return normalize_text(reference) == normalize_text(answer)
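# Illustrative behavior of the scoring helpers above (doctest-style sketch, not
# executed; values follow from the normalization rules in normalize_text):
#
#   >>> lexical_overlap("Victor Hugo", "C'est Victor Hugo qui a écrit ce roman.")
#   1.0
#   >>> lexical_overlap("Le Nil", "Le fleuve Amazone")   # only "le" overlaps
#   0.5
#   >>> exact_match("Paris", "  PARIS. ")
#   True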
tokenizer_prefix=f"{output_dir}/tokenizer") cfg.output_dir = output_dir cfg.tokenizer_prefix = f"{output_dir}/tokenizer" return cfg def run_test( train_script: str, output_dir: str, ckpt_path: str, questions: List[Dict], save_report: bool, ): train_module = import_train_module(train_script) cfg = build_cfg(train_module, output_dir) bot = train_module.Chatbot(cfg, ckpt_path) results = [] categories: Dict[str, List[Dict]] = {} sep = "─" * 64 print(f"\n{'═'*64}") print(" TEST QA SIMPLE — MODÈLE ENTRAÎNÉ") print(f" Checkpoint : {ckpt_path}") print(f" Questions : {len(questions)}") print(f"{'═'*64}\n") for i, item in enumerate(questions, 1): q = item["question"] ref = item.get("reference") cat = item.get("category", "Général") ctx = item.get("context", "") t0 = time.time() ans = bot.chat(q, context=ctx) latency = time.time() - t0 overlap = lexical_overlap(ref, ans) em = exact_match(ref, ans) row = { "id": i, "category": cat, "question": q, "context": ctx, "reference": ref, "answer": ans, "overlap_score": overlap, "exact_match": em, "latency_s": round(latency, 3), "tokens_generated_approx": len(ans.split()), } results.append(row) categories.setdefault(cat, []).append(row) score_text = [] if overlap is not None: score_text.append(f"overlap={overlap:.0%}") if em is not None: score_text.append(f"EM={'oui' if em else 'non'}") score_str = f" [{' | '.join(score_text)}]" if score_text else "" print(sep) print(f"[{i:02d}] [{cat}]{score_str}") if ctx: print(f" Contexte : {ctx[:120]}{'...' if len(ctx) > 120 else ''}") print(f" User : {q}") print(f" Assistant : {ans}") if ref: print(f" Référence : {ref}") print(f" ⏱ {latency:.2f}s | ~{row['tokens_generated_approx']} mots\n") scored = [r for r in results if r["overlap_score"] is not None] avg_overlap = sum(r["overlap_score"] for r in scored) / len(scored) if scored else 0.0 em_rows = [r for r in results if r["exact_match"] is not None] em_rate = sum(1 for r in em_rows if r["exact_match"]) / len(em_rows) if em_rows else 0.0 avg_latency = sum(r["latency_s"] for r in results) / max(1, len(results)) avg_tokens = sum(r["tokens_generated_approx"] for r in results) / max(1, len(results)) scores_by_category = {} for cat, items in categories.items(): cat_scored = [x for x in items if x["overlap_score"] is not None] cat_em = [x for x in items if x["exact_match"] is not None] scores_by_category[cat] = { "avg_overlap": round(sum(x["overlap_score"] for x in cat_scored) / len(cat_scored), 4) if cat_scored else None, "exact_match_rate": round(sum(1 for x in cat_em if x["exact_match"]) / len(cat_em), 4) if cat_em else None, } summary = { "checkpoint": ckpt_path, "total_questions": len(results), "avg_overlap_score": round(avg_overlap, 4), "exact_match_rate": round(em_rate, 4), "avg_latency_s": round(avg_latency, 3), "avg_tokens_generated_approx": round(avg_tokens, 1), "scores_by_category": scores_by_category, "results": results, } print(f"{'═'*64}") print(" RÉSUMÉ") print(f"{'═'*64}") print(f" Questions testées : {len(results)}") print(f" Overlap moyen : {avg_overlap:.1%}") print(f" Exact match : {em_rate:.1%}") print(f" Latence moyenne : {avg_latency:.2f}s") print(f" Mots moyens : {avg_tokens:.0f}") print(" Scores / catégorie :") for cat, sc in scores_by_category.items(): ov = sc["avg_overlap"] emc = sc["exact_match_rate"] print(f" - {cat:<15} overlap={ov if ov is not None else 'n/a'} | EM={emc if emc is not None else 'n/a'}") print(f"{'═'*64}\n") if save_report: report_json = Path(output_dir) / "qa_test_simple_report.json" report_txt = Path(output_dir) / 
"qa_test_simple_report.txt" with open(report_json, "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) with open(report_txt, "w", encoding="utf-8") as f: f.write("TEST QA SIMPLE — MODÈLE ENTRAÎNÉ\n") f.write(f"Checkpoint : {ckpt_path}\n\n") for r in results: f.write(f"[{r['id']:02d}] {r['category']}\n") if r["context"]: f.write(f" Contexte : {r['context']}\n") f.write(f" User : {r['question']}\n") f.write(f" Assistant : {r['answer']}\n") if r["reference"]: f.write(f" Référence : {r['reference']}\n") if r["overlap_score"] is not None: f.write(f" Overlap : {r['overlap_score']:.0%}\n") if r["exact_match"] is not None: f.write(f" EM : {'oui' if r['exact_match'] else 'non'}\n") f.write(f" Latence : {r['latency_s']}s\n\n") print(f"Rapport JSON -> {report_json}") print(f"Rapport TXT -> {report_txt}") return summary if __name__ == "__main__": parser = argparse.ArgumentParser("Test QA simple pour modèle déjà entraîné") parser.add_argument("--train_script", type=str, default="./train_chatbot_100m_large.py") parser.add_argument("--output_dir", type=str, default="./fr_100m") parser.add_argument("--ckpt", type=str, default=None) parser.add_argument("--questions", type=str, default=None, help="JSON optionnel [{question, reference, category, context?}]") parser.add_argument("--no_save", action="store_true") args = parser.parse_args() ckpt_path = args.ckpt or os.path.join(args.output_dir, "sft_best.pt") if not Path(ckpt_path).exists(): raise FileNotFoundError( f"Checkpoint introuvable: {ckpt_path}\n" f"Exemple: --ckpt {args.output_dir}/sft_final.pt" ) if args.questions: with open(args.questions, "r", encoding="utf-8") as f: questions = json.load(f) else: questions = DEFAULT_QUESTIONS run_test( train_script=args.train_script, output_dir=args.output_dir, ckpt_path=ckpt_path, questions=questions, save_report=not args.no_save, )