FirstChat / simple_qa_test_aramix_v2.py
Medyassino's picture
Add files using upload-large-folder tool
b9049d2 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
simple_qa_test_aramix_v2.py
Version corrigée du test QA simple pour repo Aramix.
Correction principale :
- gère les checkpoints sauvés depuis torch.compile() avec préfixe "_orig_mod."
- contourne load_checkpoint() du script train si celui-ci échoue sur ce cas
Usage
-----
python simple_qa_test_aramix_v2.py --repo_dir ./aramix_h100 --save_report
"""
from __future__ import annotations
import argparse
import importlib.util
import json
import os
import re
import sys
import time
import unicodedata
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
import torch.nn.functional as F
# Built-in French QA probes used when no --questions file is given.
# Each entry is {"category", "question", "reference"}; a None reference marks an
# open-ended prompt that is not scored.
DEFAULT_QUESTIONS = [
    {"category": category, "question": question, "reference": reference}
    for category, question, reference in (
        ("Géographie", "Quelle est la capitale de la France ?", "Paris"),
        ("Géographie", "Quel est le plus long fleuve d'Afrique ?", "Le Nil"),
        ("Science", "Qu'est-ce que la photosynthèse ?", "Processus par lequel les plantes convertissent la lumière en énergie"),
        ("Science", "Combien d'os compte le corps humain adulte ?", "206"),
        ("Histoire", "En quelle année a eu lieu la Révolution française ?", "1789"),
        ("Histoire", "Qui a écrit Les Misérables ?", "Victor Hugo"),
        ("Mathématiques", "Quelle est la formule de l'aire d'un cercle ?", "pi r carre"),
        ("Langage", "Donne un synonyme du mot heureux.", "joyeux"),
        ("Raisonnement", "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?", "3"),
        ("Dialogue", "Comment vas-tu aujourd'hui ?", None),
    )
]
def load_module_from_file(py_path: Path):
    """Import the Python file *py_path* as a module named after its stem.

    The module is registered in ``sys.modules`` before execution, so code in
    the file can refer to itself by name.  If executing the file raises, the
    registration is rolled back (restoring any module previously registered
    under the same name) and the exception propagates.

    Raises:
        RuntimeError: if no import spec/loader can be built for *py_path*.
    """
    name = py_path.stem
    spec = importlib.util.spec_from_file_location(name, py_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Impossible de charger le module: {py_path}")
    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(name)
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module (or a clobbered one) behind.
        if previous is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return module
def normalize_text(text: str) -> str:
    """Normalize *text* for lexical comparison.

    Applies NFKD compatibility decomposition and case-folding, drops combining
    marks (accents), maps ``π`` to ``pi``, replaces every run of
    non-alphanumeric characters (including ``_``) by one space and trims.
    ``None`` or empty input yields ``""``.
    """
    # Decompose *before* folding case: compatibility characters such as "ℌ" or
    # "𝐏" only become plain letters under NFKD, so lowering first left them
    # upper-case.
    text = unicodedata.normalize("NFKD", text or "").casefold()
    # casefold() may emit new base+mark sequences; decompose once more so the
    # combining-mark filter below sees them.
    text = unicodedata.normalize("NFKD", text)
    text = "".join(ch for ch in text if not unicodedata.combining(ch))
    text = text.replace("π", "pi")
    # Whitespace is itself \W, so this single pass also collapses spaces.
    text = re.sub(r"[\W_]+", " ", text)
    return text.strip()
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
    """Return the share of the reference's distinct normalized words found in *answer*.

    Returns ``None`` when *reference* is empty/``None`` or normalizes to no words.
    """
    if not reference:
        return None
    reference_words, answer_words = (
        set(normalize_text(chunk).split()) for chunk in (reference, answer)
    )
    if not reference_words:
        return None
    return len(reference_words.intersection(answer_words)) / len(reference_words)
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
    """Return whether *answer* equals *reference* once both are normalized.

    ``None`` is returned when there is no (non-empty) reference to compare to.
    """
    if reference:
        return normalize_text(reference) == normalize_text(answer)
    return None
def infer_repo_defaults(repo_dir: Path):
    """Guess the standard artifact paths for an Aramix repo directory.

    Returns ``(train_script, checkpoint, config_json, tokenizer_dir)``.  The
    training script is looked up next to *repo_dir* first, then inside it; the
    checkpoint prefers ``model_best.pt`` over ``model.pt``.  Only the first two
    lookups check existence; the returned paths may still not exist.
    """
    script_name = "train_aramix_h100_full.py"
    sibling_script = repo_dir.parent / script_name
    train_script = sibling_script if sibling_script.exists() else repo_dir / script_name
    best_ckpt = repo_dir / "model_best.pt"
    ckpt = best_ckpt if best_ckpt.exists() else repo_dir / "model.pt"
    return train_script, ckpt, repo_dir / "config.json", repo_dir / "tokenizer_32k"
def safe_get(cfg: Dict[str, Any], *names: str, default=None):
    """Return ``cfg[name]`` for the first of *names* present in *cfg*, else *default*.

    Presence is what counts: a key mapped to ``None`` or a falsy value still wins.
    """
    return next((cfg[key] for key in names if key in cfg), default)
def build_model_config_dict(cfg_json: Dict[str, Any], vocab_size: int) -> Dict[str, Any]:
    """Map a loosely-named JSON config onto the GPTConfig keyword arguments.

    Several alias keys are accepted per field (first present wins, via
    ``safe_get``).  Missing fields default to block_size=512, d_model=768,
    n_heads=12, n_layers=12 and d_ff=4*d_model.  All values are coerced to int.
    """
    block_size = int(safe_get(cfg_json, "block_size", "max_seq_len", "seq_len", default=512))
    d_model = int(safe_get(cfg_json, "d_model", "n_embd", "dim", default=768))
    n_heads = int(safe_get(cfg_json, "n_heads", "n_head", "num_heads", default=12))
    n_layers = int(safe_get(cfg_json, "n_layers", "n_layer", "num_layers", default=12))
    # Derive the default FFN width from the int-converted d_model: with a string
    # value such as "768", ``d_model * 4`` would repeat the string instead.
    d_ff = int(safe_get(cfg_json, "d_ff", "ffn_dim", "intermediate_size", default=d_model * 4))
    return {
        "vocab_size": vocab_size,
        "block_size": block_size,
        "d_model": d_model,
        "n_heads": n_heads,
        "n_layers": n_layers,
        "d_ff": d_ff,
    }
def strip_orig_mod_prefix(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Remove the ``_orig_mod.`` prefix that ``torch.compile`` adds to state_dict keys.

    If no key starts with the prefix, *state* is returned unchanged (same
    object).  Otherwise a new dict is returned in which the prefix is removed
    from every key that starts with it; other keys, including ones containing
    ``_orig_mod.`` later in the path, are kept verbatim.
    """
    prefix = "_orig_mod."
    if not any(key.startswith(prefix) for key in state):
        return state
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }
class AramixChatTester:
    """Load a trained Aramix GPT through its training script and answer questions with it.

    The training script is imported as a module and must define ``GPT``,
    ``GPTConfig``, ``train_or_load_tokenizer`` and ``DOMAINS``.  If it also
    defines ``load_checkpoint``, that loader is tried first; the built-in
    loader (which strips ``torch.compile``'s ``_orig_mod.`` prefix) takes over
    when it fails with an error mentioning that prefix.
    """

    def __init__(
        self,
        repo_dir: Path,
        train_script: Path,
        ckpt_path: Path,
        config_path: Path,
        device: Optional[str] = None,
    ):
        """Import the training script, read the optional JSON config, then load tokenizer and model.

        *device* defaults to ``"cuda"`` when available, else ``"cpu"``.

        Raises:
            RuntimeError: if the training script lacks a required symbol.
        """
        self.repo_dir = repo_dir
        self.train_script = train_script
        self.ckpt_path = ckpt_path
        self.config_path = config_path
        self.device = torch.device(device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu"))
        # Importing executes the training script's top-level code.
        self.M = load_module_from_file(self.train_script)
        required = ["GPT", "GPTConfig", "train_or_load_tokenizer", "DOMAINS"]
        missing = [x for x in required if not hasattr(self.M, x)]
        if missing:
            raise RuntimeError(f"Le fichier {self.train_script.name} ne contient pas les symboles attendus: {missing}")
        self.cfg_json: Dict[str, Any] = {}
        if self.config_path.exists():
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.cfg_json = json.load(f)
        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()

    def _load_tokenizer(self):
        """Return ``train_or_load_tokenizer(DOMAINS)`` evaluated from the repo's parent directory.

        The working directory is switched temporarily and always restored.
        Presumably the training script resolves its tokenizer path relative to
        the CWD (TODO confirm).
        """
        old_cwd = Path.cwd()
        try:
            os.chdir(self.repo_dir.parent)
            tok = self.M.train_or_load_tokenizer(self.M.DOMAINS)
        finally:
            os.chdir(old_cwd)
        return tok

    def _make_gpt_config(self):
        """Build a ``GPTConfig`` from the JSON config.

        If ``GPTConfig`` rejects the generic keyword set (``TypeError``), fall
        back to ``GPTConfig(vocab_size=...)`` with the script's own defaults.
        """
        kwargs = build_model_config_dict(self.cfg_json, vocab_size=len(self.tokenizer))
        try:
            return self.M.GPTConfig(**kwargs)
        except TypeError:
            return self.M.GPTConfig(vocab_size=len(self.tokenizer))

    def _manual_load_state(self, model: torch.nn.Module) -> None:
        """Load ``self.ckpt_path`` into *model* without the training script's helper.

        Accepts a bare state_dict or a dict wrapping it under ``"model"`` (also
        ``"model_state_dict"`` / ``"state_dict"``).  Strips the
        ``_orig_mod.`` prefix and loads non-strictly.  A missing
        ``lm_head.weight`` is tolerated (presumably tied to the embedding;
        TODO confirm against GPT).

        Raises:
            RuntimeError: if no state_dict is found, or any other key is missing or unexpected.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(self.ckpt_path, map_location=self.device)
        state = ckpt
        if isinstance(ckpt, dict):
            for key in ("model", "model_state_dict", "state_dict"):
                if key in ckpt:
                    state = ckpt[key]
                    break
        if not isinstance(state, dict):
            raise RuntimeError("Checkpoint non reconnu: pas de state_dict exploitable.")
        state = strip_orig_mod_prefix(state)
        missing, unexpected = model.load_state_dict(state, strict=False)
        # Tolerate only lm_head (possibly tied weights); anything else is a real mismatch.
        hard_missing = [k for k in missing if not k.endswith("lm_head.weight")]
        hard_unexpected = [k for k in unexpected if not k.startswith("_orig_mod.")]
        if hard_missing or hard_unexpected:
            raise RuntimeError(
                "Chargement manuel incomplet.\n"
                f"Missing: {hard_missing[:20]}\n"
                f"Unexpected: {hard_unexpected[:20]}"
            )

    def _load_model(self):
        """Instantiate ``GPT`` on ``self.device``, load the checkpoint and return it in eval mode.

        The script's ``load_checkpoint`` is tried with the 4-argument signature,
        then the 3-argument one on ``TypeError``.  A ``RuntimeError`` mentioning
        ``_orig_mod.`` triggers the manual loader; any other error propagates.
        """
        cfg = self._make_gpt_config()
        model = self.M.GPT(cfg).to(self.device)
        loader = getattr(self.M, "load_checkpoint", None)
        if loader is not None:
            try:
                try:
                    loader(model, None, self.ckpt_path, self.device)
                except TypeError:
                    loader(model, self.ckpt_path, self.device)
                model.eval()
                return model
            except RuntimeError as e:
                if "_orig_mod." not in str(e):
                    raise
        model_state_loader = self._manual_load_state
        model_state_loader(model)
        model.eval()
        return model

    def encode_prompt(self, question: str) -> List[int]:
        """Tokenize the ``Question: …\\nRéponse:`` prompt.

        BOS is prepended when the tokenizer defines one.  A trailing EOS is
        removed so the model continues the answer instead of ending it.
        """
        bos = getattr(self.tokenizer, "bos_token_id", None)
        eos = getattr(self.tokenizer, "eos_token_id", None)
        prompt = f"Question: {question}\nRéponse:"
        ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        if bos is not None:
            ids = [bos] + ids
        if eos is not None and ids and ids[-1] == eos:
            ids = ids[:-1]
        return ids

    def _context_length(self) -> int:
        """Return the model's context window size.

        Read from ``model.cfg`` or ``model.config`` when available; otherwise
        from the JSON config, using the same aliases as ``build_model_config_dict``.
        """
        for attr in ("cfg", "config"):
            value = getattr(getattr(self.model, attr, None), "block_size", None)
            if value is not None:
                return int(value)
        return int(safe_get(self.cfg_json, "block_size", "max_seq_len", "seq_len", default=512))

    @staticmethod
    def _apply_repetition_penalty(logits: torch.Tensor, context: torch.Tensor, penalty: float) -> None:
        """Penalize, in place, the logits of every token id present in *context*.

        Follows the CTRL / Hugging Face rule.  Positive logits are divided by
        *penalty* and negative ones multiplied by it, so a penalty > 1 always
        makes an already-seen token less likely.
        """
        if penalty == 1.0 or context.numel() == 0:
            return
        seen = torch.unique(context)
        current = logits[:, seen]
        logits[:, seen] = torch.where(current > 0, current / penalty, current * penalty)

    @staticmethod
    def _sample(logits: torch.Tensor, temperature: float, top_k: Optional[int]) -> torch.Tensor:
        """Pick the next token id with shape ``(batch, 1)``.

        Uses greedy decoding when ``temperature <= 0``; otherwise
        temperature-scaled top-k multinomial sampling.
        """
        if temperature <= 0:
            return torch.argmax(logits, dim=-1, keepdim=True)
        logits = logits / max(temperature, 1e-5)
        if top_k is not None and top_k > 0:
            kth = torch.topk(logits, k=min(top_k, logits.size(-1))).values[:, -1:]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(
        self,
        question: str,
        max_new_tokens: int = 96,
        temperature: float = 0.4,
        top_k: int = 40,
        repetition_penalty: float = 1.12,
    ) -> str:
        """Generate an answer to *question* and return it as cleaned-up text.

        The model sees at most the last ``block_size`` tokens.  The repetition
        penalty applies to tokens among the last 64 (prompt included).  EOS is
        suppressed for the first two new tokens so answers are never empty.
        Generation stops at EOS, which is not included in the decoded text.
        Whitespace is collapsed and any echoed ``Réponse:`` marker is removed.

        Raises:
            ValueError: if ``repetition_penalty <= 0``.
        """
        if repetition_penalty <= 0:
            raise ValueError(f"repetition_penalty doit être > 0 (reçu {repetition_penalty}).")
        prompt_ids = self.encode_prompt(question)
        x = torch.tensor([prompt_ids], dtype=torch.long, device=self.device)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        block_size = self._context_length()
        for step in range(max_new_tokens):
            out = self.model(x[:, -block_size:])
            logits = out[0] if isinstance(out, tuple) else getattr(out, "logits", out)
            # Private float32 copy: the edits below must not write into the model output.
            logits = logits[:, -1, :].float().clone()
            self._apply_repetition_penalty(logits, x[0, -64:], repetition_penalty)
            if eos_id is not None and step < 2 and 0 <= eos_id < logits.size(-1):
                logits[:, eos_id] = float("-inf")
            next_tok = self._sample(logits, temperature, top_k)
            if eos_id is not None and next_tok.item() == eos_id:
                break
            x = torch.cat([x, next_tok], dim=1)
        new_ids = x[0, len(prompt_ids):].tolist()
        text = re.sub(r"\s+", " ", self.tokenizer.decode(new_ids)).strip()
        return text.replace("Réponse :", "").replace("Réponse:", "").strip()
def load_questions(path: Optional[str]) -> List[Dict[str, Any]]:
    """Return the QA items to run.

    Without *path*, a fresh copy of ``DEFAULT_QUESTIONS`` is returned, so
    callers may mutate it safely.  Otherwise *path* must be a UTF-8 JSON file
    holding a list of objects, each with at least a ``"question"`` key;
    ``"reference"`` and ``"category"`` are optional.

    Raises:
        ValueError: if the file is not a JSON list or an item lacks ``"question"``.
    """
    if not path:
        return [dict(item) for item in DEFAULT_QUESTIONS]
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("Le fichier questions doit contenir une liste JSON.")
    # Validate up front: main() loads the model (slow) before iterating questions.
    for i, item in enumerate(data, 1):
        if not isinstance(item, dict) or "question" not in item:
            raise ValueError(f"Question #{i} invalide: un objet JSON avec une clé 'question' est attendu.")
    return data
def format_bar(score: float, width: int = 20) -> str:
    """Render *score* (nominally 0..1) as a *width*-character text bar.

    The filled length is ``round(score * width)``, clamped to ``[0, width]``.
    """
    filled = int(round(score * width))
    filled = max(0, min(width, filled))
    empty = width - filled
    return f"{'█' * filled}{'░' * empty}"
def save_reports(output_dir: Path, summary: Dict[str, Any]) -> None:
    """Write *summary* as a JSON report and a human-readable text report.

    Writes ``qa_test_report_simple.json`` and ``qa_test_report_simple.txt``
    into *output_dir*, creating the directory and its parents when needed.
    The text report has one section per entry of ``summary["results"]``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "qa_test_report_simple.json", "w", encoding="utf-8") as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=2)
    with open(output_dir / "qa_test_report_simple.txt", "w", encoding="utf-8") as handle:
        print("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ", file=handle)
        print("=" * 60, end="\n\n", file=handle)
        for entry in summary["results"]:
            print(f"[{entry['id']:02d}] {entry['category']}", file=handle)
            print(f" User : {entry['question']}", file=handle)
            print(f" Assistant : {entry['answer']}", file=handle)
            if entry["reference"]:
                print(f" Référence : {entry['reference']}", file=handle)
            overlap = entry["overlap_score"]
            if overlap is not None:
                print(f" Overlap : {overlap:.0%}", file=handle)
            matched = entry["exact_match"]
            if matched is not None:
                print(f" ExactMatch: {matched}", file=handle)
            print(f" Latence : {entry['latency_s']}s", end="\n\n", file=handle)
