"""Batch prompt evaluator for Jarvis chat checkpoints.

Reads "User:" prompts from a text file, generates one reply per prompt with
the given checkpoint, computes a few cheap quality heuristics (empty rate,
gibberish rate, trigram repetition) and writes a JSON and a text report.
"""

import argparse
import json
import os
import re
from datetime import datetime

import torch

from GPT_model import GPT, block_size
from chat import collect_banned_token_ids, generate, load_tokenizer, DEFAULT_SYSTEM_PROMPT


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments."""
    p = argparse.ArgumentParser(description="Batch prompt evaluator for Jarvis chat checkpoints")
    p.add_argument("--ckpt", default="cpu_gpt_jarvis_rebuild_l6_v2048_best.pth")
    p.add_argument("--prompts-file", default=os.path.join("data", "jarvis_eval_prompts.txt"))
    p.add_argument("--out-prefix", default="jarvis_eval")
    p.add_argument("--num-prompts", type=int, default=120)
    p.add_argument("--temperature", type=float, default=0.62)
    p.add_argument("--top-k", type=int, default=32)
    p.add_argument("--top-p", type=float, default=0.9)
    p.add_argument("--repetition-penalty", type=float, default=1.12)
    p.add_argument("--no-repeat-ngram", type=int, default=3)
    p.add_argument("--max-new-tokens", type=int, default=96)
    p.add_argument("--min-new-tokens", type=int, default=10)
    p.add_argument("--max-context-tokens", type=int, default=block_size)
    p.add_argument("--system-prompt", default=DEFAULT_SYSTEM_PROMPT)
    p.add_argument("--ban-empty-tokens", action=argparse.BooleanOptionalAction, default=True)
    # Default leaves two cores free and never uses more than six threads.
    p.add_argument("--threads", type=int, default=max(1, min(6, (os.cpu_count() or 4) - 2)))
    p.add_argument("--interop-threads", type=int, default=1)
    p.add_argument("--seed", type=int, default=1337)
    p.add_argument("--int8", action=argparse.BooleanOptionalAction, default=False)
    return p.parse_args()


def parse_prompts(path: str) -> list[str]:
    """Extract the user prompts from a prompts file.

    The file is split into chunks on blank lines. In each chunk the text
    after ``User:`` (up to a following ``Assistant:`` line or the end of the
    chunk) is taken as the prompt, with all whitespace runs collapsed to a
    single space. Chunks without a ``User:`` marker or with an empty prompt
    are skipped.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Prompts file not found: {path}")
    # Context manager so the handle is closed deterministically.
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        text = f.read()
    chunks = [c.strip() for c in re.split(r"\n\s*\n", text) if c.strip()]
    prompts = []
    for chunk in chunks:
        match = re.search(r"User:\s*(.*?)(?:\nAssistant:|$)", chunk, flags=re.S)
        if not match:
            continue
        user = re.sub(r"\s+", " ", match.group(1)).strip()
        if user:
            prompts.append(user)
    return prompts


def likely_gibberish(text: str) -> bool:
    """Heuristically flag a reply as gibberish.

    Returns True when any of the following holds:
      * the text is empty;
      * it contains at least three runs of 18+ ASCII letters that each use
        more than 12 distinct letters;
      * fewer than ~90% of its characters are printable ASCII or
        newline/tab/carriage-return;
      * it contains fewer than 10 alphabetic characters.
    """
    if not text:
        return True
    words = re.findall(r"[A-Za-z]{18,}", text)
    weird_words = [w for w in words if len(set(w.lower())) > 12]
    if len(weird_words) >= 3:
        return True
    alpha = sum(ch.isalpha() for ch in text)
    printable = sum((31 < ord(ch) < 127) or ch in "\n\t\r" for ch in text)
    if printable < max(1, int(0.9 * len(text))):
        return True
    return alpha < 10


def repetition_score(text: str) -> float:
    """Return the fraction of repeated word trigrams in ``text``.

    The text is lower-cased and split into ``\\w+`` tokens. The score is
    ``1 - unique_trigrams / total_trigrams``: 0.0 means no trigram repeats.
    Texts with fewer than six tokens score 0.0.
    """
    tokens = re.findall(r"\w+", text.lower())
    if len(tokens) < 6:
        return 0.0
    # With at least six tokens there are always at least four trigrams,
    # so the division below cannot be by zero.
    trigrams = [" ".join(tokens[i : i + 3]) for i in range(len(tokens) - 2)]
    return 1.0 - (len(set(trigrams)) / len(trigrams))


def load_model(ckpt_path: str, vocab_size: int, use_int8: bool):
    """Load a GPT checkpoint on CPU and return ``(model, ckpt)``.

    Args:
        ckpt_path: path to the checkpoint file.
        vocab_size: tokenizer vocabulary size used to build the model.
        use_int8: when True, dynamically quantize ``torch.nn.Linear`` layers
            to ``qint8`` after loading the weights.

    Raises:
        FileNotFoundError: if ``ckpt_path`` does not exist.
        RuntimeError: if the checkpoint records a ``vocab_size`` that differs
            from ``vocab_size``.
    """
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    model = GPT(vocab_size).to("cpu")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    ckpt_vocab = ckpt.get("vocab_size")
    if ckpt_vocab is not None and int(ckpt_vocab) != vocab_size:
        raise RuntimeError(
            f"Checkpoint/tokenizer mismatch: ckpt vocab_size={ckpt_vocab}, tokenizer vocab_size={vocab_size}"
        )
    model.load_state_dict(ckpt["model"], strict=True)
    model.eval()
    if use_int8:
        model = torch.ao.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
        model.eval()
    return model, ckpt


