File size: 6,894 Bytes
685d968 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
"""
Quick test of the RL model.
"""
import asyncio
import json
import os
from dotenv import load_dotenv
load_dotenv()
import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer
# Model name handed to get_tokenizer(); the resulting tokenizer backs the
# renderer that builds prompts and parses sampled tokens.
BASE_MODEL = "meta-llama/Llama-3.1-8B"
# Run with both SFT and RL (most iterations)
RL_CHECKPOINT = "tinker://398393e1-7182-555d-aa1b-7ddf23892338:train:0/sampler_weights/rl_iter_005"
# SFT from the same run
SFT_CHECKPOINT = "tinker://398393e1-7182-555d-aa1b-7ddf23892338:train:0/sampler_weights/sft_final_sampler"
# Category labels accepted when parsing a model response: any comma-separated
# entry of the response that is not in this set is dropped before scoring.
# The same names are listed to the model in SYSTEM_PROMPT below.
VALID_CATEGORIES = {
"company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
"company.business_priorities", "company.tools_config", "company.performance_context",
"user.communication_style", "user.strategic_approach", "user.role_context",
"user.workflow_patterns", "user.session_history", "user.interaction_preferences",
"none"
}
# System message placed in front of every evaluation prompt. It is sent to the
# model verbatim, so any edit to this text changes what is being evaluated.
SYSTEM_PROMPT = """You route marketing conversations into structured memory categories.
Available categories:
- company.brand_core: Voice, values, positioning, identity anchors (Long >1y)
- company.strategic_signatures: Decision frameworks, strategic heuristics (Long >1y)
- company.knowledge_artifacts: Docs, style guides, playbooks (Long >1y)
- company.business_priorities: Quarterly/seasonal goals, active campaigns (Short <3m)
- company.tools_config: Integrations, API keys, workflow settings (Medium ~6m)
- company.performance_context: Campaign metrics, retrospectives, learnings (Rolling ~6m)
- user.communication_style: Tone, verbosity, format expectations (Long >1y)
- user.strategic_approach: Personal priorities, success definitions (Long >1y)
- user.role_context: Title, scope, decision authority (Medium ~1y)
- user.workflow_patterns: Review cadence, collaboration norms (Medium ~1y)
- user.session_history: Immediate context, recent asks (Short <2w)
- user.interaction_preferences: Coaching style, feedback expectations (Evolving)
- none: Irrelevant, vague, or transactional content
Respond with comma-separated categories. Use 'none' only if no other category applies."""
async def test_model(checkpoint: str, name: str, test_examples: list):
    """Evaluate one sampler checkpoint on labelled routing examples.

    For every example the conversation is flattened into ``ROLE: content``
    lines, wrapped in the system/user prompt, and a single completion is
    sampled. The completion is split on commas; entries that are not in
    ``VALID_CATEGORIES`` are discarded before comparing with the gold labels.
    The first five examples and a final summary are printed.

    Args:
        checkpoint: Model path passed to ``create_sampling_client``.
        name: Human-readable label used in the printed report.
        test_examples: Dicts with ``"messages"`` (list of dicts carrying
            ``"role"`` and ``"content"``) and ``"categories"`` (list of gold
            category strings). A missing or ``None`` value for either key is
            treated as empty.

    Returns:
        ``{"any_match": float, "exact_match": float}`` holding the fraction
        of examples whose prediction overlaps / equals the gold set. Both
        values are 0 when ``test_examples`` is empty.
    """
    print(f"\n{'='*60}")
    print(f"TESTING: {name}")
    print(f"Checkpoint: {checkpoint}")
    print(f"{'='*60}")

    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)
    sampling_client = service_client.create_sampling_client(model_path=checkpoint)
    stop_sequences = renderer.get_stop_sequences()
    # The sampling parameters never change between examples, so build them once.
    params = types.SamplingParams(max_tokens=100, temperature=0.1, stop=stop_sequences)

    results = []
    for i, example in enumerate(test_examples):
        # `or []` also covers an explicit None, which would otherwise fail
        # when iterated below.
        messages = example.get("messages") or []
        gold = example.get("categories") or []

        # Build prompt with system message (matching training format)
        conversation_text = "".join(
            f"{m['role'].upper()}: {m['content']}\n" for m in messages
        )
        prompt_messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Conversation:\n{conversation_text}"},
        ]
        prompt = renderer.build_generation_prompt(prompt_messages)

        # NOTE(review): .result() presumably blocks until the sample is ready,
        # so examples are processed strictly one after another.
        result = sampling_client.sample(
            prompt=prompt, sampling_params=params, num_samples=1
        ).result()
        response, success = renderer.parse_response(result.sequences[0].tokens)
        predicted = response["content"] if success else ""

        # Keep only the comma-separated entries that are known categories.
        candidates = (part.strip().lower() for part in predicted.split(","))
        predicted_set = {c for c in candidates if c in VALID_CATEGORIES}
        gold_set = {c.lower() for c in gold}

        # With no gold labels, a "match" means the model predicted nothing valid.
        if gold_set:
            any_match = bool(predicted_set & gold_set)
        else:
            any_match = not predicted_set
        exact_match = predicted_set == gold_set

        results.append({
            "any_match": any_match,
            "exact_match": exact_match,
            "predicted": predicted,
            "gold": gold,
        })

        # Show first 5 examples
        if i < 5:
            print(f"\nExample {i+1}:")
            print(f" Gold: {gold}")
            print(f" Pred: {predicted}")
            print(f" Match: {'Yes' if any_match else 'No'}")

    # Summary (rates fall back to 0 when nothing was evaluated).
    total = len(results)
    any_match_rate = sum(r["any_match"] for r in results) / total if total else 0
    exact_match_rate = sum(r["exact_match"] for r in results) / total if total else 0
    print(f"\n--- Results ({total} examples) ---")
    print(f"Any Match: {any_match_rate:.1%}")
    print(f"Exact Match: {exact_match_rate:.1%}")
    return {"any_match": any_match_rate, "exact_match": exact_match_rate}
