AI / scripts /eval_chat_prompts.py
shsplas's picture
Upload 15 files
c5f49b9 verified
Raw
History Blame Contribute Delete
7.57 kB
import argparse
import json
import os
import re
from datetime import datetime
import torch
from GPT_model import GPT, block_size
from chat import collect_banned_token_ids, generate, load_tokenizer, DEFAULT_SYSTEM_PROMPT
def parse_args():
    """Define the evaluator's command-line options and parse them.

    Returns:
        argparse.Namespace with the checkpoint and prompt-file paths, the
        output prefix, the decoding settings, the CPU threading settings,
        the random seed and the int8 quantization switch.
    """
    parser = argparse.ArgumentParser(
        description="Batch prompt evaluator for Jarvis chat checkpoints"
    )
    add = parser.add_argument

    # Leave two cores free, never use more than six, always use at least one.
    cpu_total = os.cpu_count() or 4
    default_threads = max(1, min(6, cpu_total - 2))

    # Inputs, outputs and how many prompts to run.
    add("--ckpt", default="cpu_gpt_jarvis_rebuild_l6_v2048_best.pth")
    add("--prompts-file", default=os.path.join("data", "jarvis_eval_prompts.txt"))
    add("--out-prefix", default="jarvis_eval")
    add("--num-prompts", type=int, default=120)

    # Sampling / decoding controls forwarded to the generator.
    add("--temperature", type=float, default=0.62)
    add("--top-k", type=int, default=32)
    add("--top-p", type=float, default=0.9)
    add("--repetition-penalty", type=float, default=1.12)
    add("--no-repeat-ngram", type=int, default=3)
    add("--max-new-tokens", type=int, default=96)
    add("--min-new-tokens", type=int, default=10)
    add("--max-context-tokens", type=int, default=block_size)
    add("--system-prompt", default=DEFAULT_SYSTEM_PROMPT)
    add("--ban-empty-tokens", action=argparse.BooleanOptionalAction, default=True)

    # Runtime configuration.
    add("--threads", type=int, default=default_threads)
    add("--interop-threads", type=int, default=1)
    add("--seed", type=int, default=1337)
    add("--int8", action=argparse.BooleanOptionalAction, default=False)

    return parser.parse_args()
