| """ |
| Test QA simple pour un modèle déjà entraîné. |
| |
| Fonctionne avec le script : train_chatbot_100m_large.py |
| - charge la config depuis output_dir/train_config.json si disponible |
| - charge un checkpoint fini (par défaut: output_dir/sft_best.pt) |
| - pose une petite liste de questions QA |
| - calcule un score simple d'overlap lexical |
| - sauvegarde un rapport JSON + TXT |
| |
| Exemples |
| -------- |
| python simple_qa_test_finished_model.py --output_dir ./fr_100m |
| python simple_qa_test_finished_model.py --output_dir ./fr_100m --ckpt ./fr_100m/sft_final.pt |
| python simple_qa_test_finished_model.py --output_dir ./fr_100m --questions qa_questions.json |
| """ |

from __future__ import annotations

import argparse
import importlib.util
import json
import os
import re
import sys
import time
import unicodedata
from pathlib import Path
from typing import Dict, List, Optional
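

# Default probe set, kept in French since the script targets a French-language
# model (see the fr_100m defaults in __main__).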
DEFAULT_QUESTIONS = [
    {
        "category": "Géographie",
        "question": "Quelle est la capitale de la France ?",
        "reference": "Paris",
    },
    {
        "category": "Géographie",
        "question": "Quel est le plus long fleuve d'Afrique ?",
        "reference": "Le Nil",
    },
    {
        "category": "Science",
        "question": "Qu'est-ce que la photosynthèse ?",
        "reference": "Processus par lequel les plantes convertissent la lumière en énergie",
    },
    {
        "category": "Science",
        "question": "Combien d'os compte le corps humain adulte ?",
        "reference": "206",
    },
    {
        "category": "Histoire",
        "question": "En quelle année a eu lieu la Révolution française ?",
        "reference": "1789",
    },
    {
        "category": "Histoire",
        "question": "Qui a écrit Les Misérables ?",
        "reference": "Victor Hugo",
    },
    {
        "category": "Mathématiques",
        "question": "Quelle est la formule de l'aire d'un cercle ?",
        "reference": "π × r²",
    },
    {
        "category": "Langage",
        "question": "Donne un synonyme du mot heureux.",
        "reference": "joyeux",
    },
    {
        "category": "Raisonnement",
        "question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?",
        "reference": "3",
    },
    {
        "category": "Dialogue",
        "question": "Comment vas-tu aujourd'hui ?",
        "reference": None,
    },
]
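

# Shared text normalization for both metrics: lowercase, NFKD accent stripping,
# punctuation -> spaces, whitespace collapse.
# e.g. normalize_text("  L'Été, déjà ?!  ") -> "l ete deja"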
def normalize_text(s: str) -> str:
    s = (s or "").strip().lower()
    s = unicodedata.normalize("NFKD", s)
    s = "".join(ch for ch in s if not unicodedata.combining(ch))
    s = re.sub(r"[^\w\s]", " ", s, flags=re.UNICODE)
    s = re.sub(r"\s+", " ", s).strip()
    return s
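

# Fraction of reference tokens that also appear in the answer (recall-style).
# e.g. lexical_overlap("Victor Hugo", "C'est Victor Hugo.") -> 1.0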
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
    if not reference:
        return None
    ref_tokens = set(normalize_text(reference).split())
    ans_tokens = set(normalize_text(answer).split())
    if not ref_tokens:
        return 0.0
    return len(ref_tokens & ans_tokens) / len(ref_tokens)
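

# Exact match after normalization.
# e.g. exact_match("Paris", " PARIS. ") -> True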
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
    if not reference:
        return None
    return normalize_text(reference) == normalize_text(answer)
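

# Imports the training script as a standalone module so that its TrainConfig
# and Chatbot classes can be reused here, e.g.
#   train_module = import_train_module("./train_chatbot_100m_large.py")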
def import_train_module(train_script_path: str):
    path = Path(train_script_path)
    if not path.exists():
        raise FileNotFoundError(f"Training script not found: {path}")

    spec = importlib.util.spec_from_file_location("train_module", str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load module: {path}")

    module = importlib.util.module_from_spec(spec)
    sys.modules["train_module"] = module
    spec.loader.exec_module(module)
    return module
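

# Assumption: TrainConfig can be rebuilt from the exact fields saved in
# train_config.json (dataclass-style keyword arguments).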
def build_cfg(train_module, output_dir: str):
    cfg_path = Path(output_dir) / "train_config.json"
    if cfg_path.exists():
        with open(cfg_path, "r", encoding="utf-8") as f:
            saved = json.load(f)
        cfg = train_module.TrainConfig(**saved)
    else:
        cfg = train_module.TrainConfig(output_dir=output_dir, tokenizer_prefix=f"{output_dir}/tokenizer")

    # Re-point paths at the requested output_dir even if the saved config moved.
    cfg.output_dir = output_dir
    cfg.tokenizer_prefix = f"{output_dir}/tokenizer"
    return cfg
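

# End-to-end driver: loads the training module, rebuilds the chatbot from the
# checkpoint, asks each question, scores it, and prints/saves a report.
# Programmatic use (paths are the CLI defaults):
#   summary = run_test(
#       train_script="./train_chatbot_100m_large.py",
#       output_dir="./fr_100m",
#       ckpt_path="./fr_100m/sft_best.pt",
#       questions=DEFAULT_QUESTIONS,
#       save_report=False,
#   )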
def run_test(
    train_script: str,
    output_dir: str,
    ckpt_path: str,
    questions: List[Dict],
    save_report: bool,
):
    train_module = import_train_module(train_script)
    cfg = build_cfg(train_module, output_dir)
    bot = train_module.Chatbot(cfg, ckpt_path)

    results = []
    categories: Dict[str, List[Dict]] = {}

    sep = "─" * 64
    print(f"\n{'═'*64}")
    print(" SIMPLE QA TEST - TRAINED MODEL")
    print(f" Checkpoint : {ckpt_path}")
    print(f" Questions  : {len(questions)}")
    print(f"{'═'*64}\n")

    for i, item in enumerate(questions, 1):
        q = item["question"]
        ref = item.get("reference")
        cat = item.get("category", "Général")
        ctx = item.get("context", "")

        t0 = time.time()
        ans = bot.chat(q, context=ctx)
        latency = time.time() - t0

        overlap = lexical_overlap(ref, ans)
        em = exact_match(ref, ans)

        row = {
            "id": i,
            "category": cat,
            "question": q,
            "context": ctx,
            "reference": ref,
            "answer": ans,
            "overlap_score": overlap,
            "exact_match": em,
            "latency_s": round(latency, 3),
            "tokens_generated_approx": len(ans.split()),
        }
        results.append(row)
        categories.setdefault(cat, []).append(row)

        score_text = []
        if overlap is not None:
            score_text.append(f"overlap={overlap:.0%}")
        if em is not None:
            score_text.append(f"EM={'yes' if em else 'no'}")
        score_str = f" [{' | '.join(score_text)}]" if score_text else ""

        print(sep)
        print(f"[{i:02d}] [{cat}]{score_str}")
        if ctx:
            print(f"  Context   : {ctx[:120]}{'...' if len(ctx) > 120 else ''}")
        print(f"  User      : {q}")
        print(f"  Assistant : {ans}")
        if ref:
            print(f"  Reference : {ref}")
        print(f"  ⏱ {latency:.2f}s | ~{row['tokens_generated_approx']} words\n")

    scored = [r for r in results if r["overlap_score"] is not None]
    avg_overlap = sum(r["overlap_score"] for r in scored) / len(scored) if scored else 0.0
    em_rows = [r for r in results if r["exact_match"] is not None]
    em_rate = sum(1 for r in em_rows if r["exact_match"]) / len(em_rows) if em_rows else 0.0
    avg_latency = sum(r["latency_s"] for r in results) / max(1, len(results))
    avg_tokens = sum(r["tokens_generated_approx"] for r in results) / max(1, len(results))

    scores_by_category = {}
    for cat, items in categories.items():
        cat_scored = [x for x in items if x["overlap_score"] is not None]
        cat_em = [x for x in items if x["exact_match"] is not None]
        scores_by_category[cat] = {
            "avg_overlap": round(sum(x["overlap_score"] for x in cat_scored) / len(cat_scored), 4) if cat_scored else None,
            "exact_match_rate": round(sum(1 for x in cat_em if x["exact_match"]) / len(cat_em), 4) if cat_em else None,
        }

    summary = {
        "checkpoint": ckpt_path,
        "total_questions": len(results),
        "avg_overlap_score": round(avg_overlap, 4),
        "exact_match_rate": round(em_rate, 4),
        "avg_latency_s": round(avg_latency, 3),
        "avg_tokens_generated_approx": round(avg_tokens, 1),
        "scores_by_category": scores_by_category,
        "results": results,
    }

    print(f"{'═'*64}")
    print(" SUMMARY")
    print(f"{'═'*64}")
    print(f" Questions tested : {len(results)}")
    print(f" Mean overlap     : {avg_overlap:.1%}")
    print(f" Exact match      : {em_rate:.1%}")
    print(f" Mean latency     : {avg_latency:.2f}s")
    print(f" Mean words       : {avg_tokens:.0f}")
    print(" Scores by category:")
    for cat, sc in scores_by_category.items():
        ov = sc["avg_overlap"]
        emc = sc["exact_match_rate"]
        print(f"  - {cat:<15} overlap={ov if ov is not None else 'n/a'} | EM={emc if emc is not None else 'n/a'}")
    print(f"{'═'*64}\n")

    if save_report:
        report_json = Path(output_dir) / "qa_test_simple_report.json"
        report_txt = Path(output_dir) / "qa_test_simple_report.txt"

        with open(report_json, "w", encoding="utf-8") as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        with open(report_txt, "w", encoding="utf-8") as f:
            f.write("SIMPLE QA TEST - TRAINED MODEL\n")
            f.write(f"Checkpoint : {ckpt_path}\n\n")
            for r in results:
                f.write(f"[{r['id']:02d}] {r['category']}\n")
                if r["context"]:
                    f.write(f"  Context   : {r['context']}\n")
                f.write(f"  User      : {r['question']}\n")
                f.write(f"  Assistant : {r['answer']}\n")
                if r["reference"]:
                    f.write(f"  Reference : {r['reference']}\n")
                if r["overlap_score"] is not None:
                    f.write(f"  Overlap   : {r['overlap_score']:.0%}\n")
                if r["exact_match"] is not None:
                    f.write(f"  EM        : {'yes' if r['exact_match'] else 'no'}\n")
                f.write(f"  Latency   : {r['latency_s']}s\n\n")

        print(f"JSON report -> {report_json}")
        print(f"TXT report  -> {report_txt}")

    return summary


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple QA test for an already-trained model")
    parser.add_argument("--train_script", type=str, default="./train_chatbot_100m_large.py")
    parser.add_argument("--output_dir", type=str, default="./fr_100m")
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--questions", type=str, default=None,
                        help="optional JSON file: [{question, reference, category, context?}]")
    parser.add_argument("--no_save", action="store_true")
    args = parser.parse_args()

    ckpt_path = args.ckpt or os.path.join(args.output_dir, "sft_best.pt")
    if not Path(ckpt_path).exists():
        raise FileNotFoundError(
            f"Checkpoint not found: {ckpt_path}\n"
            f"Example: --ckpt {args.output_dir}/sft_final.pt"
        )

    if args.questions:
        with open(args.questions, "r", encoding="utf-8") as f:
            questions = json.load(f)
    else:
        questions = DEFAULT_QUESTIONS

    run_test(
        train_script=args.train_script,
        output_dir=args.output_dir,
        ckpt_path=ckpt_path,
        questions=questions,
        save_report=not args.no_save,
    )