"""
Quick evaluation of SFT model on test data and novel inputs.
"""
import asyncio
import json

from dotenv import load_dotenv

# Load environment variables before the tinker imports so credentials are set.
load_dotenv()

import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer
BASE_MODEL = "meta-llama/Llama-3.1-8B"
SFT_CHECKPOINT = "tinker://44184486-b599-5739-8819-c38f1d5d8ad0:train:0/sampler_weights/sft_step_0049"
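# NOTE: the tinker:// path above is tied to one specific training run; swap in
# your own sampler-weights checkpoint to reproduce this evaluation.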
SYSTEM_PROMPT = """You route marketing conversations into structured memory categories.
Available categories:
- company.brand_core: Voice, values, positioning, identity anchors (Long >1y)
- company.strategic_signatures: Decision frameworks, strategic heuristics (Long >1y)
- company.knowledge_artifacts: Docs, style guides, playbooks (Long >1y)
- company.business_priorities: Quarterly/seasonal goals, active campaigns (Short <3m)
- company.tools_config: Integrations, API keys, workflow settings (Medium ~6m)
- company.performance_context: Campaign metrics, retrospectives, learnings (Rolling ~6m)
- user.communication_style: Tone, verbosity, format expectations (Long >1y)
- user.strategic_approach: Personal priorities, success definitions (Long >1y)
- user.role_context: Title, scope, decision authority (Medium ~1y)
- user.workflow_patterns: Review cadence, collaboration norms (Medium ~1y)
- user.session_history: Immediate context, recent asks (Short <2w)
- user.interaction_preferences: Coaching style, feedback expectations (Evolving)
- none: Irrelevant, vague, or transactional content
Respond with comma-separated categories. Use 'none' only if no other category applies."""
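
# A well-formed completion is just the bare labels, e.g.:
#   "company.brand_core, user.communication_style"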
VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none",
}

# Novel test cases the model has NEVER seen
NOVEL_TESTS = [
    {
        "name": "CEO Strategy Discussion",
        "conversation": """USER: As CEO, I've decided we're pivoting from B2B to B2C next quarter. This is confidential until the board meeting.
ASSISTANT: Understood. This strategic pivot will require significant changes across marketing, sales, and product. Should I help outline the key considerations?
USER: Yes, and remember - I prefer bullet points over long paragraphs. Keep updates brief.""",
        "expected_categories": ["user.strategic_approach", "company.strategic_signatures", "user.communication_style"],
    },
    {
        "name": "Tool Integration Request",
        "conversation": """USER: Can you help me set up the Slack integration with our CRM? The API key is sk-xxx-123.
ASSISTANT: I'll help you configure the Slack-CRM integration. First, let me verify the API key format...
USER: Also, make sure notifications go to #marketing-alerts channel.""",
        "expected_categories": ["company.tools_config"],
    },
    {
        "name": "Casual Chat (None)",
        "conversation": """USER: Hey, what's the weather like today?
ASSISTANT: I don't have access to real-time weather data. You might want to check a weather app or website.
USER: No worries, just making small talk while I wait for my coffee.""",
        "expected_categories": ["none"],
    },
    {
        "name": "Brand Voice Guidelines",
        "conversation": """USER: Our brand voice should always be warm, approachable, and slightly playful. We never use corporate jargon.
ASSISTANT: Got it - warm, approachable, playful, no jargon. Should I apply this to all future content?
USER: Yes, and we always sign off emails with 'Cheers' not 'Best regards'.""",
        "expected_categories": ["company.brand_core", "company.knowledge_artifacts"],
    },
    {
        "name": "Campaign Performance Review",
        "conversation": """USER: The Q3 email campaign had a 24% open rate and 3.2% CTR. That's up from 18% and 2.1% last quarter.
ASSISTANT: Great improvement! The new subject line strategy seems to be working. What do you attribute the gains to?
USER: Definitely the personalization we added. Let's document this for future campaigns.""",
        "expected_categories": ["company.performance_context", "company.knowledge_artifacts"],
    },
    {
        "name": "Session Context",
        "conversation": """USER: Let's continue working on the proposal we started yesterday.
ASSISTANT: Of course! We were drafting the executive summary for the Johnson account. You wanted to emphasize ROI projections.
USER: Right, and I mentioned I need this done by Friday 3pm.""",
        "expected_categories": ["user.session_history"],
    },
    {
        "name": "Role and Authority",
        "conversation": """USER: I'm the VP of Marketing and I report directly to the CMO. I have budget authority up to $50k.
ASSISTANT: Noted. With that authority level, you can approve most campaign expenditures directly.
USER: Exactly. Anything above $50k needs Sarah's sign-off.""",
        "expected_categories": ["user.role_context"],
    },
    {
        "name": "Workflow Preferences",
        "conversation": """USER: I review all campaign drafts on Monday mornings. Don't send me anything on Fridays - I'm in back-to-back meetings.
ASSISTANT: Understood. Monday morning reviews, no Friday submissions.
USER: And I always want to see the A/B test variants side by side, not in separate documents.""",
        "expected_categories": ["user.workflow_patterns", "user.interaction_preferences"],
    },
    {
        "name": "Q4 Goals",
        "conversation": """USER: For Q4, we're focusing on lead gen for the enterprise segment. Target is 500 MQLs.
ASSISTANT: Got it - enterprise focus, 500 MQL target for Q4. Should I prioritize content for that segment?
USER: Yes, and the holiday campaign should tie into this goal.""",
        "expected_categories": ["company.business_priorities"],
    },
    {
        "name": "Technical Documentation",
        "conversation": """USER: Here's our updated style guide. Headlines should be sentence case, max 60 characters. Body copy in AP style.
ASSISTANT: Noted - sentence case headlines under 60 chars, AP style for body.
USER: I'm also attaching our brand color codes: primary #2563EB, secondary #10B981.""",
        "expected_categories": ["company.knowledge_artifacts", "company.brand_core"],
    },
]
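
# Shared by both evaluation passes in test_model(). The helper name is our own
# (not part of tinker); it just normalizes a raw completion string into the set
# of valid category labels it contains.
def parse_categories(raw: str) -> set[str]:
    """Split a comma-separated completion, keeping only known categories."""
    return {c.strip().lower() for c in raw.split(",") if c.strip().lower() in VALID_CATEGORIES}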


async def test_model():
    print("=" * 70)
    print("SFT MODEL EVALUATION")
    print("=" * 70)
    print(f"Checkpoint: {SFT_CHECKPOINT}")
    print()

    # Initialize
    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)
    sampling_client = service_client.create_sampling_client(model_path=SFT_CHECKPOINT)
    stop_sequences = renderer.get_stop_sequences()
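    # Near-greedy decoding: temperature 0.1 keeps the label outputs essentially
    # stable across runs, and 100 tokens is ample for a comma-separated list.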
    params = types.SamplingParams(max_tokens=100, temperature=0.1, stop=stop_sequences)

    # Part 1: Test on held-out test data
    print("-" * 70)
    print("PART 1: TEST SET EVALUATION (50 examples)")
    print("-" * 70)
    with open("synthetic_data/training_dataset_1000.jsonl") as f:
        all_data = [json.loads(line) for line in f]
    # Hold out the last 200 examples as the test split; evaluate the first 50.
    test_data = all_data[-200:][:50]
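    # Scoring: "any match" counts an example whose predicted and gold label
    # sets overlap at all; "exact match" requires the two sets to be identical.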
    correct_any = 0
    correct_exact = 0
    for i, item in enumerate(test_data):
        conv = item.get("conversation", [])
        gold = item.get("labels", {}).get("categories", [])
        # Build conversation text
        conv_text = ""
        for turn in conv:
            if isinstance(turn, dict):
                conv_text += f"{turn['role'].upper()}: {turn['content']}\n"
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Conversation:\n{conv_text}"},
        ]
        prompt = renderer.build_generation_prompt(messages)
        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, _ = renderer.parse_response(result.sequences[0].tokens)
        pred = response["content"]
        pred_set = parse_categories(pred)
        gold_set = {c.lower() for c in gold}
        if pred_set & gold_set:
            correct_any += 1
        if pred_set == gold_set:
            correct_exact += 1
        if (i + 1) % 10 == 0:
            print(f" Processed {i+1}/{len(test_data)}...")

    print()
    print(f"Any Match Accuracy: {correct_any}/{len(test_data)} = {correct_any/len(test_data):.1%}")
    print(f"Exact Match Accuracy: {correct_exact}/{len(test_data)} = {correct_exact/len(test_data):.1%}")

    # Part 2: Novel inputs
    print()
    print("-" * 70)
    print("PART 2: NOVEL INPUTS (Never seen during training)")
    print("-" * 70)
    novel_correct = 0
    novel_exact = 0
    for test in NOVEL_TESTS:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Conversation:\n{test['conversation']}"},
        ]
        prompt = renderer.build_generation_prompt(messages)
        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, _ = renderer.parse_response(result.sequences[0].tokens)
        pred = response["content"]
        pred_set = parse_categories(pred)
        expected_set = {c.lower() for c in test["expected_categories"]}
        any_match = bool(pred_set & expected_set)
        exact_match = pred_set == expected_set
        if any_match:
            novel_correct += 1
        if exact_match:
            novel_exact += 1
        match_icon = "✓" if any_match else "✗"
        exact_icon = " [EXACT]" if exact_match else ""
        print(f"\n{match_icon} {test['name']}{exact_icon}")
        print(f" Expected: {', '.join(sorted(test['expected_categories']))}")
        print(f" Predicted: {pred.strip()}")

    print()
    print("-" * 70)
    print("NOVEL INPUT RESULTS")
    print("-" * 70)
    print(f"Any Match: {novel_correct}/{len(NOVEL_TESTS)} = {novel_correct/len(NOVEL_TESTS):.1%}")
    print(f"Exact Match: {novel_exact}/{len(NOVEL_TESTS)} = {novel_exact/len(NOVEL_TESTS):.1%}")
    print()


if __name__ == "__main__":
    asyncio.run(test_model())