def main():
    """Command-line entry point.

    Resolves the repo artifacts, loads the model, runs every QA probe, prints a
    per-question log and a summary, and optionally saves reports.  Paths
    default to ``infer_repo_defaults(--repo_dir)`` and can each be
    overridden.  ``--seed`` makes sampled answers reproducible.

    Raises:
        FileNotFoundError: if the training script or checkpoint is missing.
    """
    # The first positional argument of ArgumentParser is ``prog``; pass the text as a description.
    parser = argparse.ArgumentParser(description="Test QA simple pour modèle Aramix déjà entraîné")
    parser.add_argument("--repo_dir", type=str, default="./aramix_h100")
    parser.add_argument("--train_script", type=str, default=None)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--questions", type=str, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=96)
    parser.add_argument("--temperature", type=float, default=0.4)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--repetition_penalty", type=float, default=1.12)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None,
                        help="Graine torch pour rendre l'échantillonnage reproductible.")
    parser.add_argument("--save_report", action="store_true")
    args = parser.parse_args()

    repo_dir = Path(args.repo_dir).resolve()
    train_script, ckpt_path, config_path, tokenizer_dir = infer_repo_defaults(repo_dir)
    if args.train_script:
        train_script = Path(args.train_script).resolve()
    if args.ckpt:
        ckpt_path = Path(args.ckpt).resolve()
    if args.config:
        config_path = Path(args.config).resolve()
    if not train_script.exists():
        raise FileNotFoundError(f"Script train introuvable: {train_script}")
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint introuvable: {ckpt_path}")
    if not config_path.exists():
        # Without a config, build_model_config_dict falls back to its defaults,
        # not to GPTConfig(vocab_size=...), which is only used on TypeError.
        print(f"[WARN] config.json introuvable: {config_path} — hyperparamètres par défaut "
              "(block_size=512, d_model=768, n_heads=12, n_layers=12).")

    # Load (and validate) questions before the expensive model load.
    questions = load_questions(args.questions)
    tester = AramixChatTester(
        repo_dir=repo_dir,
        train_script=train_script,
        ckpt_path=ckpt_path,
        config_path=config_path,
        device=args.device,
    )
    if args.seed is not None:
        torch.manual_seed(args.seed)

    results: List[Dict[str, Any]] = []
    categories: Dict[str, List[Dict[str, Any]]] = {}
    print("\n" + "═" * 70)
    print(" TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ")
    print("═" * 70)
    print(f"Repo : {repo_dir}")
    print(f"Train script: {train_script}")
    print(f"Checkpoint : {ckpt_path}")
    print(f"Config : {config_path}")
    print(f"Tokenizer : {tokenizer_dir}")
    print(f"Device : {tester.device}")
    print(f"Questions : {len(questions)}")
    print("═" * 70 + "\n")

    for i, item in enumerate(questions, 1):
        q = item["question"]
        ref = item.get("reference")
        cat = item.get("category", "Général")
        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        ans = tester.generate(
            q,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
        )
        latency = time.perf_counter() - t0
        overlap = lexical_overlap(ref, ans)
        em = exact_match(ref, ans)
        entry = {
            "id": i,
            "category": cat,
            "question": q,
            "answer": ans,
            "reference": ref,
            "latency_s": round(latency, 3),
            "tokens_generated_approx": len(ans.split()),  # whitespace word count, not tokens
            "overlap_score": None if overlap is None else round(overlap, 4),
            "exact_match": em,
        }
        results.append(entry)
        categories.setdefault(cat, []).append(entry)
        overlap_str = f"{overlap:.0%}" if overlap is not None else "n/a"
        em_str = "✓" if em else ("✗" if em is not None else "n/a")
        print("─" * 70)
        print(f"[{i:02d}] [{cat}] overlap={overlap_str} | EM={em_str}")
        print(f" User : {q}")
        print(f" Assistant : {ans}")
        if ref:
            print(f" Référence : {ref}")
        print(f" ⏱ {latency:.2f}s | ~{entry['tokens_generated_approx']} mots\n")

    scored_overlap = [r["overlap_score"] for r in results if r["overlap_score"] is not None]
    scored_em = [r["exact_match"] for r in results if r["exact_match"] is not None]
    avg_overlap = sum(scored_overlap) / len(scored_overlap) if scored_overlap else 0.0
    em_rate = sum(1 for x in scored_em if x) / len(scored_em) if scored_em else 0.0
    avg_latency = sum(r["latency_s"] for r in results) / len(results) if results else 0.0
    avg_words = sum(r["tokens_generated_approx"] for r in results) / len(results) if results else 0.0
    # None = category has no reference-scored item; reporting 0% would read as a failure.
    cat_scores: Dict[str, Optional[float]] = {}
    for cat, items in categories.items():
        vals = [r["overlap_score"] for r in items if r["overlap_score"] is not None]
        cat_scores[cat] = (sum(vals) / len(vals)) if vals else None
    summary = {
        "repo_dir": str(repo_dir),
        "train_script": str(train_script),
        "checkpoint": str(ckpt_path),
        "config_path": str(config_path),
        "tokenizer_dir": str(tokenizer_dir),
        "device": str(tester.device),
        "seed": args.seed,
        "total_questions": len(results),
        "avg_overlap_score": round(avg_overlap, 4),
        "exact_match_rate": round(em_rate, 4),
        "avg_latency_s": round(avg_latency, 3),
        "avg_words_generated": round(avg_words, 1),
        "scores_by_category": {k: (None if v is None else round(v, 4)) for k, v in cat_scores.items()},
        "results": results,
    }

    print("═" * 70)
    print(" RÉSUMÉ")
    print("═" * 70)
    print(f"Questions testées : {len(results)}")
    print(f"Overlap moyen : {avg_overlap:.1%}")
    print(f"Exact match : {em_rate:.1%}")
    print(f"Latence moyenne : {avg_latency:.2f}s")
    print(f"Mots moyens : {avg_words:.1f}")
    print("Scores / catégorie:")
    for cat, score in sorted(cat_scores.items()):
        if score is None:
            print(f" {cat:<15} {' ' * 20} n/a")
        else:
            print(f" {cat:<15} {format_bar(score)} {score:.0%}")
    print("═" * 70)
    if args.save_report:
        save_reports(repo_dir, summary)
        print(f"Rapports sauvegardés dans : {repo_dir / 'qa_test_report_simple.json'}")
        print(f" {repo_dir / 'qa_test_report_simple.txt'}")
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()