"""Quick evaluation script - shows results as they come in."""

import asyncio
import json

from dotenv import load_dotenv

# Load environment variables (e.g. a TINKER_API_KEY from .env) before creating any Tinker clients.
load_dotenv()

import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer

VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none",
}
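# Note: model completions are expected to be a comma-separated list of these labels
# (see parse_prediction below); anything outside this set is silently dropped.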


def parse_prediction(text):
    """Parse a model completion into the set of valid category labels it names."""
    if not text or not text.strip():
        return set()
    cats = [c.strip().lower() for c in text.split(",")]
    return {c for c in cats if c in VALID_CATEGORIES}
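# Illustrative example:
#   parse_prediction("user.role_context, company.brand_core, foo")
#   -> {"user.role_context", "company.brand_core"}; unrecognized labels like "foo" are dropped.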


def compute_f1(predicted, gold):
    """Set-based F1 between predicted and gold label sets (1.0 when both are empty)."""
    if not predicted and not gold:
        return 1.0
    if not predicted or not gold:
        return 0.0
    tp = len(predicted & gold)
    prec = tp / len(predicted)
    rec = tp / len(gold)
    return 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
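# Worked example: predicted = {"user.role_context", "none"}, gold = {"user.role_context"}
#   tp = 1, precision = 1/2 = 0.5, recall = 1/1 = 1.0, F1 = 2 * 0.5 * 1.0 / 1.5 ≈ 0.67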


async def eval_model(name, checkpoint, model_name, renderer_name, test_data, n=10):
    print(f"\n{'='*60}", flush=True)
    print(f"EVALUATING: {name}", flush=True)
    print(f"Checkpoint: {checkpoint}", flush=True)
    print(f"{'='*60}", flush=True)

    # Build a sampling client for the checkpoint and a renderer matching the model's chat format.
    service_client = tinker.ServiceClient()
    sampling_client = service_client.create_sampling_client(model_path=checkpoint)
    tokenizer = get_tokenizer(model_name)
    renderer = renderers.get_renderer(name=renderer_name, tokenizer=tokenizer)
    stop = renderer.get_stop_sequences()
    params = types.SamplingParams(max_tokens=100, temperature=0.1, stop=stop)

    correct = exact = 0
    total_f1 = 0.0

    for i, item in enumerate(test_data[:n]):
        # Use all but the last message as context; reference labels come from item["categories"].
        messages = item["messages"][:-1]
        gold = set(item["categories"])

        prompt = renderer.build_generation_prompt(messages)
        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, _ = renderer.parse_response(result.sequences[0].tokens)
        predicted = parse_prediction(response["content"])

        f1 = compute_f1(predicted, gold)
        total_f1 += f1
        if predicted & gold:
            correct += 1
        if predicted == gold:
            exact += 1

        status = "✓" if predicted & gold else "✗"
        ex = "EXACT" if predicted == gold else ""
        print(f"[{i+1:2d}] {status} Gold: {sorted(gold)} | Pred: {sorted(predicted)} | F1={f1:.2f} {ex}", flush=True)

    print(f"\n--- SUMMARY ({n} examples) ---", flush=True)
    print(f"Any Match: {correct}/{n} ({correct/n:.0%})", flush=True)
    print(f"Exact Match: {exact}/{n} ({exact/n:.0%})", flush=True)
    print(f"Avg F1: {total_f1/n:.2f}", flush=True)

    return {"any": correct/n, "exact": exact/n, "f1": total_f1/n}


async def main():
    with open("training/processed_data/test_data.json", "r") as f:
        test_data = json.load(f)

    print(f"Loaded {len(test_data)} test examples", flush=True)

    llama_result = await eval_model(
        name="Llama-8B RL (iter 12)",
        checkpoint="tinker://4f4bae1f-5a95-5f53-a55a-a14f2872825c:train:0/sampler_weights/rl_iter_012",
        model_name="meta-llama/Llama-3.1-8B",
        renderer_name="llama3",
        test_data=test_data,
        n=15,
    )

    qwen_result = await eval_model(
        name="Qwen3-32B SFT (step 30)",
        checkpoint="tinker://b7be2502-e321-59ee-9477-f3fd8a52ab4e:train:0/sampler_weights/sft_step_0030",
        model_name="Qwen/Qwen3-32B",
        renderer_name="qwen3",
        test_data=test_data,
        n=15,
    )

    print(f"\n{'='*60}", flush=True)
    print("COMPARISON", flush=True)
    print(f"{'='*60}", flush=True)
    print(f"{'Model':<25} {'Any Match':<12} {'Exact':<12} {'F1':<10}", flush=True)
    print("-" * 60, flush=True)
    print(f"{'Llama-8B RL':<25} {llama_result['any']:<12.0%} {llama_result['exact']:<12.0%} {llama_result['f1']:<10.2f}", flush=True)
    print(f"{'Qwen3-32B SFT':<25} {qwen_result['any']:<12.0%} {qwen_result['exact']:<12.0%} {qwen_result['f1']:<10.2f}", flush=True)


if __name__ == "__main__":
    asyncio.run(main())
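# Usage sketch (script name is illustrative; assumes Tinker credentials, e.g. TINKER_API_KEY,
# are in .env and the processed test set exists at training/processed_data/test_data.json):
#   python quick_eval.py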