FirstChat / simple_qa_test_aramix_v3.py
Medyassino's picture
Add files using upload-large-folder tool
b9049d2 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
simple_qa_test_aramix_v3.py
Test QA simple et strict pour un modèle déjà entraîné dans un dépôt de type :
- train_aramix_h100_full.py
- aramix_h100/
    - config.json
    - model_best.pt
    - model.pt
    - tokenizer_32k/
Cette version améliore le test QA en :
- gérant les checkpoints torch.compile() avec préfixe "_orig_mod."
- utilisant un prompt plus directif pour des réponses courtes
- privilégiant une génération greedy / quasi-greedy
- tronquant proprement les réponses trop longues
- ajoutant une métrique "contains_reference"
Usage
-----
python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --save_report
python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --temperature 0 --max_new_tokens 16
python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --ckpt ./aramix_h100/model_best.pt
python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --questions qa_questions.json
"""
from __future__ import annotations
import argparse
import importlib.util
import json
import os
import re
import sys
import time
import unicodedata
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
# Built-in French QA set used when no --questions file is supplied.
# Each entry is a dict with the keys "category", "question" and "reference".
# A "reference" of None marks an open-ended item: the scoring helpers in this
# file return None for it, so it is excluded from the averaged metrics.
DEFAULT_QUESTIONS = [
{"category": "Géographie", "question": "Quelle est la capitale de la France ?", "reference": "Paris"},
{"category": "Géographie", "question": "Quel est le plus long fleuve d'Afrique ?", "reference": "Le Nil"},
{"category": "Science", "question": "Qu'est-ce que la photosynthèse ?", "reference": "Processus par lequel les plantes convertissent la lumière en énergie"},
{"category": "Science", "question": "Combien d'os compte le corps humain adulte ?", "reference": "206"},
{"category": "Histoire", "question": "En quelle année a eu lieu la Révolution française ?", "reference": "1789"},
{"category": "Histoire", "question": "Qui a écrit Les Misérables ?", "reference": "Victor Hugo"},
{"category": "Mathématiques", "question": "Quelle est la formule de l'aire d'un cercle ?", "reference": "pi r carre"},
{"category": "Langage", "question": "Donne un synonyme du mot heureux.", "reference": "joyeux"},
{"category": "Raisonnement", "question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?", "reference": "3"},
{"category": "Dialogue", "question": "Comment vas-tu aujourd'hui ?", "reference": None},
]
def load_module_from_file(py_path: Path):
    """Import the Python source file *py_path* as a module and return it.

    The module is registered in ``sys.modules`` under the file's stem before
    its code is executed. If execution fails, the registration is rolled back
    (the previous entry, if any, is restored) so that a half-initialised
    module never stays importable.

    Raises:
        RuntimeError: if no import spec/loader can be built for the file.
        Exception: whatever the module's own top-level code raises.
    """
    name = py_path.stem
    spec = importlib.util.spec_from_file_location(name, py_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Impossible de charger le module: {py_path}")
    module = importlib.util.module_from_spec(spec)

    # Remember what was registered under this name so it can be restored.
    absent = object()
    previous = sys.modules.get(name, absent)
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        if previous is absent:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return module
def normalize_text(text: str) -> str:
    """Return a canonical form of *text* for lenient string comparison.

    The text is trimmed and lower-cased, NFKD-decomposed with combining marks
    dropped (which removes accents), the Greek letter pi is spelled out as
    "pi", and every run of non-word characters or underscores becomes a
    single space. A falsy input yields an empty string.
    """
    lowered = (text or "").strip().lower()
    decomposed = unicodedata.normalize("NFKD", lowered)
    base_chars = [ch for ch in decomposed if not unicodedata.combining(ch)]
    spelled = "".join(base_chars).replace("π", "pi")
    spaced = re.sub(r"[\W_]+", " ", spelled, flags=re.UNICODE)
    # Only plain spaces remain as separators, so split/join collapses them.
    return " ".join(spaced.split())
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
    """Return the fraction of distinct reference words found in *answer*.

    Both strings are normalised with ``normalize_text`` and compared as sets
    of words. Returns ``None`` when there is no reference, or when the
    reference contains no word once normalised.
    """
    if not reference:
        return None
    expected_words = set(normalize_text(reference).split())
    produced_words = set(normalize_text(answer).split())
    if not expected_words:
        return None
    found = expected_words.intersection(produced_words)
    return len(found) / len(expected_words)
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
    """Tell whether *answer* equals *reference* after ``normalize_text``.

    Returns ``None`` when no reference is provided.
    """
    if reference:
        expected = normalize_text(reference)
        produced = normalize_text(answer)
        return expected == produced
    return None
def contains_reference(reference: Optional[str], answer: str) -> Optional[bool]:
    """Tell whether the normalised reference appears inside the answer.

    Both strings go through ``normalize_text`` and the reference must match a
    run of whole words of the answer. Matching on word boundaries avoids
    false positives such as the reference "3" being found in the answer "13".

    Returns ``None`` when there is no reference, or when the reference is
    empty once normalised.
    """
    if not reference:
        return None
    ref = normalize_text(reference)
    ans = normalize_text(answer)
    if not ref:
        return None
    # Normalised text is a sequence of words separated by single spaces, so
    # padding both sides with a space turns substring search into a
    # whole-word phrase search.
    return f" {ref} " in f" {ans} "
def infer_repo_defaults(repo_dir: Path):
    """Derive the default file locations from the repository directory.

    Returns a 4-tuple ``(train_script, ckpt, config, tokenizer_dir)``:
    the training script is looked up next to *repo_dir* first and inside it
    otherwise; the checkpoint is ``model_best.pt`` when that file exists and
    ``model.pt`` otherwise. The returned fallback paths are not checked for
    existence.
    """
    script_name = "train_aramix_h100_full.py"
    beside_repo = repo_dir.parent / script_name
    train_script = beside_repo if beside_repo.exists() else repo_dir / script_name

    best_ckpt = repo_dir / "model_best.pt"
    ckpt = best_ckpt if best_ckpt.exists() else repo_dir / "model.pt"

    return train_script, ckpt, repo_dir / "config.json", repo_dir / "tokenizer_32k"
def safe_get(cfg: Dict[str, Any], *names: str, default=None):
    """Return the value of the first key in *names* that exists in *cfg*.

    Keys are tried in the order given; a key that is present wins even if its
    value is falsy or ``None``. When none is present, *default* is returned.
    """
    return next((cfg[key] for key in names if key in cfg), default)
def build_model_config_dict(cfg_json: Dict[str, Any], vocab_size: int) -> Dict[str, Any]:
    """Translate a loaded config mapping into model-config keyword arguments.

    Each hyper-parameter is looked up under several alternative key names via
    ``safe_get`` and converted with ``int``. Defaults when no alias is
    present: block_size 512, d_model 768, n_heads 12, n_layers 12, and d_ff
    four times the (raw) d_model value. *vocab_size* is passed through as is.
    """
    alias_table = (
        ("block_size", ("block_size", "max_seq_len", "seq_len"), 512),
        ("d_model", ("d_model", "n_embd", "dim"), 768),
        ("n_heads", ("n_heads", "n_head", "num_heads"), 12),
        ("n_layers", ("n_layers", "n_layer", "num_layers"), 12),
    )
    raw = {
        field: safe_get(cfg_json, *aliases, default=fallback)
        for field, aliases, fallback in alias_table
    }
    # The feed-forward width defaults to 4x the model width found above.
    raw["d_ff"] = safe_get(
        cfg_json, "d_ff", "ffn_dim", "intermediate_size", default=raw["d_model"] * 4
    )

    config: Dict[str, Any] = {"vocab_size": vocab_size}
    for field, value in raw.items():
        config[field] = int(value)
    return config
def strip_orig_mod_prefix(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Remove the "_orig_mod." marker from the keys of a state dict.

    If no key starts with "_orig_mod.", the very same mapping object is
    returned. Otherwise a new dict is built in which the first occurrence of
    "_orig_mod." has been removed from every key.
    """
    marker = "_orig_mod."
    for key in state:
        if key.startswith(marker):
            break
    else:
        # No key carries the prefix: nothing to rewrite.
        return state
    return {key.replace(marker, "", 1): value for key, value in state.items()}