def _category_from_scenario_id(scenario_id: str) -> str:
    """Return the category encoded at the start of a dotted scenario id.

    Every dotted name in ``VALID_CATEGORIES`` contains an underscore after the
    dot (e.g. ``company.brand_core``), so cutting the id at its first ``_``
    would truncate the category. Whole category names are matched instead;
    the first-underscore split is kept only as a fallback for ids that do not
    start with a known category.
    """
    lowered = scenario_id.lower()
    # Longest names first so a longer category is never shadowed by a shorter one.
    for category in sorted(VALID_CATEGORIES, key=len, reverse=True):
        if not lowered.startswith(category):
            continue
        remainder = lowered[len(category):]
        # Require a separator (or end of string) right after the category so a
        # longer, unrelated identifier is not mistaken for it.
        if not remainder or not remainder[0].isalnum():
            return category
    return scenario_id.split("_")[0]


def _load_examples(path: str) -> list:
    """Read a JSONL file and return ``{"messages", "categories"}`` dicts."""
    examples = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Blank lines (e.g. a trailing newline) are not valid JSON; skip them.
            if not line.strip():
                continue
            item = json.loads(line)

            messages = [
                {"role": turn["role"], "content": turn["content"]}
                for turn in item.get("conversation") or []
                if isinstance(turn, dict)
            ]

            # Extract categories - handle nested labels structure
            labels = item.get("labels", {})
            if isinstance(labels, dict):
                categories = labels.get("categories") or []
            elif isinstance(labels, list):
                categories = labels
            else:
                categories = []
            # A bare string would later be iterated character by character.
            if isinstance(categories, str):
                categories = [categories]

            if not categories:
                # Parse from scenario_id
                scenario_id = item.get("scenario_id") or ""
                if isinstance(scenario_id, str) and "." in scenario_id:
                    categories = [_category_from_scenario_id(scenario_id)]

            examples.append({"messages": messages, "categories": categories})
    return examples


async def main(
    data_path: str = "synthetic_data/training_dataset_1000.jsonl",
    test_size: int = 50,
):
    """Load the dataset, evaluate the RL and SFT checkpoints, and compare them.

    Args:
        data_path: JSONL file with one example per line. Each example may
            carry ``conversation``, ``labels`` and ``scenario_id`` fields.
        test_size: Number of examples taken from the end of the file as the
            evaluation set.

    NOTE(review): the evaluation slice comes from a file named
    ``training_dataset_1000.jsonl`` - confirm these rows were held out of
    training, otherwise the reported rates are optimistic.
    """
    # First, preprocess data
    print("=" * 60)
    print("LOADING TEST DATA")
    print("=" * 60)

    data = _load_examples(data_path)
    print(f"Total examples: {len(data)}")

    # Use the last `test_size` rows as test. `data[-0:]` would be the whole
    # list, so a non-positive size is mapped to an empty test set explicitly.
    test_data = data[-test_size:] if test_size > 0 else []
    print(f"Test examples: {len(test_data)}")

    # Test RL model
    rl_results = await test_model(RL_CHECKPOINT, "RL Model (5 iters)", test_data)
    # Test SFT model for comparison
    sft_results = await test_model(SFT_CHECKPOINT, "SFT Model", test_data)

    print("\n" + "=" * 60)
    print("COMPARISON")
    print("=" * 60)
    print(f"SFT Any Match: {sft_results['any_match']:.1%}")
    print(f"RL Any Match: {rl_results['any_match']:.1%}")
    print(f"Improvement: {(rl_results['any_match'] - sft_results['any_match'])*100:+.1f}pp")


if __name__ == "__main__":
    asyncio.run(main())
|