def main() -> None:
    """Run the batch evaluation and write the JSON and text reports.

    Raises:
        RuntimeError: if the prompts file yields no usable prompt.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    torch.set_num_threads(args.threads)
    torch.set_num_interop_threads(args.interop_threads)

    tokenizer = load_tokenizer()
    prompts = parse_prompts(args.prompts_file)
    if not prompts:
        raise RuntimeError("No valid prompts found.")
    # Always evaluate at least one prompt, even for --num-prompts <= 0.
    prompts = prompts[: max(1, args.num_prompts)]

    model, ckpt = load_model(args.ckpt, len(tokenizer.vocab), args.int8)
    # Clamp the context window to [32, block_size].
    max_ctx = max(32, min(args.max_context_tokens, block_size))
    banned_token_ids = collect_banned_token_ids(tokenizer, args.ban_empty_tokens)

    # The system prompt is injected as a first user turn that the assistant
    # acknowledges; it is encoded once and reused for every prompt.
    bootstrap = ""
    if args.system_prompt.strip():
        bootstrap = f"User: {args.system_prompt.strip()}\nAssistant: Understood.\n"
    bootstrap_tokens = tokenizer.encode(bootstrap)

    rows = []
    empty_count = 0
    gibberish_count = 0
    repetition_scores = []
    lengths = []

    for i, user in enumerate(prompts, start=1):
        turn_prefix = f"\nUser: {user}\nAssistant:"
        # Keep only the most recent max_ctx tokens of the assembled prompt.
        prompt_tokens = (bootstrap_tokens + tokenizer.encode(turn_prefix))[-max_ctx:]
        reply, _ = generate(
            model=model,
            tokenizer=tokenizer,
            prompt_tokens=prompt_tokens,
            max_new_tokens=args.max_new_tokens,
            min_new_tokens=args.min_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            no_repeat_ngram=args.no_repeat_ngram,
            max_context_tokens=max_ctx,
            banned_token_ids=banned_token_ids,
        )
        reply = reply.strip()
        if not reply:
            # NOTE: the placeholder is what gets scored below, so an empty
            # reply is also counted as likely gibberish (fewer than 10
            # letters) and contributes len("(empty)") to the average length.
            reply = "(empty)"
            empty_count += 1
        if likely_gibberish(reply):
            gibberish_count += 1
        rep = repetition_score(reply)
        repetition_scores.append(rep)
        lengths.append(len(reply))
        rows.append(
            {
                "idx": i,
                "user": user,
                "assistant": reply,
                "repetition_score": round(rep, 4),
            }
        )

    avg_len = sum(lengths) / max(1, len(lengths))
    avg_rep = sum(repetition_scores) / max(1, len(repetition_scores))

    # Read the clock once so the summary timestamp and the output file names
    # always refer to the same instant.
    now = datetime.now()
    summary = {
        "timestamp": now.isoformat(timespec="seconds"),
        "checkpoint": args.ckpt,
        "ckpt_step": ckpt.get("step"),
        "ckpt_best_val": ckpt.get("best_val"),
        "prompts_file": args.prompts_file,
        "num_prompts": len(rows),
        "empty_count": empty_count,
        "empty_rate": round(empty_count / max(1, len(rows)), 4),
        "likely_gibberish_count": gibberish_count,
        "likely_gibberish_rate": round(gibberish_count / max(1, len(rows)), 4),
        "avg_response_chars": round(avg_len, 2),
        "avg_repetition_score": round(avg_rep, 4),
        "decode": {
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "repetition_penalty": args.repetition_penalty,
            "no_repeat_ngram": args.no_repeat_ngram,
        },
    }

    ts = now.strftime("%Y%m%d_%H%M%S")
    out_json = f"{args.out_prefix}_{ts}.json"
    out_txt = f"{args.out_prefix}_{ts}.txt"

    with open(out_json, "w", encoding="utf-8") as f:
        json.dump({"summary": summary, "samples": rows}, f, indent=2)

    with open(out_txt, "w", encoding="utf-8") as f:
        f.write(json.dumps(summary, indent=2))
        f.write("\n\n")
        for row in rows:
            f.write(f"[{row['idx']}] User: {row['user']}\n")
            f.write(f"[{row['idx']}] Assistant: {row['assistant']}\n")
            f.write(f"[{row['idx']}] repetition_score={row['repetition_score']}\n\n")

    print(json.dumps(summary, indent=2))
    print(f"Saved: {out_json}")
    print(f"Saved: {out_txt}")


if __name__ == "__main__":
    main()