"""
Benchmark: Memory Routing Model Evaluation

This script evaluates our trained models against the untrained base:
1. Base model (untrained Llama-3.1-8B)
2. Our SFT model
3. Our RL model

We measure:
- Classification metrics (F1, precision, recall)
- Task-specific metrics (temporal alignment, scope parity)
- Efficiency (tokens generated, latency)
"""
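
# Example invocation (the script filename and checkpoint paths are placeholders;
# overrides are plain key=value CLI arguments parsed in main()):
#   python benchmark.py sft_checkpoint=<sft-checkpoint> rl_checkpoint=<rl-checkpoint>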

import asyncio
import json
import time
import os
import numpy as np
from typing import List, Dict, Any, Tuple
from collections import Counter
from dataclasses import dataclass


@dataclass
class BenchmarkConfig:
    base_model: str = "meta-llama/Llama-3.1-8B"
    renderer_name: str = "llama3"
    test_data_path: str = "training/processed_data/test_data.json"
    output_dir: str = "training/benchmarks"

    # Checkpoint paths; when left empty, that model is skipped by run_benchmark().
    sft_checkpoint: str = ""
    rl_checkpoint: str = ""
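

# Valid routing labels (12 memory categories plus "none"); anything else in the
# model output is discarded at parse time.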
VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none",
}
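
# Expected persistence horizon per category; feeds the temporal-alignment metric
# in compute_metrics().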
CATEGORY_PERSISTENCE = {
    "company.brand_core": "long", "company.strategic_signatures": "long",
    "company.knowledge_artifacts": "long", "company.business_priorities": "short",
    "company.tools_config": "medium", "company.performance_context": "rolling",
    "user.communication_style": "long", "user.strategic_approach": "long",
    "user.role_context": "medium", "user.workflow_patterns": "medium",
    "user.session_history": "short", "user.interaction_preferences": "evolving",
    "none": "short",
}
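
# Routing system prompt: lists the valid categories and the required
# comma-separated output format.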
SYSTEM_PROMPT = """You route marketing conversations into structured memory categories.

Available categories:
- company.brand_core: Voice, values, positioning
- company.strategic_signatures: Decision frameworks
- company.knowledge_artifacts: Docs, style guides
- company.business_priorities: Quarterly goals, campaigns
- company.tools_config: Integrations, settings
- company.performance_context: Campaign metrics
- user.communication_style: Tone, format expectations
- user.strategic_approach: Personal priorities
- user.role_context: Title, scope
- user.workflow_patterns: Review cadence
- user.session_history: Recent context
- user.interaction_preferences: Coaching style
- none: Irrelevant or transactional

Respond with comma-separated categories only. No explanations."""


def parse_prediction(text: str) -> set:
    """Parse model output into a set of valid category labels."""
    if not text:
        return set()

    categories = set()
    for part in text.split(","):
        cat = part.strip().lower()
        if cat in VALID_CATEGORIES:
            categories.add(cat)

    # "none" only makes sense on its own; drop it when real categories are present.
    if "none" in categories and len(categories) > 1:
        categories.discard("none")

    return categories
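
# Example: parse_prediction("user.role_context, none") -> {"user.role_context"};
# the "none" label is kept only when it is the sole prediction.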


def compute_metrics(predicted: set, gold: set) -> Dict[str, float]:
    """Compute all evaluation metrics for a single example."""
    metrics = {}

    # Classification metrics.
    tp = len(predicted & gold)
    metrics["precision"] = tp / len(predicted) if predicted else 0
    metrics["recall"] = tp / len(gold) if gold else 0
    metrics["f1"] = (
        2 * metrics["precision"] * metrics["recall"] / (metrics["precision"] + metrics["recall"])
        if (metrics["precision"] + metrics["recall"]) > 0
        else 0
    )
    metrics["exact_match"] = float(predicted == gold)
    metrics["any_match"] = float(tp > 0)

    # Temporal alignment: compare the majority persistence class of predicted vs. gold.
    def majority_persistence(cats):
        if not cats:
            return "medium"
        persis = [CATEGORY_PERSISTENCE.get(c, "medium") for c in cats]
        return Counter(persis).most_common(1)[0][0]

    pred_pers = majority_persistence(predicted)
    gold_pers = majority_persistence(gold)
    metrics["temporal_match"] = float(pred_pers == gold_pers)

    # Scope parity: company-only, user-only, mixed, or none.
    def get_scope(cats):
        scopes = set()
        for c in cats:
            if c.startswith("company."):
                scopes.add("company")
            elif c.startswith("user."):
                scopes.add("user")
        if len(scopes) == 2:
            return "mixed"
        return scopes.pop() if scopes else "none"

    metrics["scope_match"] = float(get_scope(predicted) == get_scope(gold))

    # Efficiency: penalize over-prediction beyond three categories.
    n = len(predicted)
    metrics["n_categories"] = n
    metrics["efficiency"] = 1.0 if n <= 3 else (0.7 if n == 4 else 0.4)

    return metrics
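
# Worked example (illustrative): predicted = {"company.brand_core"},
# gold = {"company.brand_core", "user.role_context"} gives precision 1.0, recall 0.5,
# f1 of roughly 0.667, and scope_match 0.0 (predicted scope "company" vs. gold scope "mixed").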


async def evaluate_model(
    service_client,
    tokenizer,
    renderer,
    checkpoint: str,
    test_data: List[Dict],
    model_name: str,
) -> Tuple[Dict, List[Dict]]:
    """Evaluate a single model checkpoint against the test set."""
    from tinker import types

    print(f"\nEvaluating: {model_name}")
    print(f"Checkpoint: {checkpoint}")

    sampling_client = service_client.create_sampling_client(model_path=checkpoint)
    stop_sequences = renderer.get_stop_sequences()

    results = []
    latencies = []

    for i, example in enumerate(test_data):
        gold = {c.lower() for c in example.get("categories", [])}
        messages = example.get("messages", [])
        prompt_messages = [m for m in messages if m.get("role") != "assistant"]

        if not prompt_messages:
            continue

        prompt = renderer.build_generation_prompt(prompt_messages)
        # Routing outputs are short category lists, so a small token budget and low
        # temperature keep generations cheap and near-deterministic.
        params = types.SamplingParams(max_tokens=50, temperature=0.1, stop=stop_sequences)

        start_time = time.time()
        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        latency = time.time() - start_time
        latencies.append(latency)

        response, success = renderer.parse_response(result.sequences[0].tokens)
        predicted_text = response["content"] if success else ""
        predicted = parse_prediction(predicted_text)

        metrics = compute_metrics(predicted, gold)
        metrics["gold"] = list(gold)
        metrics["predicted"] = list(predicted)
        metrics["predicted_text"] = predicted_text
        metrics["latency"] = latency
        metrics["format_valid"] = bool(predicted) or predicted_text.strip().lower() == "none"

        results.append(metrics)

        if (i + 1) % 50 == 0:
            print(f"  Progress: {i + 1}/{len(test_data)}")

    # Aggregate per-example metrics into means (plus latency percentiles) for reporting.
    aggregate = {
        "model_name": model_name,
        "checkpoint": checkpoint,
        "n_examples": len(results),
        "f1": np.mean([r["f1"] for r in results]),
        "precision": np.mean([r["precision"] for r in results]),
        "recall": np.mean([r["recall"] for r in results]),
        "exact_match": np.mean([r["exact_match"] for r in results]),
        "any_match": np.mean([r["any_match"] for r in results]),
        "temporal_match": np.mean([r["temporal_match"] for r in results]),
        "scope_match": np.mean([r["scope_match"] for r in results]),
        "efficiency": np.mean([r["efficiency"] for r in results]),
        "format_valid": np.mean([r["format_valid"] for r in results]),
        "mean_latency": np.mean(latencies),
        "p95_latency": np.percentile(latencies, 95),
    }

    return aggregate, results
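
# evaluate_model() returns (aggregate, per_example_details): aggregate is a flat dict of
# metric means over the test set plus mean and p95 latency in seconds.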


async def run_benchmark(config: BenchmarkConfig):
    """Run the full benchmark suite and save results to disk."""
    import tinker
    from tinker_cookbook import renderers
    from tinker_cookbook.tokenizer_utils import get_tokenizer
    from dotenv import load_dotenv
    from datetime import datetime

    load_dotenv()

    print("=" * 70)
    print("MEMORY ROUTING BENCHMARK")
    print("=" * 70)

    os.makedirs(config.output_dir, exist_ok=True)
    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(config.base_model)
    renderer = renderers.get_renderer(name=config.renderer_name, tokenizer=tokenizer)

    with open(config.test_data_path, "r") as f:
        test_data = json.load(f)

    print(f"Test examples: {len(test_data)}")

    # Only checkpoints that were actually provided are benchmarked.
    models = []
    if config.sft_checkpoint:
        models.append(("SFT Model (Llama-3.1-8B + LoRA)", config.sft_checkpoint))
    if config.rl_checkpoint:
        models.append(("RL Model (Llama-3.1-8B + LoRA)", config.rl_checkpoint))

    all_results = {}
    for model_name, checkpoint in models:
        aggregate, details = await evaluate_model(
            service_client, tokenizer, renderer, checkpoint, test_data, model_name
        )
        all_results[model_name] = {
            "aggregate": aggregate,
            "details": details,
        }

    # Print a comparison table: one column per model, one row per metric.
    print("\n" + "=" * 70)
    print("BENCHMARK RESULTS")
    print("=" * 70)

    print(f"\n{'Metric':<20} ", end="")
    for model_name in all_results:
        short_name = model_name.split(" (")[0]
        print(f"{short_name:<15} ", end="")
    print()
    print("-" * 70)

    metrics_to_show = [
        ("F1 Score", "f1"),
        ("Precision", "precision"),
        ("Recall", "recall"),
        ("Exact Match", "exact_match"),
        ("Any Match", "any_match"),
        ("Temporal Match", "temporal_match"),
        ("Scope Match", "scope_match"),
        ("Format Valid", "format_valid"),
        ("Mean Latency", "mean_latency"),
    ]

    for display_name, key in metrics_to_show:
        print(f"{display_name:<20} ", end="")
        for model_name in all_results:
            value = all_results[model_name]["aggregate"][key]
            if key == "mean_latency":
                print(f"{value:.3f}s ", end="")
            else:
                print(f"{value:.1%} ", end="")
        print()

    # Persist aggregates and per-example details as a timestamped JSON report.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(config.output_dir, f"benchmark_{timestamp}.json")

    with open(output_path, "w") as f:
        json.dump({
            "config": {
                "base_model": config.base_model,
                "test_examples": len(test_data),
            },
            "results": {k: v["aggregate"] for k, v in all_results.items()},
            "details": {k: v["details"] for k, v in all_results.items()},
        }, f, indent=2, default=str)

    print(f"\nResults saved to: {output_path}")

    return all_results


async def main():
    import sys

    config = BenchmarkConfig()

    # Apply key=value overrides from the command line, e.g. sft_checkpoint=<path>.
    for arg in sys.argv[1:]:
        if "=" in arg:
            key, value = arg.split("=", 1)
            if hasattr(config, key):
                setattr(config, key, value)

    await run_benchmark(config)


if __name__ == "__main__":
    asyncio.run(main())