""" Data Preprocessing for Memory Routing Training This script converts synthetic JSONL conversations to Tinker-compatible types.Datum objects for supervised fine-tuning. Per Tinker docs (rendering.mdx): - Use renderer.build_supervised_example() to get tokens and weights - Weights indicate which tokens to train on (1.0 for completion, 0.0 for prompt) - Target tokens are shifted by 1 (predicting next token) Per PRD Section 6.6: - Validate datum length <= 4096 - Ensure non-zero weights - Verify token IDs are within vocab range """ import json import os from typing import List, Dict, Any, Tuple from dataclasses import dataclass # Note: These imports require tinker and tinker-cookbook to be installed # pip install git+https://github.com/thinking-machines-lab/tinker.git # pip install git+https://github.com/thinking-machines-lab/tinker-cookbook.git MODEL_NAME = "meta-llama/Llama-3.1-8B" RENDERER_NAME = "llama3" MAX_SEQUENCE_LENGTH = 4096 # Memory taxonomy for validation VALID_CATEGORIES = { "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts", "company.business_priorities", "company.tools_config", "company.performance_context", "user.communication_style", "user.strategic_approach", "user.role_context", "user.workflow_patterns", "user.session_history", "user.interaction_preferences", "none" } @dataclass class PreprocessingStats: total_examples: int = 0 valid_examples: int = 0 skipped_too_long: int = 0 skipped_zero_weights: int = 0 skipped_invalid_tokens: int = 0 skipped_invalid_categories: int = 0 def build_routing_prompt(conversation: List[Dict[str, str]], categories: List[str]) -> List[Dict[str, str]]: """ Build the full conversation for training, including: 1. System prompt with taxonomy 2. User message with conversation 3. Assistant response with categories Per PRD Section 6 - Student Prompt format. """ # System prompt with taxonomy system_content = """You route marketing conversations into structured memory categories. Available categories: - company.brand_core: Voice, values, positioning, identity anchors (Long >1y) - company.strategic_signatures: Decision frameworks, strategic heuristics (Long >1y) - company.knowledge_artifacts: Docs, style guides, playbooks (Long >1y) - company.business_priorities: Quarterly/seasonal goals, active campaigns (Short <3m) - company.tools_config: Integrations, API keys, workflow settings (Medium ~6m) - company.performance_context: Campaign metrics, retrospectives, learnings (Rolling ~6m) - user.communication_style: Tone, verbosity, format expectations (Long >1y) - user.strategic_approach: Personal priorities, success definitions (Long >1y) - user.role_context: Title, scope, decision authority (Medium ~1y) - user.workflow_patterns: Review cadence, collaboration norms (Medium ~1y) - user.session_history: Immediate context, recent asks (Short <2w) - user.interaction_preferences: Coaching style, feedback expectations (Evolving) - none: Irrelevant, vague, or transactional content Respond with comma-separated categories. Use 'none' only if no other category applies.""" # Format the conversation for the user message conversation_text = "" for turn in conversation: # Handle malformed turns (string instead of dict) if isinstance(turn, str): conversation_text += f"UNKNOWN: {turn}\n" continue if not isinstance(turn, dict): continue role = turn.get("role", "unknown") content = turn.get("content", "") conversation_text += f"{role.upper()}: {content}\n" user_content = f"Conversation:\n{conversation_text.strip()}\n\nWhat memory categories apply?" # Assistant response is the comma-separated categories assistant_content = ", ".join(categories) return [ {"role": "system", "content": system_content}, {"role": "user", "content": user_content}, {"role": "assistant", "content": assistant_content} ] def load_synthetic_data(filepath: str) -> List[Dict[str, Any]]: """Load synthetic data from JSONL file.""" data = [] with open(filepath, "r") as f: for line in f: if line.strip(): item = json.loads(line) data.append(item) return data def validate_categories(categories: List[str]) -> bool: """Validate that all categories are in the taxonomy.""" return all(cat in VALID_CATEGORIES for cat in categories) def preprocess_example_mock(example: Dict[str, Any], stats: PreprocessingStats) -> Dict[str, Any] | None: """ Mock preprocessing that validates structure without Tinker. Returns a dict representation of what would become a Datum. Use this for testing without Tinker installed. """ conversation = example.get("conversation", []) labels = example.get("labels", {}) categories = labels.get("categories", []) # Validate categories if not validate_categories(categories): stats.skipped_invalid_categories += 1 return None # Build the full training conversation training_messages = build_routing_prompt(conversation, categories) # Mock token estimation (rough: 4 chars per token) total_chars = sum(len(m["content"]) for m in training_messages) estimated_tokens = total_chars // 4 if estimated_tokens > MAX_SEQUENCE_LENGTH: stats.skipped_too_long += 1 return None stats.valid_examples += 1 return { "messages": training_messages, "categories": categories, "estimated_tokens": estimated_tokens, "scenario_id": example.get("scenario_id", "unknown") } def preprocess_with_tinker(example: Dict[str, Any], renderer, tokenizer, vocab_size: int, stats: PreprocessingStats): """ Full preprocessing with Tinker renderer. Per Tinker docs (rendering.mdx): - build_supervised_example returns (tokens, weights) - weights=1.0 for completion tokens, weights=0.0 for prompt tokens Per Tinker docs (training-sampling.mdx): - input_tokens = tokens[:-1] - target_tokens = tokens[1:] # Shifted for next-token prediction - weights = weights[1:] """ from tinker import types conversation = example.get("conversation", []) labels = example.get("labels", {}) categories = labels.get("categories", []) # Validate categories if not validate_categories(categories): stats.skipped_invalid_categories += 1 return None # Build the full training conversation training_messages = build_routing_prompt(conversation, categories) # Use renderer to tokenize and get weights # Per Tinker rendering.mdx: build_supervised_example returns tokens and weights tokens, weights = renderer.build_supervised_example(training_messages) # Check sequence length if len(tokens) > MAX_SEQUENCE_LENGTH: stats.skipped_too_long += 1 return None # Prepare for next-token prediction # Per Tinker training-sampling.mdx example input_tokens = tokens[:-1] target_tokens = tokens[1:] loss_weights = weights[1:] # Validate non-zero weights if sum(loss_weights) == 0: stats.skipped_zero_weights += 1 return None # Validate token IDs if not all(0 <= t < vocab_size for t in target_tokens): stats.skipped_invalid_tokens += 1 return None # Create Datum object # Per Tinker types (Datum class) datum = types.Datum( model_input=types.ModelInput.from_ints(input_tokens), loss_fn_inputs=dict( target_tokens=target_tokens, weights=loss_weights ) ) stats.valid_examples += 1 return datum def preprocess_dataset( input_path: str, output_dir: str, use_tinker: bool = False, train_split: float = 0.8 ) -> Tuple[PreprocessingStats, str, str]: """ Preprocess the full dataset. Args: input_path: Path to training_dataset_1000.jsonl output_dir: Directory to save processed data use_tinker: Whether to use actual Tinker (requires installation) train_split: Fraction for training (rest is test) Returns: stats, train_path, test_path """ os.makedirs(output_dir, exist_ok=True) # Load data print(f"Loading data from {input_path}...") raw_data = load_synthetic_data(input_path) print(f"Loaded {len(raw_data)} examples") stats = PreprocessingStats(total_examples=len(raw_data)) if use_tinker: # Import Tinker components from tinker_cookbook import renderers, tokenizer_utils print(f"Initializing tokenizer for {MODEL_NAME}...") tokenizer = tokenizer_utils.get_tokenizer(MODEL_NAME) renderer = renderers.get_renderer(name=RENDERER_NAME, tokenizer=tokenizer) vocab_size = len(tokenizer) print(f"Vocab size: {vocab_size}") processed_data = [] for i, example in enumerate(raw_data): if i % 100 == 0: print(f"Processing {i}/{len(raw_data)}...") datum = preprocess_with_tinker(example, renderer, tokenizer, vocab_size, stats) if datum is not None: processed_data.append(datum) else: # Mock preprocessing for testing print("Running mock preprocessing (no Tinker)...") processed_data = [] for i, example in enumerate(raw_data): if i % 100 == 0: print(f"Processing {i}/{len(raw_data)}...") result = preprocess_example_mock(example, stats) if result is not None: processed_data.append(result) # Split into train/test split_idx = int(len(processed_data) * train_split) train_data = processed_data[:split_idx] test_data = processed_data[split_idx:] # Save processed data train_path = os.path.join(output_dir, "train_data.json") test_path = os.path.join(output_dir, "test_data.json") with open(train_path, "w") as f: json.dump([d if isinstance(d, dict) else d.model_dump() for d in train_data], f) with open(test_path, "w") as f: json.dump([d if isinstance(d, dict) else d.model_dump() for d in test_data], f) print(f"\n=== Preprocessing Complete ===") print(f"Total examples: {stats.total_examples}") print(f"Valid examples: {stats.valid_examples}") print(f"Skipped (too long): {stats.skipped_too_long}") print(f"Skipped (zero weights): {stats.skipped_zero_weights}") print(f"Skipped (invalid tokens): {stats.skipped_invalid_tokens}") print(f"Skipped (invalid categories): {stats.skipped_invalid_categories}") print(f"\nTrain set: {len(train_data)} examples") print(f"Test set: {len(test_data)} examples") print(f"\nSaved to:") print(f" Train: {train_path}") print(f" Test: {test_path}") return stats, train_path, test_path if __name__ == "__main__": import sys input_path = sys.argv[1] if len(sys.argv) > 1 else "synthetic_data/training_dataset_1000.jsonl" output_dir = sys.argv[2] if len(sys.argv) > 2 else "training/processed_data" use_tinker = "--tinker" in sys.argv preprocess_dataset(input_path, output_dir, use_tinker=use_tinker)