"""
Test the SFT model on various inputs.

Tests:
1. Examples from training dataset
2. Examples from test dataset  
3. Novel inputs the model has never seen
"""

import json
import random

from dotenv import load_dotenv

load_dotenv()  # load API credentials from .env before tinker is imported

import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer

# Configuration
SFT_CHECKPOINT = "tinker://4f4bae1f-5a95-5f53-a55a-a14f2872825c:train:0/sampler_weights/sft_step_0090"
BASE_MODEL = "meta-llama/Llama-3.1-8B"

VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none"
}
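
# This taxonomy mirrors the category list spelled out in the system prompt
# further down; parse_prediction() filters raw model output against it.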

# Novel test cases the model has never seen
NOVEL_INPUTS = [
    {
        "conversation": """USER: I prefer getting updates in bullet points, not long paragraphs. Keep it concise.
ASSISTANT: Understood! I'll format all future updates as bullet points for easy scanning.
USER: Perfect. Also, I check messages mainly in the morning, so schedule any non-urgent stuff for 8am.""",
        "expected": ["user.communication_style", "user.workflow_patterns"],
        "description": "Communication preferences + workflow timing"
    },
    {
        "conversation": """USER: Our brand voice is playful but professional. Think friendly neighbor who happens to be an expert.
ASSISTANT: Great analogy! So approachable expertise - warm but credible.
USER: Exactly. We never use corporate jargon. Everything should feel like a conversation.""",
        "expected": ["company.brand_core"],
        "description": "Brand voice definition"
    },
    {
        "conversation": """USER: What's the weather like today?
ASSISTANT: I don't have access to real-time weather data. You might want to check a weather app.
USER: Ok thanks.""",
        "expected": ["none"],
        "description": "Transactional/irrelevant conversation"
    },
    {
        "conversation": """USER: We use HubSpot for CRM, Mailchimp for emails, and Slack for team comms. Everything needs to sync.
ASSISTANT: Got it. I'll note those integrations. Any specific workflows between them?
USER: Yeah, new leads from HubSpot should trigger a Slack notification and add to a Mailchimp welcome sequence.""",
        "expected": ["company.tools_config"],
        "description": "Tool stack and integrations"
    },
    {
        "conversation": """USER: Last quarter's email campaign had 45% open rate but only 2% conversion. We need to fix the CTA.
ASSISTANT: That's a significant drop-off. The open rate suggests good subject lines but weak call-to-action.
USER: Right. Let's A/B test three different CTAs next month.""",
        "expected": ["company.performance_context"],
        "description": "Campaign metrics and learnings"
    },
    {
        "conversation": """USER: I'm the VP of Marketing, reporting directly to the CEO. I have final say on brand decisions.
ASSISTANT: Understood. So you have executive authority on brand strategy.
USER: Yes, but I like to get input from the creative team before making calls.""",
        "expected": ["user.role_context"],
        "description": "Role and authority"
    },
    {
        "conversation": """USER: We always prioritize data over intuition. Every decision needs numbers to back it up.
ASSISTANT: A data-driven approach. Do you have minimum thresholds for statistical significance?
USER: Yes, 95% confidence interval minimum. No exceptions.""",
        "expected": ["company.strategic_signatures"],
        "description": "Decision-making framework"
    },
    {
        "conversation": """USER: Q4 is all about the holiday push. Black Friday and Cyber Monday are our biggest days.
ASSISTANT: Seasonal focus noted. What's the budget allocation for Q4?
USER: 60% of annual marketing budget goes to Q4. It's make or break.""",
        "expected": ["company.business_priorities"],
        "description": "Quarterly priorities"
    },
    {
        "conversation": """USER: I want you to challenge my ideas more. Don't just agree with everything.
ASSISTANT: I appreciate that feedback. You want more critical analysis and pushback.
USER: Yes, play devil's advocate when needed. I learn better that way.""",
        "expected": ["user.interaction_preferences"],
        "description": "Coaching/feedback style"
    },
    {
        "conversation": """USER: Yesterday we discussed the influencer campaign. Did you save those notes?
ASSISTANT: Yes, I have the notes from our previous session about the influencer strategy.
USER: Good. Let's continue from where we left off - the micro-influencer targeting.""",
        "expected": ["user.session_history"],
        "description": "Reference to previous context"
    },
]
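
# Each novel case pairs a synthetic conversation with the gold labels a correct
# router should emit; the "expected" list becomes the gold set for scoring.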


def parse_prediction(text: str) -> set:
    """Parse model output into category set."""
    if not text or not text.strip():
        return set()
    cats = [c.strip().lower() for c in text.split(",")]
    return {c for c in cats if c in VALID_CATEGORIES}
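
# Illustrative behavior: parse_prediction("User.Role_Context, none, bogus")
# returns {"user.role_context", "none"}; unrecognized labels like "bogus" are dropped.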


def compute_metrics(predicted: set, gold: set) -> dict:
    """Compute evaluation metrics."""
    if not predicted and not gold:
        return {"f1": 1.0, "precision": 1.0, "recall": 1.0, "exact_match": True, "any_match": True}
    if not predicted or not gold:
        return {"f1": 0.0, "precision": 0.0, "recall": 0.0, "exact_match": False, "any_match": False}
    
    tp = len(predicted & gold)
    precision = tp / len(predicted) if predicted else 0
    recall = tp / len(gold) if gold else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    
    return {
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "exact_match": predicted == gold,
        "any_match": bool(predicted & gold)
    }
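
# Worked example:
#   predicted = {"user.role_context", "company.brand_core"}
#   gold      = {"user.role_context", "user.workflow_patterns"}
#   -> tp = 1, precision = 0.5, recall = 0.5, f1 = 0.5,
#      exact_match = False, any_match = True


def summarize(results: list) -> tuple:
    """Aggregate per-example metric dicts into (any_match_rate, exact_match_rate, avg_f1).

    Factors out the identical aggregation repeated after each test block below.
    """
    n = len(results)
    any_rate = sum(1 for r in results if r["any_match"]) / n
    exact_rate = sum(1 for r in results if r["exact_match"]) / n
    avg_f1 = sum(r["f1"] for r in results) / n
    return any_rate, exact_rate, avg_f1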


def test_model():
    print("=" * 70)
    print("SFT MODEL EVALUATION")
    print("=" * 70)
    print(f"Checkpoint: {SFT_CHECKPOINT}")
    print()
    
    # Initialize
    service_client = tinker.ServiceClient()
    sampling_client = service_client.create_sampling_client(model_path=SFT_CHECKPOINT)
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)
    
    stop_sequences = renderer.get_stop_sequences()
    params = types.SamplingParams(max_tokens=100, temperature=0.1, stop=stop_sequences)
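
    # One sampling round-trip (render -> sample -> parse), shared by all three
    # test blocks below to avoid repeating the same four lines.
    def predict(messages: list) -> str:
        """Return the model's text response for a rendered chat."""
        prompt = renderer.build_generation_prompt(messages)
        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, _ = renderer.parse_response(result.sequences[0].tokens)
        return response["content"]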
    
    # System prompt
    system_prompt = """You route marketing conversations into structured memory categories.

Available categories:
- company.brand_core: Voice, values, positioning, identity anchors
- company.strategic_signatures: Decision frameworks, strategic heuristics
- company.knowledge_artifacts: Docs, style guides, playbooks
- company.business_priorities: Quarterly/seasonal goals, active campaigns
- company.tools_config: Integrations, API keys, workflow settings
- company.performance_context: Campaign metrics, retrospectives, learnings
- user.communication_style: Tone, verbosity, format expectations
- user.strategic_approach: Personal priorities, success definitions
- user.role_context: Title, scope, decision authority
- user.workflow_patterns: Review cadence, collaboration norms
- user.session_history: Immediate context, recent asks
- user.interaction_preferences: Coaching style, feedback expectations
- none: Irrelevant, vague, or transactional content

Respond with comma-separated categories. Use 'none' only if no other category applies."""
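
    # Assumption: this should be the same system prompt used to build the SFT
    # training data; if the two drift, the scores below will understate the
    # model's real accuracy.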

    # =========================================================================
    # Test 1: Novel inputs
    # =========================================================================
    print("-" * 70)
    print("TEST 1: NOVEL INPUTS (Never seen during training)")
    print("-" * 70)
    
    novel_results = []
    
    for i, test_case in enumerate(NOVEL_INPUTS):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Analyze this conversation and determine which memory categories apply:\n\n{test_case['conversation']}"}
        ]
        
        predicted_text = predict(messages)
        
        predicted = parse_prediction(predicted_text)
        gold = set(test_case["expected"])
        metrics = compute_metrics(predicted, gold)
        
        novel_results.append(metrics)
        
        status = "✓" if metrics["any_match"] else "✗"
        exact = "EXACT" if metrics["exact_match"] else ""
        
        print(f"\n[{i+1}] {test_case['description']}")
        print(f"    Expected:  {', '.join(sorted(gold))}")
        print(f"    Predicted: {predicted_text}")
        print(f"    {status} F1: {metrics['f1']:.2f} {exact}")
    
    # Summary for novel inputs
    any_match_rate, exact_match_rate, avg_f1 = summarize(novel_results)
    
    print(f"\n{'='*50}")
    print(f"NOVEL INPUTS SUMMARY ({len(novel_results)} examples)")
    print(f"  Any Match:   {any_match_rate:.1%}")
    print(f"  Exact Match: {exact_match_rate:.1%}")
    print(f"  Avg F1:      {avg_f1:.2f}")
    
    # =========================================================================
    # Test 2: Training dataset examples
    # =========================================================================
    print("\n" + "-" * 70)
    print("TEST 2: TRAINING DATASET EXAMPLES")
    print("-" * 70)
    
    with open("training/processed_data/train_data.json", "r") as f:
        train_data = json.load(f)
    
    # Sample 20 random training examples (seeded for reproducibility)
    random.seed(42)
    sample_indices = random.sample(range(len(train_data)), min(20, len(train_data)))
    
    train_results = []
    
    for idx in sample_indices:
        item = train_data[idx]
        messages = item["messages"][:-1]  # Exclude assistant response
        gold = set(item["categories"])
        
        predicted_text = predict(messages)
        
        predicted = parse_prediction(predicted_text)
        metrics = compute_metrics(predicted, gold)
        train_results.append(metrics)
    
    avg_f1 = sum(r["f1"] for r in train_results) / len(train_results)
    any_match_rate = sum(1 for r in train_results if r["any_match"]) / len(train_results)
    exact_match_rate = sum(1 for r in train_results if r["exact_match"]) / len(train_results)
    
    print(f"TRAINING SET SAMPLE ({len(train_results)} examples)")
    print(f"  Any Match:   {any_match_rate:.1%}")
    print(f"  Exact Match: {exact_match_rate:.1%}")
    print(f"  Avg F1:      {avg_f1:.2f}")
    
    # =========================================================================
    # Test 3: Test dataset examples
    # =========================================================================
    print("\n" + "-" * 70)
    print("TEST 3: TEST DATASET EXAMPLES (Held out)")
    print("-" * 70)
    
    with open("training/processed_data/test_data.json", "r") as f:
        test_data = json.load(f)
    
    test_results = []
    
    for item in test_data[:50]:  # First 50 test examples
        messages = item["messages"][:-1]
        gold = set(item["categories"])
        
        predicted_text = predict(messages)
        
        predicted = parse_prediction(predicted_text)
        metrics = compute_metrics(predicted, gold)
        test_results.append(metrics)
    
    avg_f1 = sum(r["f1"] for r in test_results) / len(test_results)
    any_match_rate = sum(1 for r in test_results if r["any_match"]) / len(test_results)
    exact_match_rate = sum(1 for r in test_results if r["exact_match"]) / len(test_results)
    
    print(f"TEST SET ({len(test_results)} examples)")
    print(f"  Any Match:   {any_match_rate:.1%}")
    print(f"  Exact Match: {exact_match_rate:.1%}")
    print(f"  Avg F1:      {avg_f1:.2f}")
    
    # =========================================================================
    # Overall Summary
    # =========================================================================
    print("\n" + "=" * 70)
    print("OVERALL SUMMARY")
    print("=" * 70)
    
    print(f"\n{'Dataset':<20} {'Any Match':<12} {'Exact Match':<12} {'Avg F1':<10}")
    print("-" * 54)
    
    # Novel
    novel_any, novel_exact, novel_f1 = summarize(novel_results)
    print(f"{'Novel Inputs':<20} {novel_any:<12.1%} {novel_exact:<12.1%} {novel_f1:<10.2f}")
    
    # Train
    train_any, train_exact, train_f1 = summarize(train_results)
    print(f"{'Train Sample':<20} {train_any:<12.1%} {train_exact:<12.1%} {train_f1:<10.2f}")
    
    # Test
    test_any, test_exact, test_f1 = summarize(test_results)
    print(f"{'Test Set':<20} {test_any:<12.1%} {test_exact:<12.1%} {test_f1:<10.2f}")
    
    print("\n" + "=" * 70)
    print("SFT EVALUATION COMPLETE")
    print("=" * 70)


if __name__ == "__main__":
    asyncio.run(test_model())