def parse_prompts(path):
    """Extract the user prompts from a transcript-style prompts file.

    The file is split into chunks on blank lines. Within each chunk, the text
    after the first ``User:`` marker, up to a following ``Assistant:`` line or
    the end of the chunk, becomes one prompt; runs of whitespace inside it are
    collapsed to a single space. Chunks without a ``User:`` marker, or whose
    user turn is empty, are skipped.

    Args:
        path: Path of the prompts file. It is read as UTF-8 and bytes that
            cannot be decoded are dropped.

    Returns:
        The prompts as a list of strings, in file order.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Prompts file not found: {path}")

    # Context manager so the file handle is closed as soon as the text is read.
    with open(path, "r", encoding="utf-8", errors="ignore") as handle:
        text = handle.read()

    # Leading whitespace is captured (and stripped below) instead of being
    # consumed by the pattern. If the pattern consumed it, an empty user turn
    # would swallow the newline in front of "Assistant:" and the assistant
    # line itself would be returned as the prompt.
    user_turn = re.compile(r"User:(.*?)(?:\nAssistant:|$)", flags=re.S)

    prompts = []
    for chunk in re.split(r"\n\s*\n", text):
        chunk = chunk.strip()
        if not chunk:
            continue
        match = user_turn.search(chunk)
        if match is None:
            continue
        user = re.sub(r"\s+", " ", match.group(1)).strip()
        if user:
            prompts.append(user)
    return prompts
def likely_gibberish(text):
    """Heuristically decide whether a reply looks like unreadable output.

    The reply is flagged when any of the following holds:

    * it is empty;
    * it contains at least three runs of 18 or more ASCII letters that each
      use more than 12 distinct letters (case-insensitive);
    * the number of printable-ASCII, newline, tab or carriage-return
      characters is below ``max(1, int(0.9 * len(text)))``;
    * it contains fewer than 10 alphabetic characters.

    Args:
        text: The reply to inspect.

    Returns:
        True if the reply is flagged, False otherwise.
    """
    if not text:
        return True

    # Very long "words" made of many different letters.
    long_runs = re.findall(r"[A-Za-z]{18,}", text)
    scrambled = sum(1 for run in long_runs if len(set(run.lower())) > 12)
    if scrambled >= 3:
        return True

    # Single pass over the characters to collect both counts.
    allowed_controls = "\n\t\r"
    readable = 0
    letters = 0
    for ch in text:
        if 31 < ord(ch) < 127 or ch in allowed_controls:
            readable += 1
        if ch.isalpha():
            letters += 1

    required = max(1, int(0.9 * len(text)))
    return readable < required or letters < 10
def repetition_score(text):
tokens = re.findall(r"\w+", text.lower())
if len(tokens) < 6:
return 0.0
trigrams = [" ".join(tokens[i : i + 3]) for i in range(len(tokens) - 2)]
if not trigrams:
return 0.0
unique = len(set(trigrams))
return 1.0 - (unique / len(trigrams))
def load_model(ckpt_path, vocab_size, use_int8):
    """Create the GPT on CPU and restore its weights from a checkpoint file.

    Args:
        ckpt_path: Path of the checkpoint to load.
        vocab_size: Tokenizer vocabulary size. It sizes the model and is
            compared with the checkpoint's ``vocab_size`` entry when that
            entry is present.
        use_int8: When true, the model's ``torch.nn.Linear`` layers are
            dynamically quantized to int8 after the weights are loaded.

    Returns:
        A ``(model, ckpt)`` tuple: the model in eval mode and the loaded
        checkpoint object.

    Raises:
        FileNotFoundError: If ``ckpt_path`` does not exist.
        RuntimeError: If the checkpoint records a different vocabulary size.
    """
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    device = "cpu"
    net = GPT(vocab_size).to(device)
    ckpt = torch.load(ckpt_path, map_location=device)

    # Checkpoints without a recorded vocab size are accepted as-is.
    ckpt_vocab = ckpt.get("vocab_size")
    vocab_ok = ckpt_vocab is None or int(ckpt_vocab) == vocab_size
    if not vocab_ok:
        raise RuntimeError(
            f"Checkpoint/tokenizer mismatch: ckpt vocab_size={ckpt_vocab}, tokenizer vocab_size={vocab_size}"
        )

    net.load_state_dict(ckpt["model"], strict=True)
    net.eval()
    if not use_int8:
        return net, ckpt

    # Quantization returns a new module, so eval mode is set again on it.
    quantized = torch.ao.quantization.quantize_dynamic(
        net, {torch.nn.Linear}, dtype=torch.qint8
    )
    quantized.eval()
    return quantized, ckpt
def _write_reports(out_prefix, stamp, summary, rows):
    """Write the JSON and plain-text reports for one run and return their paths."""
    out_json = f"{out_prefix}_{stamp}.json"
    out_txt = f"{out_prefix}_{stamp}.txt"

    with open(out_json, "w", encoding="utf-8") as f:
        json.dump({"summary": summary, "samples": rows}, f, indent=2)

    with open(out_txt, "w", encoding="utf-8") as f:
        f.write(json.dumps(summary, indent=2))
        f.write("\n\n")
        for row in rows:
            f.write(f"[{row['idx']}] User: {row['user']}\n")
            f.write(f"[{row['idx']}] Assistant: {row['assistant']}\n")
            f.write(f"[{row['idx']}] repetition_score={row['repetition_score']}\n\n")

    return out_json, out_txt


def main():
    """Run the evaluation prompts through a checkpoint and save scored reports.

    Steps:
      1. Parse the CLI options and apply the seed and torch thread settings.
      2. Load the tokenizer, the prompts (at most ``--num-prompts``, at least
         one) and the model.
      3. Generate one reply per prompt. Each prompt is preceded by an optional
         bootstrap turn built from the system prompt and is left-truncated to
         the context window.
      4. Score every reply (empty, likely gibberish, trigram repetition,
         length in characters) and aggregate the scores into a summary.
      5. Write ``<out-prefix>_<timestamp>.json`` and ``.txt`` and print the
         summary and the two paths.

    Raises:
        RuntimeError: If the prompts file yields no prompts.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    torch.set_num_threads(args.threads)
    torch.set_num_interop_threads(args.interop_threads)

    tokenizer = load_tokenizer()
    prompts = parse_prompts(args.prompts_file)
    if not prompts:
        raise RuntimeError("No valid prompts found.")
    # max(1, ...) keeps at least one prompt even for a zero or negative limit.
    prompts = prompts[: max(1, args.num_prompts)]

    model, ckpt = load_model(args.ckpt, len(tokenizer.vocab), args.int8)
    # Clamp the requested context to the range [32, block_size].
    max_ctx = max(32, min(args.max_context_tokens, block_size))
    banned_token_ids = collect_banned_token_ids(tokenizer, args.ban_empty_tokens)

    # The system prompt is injected as a first User turn that the assistant
    # has already acknowledged.
    bootstrap = ""
    if args.system_prompt.strip():
        bootstrap = f"User: {args.system_prompt.strip()}\nAssistant: Understood.\n"
    bootstrap_tokens = tokenizer.encode(bootstrap)

    rows = []
    empty_count = 0
    gibberish_count = 0
    repetition_scores = []
    lengths = []
    for i, user in enumerate(prompts, start=1):
        turn_prefix = f"\nUser: {user}\nAssistant:"
        # Keep only the most recent max_ctx tokens when the prompt is too long.
        prompt_tokens = (bootstrap_tokens + tokenizer.encode(turn_prefix))[-max_ctx:]
        reply, _ = generate(
            model=model,
            tokenizer=tokenizer,
            prompt_tokens=prompt_tokens,
            max_new_tokens=args.max_new_tokens,
            min_new_tokens=args.min_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            no_repeat_ngram=args.no_repeat_ngram,
            max_context_tokens=max_ctx,
            banned_token_ids=banned_token_ids,
        )
        reply = reply.strip()
        if not reply:
            # The placeholder is what gets scored and reported for this prompt.
            reply = "(empty)"
            empty_count += 1
        if likely_gibberish(reply):
            gibberish_count += 1
        rep = repetition_score(reply)
        repetition_scores.append(rep)
        lengths.append(len(reply))
        rows.append(
            {
                "idx": i,
                "user": user,
                "assistant": reply,
                "repetition_score": round(rep, 4),
            }
        )

    total = max(1, len(rows))
    avg_len = sum(lengths) / max(1, len(lengths))
    avg_rep = sum(repetition_scores) / max(1, len(repetition_scores))

    # One clock reading for both the summary and the file names, so the
    # reported timestamp always matches the suffix of the files it lives in.
    finished_at = datetime.now()
    summary = {
        "timestamp": finished_at.isoformat(timespec="seconds"),
        "checkpoint": args.ckpt,
        "ckpt_step": ckpt.get("step"),
        "ckpt_best_val": ckpt.get("best_val"),
        "prompts_file": args.prompts_file,
        "num_prompts": len(rows),
        "empty_count": empty_count,
        "empty_rate": round(empty_count / total, 4),
        "likely_gibberish_count": gibberish_count,
        "likely_gibberish_rate": round(gibberish_count / total, 4),
        "avg_response_chars": round(avg_len, 2),
        "avg_repetition_score": round(avg_rep, 4),
        "decode": {
            "temperature": args.temperature,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "repetition_penalty": args.repetition_penalty,
            "no_repeat_ngram": args.no_repeat_ngram,
        },
    }

    out_json, out_txt = _write_reports(
        args.out_prefix, finished_at.strftime("%Y%m%d_%H%M%S"), summary, rows
    )
    print(json.dumps(summary, indent=2))
    print(f"Saved: {out_json}")
    print(f"Saved: {out_txt}")
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()