WitGym / scripts /compare_witgym_zero_shot.py
akshay4's picture
Upload folder using huggingface_hub
949430d verified
Raw
History Blame Contribute Delete
3.6 kB
"""WitGym vs zero-shot baseline — same model, same API backend."""
import json
import sys
import time
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
# 7 diverse scenarios (5 canonical + 2 extra).
# Every entry is fed, in order, to both the WitGym engine and the zero-shot
# baseline by run_experiment(), so the two outputs are directly comparable.
# NOTE(review): presumably the first five entries are the "canonical" ones and
# the last two the extras — the list itself does not mark the split; confirm.
TEST_INPUTS = [
    "I just got promoted to manager and I have no idea what I'm doing.",
    "My coworker keeps stealing my lunch from the fridge.",
    "I've been cc'd on an email chain I definitely should not be reading.",
    "I'm pretending to understand cryptocurrency at dinner parties.",
    "My therapist fell asleep during our session.",
    "My boss called a quick sync that's been going for two hours.",
    "I waved back at someone who wasn't waving at me.",
]
# Prompt template for the zero-shot baseline. zero_shot_reply() fills the
# ``{user_input}`` placeholder via str.format(); the template deliberately ends
# with "Funny reply:" so the model's completion is the reply itself.
# The full template is also stored in the results file under "zero_shot_prompt".
ZERO_SHOT_PROMPT = """You are a witty friend. The user shares an awkward or funny situation.
Reply with ONE sharp, funny line. Maximum 2 sentences. No preamble.
User: {user_input}
Funny reply:"""
def zero_shot_reply(user_input: str, model, tokenizer) -> str:
    """Produce the zero-shot baseline reply for one user input.

    The input is substituted into ``ZERO_SHOT_PROMPT`` and sent through
    ``witgym.model.generate_text`` with ``config_type="generate"``.

    Args:
        user_input: The situation text to respond to.
        model: Model handle, passed through unchanged to ``generate_text``.
        tokenizer: Tokenizer handle, passed through unchanged to
            ``generate_text``.

    Returns:
        The generated text with surrounding whitespace removed and then any
        leading/trailing double-quote characters removed.
    """
    # Imported on call rather than at module import, so merely importing this
    # script does not pull in the witgym package.
    from witgym.model import generate_text

    completion = generate_text(
        ZERO_SHOT_PROMPT.format(user_input=user_input),
        model,
        tokenizer,
        config_type="generate",
    )
    trimmed = completion.strip()
    return trimmed.strip('"')
def _timed_call(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and return ``(result, elapsed_seconds)``.

    Uses ``time.perf_counter`` so the measurement is not affected by
    wall-clock adjustments that can happen during a slow API call.
    """
    started = time.perf_counter()
    result = func(*args, **kwargs)
    return result, time.perf_counter() - started


def run_experiment(index_path: str = "data/index.npz") -> dict:
    """Run WitGym and the zero-shot baseline on every entry of ``TEST_INPUTS``.

    For each input a fresh ``WitGymEngine`` is built (sharing one
    ``SharedResources`` instance), its response is timed, and then the
    zero-shot reply from ``zero_shot_reply`` is timed with the same
    model/tokenizer pair. Per-input progress is printed to stdout and the
    collected results are written to ``data/eval_witgym_vs_zero_shot.json``.

    Args:
        index_path: Path to the index file handed to ``SharedResources`` and
            ``WitGymEngine``.

    Returns:
        A dict with the keys ``"model"``, ``"backend"``, ``"zero_shot_prompt"``
        and ``"results"`` (one row per test input) — the same object that is
        serialized to the JSON file.

    Raises:
        SystemExit: With status 1 when ``index_path`` does not exist.
    """
    import os

    # setdefault keeps any backend the caller already exported. It is set
    # before the witgym imports below — presumably witgym.config reads the
    # variable at import time; confirm against witgym.config.
    os.environ.setdefault("LLM_BACKEND", "hf_api")

    from witgym import config
    from witgym.engine import SharedResources, WitGymEngine
    from witgym.model import load_model

    if not Path(index_path).exists():
        print(f"[ERROR] Index missing: {index_path}")
        sys.exit(1)

    print(f"Model: {config.LLM_MODEL_ID} | Backend: {config.LLM_BACKEND}")
    print(f"Providers: {config.HF_INFERENCE_PROVIDERS}\n")

    model, tokenizer = load_model()

    # Shared embedder/index only — a fresh engine is still created per input
    # below so no conversation state bleeds between scenarios.
    shared = SharedResources(index_path=index_path)

    rows = []
    total = len(TEST_INPUTS)
    for i, user_input in enumerate(TEST_INPUTS, 1):
        # Only show an ellipsis when the preview was actually cut short.
        preview = user_input if len(user_input) <= 65 else f"{user_input[:65]}..."
        print(f"[{i}/{total}] {preview}")

        engine = WitGymEngine(index_path=index_path, resources=shared)
        wg, wg_elapsed = _timed_call(engine.respond, user_input)
        zs, zs_elapsed = _timed_call(zero_shot_reply, user_input, model, tokenizer)

        row = {
            "input": user_input,
            "witgym": {
                "text": wg.selected,
                "words": len(wg.selected.split()),
                "latency_s": round(wg_elapsed, 1),
                "archetype": wg.metadata.archetype.value,
                "candidates": [c.text for c in wg.candidates],
            },
            "zero_shot": {
                "text": zs,
                "words": len(zs.split()),
                "latency_s": round(zs_elapsed, 1),
            },
        }
        rows.append(row)
        print(f" WitGym ({row['witgym']['words']}w, {wg_elapsed:.1f}s): {wg.selected}")
        print(f" Zero ({row['zero_shot']['words']}w, {zs_elapsed:.1f}s): {zs}\n")

    out = {
        "model": config.LLM_MODEL_ID,
        "backend": config.LLM_BACKEND,
        "zero_shot_prompt": ZERO_SHOT_PROMPT,
        "results": rows,
    }
    out_path = Path("data/eval_witgym_vs_zero_shot.json")
    # index_path may live outside data/, so make sure the output directory
    # exists instead of failing after all the generation work is done.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2), encoding="utf-8")
    print(f"Saved → {out_path}")
    return out
# Script entry point: run the comparison with the default index path
# ("data/index.npz") when this file is executed directly.
if __name__ == "__main__":
    run_experiment()