File size: 4,446 Bytes
685d968 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
"""Quick evaluation script - shows results as they come in."""
import asyncio
import json
import os
from dotenv import load_dotenv
load_dotenv()
import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer
# Closed label set the model may emit: six company-scoped categories,
# six user-scoped categories, and the sentinel "none".
_COMPANY_SUBCATS = (
    "brand_core", "strategic_signatures", "knowledge_artifacts",
    "business_priorities", "tools_config", "performance_context",
)
_USER_SUBCATS = (
    "communication_style", "strategic_approach", "role_context",
    "workflow_patterns", "session_history", "interaction_preferences",
)
VALID_CATEGORIES = (
    {f"company.{sub}" for sub in _COMPANY_SUBCATS}
    | {f"user.{sub}" for sub in _USER_SUBCATS}
    | {"none"}
)
def parse_prediction(text):
    """Parse a comma-separated model completion into the set of valid categories it names.

    Unknown labels are silently dropped; blank or missing input yields an empty set.
    """
    if text is None or not text.strip():
        return set()
    tokens = (piece.strip().lower() for piece in text.split(","))
    return set(filter(VALID_CATEGORIES.__contains__, tokens))
def compute_f1(predicted, gold):
    """Return the set-based F1 score between predicted and gold category sets.

    Conventions:
    - both sets empty -> 1.0 (vacuously perfect prediction)
    - exactly one set empty -> 0.0
    - otherwise the harmonic mean of precision and recall.

    Fix: the zero-overlap branch previously returned the int ``0``; it now
    returns ``0.0`` so every path yields a float.
    """
    if not predicted and not gold:
        return 1.0
    if not predicted or not gold:
        return 0.0
    tp = len(predicted & gold)
    # Both sets are non-empty here, so the divisions are safe.
    prec = tp / len(predicted)
    rec = tp / len(gold)
    if prec + rec == 0:  # no overlap at all
        return 0.0
    return 2 * prec * rec / (prec + rec)
async def eval_model(name, checkpoint, model_name, renderer_name, test_data, n=10):
    """Evaluate one sampler checkpoint on the first *n* test examples.

    Args:
        name: Human-readable label printed in the report header.
        checkpoint: tinker:// model path passed to the sampling client.
        model_name: Base-model identifier used to load the tokenizer.
        renderer_name: Renderer key (e.g. "llama3", "qwen3") for prompt building.
        test_data: List of examples; each has "messages" (last entry is the
            gold completion, which is stripped off) and "categories".
        n: Max number of examples to evaluate.

    Returns:
        Dict with "any" (any-overlap accuracy), "exact" (exact-set-match
        accuracy), and "f1" (mean per-example F1), each averaged over the
        number of examples actually evaluated.

    Fix: metrics were previously divided by the requested ``n`` even when
    ``test_data`` held fewer examples, deflating every reported number (and
    dividing by zero for an empty set). They are now averaged over the
    actual evaluated count.
    """
    print(f"\n{'='*60}", flush=True)
    print(f"EVALUATING: {name}", flush=True)
    print(f"Checkpoint: {checkpoint}", flush=True)
    print(f"{'='*60}", flush=True)
    service_client = tinker.ServiceClient()
    sampling_client = service_client.create_sampling_client(model_path=checkpoint)
    tokenizer = get_tokenizer(model_name)
    renderer = renderers.get_renderer(name=renderer_name, tokenizer=tokenizer)
    stop = renderer.get_stop_sequences()
    # Near-greedy decoding; 100 tokens is ample for a short category list.
    params = types.SamplingParams(max_tokens=100, temperature=0.1, stop=stop)
    batch = test_data[:n]
    n_eval = len(batch)  # may be < n if the test set is small
    if n_eval == 0:
        print("\n--- SUMMARY (0 examples) ---", flush=True)
        return {"any": 0.0, "exact": 0.0, "f1": 0.0}
    correct = exact = total_f1 = 0
    for i, item in enumerate(batch):
        messages = item["messages"][:-1]  # drop the gold assistant turn
        gold = set(item["categories"])
        prompt = renderer.build_generation_prompt(messages)
        # Blocking wait on the sampling future; examples run sequentially.
        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, _ = renderer.parse_response(result.sequences[0].tokens)
        predicted = parse_prediction(response["content"])
        f1 = compute_f1(predicted, gold)
        total_f1 += f1
        if predicted & gold:
            correct += 1
        if predicted == gold:
            exact += 1
        status = "✓" if predicted & gold else "✗"
        ex = "EXACT" if predicted == gold else ""
        print(f"[{i+1:2d}] {status} Gold: {sorted(gold)} | Pred: {sorted(predicted)} | F1={f1:.2f} {ex}", flush=True)
    print(f"\n--- SUMMARY ({n_eval} examples) ---", flush=True)
    print(f"Any Match: {correct}/{n_eval} ({correct/n_eval:.0%})", flush=True)
    print(f"Exact Match: {exact}/{n_eval} ({exact/n_eval:.0%})", flush=True)
    print(f"Avg F1: {total_f1/n_eval:.2f}", flush=True)
    return {"any": correct/n_eval, "exact": exact/n_eval, "f1": total_f1/n_eval}
async def main():
    """Load the test set, evaluate both checkpoints, and print a comparison table."""
    with open("training/processed_data/test_data.json", "r") as f:
        test_data = json.load(f)
    print(f"Loaded {len(test_data)} test examples", flush=True)
    # (label, checkpoint path, base model for tokenizer, renderer key)
    model_specs = [
        (
            "Llama-8B RL (iter 12)",
            "tinker://4f4bae1f-5a95-5f53-a55a-a14f2872825c:train:0/sampler_weights/rl_iter_012",
            "meta-llama/Llama-3.1-8B",
            "llama3",
        ),
        (
            "Qwen3-32B SFT (step 30)",
            "tinker://b7be2502-e321-59ee-9477-f3fd8a52ab4e:train:0/sampler_weights/sft_step_0030",
            "Qwen/Qwen3-32B",
            "qwen3",
        ),
    ]
    results = []
    for label, ckpt, base_model, renderer_key in model_specs:
        metrics = await eval_model(
            name=label,
            checkpoint=ckpt,
            model_name=base_model,
            renderer_name=renderer_key,
            test_data=test_data,
            n=15,
        )
        results.append(metrics)
    llama_result, qwen_result = results
    # Side-by-side comparison table.
    print(f"\n{'='*60}", flush=True)
    print("COMPARISON", flush=True)
    print(f"{'='*60}", flush=True)
    print(f"{'Model':<25} {'Any Match':<12} {'Exact':<12} {'F1':<10}", flush=True)
    print("-" * 60, flush=True)
    for row_label, res in (("Llama-8B RL", llama_result), ("Qwen3-32B SFT", qwen_result)):
        print(f"{row_label:<25} {res['any']:<12.0%} {res['exact']:<12.0%} {res['f1']:<10.2f}", flush=True)
# Script entry point: run the full async evaluation pipeline.
if __name__ == "__main__":
    asyncio.run(main())
|