""" Test the SFT model on various inputs. Tests: 1. Examples from training dataset 2. Examples from test dataset 3. Novel inputs the model has never seen """ import asyncio import json import os from dotenv import load_dotenv load_dotenv() import tinker from tinker import types from tinker_cookbook import renderers from tinker_cookbook.tokenizer_utils import get_tokenizer # Configuration SFT_CHECKPOINT = "tinker://4f4bae1f-5a95-5f53-a55a-a14f2872825c:train:0/sampler_weights/sft_step_0090" BASE_MODEL = "meta-llama/Llama-3.1-8B" VALID_CATEGORIES = { "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts", "company.business_priorities", "company.tools_config", "company.performance_context", "user.communication_style", "user.strategic_approach", "user.role_context", "user.workflow_patterns", "user.session_history", "user.interaction_preferences", "none" } # Novel test cases the model has never seen NOVEL_INPUTS = [ { "conversation": """USER: I prefer getting updates in bullet points, not long paragraphs. Keep it concise. ASSISTANT: Understood! I'll format all future updates as bullet points for easy scanning. USER: Perfect. Also, I check messages mainly in the morning, so schedule any non-urgent stuff for 8am.""", "expected": ["user.communication_style", "user.workflow_patterns"], "description": "Communication preferences + workflow timing" }, { "conversation": """USER: Our brand voice is playful but professional. Think friendly neighbor who happens to be an expert. ASSISTANT: Great analogy! So approachable expertise - warm but credible. USER: Exactly. We never use corporate jargon. Everything should feel like a conversation.""", "expected": ["company.brand_core"], "description": "Brand voice definition" }, { "conversation": """USER: What's the weather like today? ASSISTANT: I don't have access to real-time weather data. You might want to check a weather app. USER: Ok thanks.""", "expected": ["none"], "description": "Transactional/irrelevant conversation" }, { "conversation": """USER: We use HubSpot for CRM, Mailchimp for emails, and Slack for team comms. Everything needs to sync. ASSISTANT: Got it. I'll note those integrations. Any specific workflows between them? USER: Yeah, new leads from HubSpot should trigger a Slack notification and add to a Mailchimp welcome sequence.""", "expected": ["company.tools_config"], "description": "Tool stack and integrations" }, { "conversation": """USER: Last quarter's email campaign had 45% open rate but only 2% conversion. We need to fix the CTA. ASSISTANT: That's a significant drop-off. The open rate suggests good subject lines but weak call-to-action. USER: Right. Let's A/B test three different CTAs next month.""", "expected": ["company.performance_context"], "description": "Campaign metrics and learnings" }, { "conversation": """USER: I'm the VP of Marketing, reporting directly to the CEO. I have final say on brand decisions. ASSISTANT: Understood. So you have executive authority on brand strategy. USER: Yes, but I like to get input from the creative team before making calls.""", "expected": ["user.role_context"], "description": "Role and authority" }, { "conversation": """USER: We always prioritize data over intuition. Every decision needs numbers to back it up. ASSISTANT: A data-driven approach. Do you have minimum thresholds for statistical significance? USER: Yes, 95% confidence interval minimum. 
    {
        "conversation": """USER: Q4 is all about the holiday push. Black Friday and Cyber Monday are our biggest days.
ASSISTANT: Seasonal focus noted. What's the budget allocation for Q4?
USER: 60% of annual marketing budget goes to Q4. It's make or break.""",
        "expected": ["company.business_priorities"],
        "description": "Quarterly priorities"
    },
    {
        "conversation": """USER: I want you to challenge my ideas more. Don't just agree with everything.
ASSISTANT: I appreciate that feedback. You want more critical analysis and pushback.
USER: Yes, play devil's advocate when needed. I learn better that way.""",
        "expected": ["user.interaction_preferences"],
        "description": "Coaching/feedback style"
    },
    {
        "conversation": """USER: Yesterday we discussed the influencer campaign. Did you save those notes?
ASSISTANT: Yes, I have the notes from our previous session about the influencer strategy.
USER: Good. Let's continue from where we left off - the micro-influencer targeting.""",
        "expected": ["user.session_history"],
        "description": "Reference to previous context"
    },
]


def parse_prediction(text: str) -> set:
    """Parse model output into a category set."""
    if not text or not text.strip():
        return set()
    cats = [c.strip().lower() for c in text.split(",")]
    return {c for c in cats if c in VALID_CATEGORIES}


def compute_metrics(predicted: set, gold: set) -> dict:
    """Compute evaluation metrics."""
    if not predicted and not gold:
        return {"f1": 1.0, "precision": 1.0, "recall": 1.0, "exact_match": True, "any_match": True}
    if not predicted or not gold:
        return {"f1": 0.0, "precision": 0.0, "recall": 0.0, "exact_match": False, "any_match": False}
    tp = len(predicted & gold)
    precision = tp / len(predicted) if predicted else 0
    recall = tp / len(gold) if gold else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return {
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "exact_match": predicted == gold,
        "any_match": bool(predicted & gold),
    }
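
# Illustrative sanity check for the two helpers above (not part of the evaluation;
# the prediction string and gold set are made up for demonstration and this function
# is not called automatically - run it manually if desired).
# parse_prediction lowercases, splits on commas, and drops anything outside
# VALID_CATEGORIES; compute_metrics then scores the overlap.
def _demo_metrics():
    predicted = parse_prediction("Company.brand_core, user.role_context, made_up_category")
    # -> {"company.brand_core", "user.role_context"} (unknown category discarded)
    gold = {"company.brand_core", "user.communication_style"}
    metrics = compute_metrics(predicted, gold)
    # tp = 1, precision = 0.5, recall = 0.5, f1 = 0.5, any_match = True, exact_match = False
    assert abs(metrics["f1"] - 0.5) < 1e-9
    assert metrics["any_match"] and not metrics["exact_match"]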

async def test_model():
    print("=" * 70)
    print("SFT MODEL EVALUATION")
    print("=" * 70)
    print(f"Checkpoint: {SFT_CHECKPOINT}")
    print()

    # Initialize
    service_client = tinker.ServiceClient()
    sampling_client = service_client.create_sampling_client(model_path=SFT_CHECKPOINT)
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)
    stop_sequences = renderer.get_stop_sequences()
    params = types.SamplingParams(max_tokens=100, temperature=0.1, stop=stop_sequences)

    # System prompt
    system_prompt = """You route marketing conversations into structured memory categories.

Available categories:
- company.brand_core: Voice, values, positioning, identity anchors
- company.strategic_signatures: Decision frameworks, strategic heuristics
- company.knowledge_artifacts: Docs, style guides, playbooks
- company.business_priorities: Quarterly/seasonal goals, active campaigns
- company.tools_config: Integrations, API keys, workflow settings
- company.performance_context: Campaign metrics, retrospectives, learnings
- user.communication_style: Tone, verbosity, format expectations
- user.strategic_approach: Personal priorities, success definitions
- user.role_context: Title, scope, decision authority
- user.workflow_patterns: Review cadence, collaboration norms
- user.session_history: Immediate context, recent asks
- user.interaction_preferences: Coaching style, feedback expectations
- none: Irrelevant, vague, or transactional content

Respond with comma-separated categories. Use 'none' only if no other category applies."""

    # =========================================================================
    # Test 1: Novel inputs
    # =========================================================================
    print("-" * 70)
    print("TEST 1: NOVEL INPUTS (Never seen during training)")
    print("-" * 70)

    novel_results = []
    for i, test_case in enumerate(NOVEL_INPUTS):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Analyze this conversation and determine which memory categories apply:\n\n{test_case['conversation']}"},
        ]

        prompt = renderer.build_generation_prompt(messages)
        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, _ = renderer.parse_response(result.sequences[0].tokens)
        predicted_text = response["content"]

        predicted = parse_prediction(predicted_text)
        gold = set(test_case["expected"])
        metrics = compute_metrics(predicted, gold)
        novel_results.append(metrics)

        status = "✓" if metrics["any_match"] else "✗"
        exact = "EXACT" if metrics["exact_match"] else ""
        print(f"\n[{i+1}] {test_case['description']}")
        print(f"  Expected:  {', '.join(sorted(gold))}")
        print(f"  Predicted: {predicted_text}")
        print(f"  {status} F1: {metrics['f1']:.2f} {exact}")

    # Summary for novel inputs
    avg_f1 = sum(r["f1"] for r in novel_results) / len(novel_results)
    any_match_rate = sum(1 for r in novel_results if r["any_match"]) / len(novel_results)
    exact_match_rate = sum(1 for r in novel_results if r["exact_match"]) / len(novel_results)
    print(f"\n{'='*50}")
    print(f"NOVEL INPUTS SUMMARY ({len(novel_results)} examples)")
    print(f"  Any Match:   {any_match_rate:.1%}")
    print(f"  Exact Match: {exact_match_rate:.1%}")
    print(f"  Avg F1:      {avg_f1:.2f}")
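
    # NOTE (assumption): Tests 2 and 3 below expect the processed data files to be
    # JSON lists of records shaped roughly like
    #   {"messages": [{"role": "...", "content": "..."}, ...],
    #    "categories": ["company.brand_core", ...]}
    # where the final message is the gold assistant completion, hence messages[:-1]
    # when building the generation prompt.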
    # =========================================================================
    # Test 2: Training dataset examples
    # =========================================================================
    print("\n" + "-" * 70)
    print("TEST 2: TRAINING DATASET EXAMPLES")
    print("-" * 70)

    with open("training/processed_data/train_data.json", "r") as f:
        train_data = json.load(f)

    # Sample 20 random examples
    random.seed(42)
    sample_indices = random.sample(range(len(train_data)), min(20, len(train_data)))

    train_results = []
    for idx in sample_indices:
        item = train_data[idx]
        messages = item["messages"][:-1]  # Exclude assistant response
        gold = set(item["categories"])

        prompt = renderer.build_generation_prompt(messages)
        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, _ = renderer.parse_response(result.sequences[0].tokens)
        predicted_text = response["content"]

        predicted = parse_prediction(predicted_text)
        metrics = compute_metrics(predicted, gold)
        train_results.append(metrics)

    avg_f1 = sum(r["f1"] for r in train_results) / len(train_results)
    any_match_rate = sum(1 for r in train_results if r["any_match"]) / len(train_results)
    exact_match_rate = sum(1 for r in train_results if r["exact_match"]) / len(train_results)
    print(f"TRAINING SET SAMPLE ({len(train_results)} examples)")
    print(f"  Any Match:   {any_match_rate:.1%}")
    print(f"  Exact Match: {exact_match_rate:.1%}")
    print(f"  Avg F1:      {avg_f1:.2f}")

    # =========================================================================
    # Test 3: Test dataset examples
    # =========================================================================
    print("\n" + "-" * 70)
    print("TEST 3: TEST DATASET EXAMPLES (Held out)")
    print("-" * 70)

    with open("training/processed_data/test_data.json", "r") as f:
        test_data = json.load(f)

    test_results = []
    for item in test_data[:50]:  # First 50 test examples
        messages = item["messages"][:-1]
        gold = set(item["categories"])

        prompt = renderer.build_generation_prompt(messages)
        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, _ = renderer.parse_response(result.sequences[0].tokens)
        predicted_text = response["content"]

        predicted = parse_prediction(predicted_text)
        metrics = compute_metrics(predicted, gold)
        test_results.append(metrics)

    avg_f1 = sum(r["f1"] for r in test_results) / len(test_results)
    any_match_rate = sum(1 for r in test_results if r["any_match"]) / len(test_results)
    exact_match_rate = sum(1 for r in test_results if r["exact_match"]) / len(test_results)
    print(f"TEST SET ({len(test_results)} examples)")
    print(f"  Any Match:   {any_match_rate:.1%}")
    print(f"  Exact Match: {exact_match_rate:.1%}")
    print(f"  Avg F1:      {avg_f1:.2f}")

    # =========================================================================
    # Overall Summary
    # =========================================================================
    print("\n" + "=" * 70)
    print("OVERALL SUMMARY")
    print("=" * 70)
    print(f"\n{'Dataset':<20} {'Any Match':<12} {'Exact Match':<12} {'Avg F1':<10}")
    print("-" * 54)

    # Novel
    novel_any = sum(1 for r in novel_results if r["any_match"]) / len(novel_results)
    novel_exact = sum(1 for r in novel_results if r["exact_match"]) / len(novel_results)
    novel_f1 = sum(r["f1"] for r in novel_results) / len(novel_results)
    print(f"{'Novel Inputs':<20} {novel_any:<12.1%} {novel_exact:<12.1%} {novel_f1:<10.2f}")

    # Train
    train_any = sum(1 for r in train_results if r["any_match"]) / len(train_results)
    train_exact = sum(1 for r in train_results if r["exact_match"]) / len(train_results)
    train_f1 = sum(r["f1"] for r in train_results) / len(train_results)
    print(f"{'Train Sample':<20} {train_any:<12.1%} {train_exact:<12.1%} {train_f1:<10.2f}")

    # Test
    test_any = sum(1 for r in test_results if r["any_match"]) / len(test_results)
    test_exact = sum(1 for r in test_results if r["exact_match"]) / len(test_results)
    test_f1 = sum(r["f1"] for r in test_results) / len(test_results)
    print(f"{'Test Set':<20} {test_any:<12.1%} {test_exact:<12.1%} {test_f1:<10.2f}")

    print("\n" + "=" * 70)
    print("SFT EVALUATION COMPLETE")
    print("=" * 70)


if __name__ == "__main__":
    asyncio.run(test_model())