#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Test QA simple pour un modèle déjà entraîné.
Fonctionne avec le script : train_chatbot_100m_large.py
- charge la config depuis output_dir/train_config.json si disponible
- charge un checkpoint fini (par défaut: output_dir/sft_best.pt)
- pose une petite liste de questions QA
- calcule un score simple d'overlap lexical
- sauvegarde un rapport JSON + TXT
Exemples
--------
python simple_qa_test_finished_model.py --output_dir ./fr_100m
python simple_qa_test_finished_model.py --output_dir ./fr_100m --ckpt ./fr_100m/sft_final.pt
python simple_qa_test_finished_model.py --output_dir ./fr_100m --questions qa_questions.json
"""
from __future__ import annotations
import argparse
import importlib.util
import json
import os
import sys
import time
import unicodedata
import re
from pathlib import Path
from typing import Dict, List, Optional
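# Built-in probe set; questions and references are deliberately left in
# French, matching the French model this script targets (e.g. ./fr_100m).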
DEFAULT_QUESTIONS = [
{
"category": "Géographie",
"question": "Quelle est la capitale de la France ?",
"reference": "Paris",
},
{
"category": "Géographie",
"question": "Quel est le plus long fleuve d'Afrique ?",
"reference": "Le Nil",
},
{
"category": "Science",
"question": "Qu'est-ce que la photosynthèse ?",
"reference": "Processus par lequel les plantes convertissent la lumière en énergie",
},
{
"category": "Science",
"question": "Combien d'os compte le corps humain adulte ?",
"reference": "206",
},
{
"category": "Histoire",
"question": "En quelle année a eu lieu la Révolution française ?",
"reference": "1789",
},
{
"category": "Histoire",
"question": "Qui a écrit Les Misérables ?",
"reference": "Victor Hugo",
},
{
"category": "Mathématiques",
"question": "Quelle est la formule de l'aire d'un cercle ?",
"reference": "π × r²",
},
{
"category": "Langage",
"question": "Donne un synonyme du mot heureux.",
"reference": "joyeux",
},
{
"category": "Raisonnement",
"question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?",
"reference": "3",
},
{
"category": "Dialogue",
"question": "Comment vas-tu aujourd'hui ?",
"reference": None,
},
]
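# An external --questions file must be a JSON list with the same shape as
# DEFAULT_QUESTIONS; only "question" is required ("reference" may be null,
# "category" and "context" are optional). An illustrative entry:
#   [
#     {
#       "category": "Science",
#       "question": "Quelle est la vitesse de la lumière ?",
#       "reference": "environ 300 000 km/s",
#       "context": "optional passage passed to bot.chat(..., context=...)"
#     }
#   ]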
def normalize_text(s: str) -> str:
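    """Lowercase, strip accents (NFKD), drop punctuation, collapse whitespace.

    >>> normalize_text("Évidemment, OUI !")
    'evidemment oui'
    """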
s = (s or "").strip().lower()
s = unicodedata.normalize("NFKD", s)
s = "".join(ch for ch in s if not unicodedata.combining(ch))
s = re.sub(r"[^\w\s]", " ", s, flags=re.UNICODE)
s = re.sub(r"\s+", " ", s).strip()
return s
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
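    """Fraction of reference tokens that appear in the answer (None if no reference).

    >>> lexical_overlap("Le Nil", "C'est le Nil, en Afrique.")
    1.0
    """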
if not reference:
return None
ref_tokens = set(normalize_text(reference).split())
ans_tokens = set(normalize_text(answer).split())
if not ref_tokens:
return 0.0
return len(ref_tokens & ans_tokens) / len(ref_tokens)
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
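    """Whether the normalized answer equals the normalized reference (None if no reference).

    >>> exact_match("Paris", " paris.")
    True
    """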
if not reference:
return None
return normalize_text(reference) == normalize_text(answer)
def import_train_module(train_script_path: str):
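    """Dynamically import the training script so its TrainConfig/Chatbot can be reused."""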
path = Path(train_script_path)
if not path.exists():
raise FileNotFoundError(f"Script d'entraînement introuvable: {path}")
spec = importlib.util.spec_from_file_location("train_module", str(path))
if spec is None or spec.loader is None:
        raise ImportError(f"Could not load module: {path}")
module = importlib.util.module_from_spec(spec)
sys.modules["train_module"] = module
spec.loader.exec_module(module)
return module
def build_cfg(train_module, output_dir: str):
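    """Restore TrainConfig from output_dir/train_config.json when present,
    else build a default one; output_dir and tokenizer_prefix are re-pointed
    at output_dir either way."""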
cfg_path = Path(output_dir) / "train_config.json"
if cfg_path.exists():
with open(cfg_path, "r", encoding="utf-8") as f:
saved = json.load(f)
cfg = train_module.TrainConfig(**saved)
else:
cfg = train_module.TrainConfig(output_dir=output_dir, tokenizer_prefix=f"{output_dir}/tokenizer")
cfg.output_dir = output_dir
cfg.tokenizer_prefix = f"{output_dir}/tokenizer"
return cfg
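# End-to-end usage sketch (what run_test below does). This assumes the
# training script exposes a TrainConfig type and a Chatbot(cfg, ckpt_path)
# class with a .chat(question, context=...) method; build_cfg also assumes
# every key saved in train_config.json is a valid TrainConfig field:
#   mod = import_train_module("./train_chatbot_100m_large.py")
#   cfg = build_cfg(mod, "./fr_100m")
#   bot = mod.Chatbot(cfg, "./fr_100m/sft_best.pt")
#   print(bot.chat("Quelle est la capitale de la France ?", context=""))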
def run_test(
train_script: str,
output_dir: str,
ckpt_path: str,
questions: List[Dict],
save_report: bool,
):
train_module = import_train_module(train_script)
cfg = build_cfg(train_module, output_dir)
bot = train_module.Chatbot(cfg, ckpt_path)
results = []
categories: Dict[str, List[Dict]] = {}
sep = "─" * 64
print(f"\n{'═'*64}")
print(" TEST QA SIMPLE — MODÈLE ENTRAÎNÉ")
print(f" Checkpoint : {ckpt_path}")
print(f" Questions : {len(questions)}")
print(f"{'═'*64}\n")
for i, item in enumerate(questions, 1):
q = item["question"]
ref = item.get("reference")
cat = item.get("category", "Général")
ctx = item.get("context", "")
t0 = time.time()
ans = bot.chat(q, context=ctx)
latency = time.time() - t0
overlap = lexical_overlap(ref, ans)
em = exact_match(ref, ans)
row = {
"id": i,
"category": cat,
"question": q,
"context": ctx,
"reference": ref,
"answer": ans,
"overlap_score": overlap,
"exact_match": em,
"latency_s": round(latency, 3),
"tokens_generated_approx": len(ans.split()),
}
results.append(row)
categories.setdefault(cat, []).append(row)
score_text = []
if overlap is not None:
score_text.append(f"overlap={overlap:.0%}")
if em is not None:
score_text.append(f"EM={'oui' if em else 'non'}")
score_str = f" [{' | '.join(score_text)}]" if score_text else ""
print(sep)
print(f"[{i:02d}] [{cat}]{score_str}")
if ctx:
print(f" Contexte : {ctx[:120]}{'...' if len(ctx) > 120 else ''}")
print(f" User : {q}")
print(f" Assistant : {ans}")
if ref:
print(f" Référence : {ref}")
print(f" ⏱ {latency:.2f}s | ~{row['tokens_generated_approx']} mots\n")
scored = [r for r in results if r["overlap_score"] is not None]
avg_overlap = sum(r["overlap_score"] for r in scored) / len(scored) if scored else 0.0
em_rows = [r for r in results if r["exact_match"] is not None]
em_rate = sum(1 for r in em_rows if r["exact_match"]) / len(em_rows) if em_rows else 0.0
avg_latency = sum(r["latency_s"] for r in results) / max(1, len(results))
avg_tokens = sum(r["tokens_generated_approx"] for r in results) / max(1, len(results))
scores_by_category = {}
for cat, items in categories.items():
cat_scored = [x for x in items if x["overlap_score"] is not None]
cat_em = [x for x in items if x["exact_match"] is not None]
scores_by_category[cat] = {
"avg_overlap": round(sum(x["overlap_score"] for x in cat_scored) / len(cat_scored), 4) if cat_scored else None,
"exact_match_rate": round(sum(1 for x in cat_em if x["exact_match"]) / len(cat_em), 4) if cat_em else None,
}
summary = {
"checkpoint": ckpt_path,
"total_questions": len(results),
"avg_overlap_score": round(avg_overlap, 4),
"exact_match_rate": round(em_rate, 4),
"avg_latency_s": round(avg_latency, 3),
"avg_tokens_generated_approx": round(avg_tokens, 1),
"scores_by_category": scores_by_category,
"results": results,
}
print(f"{'═'*64}")
print(" RÉSUMÉ")
print(f"{'═'*64}")
print(f" Questions testées : {len(results)}")
print(f" Overlap moyen : {avg_overlap:.1%}")
print(f" Exact match : {em_rate:.1%}")
print(f" Latence moyenne : {avg_latency:.2f}s")
print(f" Mots moyens : {avg_tokens:.0f}")
print(" Scores / catégorie :")
for cat, sc in scores_by_category.items():
ov = sc["avg_overlap"]
emc = sc["exact_match_rate"]
print(f" - {cat:<15} overlap={ov if ov is not None else 'n/a'} | EM={emc if emc is not None else 'n/a'}")
print(f"{'═'*64}\n")
if save_report:
report_json = Path(output_dir) / "qa_test_simple_report.json"
report_txt = Path(output_dir) / "qa_test_simple_report.txt"
with open(report_json, "w", encoding="utf-8") as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
with open(report_txt, "w", encoding="utf-8") as f:
f.write("TEST QA SIMPLE — MODÈLE ENTRAÎNÉ\n")
f.write(f"Checkpoint : {ckpt_path}\n\n")
for r in results:
f.write(f"[{r['id']:02d}] {r['category']}\n")
if r["context"]:
f.write(f" Contexte : {r['context']}\n")
f.write(f" User : {r['question']}\n")
f.write(f" Assistant : {r['answer']}\n")
if r["reference"]:
f.write(f" Référence : {r['reference']}\n")
if r["overlap_score"] is not None:
f.write(f" Overlap : {r['overlap_score']:.0%}\n")
if r["exact_match"] is not None:
f.write(f" EM : {'oui' if r['exact_match'] else 'non'}\n")
f.write(f" Latence : {r['latency_s']}s\n\n")
print(f"Rapport JSON -> {report_json}")
print(f"Rapport TXT -> {report_txt}")
return summary
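# The saved JSON report can be re-read later without re-running generation
# (summary keys as produced above):
#   with open("./fr_100m/qa_test_simple_report.json", encoding="utf-8") as f:
#       report = json.load(f)
#   print(report["avg_overlap_score"], report["scores_by_category"])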
if __name__ == "__main__":
parser = argparse.ArgumentParser("Test QA simple pour modèle déjà entraîné")
parser.add_argument("--train_script", type=str, default="./train_chatbot_100m_large.py")
parser.add_argument("--output_dir", type=str, default="./fr_100m")
parser.add_argument("--ckpt", type=str, default=None)
parser.add_argument("--questions", type=str, default=None, help="JSON optionnel [{question, reference, category, context?}]")
parser.add_argument("--no_save", action="store_true")
args = parser.parse_args()
ckpt_path = args.ckpt or os.path.join(args.output_dir, "sft_best.pt")
if not Path(ckpt_path).exists():
raise FileNotFoundError(
f"Checkpoint introuvable: {ckpt_path}\n"
f"Exemple: --ckpt {args.output_dir}/sft_final.pt"
)
if args.questions:
with open(args.questions, "r", encoding="utf-8") as f:
questions = json.load(f)
else:
questions = DEFAULT_QUESTIONS
run_test(
train_script=args.train_script,
output_dir=args.output_dir,
ckpt_path=ckpt_path,
questions=questions,
save_report=not args.no_save,
)