"""
Data Preprocessing for Memory Routing Training

This script converts synthetic JSONL conversations to Tinker-compatible
types.Datum objects for supervised fine-tuning.

Per Tinker docs (rendering.mdx):
- Use renderer.build_supervised_example() to get tokens and weights
- Weights indicate which tokens to train on (1.0 for completion, 0.0 for prompt)
- Target tokens are shifted by 1 (predicting the next token)

Per PRD Section 6.6:
- Validate datum length <= 4096
- Ensure non-zero weights
- Verify token IDs are within vocab range
"""

import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

MODEL_NAME = "meta-llama/Llama-3.1-8B"
RENDERER_NAME = "llama3"
MAX_SEQUENCE_LENGTH = 4096

# Memory-routing taxonomy (PRD Section 6); must stay in sync with the system
# prompt in build_routing_prompt().
VALID_CATEGORIES = {
    "company.brand_core",
    "company.strategic_signatures",
    "company.knowledge_artifacts",
    "company.business_priorities",
    "company.tools_config",
    "company.performance_context",
    "user.communication_style",
    "user.strategic_approach",
    "user.role_context",
    "user.workflow_patterns",
    "user.session_history",
    "user.interaction_preferences",
    "none",
}

@dataclass
class PreprocessingStats:
    """Counters summarizing how many examples were kept or skipped, and why."""
    total_examples: int = 0
    valid_examples: int = 0
    skipped_too_long: int = 0
    skipped_zero_weights: int = 0
    skipped_invalid_tokens: int = 0
    skipped_invalid_categories: int = 0

def build_routing_prompt(conversation: List[Dict[str, str]], categories: List[str]) -> List[Dict[str, str]]:
    """
    Build the full conversation for training, including:
    1. System prompt with taxonomy
    2. User message with conversation
    3. Assistant response with categories

    Per PRD Section 6 - Student Prompt format.
    """
    system_content = """You route marketing conversations into structured memory categories.

Available categories:
- company.brand_core: Voice, values, positioning, identity anchors (Long >1y)
- company.strategic_signatures: Decision frameworks, strategic heuristics (Long >1y)
- company.knowledge_artifacts: Docs, style guides, playbooks (Long >1y)
- company.business_priorities: Quarterly/seasonal goals, active campaigns (Short <3m)
- company.tools_config: Integrations, API keys, workflow settings (Medium ~6m)
- company.performance_context: Campaign metrics, retrospectives, learnings (Rolling ~6m)
- user.communication_style: Tone, verbosity, format expectations (Long >1y)
- user.strategic_approach: Personal priorities, success definitions (Long >1y)
- user.role_context: Title, scope, decision authority (Medium ~1y)
- user.workflow_patterns: Review cadence, collaboration norms (Medium ~1y)
- user.session_history: Immediate context, recent asks (Short <2w)
- user.interaction_preferences: Coaching style, feedback expectations (Evolving)
- none: Irrelevant, vague, or transactional content

Respond with comma-separated categories. Use 'none' only if no other category applies."""

    conversation_text = ""
    for turn in conversation:
        # Tolerate malformed records: bare strings are kept under an UNKNOWN
        # role; any other non-dict turn is dropped.
        if isinstance(turn, str):
            conversation_text += f"UNKNOWN: {turn}\n"
            continue
        if not isinstance(turn, dict):
            continue
        role = turn.get("role", "unknown")
        content = turn.get("content", "")
        conversation_text += f"{role.upper()}: {content}\n"

    user_content = f"Conversation:\n{conversation_text.strip()}\n\nWhat memory categories apply?"

    assistant_content = ", ".join(categories)

    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content},
    ]
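
# Illustrative output of build_routing_prompt() for a one-turn conversation
# labeled ["company.brand_core"] (message contents abbreviated):
#   [{"role": "system", "content": "You route marketing conversations into ..."},
#    {"role": "user", "content": "Conversation:\nUSER: Our brand voice is playful.\n\nWhat memory categories apply?"},
#    {"role": "assistant", "content": "company.brand_core"}]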


def load_synthetic_data(filepath: str) -> List[Dict[str, Any]]:
    """Load synthetic data from a JSONL file (one JSON object per line)."""
    data = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():  # skip blank lines
                data.append(json.loads(line))
    return data
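
# An input record is expected to look roughly like this (illustrative; only
# "conversation", "labels.categories", and "scenario_id" are read downstream):
#   {"scenario_id": "s-0001",
#    "conversation": [{"role": "user", "content": "Our brand voice is playful."}],
#    "labels": {"categories": ["company.brand_core"]}}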


def validate_categories(categories: List[str]) -> bool:
    """Validate that all categories are in the taxonomy."""
    return all(cat in VALID_CATEGORIES for cat in categories)


def preprocess_example_mock(example: Dict[str, Any], stats: PreprocessingStats) -> Optional[Dict[str, Any]]:
    """
    Mock preprocessing that validates structure without Tinker.
    Returns a dict representation of what would become a Datum.

    Use this for testing without Tinker installed.
    """
    conversation = example.get("conversation", [])
    labels = example.get("labels", {})
    categories = labels.get("categories", [])

    if not validate_categories(categories):
        stats.skipped_invalid_categories += 1
        return None

    training_messages = build_routing_prompt(conversation, categories)

    # Rough length check: ~4 characters per token is a common rule of thumb
    # for English text; the Tinker path does the exact check on real tokens.
    total_chars = sum(len(m["content"]) for m in training_messages)
    estimated_tokens = total_chars // 4

    if estimated_tokens > MAX_SEQUENCE_LENGTH:
        stats.skipped_too_long += 1
        return None

    stats.valid_examples += 1

    return {
        "messages": training_messages,
        "categories": categories,
        "estimated_tokens": estimated_tokens,
        "scenario_id": example.get("scenario_id", "unknown"),
    }
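
# Minimal usage sketch (the record is illustrative, not from the real dataset):
#   stats = PreprocessingStats(total_examples=1)
#   result = preprocess_example_mock(
#       {"scenario_id": "demo",
#        "conversation": [{"role": "user", "content": "Our brand voice is playful."}],
#        "labels": {"categories": ["company.brand_core"]}},
#       stats)
#   assert result is not None and result["categories"] == ["company.brand_core"]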


def preprocess_with_tinker(example: Dict[str, Any], renderer, tokenizer, vocab_size: int, stats: PreprocessingStats):
    """
    Full preprocessing with the Tinker renderer.

    Per Tinker docs (rendering.mdx):
    - build_supervised_example returns (tokens, weights)
    - weights=1.0 for completion tokens, weights=0.0 for prompt tokens

    Per Tinker docs (training-sampling.mdx):
    - input_tokens = tokens[:-1]
    - target_tokens = tokens[1:]  # shifted for next-token prediction
    - weights = weights[1:]
    """
    from tinker import types

    conversation = example.get("conversation", [])
    labels = example.get("labels", {})
    categories = labels.get("categories", [])

    if not validate_categories(categories):
        stats.skipped_invalid_categories += 1
        return None

    training_messages = build_routing_prompt(conversation, categories)

    # Render the chat into token IDs plus per-token loss weights.
    tokens, weights = renderer.build_supervised_example(training_messages)

    if len(tokens) > MAX_SEQUENCE_LENGTH:
        stats.skipped_too_long += 1
        return None

    # Shift by one position for next-token prediction.
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    loss_weights = weights[1:]

    # All-zero weights would contribute nothing to the loss; skip, don't train.
    if sum(loss_weights) == 0:
        stats.skipped_zero_weights += 1
        return None

    # Guard against tokenizer/renderer mismatch: every target ID must be in-vocab.
    if not all(0 <= t < vocab_size for t in target_tokens):
        stats.skipped_invalid_tokens += 1
        return None

    datum = types.Datum(
        model_input=types.ModelInput.from_ints(input_tokens),
        loss_fn_inputs=dict(
            target_tokens=target_tokens,
            weights=loss_weights,
        ),
    )

    stats.valid_examples += 1
    return datum
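
# Downstream (illustrative, outside this script): the Datum objects produced
# above are what a Tinker training client consumes for supervised fine-tuning;
# see the docs cited in the module docstring for the exact training-loop calls.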


def preprocess_dataset(
    input_path: str,
    output_dir: str,
    use_tinker: bool = False,
    train_split: float = 0.8,
) -> Tuple[PreprocessingStats, str, str]:
    """
    Preprocess the full dataset.

    Args:
        input_path: Path to training_dataset_1000.jsonl
        output_dir: Directory to save processed data
        use_tinker: Whether to use actual Tinker (requires installation)
        train_split: Fraction for training (rest is test)

    Returns:
        stats, train_path, test_path
    """
    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading data from {input_path}...")
    raw_data = load_synthetic_data(input_path)
    print(f"Loaded {len(raw_data)} examples")

    stats = PreprocessingStats(total_examples=len(raw_data))

    processed_data = []
    if use_tinker:
        from tinker_cookbook import renderers, tokenizer_utils

        print(f"Initializing tokenizer for {MODEL_NAME}...")
        tokenizer = tokenizer_utils.get_tokenizer(MODEL_NAME)
        renderer = renderers.get_renderer(name=RENDERER_NAME, tokenizer=tokenizer)
        vocab_size = len(tokenizer)
        print(f"Vocab size: {vocab_size}")

        for i, example in enumerate(raw_data):
            if i % 100 == 0:
                print(f"Processing {i}/{len(raw_data)}...")
            datum = preprocess_with_tinker(example, renderer, tokenizer, vocab_size, stats)
            if datum is not None:
                processed_data.append(datum)
    else:
        print("Running mock preprocessing (no Tinker)...")
        for i, example in enumerate(raw_data):
            if i % 100 == 0:
                print(f"Processing {i}/{len(raw_data)}...")
            result = preprocess_example_mock(example, stats)
            if result is not None:
                processed_data.append(result)

    # Sequential split in file order; shuffle upstream if the source file is
    # not already randomized.
    split_idx = int(len(processed_data) * train_split)
    train_data = processed_data[:split_idx]
    test_data = processed_data[split_idx:]

    train_path = os.path.join(output_dir, "train_data.json")
    test_path = os.path.join(output_dir, "test_data.json")

    # Mock results are plain dicts; Tinker Datum objects are serialized via
    # their pydantic-style model_dump().
    with open(train_path, "w") as f:
        json.dump([d if isinstance(d, dict) else d.model_dump() for d in train_data], f)

    with open(test_path, "w") as f:
        json.dump([d if isinstance(d, dict) else d.model_dump() for d in test_data], f)

    print("\n=== Preprocessing Complete ===")
    print(f"Total examples: {stats.total_examples}")
    print(f"Valid examples: {stats.valid_examples}")
    print(f"Skipped (too long): {stats.skipped_too_long}")
    print(f"Skipped (zero weights): {stats.skipped_zero_weights}")
    print(f"Skipped (invalid tokens): {stats.skipped_invalid_tokens}")
    print(f"Skipped (invalid categories): {stats.skipped_invalid_categories}")
    print(f"\nTrain set: {len(train_data)} examples")
    print(f"Test set: {len(test_data)} examples")
    print("\nSaved to:")
    print(f"  Train: {train_path}")
    print(f"  Test: {test_path}")

    return stats, train_path, test_path
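
# Example invocation (the script filename here is illustrative):
#   python preprocess_data.py synthetic_data/training_dataset_1000.jsonl \
#       training/processed_data --tinker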


if __name__ == "__main__":
    import sys

    input_path = sys.argv[1] if len(sys.argv) > 1 else "synthetic_data/training_dataset_1000.jsonl"
    output_dir = sys.argv[2] if len(sys.argv) > 2 else "training/processed_data"
    use_tinker = "--tinker" in sys.argv

    preprocess_dataset(input_path, output_dir, use_tinker=use_tinker)