def clean_answer_text(text: str, max_words: int = 16) -> str:
    """Reduce raw generated text to a short, single-sentence answer.

    Steps, in order:
      1. replace a few known markers ("<eos>", "</s>", "<pad>", answer
         labels) with a space;
      2. keep only the first line;
      3. cut after the first sentence-ending mark when it occurs within the
         first 120 characters;
      4. collapse whitespace and keep at most *max_words* words;
      5. drop trailing commas, semicolons, colons and spaces.

    A period directly followed by a digit is not treated as a sentence end,
    so decimal numbers such as "3.14" are kept whole.
    """
    text = (text or "").strip()
    # Remove a few frequent markers.
    for marker in ("<eos>", "</s>", "<pad>", "Réponse :", "Réponse:", "Answer:"):
        text = text.replace(marker, " ")
    # Strip again before taking the first line: a leading marker followed by
    # a newline would otherwise leave an empty first line and lose the answer.
    text = text.strip().split("\n")[0].strip()
    # Cut at the first real end of a short sentence.
    sentence_end = re.search(r"[!?]|\.(?!\d)", text)
    if sentence_end and sentence_end.start() < 120:
        text = text[: sentence_end.end()]
    # Collapse whitespace.
    text = re.sub(r"\s+", " ", text).strip()
    # Truncate to the requested number of words.
    words = text.split()
    if len(words) > max_words:
        text = " ".join(words[:max_words]).strip()
    # Remove excessive trailing punctuation.
    return re.sub(r"[,\s;:]+$", "", text).strip()
