""" Memory Routing RL Environment This implements the MemoryRoutingEnv for Stage 2 (RL Optimization) per PRD Section 8. Per Tinker docs (rl/rl-envs.mdx): - Env operates on tokens, not strings - Implement initial_observation() and step() - EnvGroupBuilder creates groups of environments - RLDataset provides batches of EnvGroupBuilders Per PRD Section 4 (Reward Computation): - R_F1: Token-level F1 between predicted and gold categories - R_temp: Persistence alignment (+1.0 exact, +0.5 adjacent, 0.0 otherwise) - R_parity: Company/user scope alignment - R_eff: Storage efficiency (penalize >3 categories) - R_total = 0.6*R_F1 + 0.2*R_temp + 0.1*R_parity + 0.1*R_eff Per PRD Section 4 (Environment Design): - Single-step bandit: initial_observation returns conversation, step terminates - EnvGroupBuilder clones each conversation across group_size rollouts """ import json from typing import List, Dict, Any, Tuple, Set, Optional, Sequence from dataclasses import dataclass # Memory taxonomy VALID_CATEGORIES = { "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts", "company.business_priorities", "company.tools_config", "company.performance_context", "user.communication_style", "user.strategic_approach", "user.role_context", "user.workflow_patterns", "user.session_history", "user.interaction_preferences", "none" } # Persistence mapping CATEGORY_PERSISTENCE = { "company.brand_core": "long", "company.strategic_signatures": "long", "company.knowledge_artifacts": "long", "company.business_priorities": "short", "company.tools_config": "medium", "company.performance_context": "rolling", "user.communication_style": "long", "user.strategic_approach": "long", "user.role_context": "medium", "user.workflow_patterns": "medium", "user.session_history": "short", "user.interaction_preferences": "evolving", "none": "short" } # Scope mapping CATEGORY_SCOPE = { "company.brand_core": "company", "company.strategic_signatures": "company", "company.knowledge_artifacts": "company", "company.business_priorities": "company", "company.tools_config": "company", "company.performance_context": "company", "user.communication_style": "user", "user.strategic_approach": "user", "user.role_context": "user", "user.workflow_patterns": "user", "user.session_history": "user", "user.interaction_preferences": "user", "none": "none" } @dataclass class RewardComponents: """Breakdown of reward computation.""" r_f1: float = 0.0 r_temp: float = 0.0 r_parity: float = 0.0 r_eff: float = 0.0 r_total: float = 0.0 format_valid: bool = True def parse_categories(text: str) -> Tuple[Set[str], bool]: """ Parse comma-separated categories from model output. Returns: (set of valid categories, parse_success) """ if not text or not text.strip(): return set(), False # Split on comma, strip whitespace, lowercase raw_cats = [c.strip().lower() for c in text.split(",")] # Filter to valid categories valid_cats = {c for c in raw_cats if c in VALID_CATEGORIES} if not valid_cats: return set(), False # Check for invalid "none" mixing # Per PRD: "none" must be exclusive if "none" in valid_cats and len(valid_cats) > 1: valid_cats.discard("none") return valid_cats, True def compute_f1(predicted: Set[str], gold: Set[str]) -> float: """ Compute F1 score between predicted and gold category sets. Per PRD: Use macro-averaging if multi-label. """ if not predicted and not gold: return 1.0 if not predicted or not gold: return 0.0 true_positives = len(predicted & gold) precision = true_positives / len(predicted) if predicted else 0.0 recall = true_positives / len(gold) if gold else 0.0 if precision + recall == 0: return 0.0 return 2 * (precision * recall) / (precision + recall) def compute_temporal_reward(predicted: Set[str], gold: Set[str]) -> float: """ Compute temporal alignment reward. Per PRD: - +1.0 if predicted persistence matches gold - +0.5 if adjacent (long<->medium or medium<->short) - 0.0 otherwise - Use majority vote if multi-label """ if not predicted or not gold: return 0.0 # Get persistence for each category pred_persistence = [CATEGORY_PERSISTENCE.get(c, "medium") for c in predicted] gold_persistence = [CATEGORY_PERSISTENCE.get(c, "medium") for c in gold] # Majority vote def majority(items): from collections import Counter if not items: return "medium" counts = Counter(items) return counts.most_common(1)[0][0] pred_pers = majority(pred_persistence) gold_pers = majority(gold_persistence) # Exact match if pred_pers == gold_pers: return 1.0 # Adjacent match adjacency = { ("long", "medium"): True, ("medium", "long"): True, ("medium", "short"): True, ("short", "medium"): True, ("medium", "rolling"): True, ("rolling", "medium"): True, ("short", "rolling"): True, ("rolling", "short"): True, } if (pred_pers, gold_pers) in adjacency: return 0.5 return 0.0 def compute_parity_reward(predicted: Set[str], gold: Set[str]) -> float: """ Compute company/user scope alignment reward. Per PRD: - +1.0 if predicted scope matches gold scope exactly - 0.0 otherwise """ def get_scope(categories: Set[str]) -> str: scopes = {CATEGORY_SCOPE.get(c, "none") for c in categories} if "company" in scopes and "user" in scopes: return "mixed" elif "company" in scopes: return "company" elif "user" in scopes: return "user" else: return "none" pred_scope = get_scope(predicted) gold_scope = get_scope(gold) return 1.0 if pred_scope == gold_scope else 0.0 def compute_efficiency_reward(predicted: Set[str]) -> float: """ Compute storage efficiency reward. Per PRD: - 1.0 if ≤3 categories - 0.7 if 4 categories - 0.4 if 5 categories - 0.0 if ≥6 categories """ n = len(predicted) if n <= 3: return 1.0 elif n == 4: return 0.7 elif n == 5: return 0.4 else: return 0.0 def compute_reward(predicted_text: str, gold_categories: List[str]) -> RewardComponents: """ Compute full reward for a prediction. Per PRD Section 4: R_total = 0.6 * R_F1 + 0.2 * R_temp + 0.1 * R_parity + 0.1 * R_eff Returns RewardComponents with breakdown. """ result = RewardComponents() # Parse prediction predicted, parse_success = parse_categories(predicted_text) gold = set(gold_categories) # Format validation failure if not parse_success: result.format_valid = False result.r_total = -1.0 return result # Compute components result.r_f1 = compute_f1(predicted, gold) result.r_temp = compute_temporal_reward(predicted, gold) result.r_parity = compute_parity_reward(predicted, gold) result.r_eff = compute_efficiency_reward(predicted) # Weighted sum result.r_total = ( 0.6 * result.r_f1 + 0.2 * result.r_temp + 0.1 * result.r_parity + 0.1 * result.r_eff ) return result # Tinker Environment Classes # Per Tinker docs (rl/rl-envs.mdx) class MemoryRoutingEnv: """ Single-step bandit environment for memory routing. Per Tinker Env interface: - initial_observation() -> (Observation, StopCondition) - step(action) -> StepResult Per PRD: Single-step episodes - step() terminates immediately with reward. """ def __init__( self, conversation: List[Dict[str, str]], gold_categories: List[str], prompt_tokens: List[int], stop_tokens: List[int], scenario_id: str = "" ): self.conversation = conversation self.gold_categories = gold_categories self.prompt_tokens = prompt_tokens self.stop_tokens = stop_tokens self.scenario_id = scenario_id self._done = False async def initial_observation(self): """ Return the initial observation (prompt tokens) and stop condition. Per Tinker: Returns (Observation, StopCondition) - Observation is the model input (tokens) - StopCondition tells the sampler when to stop """ from tinker import types from tinker_cookbook.rl.types import StopCondition observation = types.ModelInput.from_ints(self.prompt_tokens) stop_condition = StopCondition(stop_tokens=self.stop_tokens) return observation, stop_condition async def step(self, action): """ Process the model's action (generated tokens) and return reward. Per Tinker: Returns StepResult with reward and done=True Per PRD: Single-step bandit, so always terminates. """ from tinker_cookbook.rl.types import StepResult # Decode action tokens to text # Note: In actual implementation, we'd use tokenizer.decode() # For now, assume action is already decoded text or we have tokenizer if isinstance(action, list): # Would decode here: action_text = tokenizer.decode(action) action_text = str(action) # Placeholder else: action_text = str(action) # Compute reward reward_components = compute_reward(action_text, self.gold_categories) self._done = True return StepResult( reward=reward_components.r_total, done=True, info={ "r_f1": reward_components.r_f1, "r_temp": reward_components.r_temp, "r_parity": reward_components.r_parity, "r_eff": reward_components.r_eff, "format_valid": reward_components.format_valid, "scenario_id": self.scenario_id } ) class MemoryRoutingEnvGroupBuilder: """ Builds a group of identical environments for variance reduction. Per Tinker docs (rl/rl-envs.mdx): - EnvGroupBuilder creates group_size copies of the same environment - This enables comparing multiple samples for the same input """ def __init__( self, conversation: List[Dict[str, str]], gold_categories: List[str], prompt_tokens: List[int], stop_tokens: List[int], group_size: int = 8, scenario_id: str = "" ): self.conversation = conversation self.gold_categories = gold_categories self.prompt_tokens = prompt_tokens self.stop_tokens = stop_tokens self.group_size = group_size self.scenario_id = scenario_id async def make_envs(self) -> Sequence["MemoryRoutingEnv"]: """Create group_size copies of the environment.""" return [ MemoryRoutingEnv( conversation=self.conversation, gold_categories=self.gold_categories, prompt_tokens=self.prompt_tokens, stop_tokens=self.stop_tokens, scenario_id=self.scenario_id ) for _ in range(self.group_size) ] def logging_tags(self) -> Dict[str, Any]: """Return tags for logging.""" return { "scenario_id": self.scenario_id, "num_gold_categories": len(self.gold_categories), "has_none": "none" in self.gold_categories } class MemoryRoutingDataset: """ Dataset of EnvGroupBuilders for RL training. Per Tinker docs (rl/rl-envs.mdx): - RLDataset.get_batch(index) returns list of EnvGroupBuilders """ def __init__( self, examples: List[Dict[str, Any]], batch_size: int, group_size: int, renderer, tokenizer ): self.examples = examples self.batch_size = batch_size self.group_size = group_size self.renderer = renderer self.tokenizer = tokenizer self.stop_tokens = renderer.get_stop_sequences() def __len__(self) -> int: return len(self.examples) // self.batch_size def get_batch(self, index: int) -> List[MemoryRoutingEnvGroupBuilder]: """Get a batch of EnvGroupBuilders.""" start_idx = (index * self.batch_size) % len(self.examples) end_idx = start_idx + self.batch_size if end_idx <= len(self.examples): batch_examples = self.examples[start_idx:end_idx] else: batch_examples = self.examples[start_idx:] batch_examples.extend(self.examples[:end_idx - len(self.examples)]) builders = [] for example in batch_examples: # Build prompt for this example messages = example.get("messages", []) if not messages: # Need to construct from conversation conversation = example.get("conversation", []) categories = example.get("labels", {}).get("categories", []) # Build without the assistant response (for generation) from training.preprocess import build_routing_prompt full_messages = build_routing_prompt(conversation, categories) # Remove assistant response for generation prompt messages = full_messages[:-1] # Tokenize prompt prompt = self.renderer.build_generation_prompt(messages) prompt_tokens = prompt.to_ints() # Get gold categories gold_categories = example.get("categories", []) if not gold_categories: gold_categories = example.get("labels", {}).get("categories", []) builders.append(MemoryRoutingEnvGroupBuilder( conversation=example.get("conversation", []), gold_categories=gold_categories, prompt_tokens=prompt_tokens, stop_tokens=self.stop_tokens, group_size=self.group_size, scenario_id=example.get("scenario_id", "") )) return builders class MemoryRoutingDatasetBuilder: """ Factory for creating train/test RL datasets. Per Tinker pattern from math_env.py example. """ def __init__( self, train_data_path: str, test_data_path: str, batch_size: int = 64, group_size: int = 8, model_name: str = "meta-llama/Llama-3.1-8B", renderer_name: str = "llama3" ): self.train_data_path = train_data_path self.test_data_path = test_data_path self.batch_size = batch_size self.group_size = group_size self.model_name = model_name self.renderer_name = renderer_name def __call__(self) -> Tuple[MemoryRoutingDataset, MemoryRoutingDataset]: """Create train and test datasets.""" from tinker_cookbook import renderers, tokenizer_utils tokenizer = tokenizer_utils.get_tokenizer(self.model_name) renderer = renderers.get_renderer(name=self.renderer_name, tokenizer=tokenizer) # Load data with open(self.train_data_path, "r") as f: train_examples = json.load(f) with open(self.test_data_path, "r") as f: test_examples = json.load(f) train_dataset = MemoryRoutingDataset( examples=train_examples, batch_size=self.batch_size, group_size=self.group_size, renderer=renderer, tokenizer=tokenizer ) test_dataset = MemoryRoutingDataset( examples=test_examples, batch_size=min(self.batch_size, len(test_examples)), group_size=self.group_size, renderer=renderer, tokenizer=tokenizer ) return train_dataset, test_dataset # Test the reward computation if __name__ == "__main__": # Test cases test_cases = [ # (predicted_text, gold_categories, expected_valid) ("company.brand_core, user.strategic_approach", ["company.brand_core", "user.strategic_approach"], True), ("none", ["none"], True), ("company.brand_core, none", ["company.brand_core"], True), # none should be removed ("invalid_category", ["company.brand_core"], False), ("", ["company.brand_core"], False), ("company.brand_core", ["company.brand_core", "user.role_context"], True), # Partial match ] print("Testing reward computation:") print("=" * 60) for pred, gold, expected_valid in test_cases: result = compute_reward(pred, gold) print(f"\nPredicted: '{pred}'") print(f"Gold: {gold}") print(f"Format valid: {result.format_valid} (expected: {expected_valid})") print(f"R_F1: {result.r_f1:.3f}") print(f"R_temp: {result.r_temp:.3f}") print(f"R_parity: {result.r_parity:.3f}") print(f"R_eff: {result.r_eff:.3f}") print(f"R_total: {result.r_total:.3f}")