"""
Reinforcement Learning Training for Memory Routing

This implements Stage 2 of the PRD: RL Optimization using Tinker's
importance_sampling loss function.

Per Tinker docs (rl.mdx):
- RL learns from trial and error with reward functions
- Use importance_sampling loss for policy gradient

Per Tinker docs (rl/rl-loops.mdx):
1. Create policy with current weights
2. Generate rollouts
3. Process trajectory data into training examples
4. Update model parameters

Per PRD Section 8:
- 25 iterations minimum
- Group size 8 for variance reduction
- KL divergence monitoring (<0.01 threshold)
"""
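
# How the importance_sampling loss uses the per-token inputs built below
# (hedged sketch of the standard estimator, not necessarily the service's
# exact implementation):
#   ratio_t = exp(learner_logprob_t - sampler_logprob_t)
#   loss   ~= -sum_t ratio_t * advantage_t
# Prompt positions get zero advantage, so only generated tokens contribute.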

import asyncio
import json
import time
import os
from typing import List, Dict
from dataclasses import dataclass

@dataclass
class RLConfig:
    """Hyperparameters for RL routing training (defaults per PRD Section 8)."""

    # Model / checkpoints
    sft_checkpoint: str = ""
    base_model: str = "meta-llama/Llama-3.1-8B"
    lora_rank: int = 32
    renderer_name: str = "llama3"

    # RL loop
    num_iterations: int = 25
    batch_size: int = 32
    group_size: int = 8
    learning_rate: float = 1e-5

    # Adam optimizer
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-8

    # Sampling
    max_tokens: int = 50
    temperature: float = 0.7

    # Monitoring / checkpointing
    kl_threshold: float = 0.01
    checkpoint_every: int = 5

    # Data and logs
    train_data_path: str = "training/processed_data/train_data.json"
    log_path: str = "training/logs/rl"
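
# Programmatic equivalent of the CLI overrides parsed in main() (illustrative):
#   config = RLConfig(sft_checkpoint="tinker://...", num_iterations=50)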


# Make sibling modules (rl_env) importable when this file is run as a script.
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from rl_env import compute_reward, VALID_CATEGORIES


def build_routing_prompt(conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Build the routing prompt for a conversation."""
    system_content = """You route marketing conversations into structured memory categories.

Available categories:
- company.brand_core: Voice, values, positioning
- company.strategic_signatures: Decision frameworks
- company.knowledge_artifacts: Docs, style guides
- company.business_priorities: Quarterly goals, campaigns
- company.tools_config: Integrations, settings
- company.performance_context: Campaign metrics
- user.communication_style: Tone, format expectations
- user.strategic_approach: Personal priorities
- user.role_context: Title, scope
- user.workflow_patterns: Review cadence
- user.session_history: Recent context
- user.interaction_preferences: Coaching style
- none: Irrelevant or transactional

Respond with comma-separated categories."""

    conversation_text = ""
    for turn in conversation:
        if isinstance(turn, dict):
            role = turn.get("role", "unknown")
            content = turn.get("content", "")
            conversation_text += f"{role.upper()}: {content}\n"

    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": f"Conversation:\n{conversation_text.strip()}\n\nCategories?"}
    ]
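
# The model is expected to answer with comma-separated category names, e.g.
# "company.brand_core, user.communication_style" (illustrative); compute_reward
# from rl_env scores that string against the gold labels.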


async def run_rl_training(config: RLConfig):
    """
    Main RL training loop.

    Per Tinker docs (rl/rl-loops.mdx):
    1. Create policy with current weights
    2. Generate rollouts (sample from model)
    3. Compute rewards and advantages
    4. Update with importance_sampling loss
    """
    # Imported lazily so the module can be inspected without Tinker installed.
    import tinker
    from tinker import types
    from tinker_cookbook import renderers
    from tinker_cookbook.tokenizer_utils import get_tokenizer
    from dotenv import load_dotenv
    import numpy as np

    load_dotenv()
    os.makedirs(config.log_path, exist_ok=True)

    print("=" * 70)
    print("REINFORCEMENT LEARNING TRAINING")
    print("=" * 70)

    print(f"\nLoading training data from {config.train_data_path}...")
    with open(config.train_data_path, "r") as f:
        train_data = json.load(f)
    print(f"Loaded {len(train_data)} examples")

    print("\nInitializing Tinker...")
    service_client = tinker.ServiceClient()

    print(f"Creating training client (base: {config.base_model})...")
    training_client = await service_client.create_lora_training_client_async(
        base_model=config.base_model,
        rank=config.lora_rank,
    )

    print(f"SFT checkpoint for reference: {config.sft_checkpoint}")
    print("Note: Using fresh LoRA weights for training, SFT checkpoint for initial sampling")

    tokenizer = get_tokenizer(config.base_model)
    renderer = renderers.get_renderer(name=config.renderer_name, tokenizer=tokenizer)
    stop_sequences = renderer.get_stop_sequences()
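
    # Hedged note: iteration 0 samples from the SFT checkpoint while gradients
    # update the fresh LoRA client above, so the first update is off-policy.
    # The importance_sampling loss corrects for that mismatch via the sampler
    # logprobs recorded with each rollout.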

    print(f"""
RL Training Configuration:
--------------------------
SFT Checkpoint: {config.sft_checkpoint}
Base Model: {config.base_model}
LoRA Rank: {config.lora_rank}
Iterations: {config.num_iterations}
Batch Size: {config.batch_size}
Group Size: {config.group_size}
Learning Rate: {config.learning_rate:.2e}
Temperature: {config.temperature}
KL Threshold: {config.kl_threshold}
""")

    metrics_log = []
    final_checkpoint = None

    for iteration in range(config.num_iterations):
        iter_start = time.time()
        print(f"\n{'='*70}")
        print(f"Iteration {iteration + 1}/{config.num_iterations}")
        print(f"{'='*70}")

        # [1/4] Pick the checkpoint to sample rollouts from.
        print("\n[1/4] Getting sampling client...")

        if iteration == 0:
            # No RL update has happened yet, so sample from the SFT checkpoint.
            sampling_path = config.sft_checkpoint
            print(f"  Using SFT checkpoint: {sampling_path}")
        else:
            # Later iterations: export the weights we just trained.
            save_future = await training_client.save_weights_for_sampler_async(
                name=f"rl_iter_{iteration:04d}"
            )
            save_result = await save_future.result_async()
            sampling_path = save_result.path
            print(f"  Saved new checkpoint: {sampling_path}")

        sampling_client = service_client.create_sampling_client(model_path=sampling_path)

        # [2/4] Generate rollouts: group_size samples per prompt.
        print("[2/4] Generating rollouts...")

        # Draw a batch of prompts without replacement (clamped so a small
        # dataset cannot crash np.random.choice).
        batch_size = min(config.batch_size, len(train_data))
        batch_indices = np.random.choice(len(train_data), size=batch_size, replace=False)
        batch_examples = [train_data[i] for i in batch_indices]

        all_rollouts = []
        all_rewards = []
        all_gold_categories = []

        for example in batch_examples:
            # Gold labels live either at the top level or under "labels".
            gold_categories = example.get("categories", [])
            if not gold_categories:
                gold_categories = example.get("labels", {}).get("categories", [])

            messages = example.get("messages", [])
            if not messages:
                continue

            # Drop the assistant turn (the gold completion) so the model has
            # to produce the routing itself.
            prompt_messages = [m for m in messages if m.get("role") != "assistant"]
            prompt = renderer.build_generation_prompt(prompt_messages)

            params = types.SamplingParams(
                max_tokens=config.max_tokens,
                temperature=config.temperature,
                stop=stop_sequences
            )

            result = sampling_client.sample(
                prompt=prompt,
                sampling_params=params,
                num_samples=config.group_size
            ).result()

            for seq in result.sequences:
                response_message, success = renderer.parse_response(seq.tokens)
                predicted_text = response_message["content"] if success else ""

                # Score the sampled routing against the gold categories.
                reward_result = compute_reward(predicted_text, gold_categories)

                # Log the first few rollouts for sanity-checking.
                if len(all_rollouts) < 3:
                    print(f"  DEBUG: predicted='{predicted_text}', gold={gold_categories}, reward={reward_result.r_total:.3f}")

                all_rollouts.append({
                    "prompt": prompt,
                    "tokens": seq.tokens,
                    "logprobs": seq.logprobs if seq.logprobs else [],
                    "predicted": predicted_text,
                    "gold": gold_categories
                })
                all_rewards.append(reward_result.r_total)
                all_gold_categories.append(gold_categories)

        # [3/4] Turn rewards into advantages.
        print("[3/4] Computing advantages...")

        rewards_array = np.array(all_rewards)
        mean_reward = rewards_array.mean()
        std_reward = rewards_array.std() + 1e-8

        # Batch-wide whitening: center on the mean reward and scale to unit
        # variance so the policy gradient is well-conditioned.
        advantages = (rewards_array - mean_reward) / std_reward
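
        # Hedged alternative (not enabled here): GRPO-style per-group
        # baselines, centering each prompt's group_size rollouts on their own
        # mean. Every sampled prompt contributes exactly group_size rollouts,
        # so the reshape is safe:
        #   grouped = rewards_array.reshape(-1, config.group_size)
        #   advantages = (grouped - grouped.mean(axis=1, keepdims=True)).ravel() / std_reward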

        # [4/4] Build importance-sampling training data and take one
        # optimizer step.
        print("[4/4] Updating model...")
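
        # Per-datum layout (illustrative): with prompt [p0, p1, p2] and
        # generated tokens [g0, g1]:
        #   input      = [p0, p1, p2, g0]
        #   target     = [p1, p2, g0, g1]
        #   logprobs   = [0.0, 0.0, lp(g0), lp(g1)]
        #   advantages = [0.0, 0.0, A, A]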
        training_data = []
        sampler_logprobs = []  # parallel to training_data, for the KL estimate
        for i, rollout in enumerate(all_rollouts):
            # Importance sampling needs the sampler's logprobs for the ratio.
            if not rollout["logprobs"]:
                continue

            tokens = rollout["tokens"]
            logprobs = rollout["logprobs"]
            advantage = advantages[i]

            prompt_tokens = rollout["prompt"].to_ints()

            n_gen = len(tokens)
            if n_gen < 1 or len(logprobs) != n_gen:
                continue

            # Next-token shift: input drops the last token, target drops the
            # first.
            full_input = prompt_tokens + tokens[:-1] if n_gen > 1 else prompt_tokens
            n_input = len(full_input)

            full_target = prompt_tokens[1:] + tokens if len(prompt_tokens) > 0 else tokens

            # Prompt positions get placeholder logprobs of 0.0; only generated
            # tokens carry real sampler logprobs.
            n_prompt = len(prompt_tokens) - 1 if len(prompt_tokens) > 0 else 0
            full_logprobs = [0.0] * n_prompt + logprobs

            # Zero advantage on prompt positions masks them out of the loss.
            full_advantages = [0.0] * n_prompt + [advantage] * n_gen

            # Every per-position array must match the input length.
            if len(full_target) != n_input or len(full_logprobs) != n_input or len(full_advantages) != n_input:
                continue

            datum = types.Datum(
                model_input=types.ModelInput.from_ints(full_input),
                loss_fn_inputs=dict(
                    target_tokens=full_target,
                    logprobs=full_logprobs,
                    advantages=full_advantages
                )
            )
            training_data.append(datum)
            sampler_logprobs.append(full_logprobs)

        mean_kl = 0.0
        if training_data:
            # Queue the forward/backward pass with the importance_sampling
            # loss, then queue the Adam step; awaiting both afterwards lets
            # the Tinker service pipeline the two calls.
            fwd_bwd_future = await training_client.forward_backward_async(
                training_data,
                loss_fn="importance_sampling"
            )

            adam_params = types.AdamParams(
                learning_rate=config.learning_rate,
                beta1=config.beta1,
                beta2=config.beta2,
                eps=config.eps,
            )
            optim_future = await training_client.optim_step_async(adam_params)

            fwd_bwd_result = await fwd_bwd_future.result_async()
            optim_result = await optim_future.result_async()
            # Monitor KL divergence (PRD: <0.01 threshold). With samples drawn
            # from the sampler policy, a simple estimator of
            # KL(sampler || learner) is the mean of (sampler_lp - learner_lp)
            # over generated tokens; prompt positions were padded with 0.0.
            kl_values = []
            for old_lps, output in zip(sampler_logprobs, fwd_bwd_result.loss_fn_outputs):
                if "logprobs" not in output:
                    continue
                new_lps = output["logprobs"].tolist()
                for old, new in zip(old_lps, new_lps):
                    if old != 0.0:
                        kl_values.append(old - new)
            if kl_values:
                mean_kl = float(np.mean(kl_values))
                if mean_kl > config.kl_threshold:
                    print(f"  WARNING: approx KL {mean_kl:.4f} exceeds threshold {config.kl_threshold}")

        iter_time = time.time() - iter_start

        # Lenient accuracy: a rollout counts as correct if it predicts at
        # least one gold category.
        correct = 0
        total = len(all_rollouts)
        for rollout in all_rollouts:
            predicted_set = set(x.strip() for x in rollout["predicted"].split(",") if x.strip() in VALID_CATEGORIES)
            gold_set = set(rollout["gold"])
            if predicted_set.intersection(gold_set):
                correct += 1
        accuracy = correct / total if total > 0 else 0

        metrics = {
            "iteration": iteration,
            "mean_reward": float(mean_reward),
            "std_reward": float(std_reward),
            "mean_kl": mean_kl,
            "accuracy": accuracy,
            "num_rollouts": len(all_rollouts),
            "num_training_examples": len(training_data),
            "iter_time": iter_time,
        }
        metrics_log.append(metrics)
        print(f"""
Iteration {iteration + 1} Results:
----------------------------------
Mean Reward: {mean_reward:.4f}
Std Reward: {std_reward:.4f}
Mean KL: {mean_kl:.4f}
Accuracy: {accuracy:.2%}
Rollouts: {len(all_rollouts)}
Training Data: {len(training_data)}
Time: {iter_time:.1f}s
""")

        # Checkpoint periodically and always on the final iteration.
        if (iteration + 1) % config.checkpoint_every == 0 or iteration == config.num_iterations - 1:
            ckpt_future = await training_client.save_weights_for_sampler_async(
                name=f"rl_final_{iteration:04d}"
            )
            ckpt_result = await ckpt_future.result_async()
            final_checkpoint = ckpt_result.path
            print(f"Checkpoint saved: {final_checkpoint}")

    # Persist per-iteration metrics as JSONL.
    metrics_path = os.path.join(config.log_path, "metrics.jsonl")
    with open(metrics_path, "w") as f:
        for m in metrics_log:
            f.write(json.dumps(m) + "\n")

    print(f"\n{'='*70}")
    print("RL TRAINING COMPLETE")
    print(f"{'='*70}")
    print(f"Final checkpoint: {final_checkpoint}")
    print(f"Metrics saved to: {metrics_path}")

    return final_checkpoint, metrics_log
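
# Each metrics.jsonl line looks like (illustrative values):
#   {"iteration": 0, "mean_reward": 0.41, "std_reward": 0.22, "mean_kl": 0.004, ...}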


async def main():
    config = RLConfig()

    # Apply key=value CLI overrides, coercing each value to the field's
    # declared type.
    for arg in sys.argv[1:]:
        if "=" in arg:
            key, value = arg.split("=", 1)
            if hasattr(config, key):
                current_value = getattr(config, key)
                if isinstance(current_value, int):
                    setattr(config, key, int(value))
                elif isinstance(current_value, float):
                    setattr(config, key, float(value))
                else:
                    setattr(config, key, value)

    if not config.sft_checkpoint:
        print("ERROR: sft_checkpoint is required")
        print("Usage: python rl_train.py sft_checkpoint=tinker://...")
        sys.exit(1)

    await run_rl_training(config)
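
# Example invocation (illustrative checkpoint path):
#   python rl_train.py sft_checkpoint=tinker://... num_iterations=25 group_size=8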


if __name__ == "__main__":
    asyncio.run(main())