| |
| |
|
|
"""
simple_qa_test_aramix.py

Simple QA test for an already-trained model stored in a repo laid out as:
  - train_aramix_h100_full.py
  - aramix_h100/
      - config.json
      - model_best.pt
      - model.pt
      - tokenizer_32k/

Assumptions aligned with your repo:
  - the training module exposes: GPT, GPTConfig, train_or_load_tokenizer,
    load_checkpoint, DOMAINS
  - the tokenizer is handled by train_or_load_tokenizer(DOMAINS)
  - the checkpoint is reloaded with load_checkpoint(model, opt, ckpt_path, device)

Usage
-----
  python simple_qa_test_aramix.py
  python simple_qa_test_aramix.py --repo_dir ./aramix_h100
  python simple_qa_test_aramix.py --ckpt ./aramix_h100/model.pt
  python simple_qa_test_aramix.py --questions qa_questions.json
  python simple_qa_test_aramix.py --max_new_tokens 96 --temperature 0.4 --top_k 40
  python simple_qa_test_aramix.py --save_report
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import importlib.util |
| import json |
| import os |
| import re |
| import sys |
| import time |
| import unicodedata |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
# Built-in question set, returned by load_questions() when no questions file
# is given. Each row is (category, question, reference); a reference of None
# marks an open-ended prompt for which no overlap / exact-match score is
# computed.
_DEFAULT_QUESTION_ROWS = (
    ("Géographie", "Quelle est la capitale de la France ?", "Paris"),
    ("Géographie", "Quel est le plus long fleuve d'Afrique ?", "Le Nil"),
    (
        "Science",
        "Qu'est-ce que la photosynthèse ?",
        "Processus par lequel les plantes convertissent la lumière en énergie",
    ),
    ("Science", "Combien d'os compte le corps humain adulte ?", "206"),
    ("Histoire", "En quelle année a eu lieu la Révolution française ?", "1789"),
    ("Histoire", "Qui a écrit Les Misérables ?", "Victor Hugo"),
    ("Mathématiques", "Quelle est la formule de l'aire d'un cercle ?", "pi r carre"),
    ("Langage", "Donne un synonyme du mot heureux.", "joyeux"),
    (
        "Raisonnement",
        "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?",
        "3",
    ),
    ("Dialogue", "Comment vas-tu aujourd'hui ?", None),
)

DEFAULT_QUESTIONS = [
    {"category": category, "question": question, "reference": reference}
    for category, question, reference in _DEFAULT_QUESTION_ROWS
]
|
|
|
|
def load_module_from_file(py_path: Path):
    """Import a Python source file as a module and register it in ``sys.modules``.

    The module is registered under the file's stem *before* its code runs, so
    that code executed at import time can find its own module through
    ``sys.modules`` (dataclasses, for instance, look up their defining module
    there). If executing the file raises, the registration is rolled back so
    a half-initialized module is never left behind.

    Args:
        py_path: Path of the ``.py`` file to load.

    Returns:
        The fully executed module object.

    Raises:
        RuntimeError: If no import spec or loader can be built for ``py_path``.
        Exception: Anything raised by the file's top-level code is propagated
            unchanged after ``sys.modules`` has been restored.
    """
    module_name = py_path.stem
    spec = importlib.util.spec_from_file_location(module_name, py_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Impossible de charger le module: {py_path}")
    module = importlib.util.module_from_spec(spec)

    # Remember what (if anything) was registered under this name so a failed
    # import can put it back instead of leaving a broken module in the cache.
    absent = object()
    previous = sys.modules.get(module_name, absent)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        if previous is absent:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module
|
|
|
|
def normalize_text(text: str) -> str:
    """Reduce *text* to a canonical form for lenient string comparison.

    The input is trimmed and lowercased, decomposed with NFKD so that
    combining marks (accents) can be dropped, "π" is spelled out as "pi",
    every run of non-word characters or underscores becomes one space, and
    whitespace is collapsed. A falsy input (``None`` or ``""``) yields ``""``.
    """
    lowered = (text or "").strip().lower()
    decomposed = unicodedata.normalize("NFKD", lowered)
    # combining() is 0 for base characters, non-zero for accents and marks.
    base_chars = [char for char in decomposed if unicodedata.combining(char) == 0]
    spelled_out = "".join(base_chars).replace("π", "pi")
    spaced = re.sub(r"[\W_]+", " ", spelled_out, flags=re.UNICODE)
    return re.sub(r"\s+", " ", spaced).strip()
|
|
|
|
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
    """Return the share of distinct reference words that occur in *answer*.

    Both strings go through ``normalize_text`` and are split on whitespace
    into word sets. The result is ``|reference ∩ answer| / |reference|``.

    Returns:
        A float in ``[0, 1]``, or ``None`` when there is no reference or the
        reference normalizes to no words at all.
    """
    if not reference:
        return None
    expected_words, produced_words = (
        set(normalize_text(source).split()) for source in (reference, answer)
    )
    if expected_words:
        return len(expected_words.intersection(produced_words)) / len(expected_words)
    return None
|
|
|
|
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
    """Tell whether *answer* equals *reference* once both are normalized.

    Returns:
        ``True``/``False`` for a scored question, ``None`` when *reference*
        is empty or missing (nothing to compare against).
    """
    if reference:
        expected = normalize_text(reference)
        produced = normalize_text(answer)
        return expected == produced
    return None
|
|
|
|
def infer_repo_defaults(repo_dir: Path):
    """Derive the default artifact paths for a model repo directory.

    The training script is looked for next to ``repo_dir`` first and, failing
    that, assumed to live inside it. ``model_best.pt`` is preferred over
    ``model.pt`` when it exists. None of the returned paths is guaranteed to
    exist.

    Returns:
        A tuple ``(train_script, ckpt, config, tokenizer_dir)`` of ``Path``s.
    """
    script_name = "train_aramix_h100_full.py"
    beside_repo = repo_dir.parent / script_name
    train_script = beside_repo if beside_repo.exists() else repo_dir / script_name

    best_ckpt = repo_dir / "model_best.pt"
    ckpt = best_ckpt if best_ckpt.exists() else repo_dir / "model.pt"

    return train_script, ckpt, repo_dir / "config.json", repo_dir / "tokenizer_32k"
|
|
|
|
def safe_get(cfg: Dict[str, Any], *names: str, default=None):
    """Return the value of the first key in *names* that is present in *cfg*.

    Presence is what matters, not truthiness: a key mapped to ``None`` still
    wins over later aliases. *default* is returned when no alias is present.
    """
    present_values = (cfg[alias] for alias in names if alias in cfg)
    return next(present_values, default)
|
|
|
|
def build_model_config_dict(cfg_json: Dict[str, Any], vocab_size: int) -> Dict[str, Any]:
    """Translate a loosely-keyed config dict into canonical model kwargs.

    Each hyper-parameter is looked up under several alias keys via
    ``safe_get`` and cast to ``int``. Missing values fall back to
    block_size=512, d_model=768, n_heads=12, n_layers=12 and
    d_ff = 4 * d_model. ``vocab_size`` is passed through unchanged.
    """
    hidden_size = safe_get(cfg_json, "d_model", "n_embd", "dim", default=768)
    # (canonical name, accepted aliases in priority order, fallback value)
    lookups = (
        ("block_size", ("block_size", "max_seq_len", "seq_len"), 512),
        ("d_model", ("d_model", "n_embd", "dim"), 768),
        ("n_heads", ("n_heads", "n_head", "num_heads"), 12),
        ("n_layers", ("n_layers", "n_layer", "num_layers"), 12),
        ("d_ff", ("d_ff", "ffn_dim", "intermediate_size"), hidden_size * 4),
    )

    model_kwargs: Dict[str, Any] = {"vocab_size": vocab_size}
    for field_name, aliases, fallback in lookups:
        model_kwargs[field_name] = int(safe_get(cfg_json, *aliases, default=fallback))
    return model_kwargs
|
|
|
|
class AramixChatTester:
    """Load a trained checkpoint and answer questions with it.

    Nothing model-specific is imported statically: the model class, its
    config class, the tokenizer loader and the checkpoint loader all come
    from the training script passed as ``train_script``, which must expose
    ``GPT``, ``GPTConfig``, ``train_or_load_tokenizer``, ``load_checkpoint``
    and ``DOMAINS``.
    """

    def __init__(
        self,
        repo_dir: Path,
        train_script: Path,
        ckpt_path: Path,
        config_path: Path,
        device: Optional[str] = None,
    ):
        """Import the training module, then load tokenizer and model.

        Args:
            repo_dir: Model repo directory; its parent is used as working
                directory while the tokenizer is loaded.
            train_script: Training script to import the model code from.
            ckpt_path: Checkpoint file to restore.
            config_path: JSON config file; silently skipped if it is missing.
            device: Torch device string. Defaults to CUDA when available,
                otherwise CPU.

        Raises:
            RuntimeError: If the training script lacks a required symbol.
        """
        self.repo_dir = repo_dir
        self.train_script = train_script
        self.ckpt_path = ckpt_path
        self.config_path = config_path

        self.device = torch.device(
            device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        )

        self.M = load_module_from_file(self.train_script)

        required = ["GPT", "GPTConfig", "train_or_load_tokenizer", "load_checkpoint", "DOMAINS"]
        missing = [x for x in required if not hasattr(self.M, x)]
        if missing:
            raise RuntimeError(
                f"Le fichier {self.train_script.name} ne contient pas les symboles attendus: {missing}"
            )

        self.cfg_json: Dict[str, Any] = {}
        if self.config_path.exists():
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.cfg_json = json.load(f)

        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()

    def _load_tokenizer(self):
        """Load the tokenizer through the training module.

        The call runs with the repo's parent directory as working directory
        and the previous one is always restored.
        """
        # NOTE(review): presumably train_or_load_tokenizer resolves its
        # tokenizer directory relative to the cwd — confirm in the train script.
        old_cwd = Path.cwd()
        try:
            os.chdir(self.repo_dir.parent)
            tok = self.M.train_or_load_tokenizer(self.M.DOMAINS)
        finally:
            os.chdir(old_cwd)
        return tok

    def _make_gpt_config(self):
        """Build the model config from ``config.json`` values.

        If ``GPTConfig`` rejects the canonical keyword names (``TypeError``),
        fall back to a config that only sets ``vocab_size``.
        """
        kwargs = build_model_config_dict(self.cfg_json, vocab_size=len(self.tokenizer))
        try:
            return self.M.GPTConfig(**kwargs)
        except TypeError:
            return self.M.GPTConfig(vocab_size=len(self.tokenizer))

    def _load_model(self):
        """Instantiate the model, restore the checkpoint, switch to eval mode.

        Three loading strategies are tried in order, moving on when a call
        raises ``TypeError`` (taken as a signature mismatch):
        ``load_checkpoint(model, None, path, device)``, then
        ``load_checkpoint(model, path, device)``, then a raw ``torch.load``
        followed by a non-strict ``load_state_dict``.
        """
        cfg = self._make_gpt_config()
        model = self.M.GPT(cfg).to(self.device)

        try:
            self.M.load_checkpoint(model, None, self.ckpt_path, self.device)
        except TypeError:
            try:
                self.M.load_checkpoint(model, self.ckpt_path, self.device)
            except TypeError:
                # SECURITY: torch.load may unpickle arbitrary objects; only
                # point this script at checkpoints you trust.
                ckpt = torch.load(self.ckpt_path, map_location=self.device)
                state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
                # Keys saved from a torch.compile()d model carry this prefix.
                if any(k.startswith("_orig_mod.") for k in state):
                    state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
                outcome = model.load_state_dict(state, strict=False)
                # strict=False hides key mismatches; surface them so a
                # partially restored model is not evaluated unnoticed.
                n_missing = len(getattr(outcome, "missing_keys", []))
                n_unexpected = len(getattr(outcome, "unexpected_keys", []))
                if n_missing or n_unexpected:
                    print(
                        "[WARN] Checkpoint chargé partiellement — "
                        f"{n_missing} clés manquantes, {n_unexpected} clés inattendues."
                    )

        model.eval()
        return model

    def encode_prompt(self, question: str) -> List[int]:
        """Wrap *question* in the QA prompt template and return its token ids.

        A BOS id is prepended when the tokenizer defines one, and a trailing
        EOS id is removed so generation continues the prompt.
        """
        bos = getattr(self.tokenizer, "bos_token_id", None)
        eos = getattr(self.tokenizer, "eos_token_id", None)

        prompt = f"Question: {question}\nRéponse:"
        ids = self.tokenizer.encode(prompt, add_special_tokens=False)

        if bos is not None:
            ids = [bos] + ids
        if eos is not None and len(ids) > 0 and ids[-1] == eos:
            ids = ids[:-1]
        return ids

    @torch.no_grad()
    def generate(
        self,
        question: str,
        max_new_tokens: int = 96,
        temperature: float = 0.4,
        top_k: int = 40,
        repetition_penalty: float = 1.12,
    ) -> str:
        """Generate an answer to *question* token by token.

        Args:
            question: Question text, inserted into the prompt template.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
            top_k: Keep only the k most likely tokens when sampling
                (``None`` or ``<= 0`` disables the filter).
            repetition_penalty: Penalty applied to tokens present in the last
                64 positions of the sequence (prompt included); ``1.0``
                disables it.

        Returns:
            The decoded continuation with whitespace collapsed, EOS tokens
            removed and any echoed "Réponse:" marker stripped.
        """
        ids = self.encode_prompt(question)
        x = torch.tensor([ids], dtype=torch.long, device=self.device)

        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        block_size = getattr(getattr(self.model, "cfg", None), "block_size", None)
        if block_size is None:
            block_size = safe_get(self.cfg_json, "block_size", "max_seq_len", default=512)
        block_size = int(block_size)

        for step in range(max_new_tokens):
            # Only the last block_size tokens are fed to the model.
            x_ctx = x[:, -block_size:]

            # Accept both `(logits, ...)` sequences and a bare logits tensor
            # with a single forward pass.
            out = self.model(x_ctx)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            # Clone so the penalty below never mutates the model's output.
            logits = logits[:, -1, :].clone()

            if repetition_penalty != 1.0:
                seen = torch.unique(x[0, -64:])
                seen_logits = logits[0, seen]
                # Dividing a negative logit by a penalty > 1 would *raise*
                # its probability, so negative logits are multiplied instead.
                logits[0, seen] = torch.where(
                    seen_logits < 0,
                    seen_logits * repetition_penalty,
                    seen_logits / repetition_penalty,
                )

            if temperature <= 0:
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / max(temperature, 1e-5)

                if top_k is not None and top_k > 0:
                    values, _ = torch.topk(logits, k=min(top_k, logits.size(-1)))
                    kth = values[:, -1].unsqueeze(-1)
                    logits = torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits)

                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)

            x = torch.cat([x, next_tok], dim=1)

            # An EOS during the first two steps is ignored so the answer is
            # not cut off immediately.
            if eos_id is not None and next_tok.item() == eos_id and step >= 2:
                break

        new_ids = x[0, len(ids):].tolist()
        if eos_id is not None:
            # Keep EOS markers out of the decoded text; they would otherwise
            # pollute the answer and the overlap / exact-match metrics.
            new_ids = [tok for tok in new_ids if tok != eos_id]
        text = self.tokenizer.decode(new_ids).strip()
        text = re.sub(r"\s+", " ", text).strip()
        text = text.replace("Réponse :", "").replace("Réponse:", "").strip()
        return text
|
|
|
|
def load_questions(path: Optional[str]) -> List[Dict[str, Any]]:
    """Return the questions to run: the built-in set or the ones in a JSON file.

    Args:
        path: Path of a JSON file holding a list of question objects. Each
            object needs a string ``"question"``; ``"reference"`` and
            ``"category"`` are optional. A falsy path selects
            ``DEFAULT_QUESTIONS``.

    Returns:
        The list of question dicts.

    Raises:
        ValueError: If the file does not contain a JSON list, or an entry is
            not an object with a string ``"question"``.
    """
    if not path:
        return DEFAULT_QUESTIONS
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("Le fichier questions doit contenir une liste JSON.")
    # Validate up front: a malformed entry would otherwise only fail with a
    # bare KeyError in the run loop, after the model has been loaded.
    for index, item in enumerate(data, 1):
        if not isinstance(item, dict) or not isinstance(item.get("question"), str):
            raise ValueError(
                f"Entrée #{index} invalide: chaque question doit être un objet JSON "
                "avec une clé 'question' de type texte."
            )
    return data
|
|
|
|
def format_bar(score: float, width: int = 20) -> str:
    """Render *score* (nominally in [0, 1]) as a text progress bar.

    The bar is *width* characters long: filled cells "█" followed by empty
    cells "░". The filled count is ``round(score * width)`` clamped to
    ``[0, width]``.
    """
    filled = int(round(score * width))
    if filled > width:
        filled = width
    if filled < 0:
        filled = 0
    empty = width - filled
    return "".join(("█" * filled, "░" * empty))
|
|
|
|
def save_reports(output_dir: Path, summary: Dict[str, Any]) -> None:
    """Write the run summary to *output_dir* as a JSON and a text report.

    The directory is created if needed. ``qa_test_report_simple.json`` holds
    the whole *summary*; ``qa_test_report_simple.txt`` lists one entry per
    item of ``summary["results"]``, leaving out the reference and score
    lines of items that have none.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    report_json = output_dir / "qa_test_report_simple.json"
    report_txt = output_dir / "qa_test_report_simple.txt"

    with report_json.open("w", encoding="utf-8") as sink:
        json.dump(summary, sink, ensure_ascii=False, indent=2)

    with report_txt.open("w", encoding="utf-8") as sink:
        emit = sink.write
        emit("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ\n")
        emit("=" * 60 + "\n\n")
        for row in summary["results"]:
            emit(f"[{row['id']:02d}] {row['category']}\n")
            emit(f" User : {row['question']}\n")
            emit(f" Assistant : {row['answer']}\n")
            if row["reference"]:
                emit(f" Référence : {row['reference']}\n")
            if row["overlap_score"] is not None:
                emit(f" Overlap : {row['overlap_score']:.0%}\n")
            if row["exact_match"] is not None:
                emit(f" ExactMatch: {row['exact_match']}\n")
            emit(f" Latence : {row['latency_s']}s\n\n")
