"""
Reinforcement Learning Training for Memory Routing
This implements Stage 2 of the PRD: RL Optimization using Tinker's
importance_sampling loss function.
Per Tinker docs (rl.mdx):
- RL learns from trial and error with reward functions
- Use importance_sampling loss for policy gradient
Per Tinker docs (rl/rl-loops.mdx):
1. Create policy with current weights
2. Generate rollouts
3. Process trajectory data into training examples
4. Update model parameters
Per PRD Section 8:
- 25 iterations minimum
- Group size 8 for variance reduction
- KL divergence monitoring (<0.01 threshold)
"""
import asyncio
import json
import time
import os
from typing import List, Dict, Any, Tuple
from dataclasses import dataclass
from collections import Counter
# Configuration
@dataclass
class RLConfig:
# Model - start from SFT checkpoint
sft_checkpoint: str = "" # Will be set from command line
base_model: str = "meta-llama/Llama-3.1-8B"
lora_rank: int = 32
renderer_name: str = "llama3"
# Training
num_iterations: int = 25
batch_size: int = 32 # Number of unique conversations per iteration
group_size: int = 8 # Rollouts per conversation for variance reduction
learning_rate: float = 1e-5 # Lower than SFT per Tinker RL docs
# Adam optimizer
beta1: float = 0.9
beta2: float = 0.95
eps: float = 1e-8
# Sampling
max_tokens: int = 50
temperature: float = 0.7
# Monitoring
kl_threshold: float = 0.01 # Per PRD: warn if KL > 0.01
checkpoint_every: int = 5
# Paths
train_data_path: str = "training/processed_data/train_data.json"
log_path: str = "training/logs/rl"
# Import reward computation from rl_env
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from rl_env import compute_reward, VALID_CATEGORIES
def build_routing_prompt(conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""Build the routing prompt for a conversation."""
system_content = """You route marketing conversations into structured memory categories.
Available categories:
- company.brand_core: Voice, values, positioning
- company.strategic_signatures: Decision frameworks
- company.knowledge_artifacts: Docs, style guides
- company.business_priorities: Quarterly goals, campaigns
- company.tools_config: Integrations, settings
- company.performance_context: Campaign metrics
- user.communication_style: Tone, format expectations
- user.strategic_approach: Personal priorities
- user.role_context: Title, scope
- user.workflow_patterns: Review cadence
- user.session_history: Recent context
- user.interaction_preferences: Coaching style
- none: Irrelevant or transactional
Respond with comma-separated categories."""
# Format the conversation
conversation_text = ""
for turn in conversation:
if isinstance(turn, dict):
role = turn.get("role", "unknown")
content = turn.get("content", "")
conversation_text += f"{role.upper()}: {content}\n"
return [
{"role": "system", "content": system_content},
{"role": "user", "content": f"Conversation:\n{conversation_text.strip()}\n\nCategories?"}
]
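# Illustrative example (hypothetical conversation): calling
#   build_routing_prompt([{"role": "user", "content": "Keep our brand voice playful but expert."}])
# yields the system message above plus a user turn ending in "Categories?"; a
# well-routed completion here would be "company.brand_core".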
async def run_rl_training(config: RLConfig):
"""
Main RL training loop.
Per Tinker docs (rl/rl-loops.mdx):
1. Create policy with current weights
2. Generate rollouts (sample from model)
3. Compute rewards and advantages
4. Update with importance_sampling loss
"""
import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer
from tinker_cookbook.hyperparam_utils import get_lr
from dotenv import load_dotenv
import numpy as np
load_dotenv()
os.makedirs(config.log_path, exist_ok=True)
print("=" * 70)
print("REINFORCEMENT LEARNING TRAINING")
print("=" * 70)
# Load training data
print(f"\nLoading training data from {config.train_data_path}...")
with open(config.train_data_path, "r") as f:
train_data = json.load(f)
print(f"Loaded {len(train_data)} examples")
# Initialize Tinker
print("\nInitializing Tinker...")
service_client = tinker.ServiceClient()
# For RL, we need to:
# 1. Create a training client (fresh LoRA weights)
# 2. Use the SFT checkpoint directly for sampling
    # Note: sampler weights cannot be loaded into a training client; the SFT
    # checkpoint is a sampler checkpoint, not a full training-state checkpoint.
print(f"Creating training client (base: {config.base_model})...")
training_client = await service_client.create_lora_training_client_async(
base_model=config.base_model,
rank=config.lora_rank,
)
# Note: For proper RL continuation from SFT, we would need:
# 1. SFT to save with save_state() not save_weights_for_sampler()
# 2. Then load_state() here
# For now, we'll use the SFT checkpoint directly for initial sampling
print(f"SFT checkpoint for reference: {config.sft_checkpoint}")
print("Note: Using fresh LoRA weights for training, SFT checkpoint for initial sampling")
# Get tokenizer and renderer
tokenizer = get_tokenizer(config.base_model)
renderer = renderers.get_renderer(name=config.renderer_name, tokenizer=tokenizer)
stop_sequences = renderer.get_stop_sequences()
print(f"""
RL Training Configuration:
--------------------------
SFT Checkpoint: {config.sft_checkpoint}
Base Model: {config.base_model}
LoRA Rank: {config.lora_rank}
Iterations: {config.num_iterations}
Batch Size: {config.batch_size}
Group Size: {config.group_size}
Learning Rate: {config.learning_rate:.2e}
Temperature: {config.temperature}
KL Threshold: {config.kl_threshold}
""")
metrics_log = []
final_checkpoint = None
for iteration in range(config.num_iterations):
iter_start = time.time()
print(f"\n{'='*70}")
print(f"Iteration {iteration + 1}/{config.num_iterations}")
print(f"{'='*70}")
# === STEP 1: Get sampling client ===
print("\n[1/4] Getting sampling client...")
if iteration == 0:
# First iteration: use SFT checkpoint
sampling_path = config.sft_checkpoint
print(f" Using SFT checkpoint: {sampling_path}")
else:
# Subsequent iterations: save current weights
save_future = await training_client.save_weights_for_sampler_async(
name=f"rl_iter_{iteration:04d}"
)
save_result = await save_future.result_async()
sampling_path = save_result.path
print(f" Saved new checkpoint: {sampling_path}")
# Create sampling client
sampling_client = service_client.create_sampling_client(model_path=sampling_path)
# === STEP 2: Generate rollouts ===
print("[2/4] Generating rollouts...")
# Sample batch of conversations
batch_indices = np.random.choice(len(train_data), size=config.batch_size, replace=False)
batch_examples = [train_data[i] for i in batch_indices]
all_rollouts = []
all_rewards = []
all_gold_categories = []
for example in batch_examples:
# Get gold categories
gold_categories = example.get("categories", [])
if not gold_categories:
gold_categories = example.get("labels", {}).get("categories", [])
# Get messages - the data is already in messages format from preprocessing
messages = example.get("messages", [])
if not messages:
continue
# Remove the assistant response if present (we want to generate it)
prompt_messages = [m for m in messages if m.get("role") != "assistant"]
# Build prompt from the messages (already formatted)
prompt = renderer.build_generation_prompt(prompt_messages)
# Sample group_size rollouts
params = types.SamplingParams(
max_tokens=config.max_tokens,
temperature=config.temperature,
stop=stop_sequences
)
result = sampling_client.sample(
prompt=prompt,
sampling_params=params,
num_samples=config.group_size
).result()
# Process each rollout
for seq in result.sequences:
response_message, success = renderer.parse_response(seq.tokens)
predicted_text = response_message["content"] if success else ""
# Compute reward
reward_result = compute_reward(predicted_text, gold_categories)
# Debug: print first few
if len(all_rollouts) < 3:
print(f" DEBUG: predicted='{predicted_text}', gold={gold_categories}, reward={reward_result.r_total:.3f}")
all_rollouts.append({
"prompt": prompt,
"tokens": seq.tokens,
"logprobs": seq.logprobs if seq.logprobs else [],
"predicted": predicted_text,
"gold": gold_categories
})
all_rewards.append(reward_result.r_total)
all_gold_categories.append(gold_categories)
# === STEP 3: Compute advantages ===
print("[3/4] Computing advantages...")
rewards_array = np.array(all_rewards)
mean_reward = rewards_array.mean()
std_reward = rewards_array.std() + 1e-8
# Normalize rewards to get advantages
advantages = (rewards_array - mean_reward) / std_reward
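        # Alternative sketch (an assumption, not mandated by the PRD text quoted in
        # the module docstring): since each conversation contributes group_size
        # rollouts, advantages could instead be centered within each group:
        #   grouped = rewards_array.reshape(-1, config.group_size)
        #   advantages = ((grouped - grouped.mean(axis=1, keepdims=True))
        #                 / (grouped.std(axis=1, keepdims=True) + 1e-8)).flatten()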
# === STEP 4: Update model ===
print("[4/4] Updating model...")
# Build training data
# Per Tinker losses.mdx: importance_sampling needs target_tokens, logprobs, advantages
# All arrays must have length N where model_input has length N
        training_data = []
        datum_sampling_logprobs = []  # sampling-time logprobs per datum, kept for KL monitoring
for i, rollout in enumerate(all_rollouts):
if not rollout["logprobs"] or len(rollout["logprobs"]) == 0:
continue
tokens = rollout["tokens"]
logprobs = rollout["logprobs"]
advantage = advantages[i]
# Get prompt tokens
prompt_tokens = rollout["prompt"].to_ints()
# For importance_sampling, per Tinker rl/train.py example:
# - model_input: the input sequence (prompt + completion[:-1])
# - target_tokens: what we predict (completion tokens)
# - logprobs: sampling logprobs for target_tokens
# - advantages: advantage values for each token
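            # Worked example of this layout (illustrative tokens): with prompt
            # [p0, p1, p2] and completion [c0, c1], the arrays line up as
            #   model_input   = [p0, p1, p2, c0]
            #   target_tokens = [p1, p2, c0, c1]
            #   logprobs      = [0.0, 0.0, lp(c0), lp(c1)]
            #   advantages    = [0.0, 0.0, A, A]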
n_gen = len(tokens)
if n_gen < 1 or len(logprobs) != n_gen:
continue
# Full input sequence
full_input = prompt_tokens + tokens[:-1] if n_gen > 1 else prompt_tokens
n_input = len(full_input)
# Target tokens (shifted by 1 for next-token prediction)
# We need to include prompt targets too for proper alignment
full_target = prompt_tokens[1:] + tokens if len(prompt_tokens) > 0 else tokens
# Logprobs: 0 for prompt positions, actual logprobs for completion
n_prompt = len(prompt_tokens) - 1 if len(prompt_tokens) > 0 else 0
full_logprobs = [0.0] * n_prompt + logprobs
# Advantages: 0 for prompt, actual advantage for completion
full_advantages = [0.0] * n_prompt + [advantage] * n_gen
# Verify all lengths match
if len(full_target) != n_input or len(full_logprobs) != n_input or len(full_advantages) != n_input:
# Length mismatch, skip
continue
datum = types.Datum(
model_input=types.ModelInput.from_ints(full_input),
loss_fn_inputs=dict(
target_tokens=full_target,
logprobs=full_logprobs,
advantages=full_advantages
)
)
            training_data.append(datum)
            datum_sampling_logprobs.append(full_logprobs)
        mean_kl = 0.0
        if training_data:
# Forward-backward with importance_sampling loss
fwd_bwd_future = await training_client.forward_backward_async(
training_data,
loss_fn="importance_sampling"
)
# Optimizer step
adam_params = types.AdamParams(
learning_rate=config.learning_rate,
beta1=config.beta1,
beta2=config.beta2,
eps=config.eps,
)
optim_future = await training_client.optim_step_async(adam_params)
# Wait for results
fwd_bwd_result = await fwd_bwd_future.result_async()
optim_result = await optim_future.result_async()
            # KL monitoring (per PRD): approximate per-token KL between the sampling
            # policy and the updated policy as E[old_logprob - new_logprob] over the
            # generated tokens (prompt positions carry a placeholder logprob of 0.0).
            kl_values = []
            for output, old_logprobs in zip(fwd_bwd_result.loss_fn_outputs, datum_sampling_logprobs):
                if "logprobs" in output:
                    new_logprobs = output["logprobs"].tolist()
                    kl_values.extend(o - n for o, n in zip(old_logprobs, new_logprobs) if o != 0.0)
            mean_kl = float(np.mean(kl_values)) if kl_values else 0.0
            if mean_kl > config.kl_threshold:
                print(f"  WARNING: approximate KL {mean_kl:.4f} exceeds threshold {config.kl_threshold}")
# Compute metrics
iter_time = time.time() - iter_start
# Category prediction accuracy
correct = 0
total = len(all_rollouts)
for rollout in all_rollouts:
predicted_set = set([x.strip() for x in rollout["predicted"].split(",") if x.strip() in VALID_CATEGORIES])
gold_set = set(rollout["gold"])
if predicted_set.intersection(gold_set):
correct += 1
accuracy = correct / total if total > 0 else 0
        metrics = {
            "iteration": iteration,
            "mean_reward": float(mean_reward),
            "std_reward": float(std_reward),
            "mean_kl": float(mean_kl),
            "accuracy": accuracy,
            "num_rollouts": len(all_rollouts),
            "num_training_examples": len(training_data),
            "iter_time": iter_time,
        }
metrics_log.append(metrics)
print(f"""
Iteration {iteration + 1} Results:
----------------------------------
Mean Reward: {mean_reward:.4f}
Std Reward: {std_reward:.4f}
Accuracy: {accuracy:.2%}
Rollouts: {len(all_rollouts)}
Training Data: {len(training_data)}
Time: {iter_time:.1f}s
""")
# Checkpoint
if (iteration + 1) % config.checkpoint_every == 0 or iteration == config.num_iterations - 1:
ckpt_future = await training_client.save_weights_for_sampler_async(
name=f"rl_final_{iteration:04d}"
)
ckpt_result = await ckpt_future.result_async()
final_checkpoint = ckpt_result.path
print(f"Checkpoint saved: {final_checkpoint}")
# Save metrics
metrics_path = os.path.join(config.log_path, "metrics.jsonl")
with open(metrics_path, "w") as f:
for m in metrics_log:
f.write(json.dumps(m) + "\n")
print(f"\n{'='*70}")
print("RL TRAINING COMPLETE")
print(f"{'='*70}")
print(f"Final checkpoint: {final_checkpoint}")
print(f"Metrics saved to: {metrics_path}")
return final_checkpoint, metrics_log
async def main():
    config = RLConfig()
# Parse command line args
for arg in sys.argv[1:]:
if "=" in arg:
key, value = arg.split("=", 1)
if hasattr(config, key):
current_value = getattr(config, key)
if isinstance(current_value, int):
setattr(config, key, int(value))
elif isinstance(current_value, float):
setattr(config, key, float(value))
else:
setattr(config, key, value)
if not config.sft_checkpoint:
print("ERROR: sft_checkpoint is required")
print("Usage: python rl_train.py sft_checkpoint=tinker://...")
sys.exit(1)
await run_rl_training(config)
if __name__ == "__main__":
asyncio.run(main())