File size: 6,894 Bytes
685d968
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""
Quick test of the RL model.
"""

import asyncio
import json
import os
from dotenv import load_dotenv

load_dotenv()

import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer

# Base model whose tokenizer/renderer conventions both checkpoints share.
BASE_MODEL = "meta-llama/Llama-3.1-8B"

# Run with both SFT and RL (most iterations)
# Sampler weights after 5 RL iterations on top of SFT.
RL_CHECKPOINT = "tinker://398393e1-7182-555d-aa1b-7ddf23892338:train:0/sampler_weights/rl_iter_005"

# SFT from the same run
SFT_CHECKPOINT = "tinker://398393e1-7182-555d-aa1b-7ddf23892338:train:0/sampler_weights/sft_final_sampler"

# Closed set of routing labels; model output tokens outside this set are
# discarded during scoring (see test_model).
VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none"
}


# System prompt sent with every evaluation example. NOTE(review): assumed to
# match the exact prompt used during training — confirm against the training
# script before editing, since any drift skews the comparison.
SYSTEM_PROMPT = """You route marketing conversations into structured memory categories.

Available categories:
- company.brand_core: Voice, values, positioning, identity anchors (Long >1y)
- company.strategic_signatures: Decision frameworks, strategic heuristics (Long >1y)
- company.knowledge_artifacts: Docs, style guides, playbooks (Long >1y)
- company.business_priorities: Quarterly/seasonal goals, active campaigns (Short <3m)
- company.tools_config: Integrations, API keys, workflow settings (Medium ~6m)
- company.performance_context: Campaign metrics, retrospectives, learnings (Rolling ~6m)
- user.communication_style: Tone, verbosity, format expectations (Long >1y)
- user.strategic_approach: Personal priorities, success definitions (Long >1y)
- user.role_context: Title, scope, decision authority (Medium ~1y)
- user.workflow_patterns: Review cadence, collaboration norms (Medium ~1y)
- user.session_history: Immediate context, recent asks (Short <2w)
- user.interaction_preferences: Coaching style, feedback expectations (Evolving)
- none: Irrelevant, vague, or transactional content

Respond with comma-separated categories. Use 'none' only if no other category applies."""


async def test_model(checkpoint: str, name: str, test_examples: list) -> dict:
    """Evaluate one sampler checkpoint on labeled routing examples.

    For each example the conversation is rendered into the same
    system+user prompt format used at training time, sampled once at low
    temperature, and the comma-separated prediction is scored against the
    gold category list.

    Args:
        checkpoint: tinker:// sampler-weights path to load.
        name: Human-readable label used in the printed report.
        test_examples: Dicts with "messages" (list of {role, content} dicts)
            and "categories" (gold label list).

    Returns:
        Dict with "any_match" and "exact_match" rates (floats in [0, 1]).
    """
    print(f"\n{'='*60}")
    print(f"TESTING: {name}")
    print(f"Checkpoint: {checkpoint}")
    print(f"{'='*60}")

    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)

    sampling_client = service_client.create_sampling_client(model_path=checkpoint)
    stop_sequences = renderer.get_stop_sequences()

    results = []

    for i, example in enumerate(test_examples):
        messages = example.get("messages", [])
        gold = example.get("categories", [])

        # Flatten the conversation into "ROLE: content" lines (matching the
        # training format). join avoids the quadratic += string build-up.
        conversation_text = "".join(
            f"{m['role'].upper()}: {m['content']}\n" for m in messages
        )

        prompt_messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Conversation:\n{conversation_text}"}
        ]

        prompt = renderer.build_generation_prompt(prompt_messages)
        params = types.SamplingParams(max_tokens=100, temperature=0.1, stop=stop_sequences)

        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, success = renderer.parse_response(result.sequences[0].tokens)
        predicted = response["content"] if success else ""

        # Normalize each comma-separated token once, then keep only tokens
        # naming a known category; anything else the model emits is noise.
        tokens = (c.strip().lower() for c in predicted.split(","))
        predicted_set = {t for t in tokens if t in VALID_CATEGORIES}
        gold_set = {c.lower() for c in gold}

        # "Any match": at least one predicted category is correct; when there
        # are no gold labels, only an empty prediction counts as correct.
        any_match = len(predicted_set & gold_set) > 0 if gold_set else (len(predicted_set) == 0)
        exact_match = predicted_set == gold_set

        results.append({
            "any_match": any_match,
            "exact_match": exact_match,
            "predicted": predicted,
            "gold": gold
        })

        # Show first 5 examples
        if i < 5:
            print(f"\nExample {i+1}:")
            print(f"  Gold: {gold}")
            print(f"  Pred: {predicted}")
            print(f"  Match: {'Yes' if any_match else 'No'}")

    # Summary (guard against an empty test set to avoid ZeroDivisionError).
    any_match_rate = sum(r["any_match"] for r in results) / len(results) if results else 0
    exact_match_rate = sum(r["exact_match"] for r in results) / len(results) if results else 0

    print(f"\n--- Results ({len(results)} examples) ---")
    print(f"Any Match:   {any_match_rate:.1%}")
    print(f"Exact Match: {exact_match_rate:.1%}")

    return {"any_match": any_match_rate, "exact_match": exact_match_rate}


async def main():
    """Load the labeled dataset, then benchmark the RL and SFT checkpoints."""
    # First, preprocess data
    print("=" * 60)
    print("LOADING TEST DATA")
    print("=" * 60)

    examples = []
    with open("synthetic_data/training_dataset_1000.jsonl", "r") as f:
        for raw_line in f:
            record = json.loads(raw_line)

            # Keep only well-formed dict turns from the conversation.
            turns = [
                {"role": turn["role"], "content": turn["content"]}
                for turn in record.get("conversation", [])
                if isinstance(turn, dict)
            ]

            # Labels may be a dict wrapping a "categories" list, a bare
            # list, or absent entirely.
            raw_labels = record.get("labels", {})
            if isinstance(raw_labels, dict):
                labels = raw_labels.get("categories", [])
            elif isinstance(raw_labels, list):
                labels = raw_labels
            else:
                labels = []

            if not labels:
                # Fall back to deriving a category from the scenario id.
                # NOTE(review): this checks for "." but splits on "_" — for
                # an id like "user.session_history_3" that yields
                # "user.session"; confirm the intended scenario_id format.
                sid = record.get("scenario_id", "")
                if "." in sid:
                    labels = [sid.split("_")[0]]

            examples.append({"messages": turns, "categories": labels})

    print(f"Total examples: {len(examples)}")

    # Hold out the final 50 records as the evaluation split.
    held_out = examples[-50:]
    print(f"Test examples: {len(held_out)}")

    # Evaluate RL first, then SFT, on the identical held-out set.
    rl_results = await test_model(RL_CHECKPOINT, "RL Model (5 iters)", held_out)
    sft_results = await test_model(SFT_CHECKPOINT, "SFT Model", held_out)

    print("\n" + "=" * 60)
    print("COMPARISON")
    print("=" * 60)
    print(f"SFT Any Match:   {sft_results['any_match']:.1%}")
    print(f"RL Any Match:    {rl_results['any_match']:.1%}")
    print(f"Improvement:     {(rl_results['any_match'] - sft_results['any_match'])*100:+.1f}pp")


if __name__ == "__main__":
    # Script entry point: run the full async evaluation/comparison.
    asyncio.run(main())