"""
Evaluate the SFT model and run RL continuation.

This script:
1. Evaluates the SFT checkpoint from our full_pipeline run
2. Continues RL training from the SFT state checkpoint
3. Evaluates the final RL model
"""
import asyncio
import json
import os
import time
from collections import Counter
from datetime import datetime

import numpy as np
from dotenv import load_dotenv

# Load environment variables (e.g. the Tinker API key) before using tinker.
load_dotenv()

import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer

BASE_MODEL = "meta-llama/Llama-3.1-8B"
LORA_RANK = 32

# Checkpoints from the earlier full_pipeline SFT run: the state checkpoint is for
# resuming training; the sampler checkpoint is for inference.
SFT_STATE_CHECKPOINT = "tinker://398393e1-7182-555d-aa1b-7ddf23892338:train:0/weights/sft_final"
SFT_SAMPLER_CHECKPOINT = "tinker://398393e1-7182-555d-aa1b-7ddf23892338:train:0/sampler_weights/sft_final_sampler"

# RL hyperparameters.
RL_ITERATIONS = 10
RL_BATCH_SIZE = 16     # prompts sampled per iteration
RL_GROUP_SIZE = 4      # completions sampled per prompt
RL_LR = 1e-5
RL_TEMPERATURE = 0.7

TRAIN_DATA_PATH = "training/processed_data/train_data.json"
TEST_DATA_PATH = "training/processed_data/test_data.json"
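
# Routing taxonomy: twelve memory categories split across "company." and "user."
# scopes, plus "none". CATEGORY_PERSISTENCE maps each category to its expected
# lifetime; the reward's temporal term compares this between prediction and gold.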
VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none",
}

CATEGORY_PERSISTENCE = {
    "company.brand_core": "long", "company.strategic_signatures": "long",
    "company.knowledge_artifacts": "long", "company.business_priorities": "short",
    "company.tools_config": "medium", "company.performance_context": "rolling",
    "user.communication_style": "long", "user.strategic_approach": "long",
    "user.role_context": "medium", "user.workflow_patterns": "medium",
    "user.session_history": "short", "user.interaction_preferences": "evolving",
    "none": "short",
}


def compute_reward(predicted_text: str, gold_categories: list) -> tuple:
    """Compute the scalar reward and a per-term breakdown.

    r_total = 0.6 * F1 + 0.2 * temporal match + 0.1 * scope parity
              + 0.1 * efficiency; unparseable output earns a flat -1.0.
    """
    info = {"format_valid": True, "r_f1": 0, "r_temp": 0, "r_parity": 0, "r_eff": 0}

    if not predicted_text or not predicted_text.strip():
        info["format_valid"] = False
        return -1.0, info

    predicted = {c.strip().lower() for c in predicted_text.split(",")
                 if c.strip().lower() in VALID_CATEGORIES}

    if not predicted:
        info["format_valid"] = False
        return -1.0, info

    # "none" only counts when it is the sole prediction.
    if "none" in predicted and len(predicted) > 1:
        predicted.discard("none")

    gold = {c.lower() for c in gold_categories}

    # F1 term. predicted is always non-empty past the early returns above, so the
    # elif branch is unreachable and kept only as a safeguard.
    if predicted and gold:
        tp = len(predicted & gold)
        precision = tp / len(predicted)
        recall = tp / len(gold)
        info["r_f1"] = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    elif not predicted and not gold:
        info["r_f1"] = 1.0

    # Temporal term: majority persistence horizons must agree.
    def majority_persistence(cats):
        if not cats:
            return "medium"
        persistences = [CATEGORY_PERSISTENCE.get(c, "medium") for c in cats]
        return Counter(persistences).most_common(1)[0][0]

    if majority_persistence(predicted) == majority_persistence(gold):
        info["r_temp"] = 1.0

    # Scope parity term: company-only, user-only, mixed, or none.
    def get_scope(cats):
        scopes = set()
        for c in cats:
            if c.startswith("company."):
                scopes.add("company")
            elif c.startswith("user."):
                scopes.add("user")
        if len(scopes) == 2:
            return "mixed"
        return scopes.pop() if scopes else "none"

    if get_scope(predicted) == get_scope(gold):
        info["r_parity"] = 1.0

    # Efficiency term: penalize over-prediction.
    n = len(predicted)
    info["r_eff"] = 1.0 if n <= 3 else (0.7 if n == 4 else 0.4)

    r_total = 0.6 * info["r_f1"] + 0.2 * info["r_temp"] + 0.1 * info["r_parity"] + 0.1 * info["r_eff"]
    return r_total, info
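

# Hand-checkable example of the reward above: a single exactly-correct prediction
# maxes all four terms (F1 = 1.0, matching "long" persistence, matching "user"
# scope, and <= 3 categories predicted), so
#   compute_reward("user.communication_style", ["user.communication_style"])
# returns (1.0, info) with every component at its maximum.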


async def evaluate_model(service_client, checkpoint, tokenizer, renderer, test_data, name, n_samples=100):
    """Evaluate a model checkpoint on up to n_samples held-out examples."""
    print(f"\n{'=' * 60}")
    print(f"EVALUATING: {name}")
    print(f"{'=' * 60}")

    sampling_client = service_client.create_sampling_client(model_path=checkpoint)
    stop_sequences = renderer.get_stop_sequences()

    results = []

    for i, example in enumerate(test_data[:n_samples]):
        gold = example.get("categories", [])
        messages = example.get("messages", [])
        prompt_messages = [m for m in messages if m.get("role") != "assistant"]

        if not prompt_messages:
            continue

        prompt = renderer.build_generation_prompt(prompt_messages)
        # Near-greedy decoding for evaluation.
        params = types.SamplingParams(max_tokens=50, temperature=0.1, stop=stop_sequences)

        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, success = renderer.parse_response(result.sequences[0].tokens)
        predicted_text = response["content"] if success else ""

        predicted_set = {c.strip().lower() for c in predicted_text.split(",")
                         if c.strip().lower() in VALID_CATEGORIES}
        gold_set = {c.lower() for c in gold}

        reward, info = compute_reward(predicted_text, gold)

        results.append({
            "any_match": len(predicted_set & gold_set) > 0,
            "exact_match": predicted_set == gold_set,
            "precision": len(predicted_set & gold_set) / len(predicted_set) if predicted_set else 0,
            "recall": len(predicted_set & gold_set) / len(gold_set) if gold_set else 0,
            "reward": reward,
            "format_valid": info["format_valid"],
        })

        if (i + 1) % 25 == 0:
            any_match = np.mean([r["any_match"] for r in results])
            print(f"  Progress: {i + 1}/{n_samples}, Any Match: {any_match:.1%}")

    metrics = {
        "any_match": np.mean([r["any_match"] for r in results]),
        "exact_match": np.mean([r["exact_match"] for r in results]),
        "precision": np.mean([r["precision"] for r in results]),
        "recall": np.mean([r["recall"] for r in results]),
        "mean_reward": np.mean([r["reward"] for r in results]),
        "format_valid": np.mean([r["format_valid"] for r in results]),
    }
    # Macro-style F1, computed from the mean precision and recall.
    p_plus_r = metrics["precision"] + metrics["recall"]
    metrics["f1"] = 2 * metrics["precision"] * metrics["recall"] / p_plus_r if p_plus_r > 0 else 0

    print(f"\nResults for {name}:")
    print(f"  Any Match: {metrics['any_match']:.1%}")
    print(f"  Exact Match: {metrics['exact_match']:.1%}")
    print(f"  F1 Score: {metrics['f1']:.1%}")
    print(f"  Mean Reward: {metrics['mean_reward']:.3f}")

    return metrics
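

# The RL loop below is GRPO-flavored: each iteration publishes the current
# weights as a sampling checkpoint, draws RL_GROUP_SIZE completions for each of
# RL_BATCH_SIZE training prompts, scores them with compute_reward, converts the
# rewards to normalized advantages, and takes one importance-sampling
# policy-gradient step.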


async def run_rl_phase(service_client, training_client, tokenizer, renderer, train_data):
    """Run the RL training phase, starting from the SFT state checkpoint."""
    print(f"\n{'=' * 60}")
    print("PHASE 2: REINFORCEMENT LEARNING")
    print(f"{'=' * 60}")

    print(f"Loading SFT state from: {SFT_STATE_CHECKPOINT}")
    await training_client.load_state_async(SFT_STATE_CHECKPOINT)
    print("SFT weights loaded successfully!")

    stop_sequences = renderer.get_stop_sequences()
    metrics_log = []

    for iteration in range(RL_ITERATIONS):
        iter_start = time.time()
        print(f"\n--- RL Iteration {iteration + 1}/{RL_ITERATIONS} ---")

        # Publish the current policy weights so sampling reflects the latest update.
        save_future = await training_client.save_weights_for_sampler_async(
            name=f"rl_iter_{iteration:03d}"
        )
        save_result = await save_future.result_async()
        sampling_client = service_client.create_sampling_client(model_path=save_result.path)

        # Sample a batch of training prompts without replacement.
        batch_indices = np.random.choice(len(train_data), size=RL_BATCH_SIZE, replace=False)

        all_rollouts = []
        all_rewards = []

        for idx in batch_indices:
            example = train_data[idx]
            gold_categories = example.get("categories", [])
            messages = example.get("messages", [])
            prompt_messages = [m for m in messages if m.get("role") != "assistant"]

            if not prompt_messages:
                continue

            prompt = renderer.build_generation_prompt(prompt_messages)
            params = types.SamplingParams(
                max_tokens=50, temperature=RL_TEMPERATURE, stop=stop_sequences
            )

            # Draw RL_GROUP_SIZE completions for this prompt.
            result = sampling_client.sample(
                prompt=prompt, sampling_params=params, num_samples=RL_GROUP_SIZE
            ).result()

            for seq in result.sequences:
                response, success = renderer.parse_response(seq.tokens)
                predicted = response["content"] if success else ""
                reward, _ = compute_reward(predicted, gold_categories)

                all_rollouts.append({
                    "prompt": prompt,
                    "tokens": seq.tokens,
                    "logprobs": seq.logprobs or [],
                    "predicted": predicted,
                    "gold": gold_categories,
                })
                all_rewards.append(reward)
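
        # Rewards are standardized across the whole batch below; canonical GRPO
        # would normalize within each per-prompt group of RL_GROUP_SIZE samples
        # instead, so this is a simpler batch-wide variant.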
        rewards_arr = np.array(all_rewards)
        # Guard against an empty batch (every prompt skipped).
        mean_reward = float(rewards_arr.mean()) if len(rewards_arr) else 0.0
        std_reward = (rewards_arr.std() + 1e-8) if len(rewards_arr) else 1.0
        advantages = (rewards_arr - mean_reward) / std_reward

        # Build training data: next-token pairs over prompt + generation, with zero
        # advantage on prompt positions so only generated tokens receive gradient.
        training_data = []
        for i, rollout in enumerate(all_rollouts):
            if not rollout["logprobs"]:
                continue

            prompt_tokens = rollout["prompt"].to_ints()
            gen_tokens = rollout["tokens"]
            logprobs = rollout["logprobs"]
            adv = advantages[i]

            n_prompt = len(prompt_tokens) - 1  # prompt positions after the one-token shift
            n_gen = len(gen_tokens)

            if len(logprobs) != n_gen:
                continue

            # Shift by one: the model at position t predicts token t + 1.
            full_input = (prompt_tokens + gen_tokens[:-1]) if n_gen > 1 else prompt_tokens
            full_target = prompt_tokens[1:] + gen_tokens
            full_logprobs = [0.0] * n_prompt + logprobs
            full_advantages = [0.0] * n_prompt + [adv] * n_gen

            if len(full_target) != len(full_input) or len(full_logprobs) != len(full_input):
                continue

            training_data.append(types.Datum(
                model_input=types.ModelInput.from_ints(full_input),
                loss_fn_inputs=dict(
                    target_tokens=full_target,
                    logprobs=full_logprobs,
                    advantages=full_advantages,
                ),
            ))
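
        # The importance_sampling loss uses the stored sampler logprobs to correct
        # the policy-gradient update for the gap between the sampling policy and
        # the current weights (see the Tinker docs for the exact loss definition).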
        if training_data:
            fwd_future = await training_client.forward_backward_async(
                training_data, loss_fn="importance_sampling"
            )
            optim_future = await training_client.optim_step_async(
                types.AdamParams(learning_rate=RL_LR, beta1=0.9, beta2=0.95, eps=1e-8)
            )
            await fwd_future.result_async()
            await optim_future.result_async()

        iter_time = time.time() - iter_start
        # "Accuracy" here is the fraction of rollouts with positive reward.
        accuracy = sum(1 for r in all_rewards if r > 0) / len(all_rewards) if all_rewards else 0

        metrics = {
            "iteration": iteration,
            "mean_reward": float(mean_reward),
            "accuracy": accuracy,
            "num_rollouts": len(all_rollouts),
            "time": iter_time,
        }
        metrics_log.append(metrics)

        print(f"  Reward: {mean_reward:.3f}, Accuracy: {accuracy:.1%}, Time: {iter_time:.1f}s")

    print("\nSaving final RL checkpoint...")
    final_future = await training_client.save_weights_for_sampler_async(name="rl_final")
    final_result = await final_future.result_async()
    rl_checkpoint = final_result.path

    print(f"RL checkpoint: {rl_checkpoint}")

    return rl_checkpoint, metrics_log


async def main():
    print("=" * 70)
    print("MEMORY ROUTING AGENT - EVALUATION & RL CONTINUATION")
    print("=" * 70)
    print(f"Timestamp: {datetime.now()}")
    print(f"Base Model: {BASE_MODEL}")
    print(f"SFT State Checkpoint: {SFT_STATE_CHECKPOINT}")

    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)

    with open(TRAIN_DATA_PATH, "r") as f:
        train_data = json.load(f)
    with open(TEST_DATA_PATH, "r") as f:
        test_data = json.load(f)

    print(f"Train: {len(train_data)}, Test: {len(test_data)}")

    # Phase 1: baseline evaluation of the SFT checkpoint.
    sft_metrics = await evaluate_model(
        service_client, SFT_SAMPLER_CHECKPOINT, tokenizer, renderer, test_data, "SFT Model", n_samples=100
    )

    # Fresh LoRA training client; run_rl_phase loads the SFT state into it.
    training_client = await service_client.create_lora_training_client_async(
        base_model=BASE_MODEL,
        rank=LORA_RANK,
    )

    # Phase 2: RL continuation from the SFT state.
    rl_checkpoint, rl_metrics = await run_rl_phase(
        service_client, training_client, tokenizer, renderer, train_data
    )

    # Phase 3: evaluate the RL-trained model on the same test slice.
    rl_eval_metrics = await evaluate_model(
        service_client, rl_checkpoint, tokenizer, renderer, test_data, "RL Model", n_samples=100
    )

    print("\n" + "=" * 70)
    print("TRAINING COMPLETE - SUMMARY")
    print("=" * 70)
    print("\nSFT Model:")
    print(f"  Checkpoint: {SFT_SAMPLER_CHECKPOINT}")
    print(f"  Any Match: {sft_metrics['any_match']:.1%}")
    print(f"  F1 Score: {sft_metrics['f1']:.1%}")

    print("\nRL Model:")
    print(f"  Checkpoint: {rl_checkpoint}")
    print(f"  Any Match: {rl_eval_metrics['any_match']:.1%}")
    print(f"  F1 Score: {rl_eval_metrics['f1']:.1%}")

    improvement = rl_eval_metrics["any_match"] - sft_metrics["any_match"]
    print(f"\nImprovement: {improvement:+.1%}")

    results = {
        "sft_checkpoint": SFT_SAMPLER_CHECKPOINT,
        "rl_checkpoint": rl_checkpoint,
        "sft_metrics": sft_metrics,
        "rl_metrics": rl_eval_metrics,
        "rl_training_log": rl_metrics,
    }

    results_path = f"training/experiments/results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2, default=str)

    print(f"\nResults saved to: {results_path}")

    return results


if __name__ == "__main__":
    asyncio.run(main())