def safe_top_k_logits(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Mask every logit below the k-th largest one with ``-inf``.

    The filter works along the last dimension and accepts tensors of any
    rank (1-D vectors as well as batched inputs). *top_k* is clamped to the
    size of the last dimension; ``None`` or a non-positive value disables the
    filter and returns *logits* unchanged (same object).
    """
    if top_k is None or top_k <= 0:
        return logits
    k = min(int(top_k), logits.size(-1))
    values, _ = torch.topk(logits, k=k, dim=-1)
    # Keep the last dimension so the threshold broadcasts against logits.
    kth = values[..., -1:]
    return torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits)
class AramixChatTester:
    """Load a trained model plus its tokenizer and answer short questions.

    The model and tokenizer classes come from the training script, which is
    imported dynamically and must expose ``GPT``, ``GPTConfig``,
    ``train_or_load_tokenizer`` and ``DOMAINS``.
    """

    def __init__(
        self,
        repo_dir: Path,
        train_script: Path,
        ckpt_path: Path,
        config_path: Path,
        device: Optional[str] = None,
        prompt_style: str = "strict_qa",
    ):
        """Import the training script, then load config, tokenizer and model.

        Args:
            repo_dir: directory holding the trained artefacts.
            train_script: path of the training script to import.
            ckpt_path: checkpoint file to load the weights from.
            config_path: JSON config file; silently ignored when missing.
            device: torch device string; defaults to "cuda" when available,
                otherwise "cpu".
            prompt_style: "strict_qa", "qa", or any other value for a raw
                prompt (see ``build_prompt``).

        Raises:
            RuntimeError: if the training script lacks a required symbol.
        """
        self.repo_dir = repo_dir
        self.train_script = train_script
        self.ckpt_path = ckpt_path
        self.config_path = config_path
        self.prompt_style = prompt_style
        self.device = torch.device(device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu"))
        self.M = load_module_from_file(self.train_script)
        required = ["GPT", "GPTConfig", "train_or_load_tokenizer", "DOMAINS"]
        missing = [x for x in required if not hasattr(self.M, x)]
        if missing:
            raise RuntimeError(f"Le fichier {self.train_script.name} ne contient pas les symboles attendus: {missing}")
        self.cfg_json: Dict[str, Any] = {}
        if self.config_path.exists():
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.cfg_json = json.load(f)
        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()

    def _load_tokenizer(self):
        """Obtain the tokenizer through the training script's own loader.

        The working directory is temporarily switched to the parent of the
        repo directory and always restored afterwards.
        NOTE(review): presumably the training script resolves its tokenizer
        path relative to the current directory; confirm against that script.
        """
        old_cwd = Path.cwd()
        try:
            os.chdir(self.repo_dir.parent)
            tok = self.M.train_or_load_tokenizer(self.M.DOMAINS)
        finally:
            os.chdir(old_cwd)
        return tok

    def _make_gpt_config(self):
        """Build the model config, falling back to a vocab-size-only config.

        The fallback is used when ``GPTConfig`` rejects the keyword set
        derived from the JSON config with a ``TypeError``.
        """
        kwargs = build_model_config_dict(self.cfg_json, vocab_size=len(self.tokenizer))
        try:
            return self.M.GPTConfig(**kwargs)
        except TypeError:
            return self.M.GPTConfig(vocab_size=len(self.tokenizer))

    def _manual_load_state(self, model: torch.nn.Module):
        """Load the checkpoint weights into *model* without the script's loader.

        Accepts either a bare state dict or a dict holding it under "model".
        Keys written by ``torch.compile`` ("_orig_mod." prefix) are renamed
        first. Only a missing ``lm_head.weight`` is tolerated.

        Raises:
            RuntimeError: if the checkpoint holds no usable state dict, or if
                keys are missing/unexpected beyond the tolerated ones.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        ckpt = torch.load(self.ckpt_path, map_location=self.device)
        if isinstance(ckpt, dict) and "model" in ckpt:
            state = ckpt["model"]
        else:
            state = ckpt
        if not isinstance(state, dict):
            raise RuntimeError("Checkpoint non reconnu: pas de state_dict exploitable.")
        state = strip_orig_mod_prefix(state)
        missing, unexpected = model.load_state_dict(state, strict=False)
        # Minimal tolerance, only for a possibly tied output head.
        hard_missing = [k for k in missing if not k.endswith("lm_head.weight")]
        hard_unexpected = [k for k in unexpected if not k.startswith("_orig_mod.")]
        if hard_missing or hard_unexpected:
            raise RuntimeError(
                "Chargement manuel incomplet.\n"
                f"Missing: {hard_missing[:20]}\n"
                f"Unexpected: {hard_unexpected[:20]}"
            )

    def _load_model(self):
        """Instantiate the model, load its weights and return it in eval mode.

        The training script's ``load_checkpoint`` is tried first, with a
        4-argument call and then a 3-argument call if the first raises
        ``TypeError``. If it raises a ``RuntimeError`` mentioning
        "_orig_mod." (or if the script has no such function), the manual
        loader is used instead. Any other ``RuntimeError`` propagates.
        """
        cfg = self._make_gpt_config()
        model = self.M.GPT(cfg).to(self.device)
        if hasattr(self.M, "load_checkpoint"):
            try:
                try:
                    self.M.load_checkpoint(model, None, self.ckpt_path, self.device)
                except TypeError:
                    self.M.load_checkpoint(model, self.ckpt_path, self.device)
                model.eval()
                return model
            except RuntimeError as e:
                if "_orig_mod." not in str(e):
                    raise
        self._manual_load_state(model)
        model.eval()
        return model

    def build_prompt(self, question: str) -> str:
        """Wrap *question* in the prompt template selected by ``prompt_style``.

        "strict_qa" adds French instructions asking for a very short answer,
        "qa" uses a bare question/answer template, and any other style
        returns the question unchanged.
        """
        if self.prompt_style == "strict_qa":
            return (
                "Réponds très brièvement et uniquement en français.\n"
                "Donne seulement la réponse finale, sans explication.\n\n"
                f"Question : {question}\n"
                "Réponse :"
            )
        if self.prompt_style == "qa":
            return f"Question: {question}\nRéponse:"
        return question

    def encode_prompt(self, question: str) -> List[int]:
        """Return the token ids of the prompt built for *question*.

        The tokenizer's BOS id is prepended when it defines one, and a
        trailing EOS id is removed so that generation continues the prompt.
        """
        bos = getattr(self.tokenizer, "bos_token_id", None)
        eos = getattr(self.tokenizer, "eos_token_id", None)
        prompt = self.build_prompt(question)
        ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        if bos is not None:
            ids = [bos] + ids
        if eos is not None and ids and ids[-1] == eos:
            ids = ids[:-1]
        return ids

    @staticmethod
    def _apply_repetition_penalty(logits: torch.Tensor, token_ids: List[int], penalty: float) -> torch.Tensor:
        """Return a copy of *logits* with the given tokens made less likely.

        *logits* is indexed as ``[0, token]``. Positive logits are divided by
        *penalty* and negative ones multiplied by it, so a penalty above 1
        always lowers the score. A missing, non-positive or neutral (1.0)
        penalty, or an empty token list, returns *logits* untouched.
        """
        if not token_ids or penalty is None or penalty <= 0 or penalty == 1.0:
            return logits
        idx = torch.tensor(sorted(set(token_ids)), dtype=torch.long, device=logits.device)
        picked = logits[0, idx]
        penalised = torch.where(picked < 0, picked * penalty, picked / penalty)
        out = logits.clone()
        out[0, idx] = penalised
        return out

    @staticmethod
    def _pick_next_token(logits: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
        """Choose the next token from last-position *logits*.

        With ``temperature <= 0`` the choice is greedy, unless ``top_k > 1``
        in which case a token is sampled among the top-k at temperature 1.
        With a positive temperature, logits are scaled, top-k filtered and
        sampled.
        """
        if temperature <= 0:
            if top_k is not None and top_k > 1:
                probs = F.softmax(safe_top_k_logits(logits, top_k), dim=-1)
                return torch.multinomial(probs, num_samples=1)
            return torch.argmax(logits, dim=-1, keepdim=True)
        scaled = safe_top_k_logits(logits / max(temperature, 1e-5), top_k)
        return torch.multinomial(F.softmax(scaled, dim=-1), num_samples=1)

    @torch.no_grad()
    def generate(
        self,
        question: str,
        max_new_tokens: int = 16,
        temperature: float = 0.0,
        top_k: int = 1,
        repetition_penalty: float = 1.10,
        max_answer_words: int = 16,
    ) -> str:
        """Generate a short cleaned answer for *question*.

        Generation stops after *max_new_tokens* tokens, on an EOS token
        (from the second generated token onwards), or as soon as the cleaned
        partial answer reaches *max_answer_words* words. The EOS token that
        ends generation is not part of the decoded text.

        Returns:
            The decoded continuation after ``clean_answer_text``.
        """
        ids = self.encode_prompt(question)
        prompt_len = len(ids)
        x = torch.tensor([ids], dtype=torch.long, device=self.device)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        block_size = getattr(getattr(self.model, "cfg", None), "block_size", None)
        if block_size is None:
            block_size = safe_get(self.cfg_json, "block_size", "max_seq_len", default=512)
        block_size = int(block_size)

        for step in range(max_new_tokens):
            # Feed at most the last block_size tokens to the model.
            out = self.model(x[:, -block_size:])
            logits = out[0] if isinstance(out, tuple) else out
            logits = logits[:, -1, :]
            # Light repetition penalty over the 48 most recent tokens.
            logits = self._apply_repetition_penalty(logits, x[0, -48:].tolist(), repetition_penalty)
            next_tok = self._pick_next_token(logits, temperature, top_k)
            if eos_id is not None and next_tok.item() == eos_id and step >= 1:
                break
            x = torch.cat([x, next_tok], dim=1)
            # Early stop once the answer is already long enough.
            partial = self.tokenizer.decode(x[0, prompt_len:].tolist()).strip()
            partial = clean_answer_text(partial, max_words=max_answer_words)
            if len(partial.split()) >= max_answer_words:
                break

        text = self.tokenizer.decode(x[0, prompt_len:].tolist()).strip()
        return clean_answer_text(text, max_words=max_answer_words)
def load_questions(path: Optional[str]) -> List[Dict[str, Any]]:
    """Return the list of question entries to run.

    Without *path*, a copy of the built-in ``DEFAULT_QUESTIONS`` is returned
    so callers cannot alter the module-level default. Otherwise the JSON file
    at *path* is read; it must contain a list of objects that each have a
    "question" key.

    Raises:
        ValueError: if the file does not hold a JSON list, or if an entry is
            not an object with a "question" key.
    """
    if not path:
        return [dict(item) for item in DEFAULT_QUESTIONS]
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("Le fichier questions doit contenir une liste JSON.")
    # Validate up front so a bad entry fails before any model work is done.
    for index, item in enumerate(data, 1):
        if not isinstance(item, dict) or "question" not in item:
            raise ValueError(
                f"Entrée n°{index} invalide: un objet JSON avec une clé 'question' est attendu."
            )
    return data
def format_bar(score: float, width: int = 20) -> str:
    """Render *score* (expected in 0..1) as a text bar of *width* cells.

    The number of filled cells is ``score * width`` rounded, then clamped to
    the range 0..width; remaining cells are drawn with the light shade.
    """
    filled = int(round(score * width))
    filled = min(width, filled)
    filled = max(0, filled)
    empty = width - filled
    return "".join(("█" * filled, "░" * empty))
def save_reports(output_dir: Path, summary: Dict[str, Any]) -> Tuple[Path, Path]:
    """Write the run summary as a JSON file and as a readable text file.

    *output_dir* is created when needed. The JSON report is a dump of
    *summary*; the text report lists every entry of ``summary["results"]``,
    omitting the reference and each metric when it is unset.

    Returns:
        ``(json_path, txt_path)`` of the two files written.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    json_path = output_dir / "qa_test_report_simple_v3.json"
    txt_path = output_dir / "qa_test_report_simple_v3.txt"

    with open(json_path, "w", encoding="utf-8") as json_file:
        json.dump(summary, json_file, ensure_ascii=False, indent=2)

    # Assemble the text report in memory, then write it in one go.
    chunks = [
        "TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ (V3)\n",
        "=" * 60 + "\n\n",
    ]
    for row in summary["results"]:
        chunks.append(f"[{row['id']:02d}] {row['category']}\n")
        chunks.append(f" User : {row['question']}\n")
        chunks.append(f" Assistant : {row['answer']}\n")
        if row["reference"]:
            chunks.append(f" Référence : {row['reference']}\n")
        if row["overlap_score"] is not None:
            chunks.append(f" Overlap : {row['overlap_score']:.0%}\n")
        if row["exact_match"] is not None:
            chunks.append(f" ExactMatch: {row['exact_match']}\n")
        if row["contains_reference"] is not None:
            chunks.append(f" Contains : {row['contains_reference']}\n")
        chunks.append(f" Latence : {row['latency_s']}s\n\n")

    with open(txt_path, "w", encoding="utf-8") as txt_file:
        txt_file.write("".join(chunks))
    return json_path, txt_path
def main():
    """Command-line entry point: run the QA set and print/save the results.

    Parses the options, resolves the training script, checkpoint and config
    paths (explicit options override the defaults inferred from --repo_dir),
    loads the model, answers every question, prints per-question metrics and
    a summary, and writes report files into the repo directory when
    --save_report is given.

    Raises:
        FileNotFoundError: if the training script or the checkpoint is
            missing.
    """
    # The text is passed as ``description``; the first positional parameter
    # of ArgumentParser is ``prog`` and would end up in the usage line.
    parser = argparse.ArgumentParser(description="Test QA simple et strict pour modèle Aramix déjà entraîné")
    parser.add_argument("--repo_dir", type=str, default="./aramix_h100")
    parser.add_argument("--train_script", type=str, default=None)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--questions", type=str, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=16)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_k", type=int, default=1)
    parser.add_argument("--repetition_penalty", type=float, default=1.10)
    parser.add_argument("--max_answer_words", type=int, default=16)
    parser.add_argument("--prompt_style", type=str, default="strict_qa", choices=["strict_qa", "qa", "raw"])
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--save_report", action="store_true")
    args = parser.parse_args()

    repo_dir = Path(args.repo_dir).resolve()
    train_script, ckpt_path, config_path, tokenizer_dir = infer_repo_defaults(repo_dir)
    if args.train_script:
        train_script = Path(args.train_script).resolve()
    if args.ckpt:
        ckpt_path = Path(args.ckpt).resolve()
    if args.config:
        config_path = Path(args.config).resolve()
    if not train_script.exists():
        raise FileNotFoundError(f"Script train introuvable: {train_script}")
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint introuvable: {ckpt_path}")
    if not config_path.exists():
        print(f"[WARN] config.json introuvable: {config_path} — fallback sur GPTConfig(vocab_size=...).")

    questions = load_questions(args.questions)
    tester = AramixChatTester(
        repo_dir=repo_dir,
        train_script=train_script,
        ckpt_path=ckpt_path,
        config_path=config_path,
        device=args.device,
        prompt_style=args.prompt_style,
    )
    results: List[Dict[str, Any]] = []
    categories: Dict[str, List[Dict[str, Any]]] = {}

    print("\n" + "=" * 60)
    print("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ")
    print("=" * 60)
    print(f"Repo : {repo_dir}")
    print(f"Train script: {train_script}")
    print(f"Checkpoint : {ckpt_path}")
    print(f"Config : {config_path}")
    print(f"Tokenizer : {tokenizer_dir}")
    print(f"Device : {tester.device}")
    print(f"Prompt : {args.prompt_style}")
    print(f"Questions : {len(questions)}")
    print("=" * 60 + "\n")

    for i, item in enumerate(questions, 1):
        q = item["question"]
        ref = item.get("reference")
        cat = item.get("category", "Général")
        # perf_counter is monotonic, unlike the wall clock.
        t0 = time.perf_counter()
        ans = tester.generate(
            q,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            max_answer_words=args.max_answer_words,
        )
        latency = time.perf_counter() - t0
        overlap = lexical_overlap(ref, ans)
        em = exact_match(ref, ans)
        contains_ref = contains_reference(ref, ans)
        entry = {
            "id": i,
            "category": cat,
            "question": q,
            "answer": ans,
            "reference": ref,
            "latency_s": round(latency, 3),
            # Word count of the cleaned answer, used as a rough size proxy.
            "tokens_generated_approx": len(ans.split()),
            "overlap_score": None if overlap is None else round(overlap, 4),
            "exact_match": em,
            "contains_reference": contains_ref,
        }
        results.append(entry)
        categories.setdefault(cat, []).append(entry)

        overlap_str = f"{overlap:.0%}" if overlap is not None else "n/a"
        em_str = "✓" if em else ("✗" if em is not None else "n/a")
        contains_str = "✓" if contains_ref else ("✗" if contains_ref is not None else "n/a")
        print(f"[{i:02d}] {cat}")
        print(f" User : {q}")
        print(f" Assistant : {ans}")
        if ref:
            print(f" Référence : {ref}")
        print(f" Overlap : {overlap_str}")
        print(f" ExactMatch: {em_str}")
        print(f" Contains : {contains_str}")
        print(f" Latence : {latency:.3f}s\n")

    # Items without a reference carry None metrics and are left out of the
    # averages below.
    scored_overlap = [r["overlap_score"] for r in results if r["overlap_score"] is not None]
    scored_em = [r["exact_match"] for r in results if r["exact_match"] is not None]
    scored_contains = [r["contains_reference"] for r in results if r["contains_reference"] is not None]
    avg_overlap = sum(scored_overlap) / len(scored_overlap) if scored_overlap else 0.0
    em_rate = sum(1 for x in scored_em if x) / len(scored_em) if scored_em else 0.0
    contains_rate = sum(1 for x in scored_contains if x) / len(scored_contains) if scored_contains else 0.0
    avg_latency = sum(r["latency_s"] for r in results) / len(results) if results else 0.0
    avg_words = sum(r["tokens_generated_approx"] for r in results) / len(results) if results else 0.0

    # A category with no scored item gets no score at all rather than a
    # misleading 0%.
    cat_scores: Dict[str, float] = {}
    for cat, items in categories.items():
        vals = [r["overlap_score"] for r in items if r["overlap_score"] is not None]
        if vals:
            cat_scores[cat] = sum(vals) / len(vals)

    summary = {
        "repo_dir": str(repo_dir),
        "train_script": str(train_script),
        "checkpoint": str(ckpt_path),
        "config_path": str(config_path),
        "tokenizer_dir": str(tokenizer_dir),
        "device": str(tester.device),
        "prompt_style": args.prompt_style,
        "total_questions": len(results),
        "avg_overlap_score": round(avg_overlap, 4),
        "exact_match_rate": round(em_rate, 4),
        "contains_reference_rate": round(contains_rate, 4),
        "avg_latency_s": round(avg_latency, 3),
        "avg_words_generated": round(avg_words, 1),
        "scores_by_category": {k: round(v, 4) for k, v in cat_scores.items()},
        "results": results,
    }

    print("=" * 60)
    print("RÉSUMÉ")
    print("=" * 60)
    print(f"Questions testées : {len(results)}")
    print(f"Overlap moyen : {avg_overlap:.1%}")
    print(f"Exact match : {em_rate:.1%}")
    print(f"Contains ref : {contains_rate:.1%}")
    print(f"Latence moyenne : {avg_latency:.2f}s")
    print(f"Mots moyens : {avg_words:.1f}")
    print("Scores / catégorie:")
    for cat in sorted(categories):
        if cat in cat_scores:
            score = cat_scores[cat]
            print(f" {cat:<15} {format_bar(score)} {score:.0%}")
        else:
            print(f" {cat:<15} n/a")
    print("=" * 60)

    if args.save_report:
        json_path, txt_path = save_reports(repo_dir, summary)
        print(f"Rapports sauvegardés dans : {json_path}")
        print(f" {txt_path}")


if __name__ == "__main__":
    main()