|
|
|
|
def main():
    """Command-line entry point: load the model, run the QA set, report.

    Resolves the artifact paths (CLI overrides win over the defaults derived
    from ``--repo_dir``), builds an ``AramixChatTester``, generates an answer
    for every question, prints per-question and aggregate metrics, and
    writes the report files into the repo directory when ``--save_report``
    is given.

    Raises:
        FileNotFoundError: If the training script or the checkpoint is
            missing. A missing config file only triggers a warning.
    """
    # The text is the program description; passed positionally it would be
    # taken as `prog` and end up as the program name in the usage line.
    parser = argparse.ArgumentParser(
        description="Test QA simple pour modèle Aramix déjà entraîné"
    )
    parser.add_argument("--repo_dir", type=str, default="./aramix_h100")
    parser.add_argument("--train_script", type=str, default=None)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--questions", type=str, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=96)
    parser.add_argument("--temperature", type=float, default=0.4)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--repetition_penalty", type=float, default=1.12)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--save_report", action="store_true")
    args = parser.parse_args()

    repo_dir = Path(args.repo_dir).resolve()
    train_script, ckpt_path, config_path, tokenizer_dir = infer_repo_defaults(repo_dir)

    # Explicit command-line paths override the inferred defaults.
    if args.train_script:
        train_script = Path(args.train_script).resolve()
    if args.ckpt:
        ckpt_path = Path(args.ckpt).resolve()
    if args.config:
        config_path = Path(args.config).resolve()

    if not train_script.exists():
        raise FileNotFoundError(f"Script train introuvable: {train_script}")
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint introuvable: {ckpt_path}")
    if not config_path.exists():
        print(f"[WARN] config.json introuvable: {config_path} — fallback sur GPTConfig(vocab_size=...).")

    # Questions are loaded first so a bad questions file fails before the
    # (potentially slow) model load.
    questions = load_questions(args.questions)
    tester = AramixChatTester(
        repo_dir=repo_dir,
        train_script=train_script,
        ckpt_path=ckpt_path,
        config_path=config_path,
        device=args.device,
    )

    results: List[Dict[str, Any]] = []
    categories: Dict[str, List[Dict[str, Any]]] = {}

    print("\n" + "═" * 70)
    print(" TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ")
    print("═" * 70)
    print(f"Repo : {repo_dir}")
    print(f"Train script: {train_script}")
    print(f"Checkpoint : {ckpt_path}")
    print(f"Config : {config_path}")
    print(f"Tokenizer : {tokenizer_dir}")
    print(f"Device : {tester.device}")
    print(f"Questions : {len(questions)}")
    print("═" * 70 + "\n")

    for i, item in enumerate(questions, 1):
        q = item["question"]
        ref = item.get("reference")
        cat = item.get("category", "Général")

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        t0 = time.perf_counter()
        ans = tester.generate(
            q,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
        )
        latency = time.perf_counter() - t0

        # Both metrics are None for questions without a reference.
        overlap = lexical_overlap(ref, ans)
        em = exact_match(ref, ans)

        entry = {
            "id": i,
            "category": cat,
            "question": q,
            "answer": ans,
            "reference": ref,
            "latency_s": round(latency, 3),
            # Whitespace-separated word count, not a true token count.
            "tokens_generated_approx": len(ans.split()),
            "overlap_score": None if overlap is None else round(overlap, 4),
            "exact_match": em,
        }
        results.append(entry)
        categories.setdefault(cat, []).append(entry)

        overlap_str = f"{overlap:.0%}" if overlap is not None else "n/a"
        em_str = "✓" if em else ("✗" if em is not None else "n/a")

        print("─" * 70)
        print(f"[{i:02d}] [{cat}] overlap={overlap_str} | EM={em_str}")
        print(f" User : {q}")
        print(f" Assistant : {ans}")
        if ref:
            print(f" Référence : {ref}")
        print(f" ⏱ {latency:.2f}s | ~{entry['tokens_generated_approx']} mots\n")

    # Aggregates only consider questions that could actually be scored.
    scored_overlap = [r["overlap_score"] for r in results if r["overlap_score"] is not None]
    scored_em = [r["exact_match"] for r in results if r["exact_match"] is not None]

    avg_overlap = sum(scored_overlap) / len(scored_overlap) if scored_overlap else 0.0
    em_rate = sum(1 for x in scored_em if x) / len(scored_em) if scored_em else 0.0
    avg_latency = sum(r["latency_s"] for r in results) / len(results) if results else 0.0
    avg_words = sum(r["tokens_generated_approx"] for r in results) / len(results) if results else 0.0

    # A category without any scored question gets None rather than 0.0, so
    # "not scored" is not mistaken for "scored 0%" — the same convention as
    # the per-question overlap_score.
    cat_scores: Dict[str, Optional[float]] = {}
    for cat, items in categories.items():
        vals = [r["overlap_score"] for r in items if r["overlap_score"] is not None]
        cat_scores[cat] = (sum(vals) / len(vals)) if vals else None

    summary = {
        "repo_dir": str(repo_dir),
        "train_script": str(train_script),
        "checkpoint": str(ckpt_path),
        "config_path": str(config_path),
        "tokenizer_dir": str(tokenizer_dir),
        "device": str(tester.device),
        "total_questions": len(results),
        "avg_overlap_score": round(avg_overlap, 4),
        "exact_match_rate": round(em_rate, 4),
        "avg_latency_s": round(avg_latency, 3),
        "avg_words_generated": round(avg_words, 1),
        "scores_by_category": {
            k: (None if v is None else round(v, 4)) for k, v in cat_scores.items()
        },
        "results": results,
    }

    print("═" * 70)
    print(" RÉSUMÉ")
    print("═" * 70)
    print(f"Questions testées : {len(results)}")
    print(f"Overlap moyen : {avg_overlap:.1%}")
    print(f"Exact match : {em_rate:.1%}")
    print(f"Latence moyenne : {avg_latency:.2f}s")
    print(f"Mots moyens : {avg_words:.1f}")
    print("Scores / catégorie:")
    for cat in sorted(cat_scores):
        score = cat_scores[cat]
        if score is None:
            print(f" {cat:<15} n/a")
        else:
            print(f" {cat:<15} {format_bar(score)} {score:.0%}")
    print("═" * 70)

    if args.save_report:
        save_reports(repo_dir, summary)
        print(f"Rapports sauvegardés dans : {repo_dir / 'qa_test_report_simple.json'}")
        print(f" {repo_dir / 'qa_test_report_simple.txt'}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|