"""
Fixed Training Script for Memory Routing Agent.

Key fixes based on Tinker docs:
1. Proper advantage computation (centered within groups)
2. Correct tensor alignment for importance_sampling loss
3. Proper group-based rollout collection
4. KL divergence monitoring
"""

import asyncio
import json
import os
import time
from datetime import datetime
from dataclasses import dataclass
from typing import List, Dict, Tuple

import numpy as np
from dotenv import load_dotenv

# Load environment variables (e.g. the Tinker API key) before importing tinker.
load_dotenv()

import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer
from tinker_cookbook.hyperparam_utils import get_lr


# Model
BASE_MODEL = "meta-llama/Llama-3.1-8B"
LORA_RANK = 32

# SFT phase
SFT_STEPS = 50
SFT_BATCH_SIZE = 32

# RL phase
RL_ITERATIONS = 20
RL_GROUPS_PER_BATCH = 32
RL_GROUP_SIZE = 4
RL_LR = 2e-5
RL_TEMPERATURE = 0.7
RL_MAX_TOKENS = 100

# Data and logging paths
TRAIN_DATA = "training/processed_data/train_data.json"
TEST_DATA = "training/processed_data/test_data.json"
LOG_DIR = "training/logs/run_" + datetime.now().strftime("%Y%m%d_%H%M%S")
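
# NOTE: the paths above are relative, so the script assumes it is launched
# from the repository root.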

VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none",
}

SYSTEM_PROMPT = """You route marketing conversations into structured memory categories.

Available categories:
- company.brand_core: Voice, values, positioning, identity anchors (Long >1y)
- company.strategic_signatures: Decision frameworks, strategic heuristics (Long >1y)
- company.knowledge_artifacts: Docs, style guides, playbooks (Long >1y)
- company.business_priorities: Quarterly/seasonal goals, active campaigns (Short <3m)
- company.tools_config: Integrations, API keys, workflow settings (Medium ~6m)
- company.performance_context: Campaign metrics, retrospectives, learnings (Rolling ~6m)
- user.communication_style: Tone, verbosity, format expectations (Long >1y)
- user.strategic_approach: Personal priorities, success definitions (Long >1y)
- user.role_context: Title, scope, decision authority (Medium ~1y)
- user.workflow_patterns: Review cadence, collaboration norms (Medium ~1y)
- user.session_history: Immediate context, recent asks (Short <2w)
- user.interaction_preferences: Coaching style, feedback expectations (Evolving)
- none: Irrelevant, vague, or transactional content

Respond with comma-separated categories. Use 'none' only if no other category applies."""
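
# Illustrative routing (hypothetical message, not from the dataset): a user
# saying "Q3 is all about the spring launch; keep replies short and bulleted"
# should map to:
#   company.business_priorities, user.communication_style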


@dataclass
class Rollout:
    """Single rollout from a problem."""
    prompt_tokens: List[int]
    gen_tokens: List[int]
    logprobs: List[float]
    reward: float
    predicted: str
    gold: List[str]


@dataclass
class RolloutGroup:
    """Group of rollouts for the same problem."""
    problem_id: int
    rollouts: List[Rollout]

    def get_rewards(self) -> List[float]:
        return [r.reward for r in self.rollouts]

    def is_constant_reward(self) -> bool:
        """Check if all rewards are the same (no learning signal)."""
        rewards = self.get_rewards()
        return len(set(rewards)) == 1


class TrainingLogger:
    def __init__(self, log_dir):
        os.makedirs(log_dir, exist_ok=True)
        self.log_dir = log_dir
        self.sft_log = open(os.path.join(log_dir, "sft_metrics.jsonl"), "w")
        self.rl_log = open(os.path.join(log_dir, "rl_metrics.jsonl"), "w")
        self.start_time = time.time()

    def log_sft(self, step, metrics):
        metrics["step"] = step
        metrics["elapsed_time"] = time.time() - self.start_time
        self.sft_log.write(json.dumps(metrics) + "\n")
        self.sft_log.flush()

        test_loss = metrics.get('test_loss')
        test_str = f"{test_loss:.4f}" if isinstance(test_loss, (int, float)) else "N/A"
        print(f"[SFT {step:3d}] "
              f"Loss: {metrics.get('train_loss', 0):.4f} | "
              f"Test: {test_str} | "
              f"Acc: {metrics.get('accuracy', 'N/A')} | "
              f"Time: {metrics.get('step_time', 0):.1f}s")

    def log_rl(self, iteration, metrics):
        metrics["iteration"] = iteration
        metrics["elapsed_time"] = time.time() - self.start_time
        self.rl_log.write(json.dumps(metrics) + "\n")
        self.rl_log.flush()

        print(f"[RL {iteration:3d}] "
              f"Reward: {metrics.get('mean_reward', 0):.3f} (±{metrics.get('std_reward', 0):.3f}) | "
              f"Acc: {metrics.get('accuracy', 0):.1%} | "
              f"KL: {metrics.get('kl_divergence', 0):.4f} | "
              f"Groups: {metrics.get('active_groups', 0)} | "
              f"Time: {metrics.get('iter_time', 0):.1f}s")

    def close(self):
        self.sft_log.close()
        self.rl_log.close()


def compute_reward(predicted_text: str, gold_categories: List[str]) -> Tuple[float, Dict]:
    """Compute F1-based reward for RL."""
    if not predicted_text or not predicted_text.strip():
        return -1.0, {"format_valid": False, "predicted": set(), "gold": set(gold_categories)}

    # Keep only labels that parse to valid categories.
    predicted = {c.strip().lower() for c in predicted_text.split(",")
                 if c.strip().lower() in VALID_CATEGORIES}
    if not predicted:
        return -1.0, {"format_valid": False, "predicted": set(), "gold": set(gold_categories)}

    gold = {c.lower() for c in gold_categories}

    if predicted and gold:
        tp = len(predicted & gold)
        precision = tp / len(predicted)
        recall = tp / len(gold)
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    else:
        # predicted is guaranteed non-empty here, so this branch means gold
        # is empty and the prediction scores zero.
        f1 = 1.0 if not predicted and not gold else 0.0

    return f1, {"format_valid": True, "f1": f1, "predicted": predicted, "gold": gold}
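
# Reward sanity check (illustrative values, not part of the pipeline):
# predicting "user.session_history, company.brand_core" against gold labels
# ["user.session_history", "user.role_context"] gives tp = 1, so
# precision = recall = 0.5 and the reward is F1 = 0.5. An empty or
# unparseable prediction is penalized at -1.0 rather than 0.0, pushing the
# policy toward emitting well-formed category lists.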


def compute_group_advantages(groups: List[RolloutGroup]) -> List[List[float]]:
    """
    Compute advantages by centering rewards within each group.
    This is the group-baseline approach recommended in the Tinker docs.
    """
    all_advantages = []

    for group in groups:
        rewards = np.array(group.get_rewards())

        # Center within the group so each advantage is measured relative to
        # the group's own baseline.
        mean_reward = rewards.mean()

        # Normalize by the std when it is non-degenerate; otherwise just center.
        std_reward = rewards.std()
        if std_reward > 1e-8:
            advantages = (rewards - mean_reward) / std_reward
        else:
            advantages = rewards - mean_reward

        all_advantages.append(advantages.tolist())

    return all_advantages
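
# Worked example: group rewards [1.0, 0.0, 0.0, 1.0] have mean 0.5 and
# std 0.5, giving advantages [+1, -1, -1, +1]; the two correct rollouts are
# reinforced and the two incorrect ones suppressed. Rewards [1, 1, 1, 1]
# center to all zeros, which is why constant-reward groups are filtered out.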


def build_rl_datum(rollout: Rollout, advantage: float) -> types.Datum:
    """
    Build a Datum for the importance_sampling loss.

    Per the Tinker docs, importance_sampling requires:
    - target_tokens: array[(N,), int] - target token IDs from the sampler
    - logprobs: array[(N,), float] - reference log probabilities from the sampler
    - advantages: array[(N,), float] - advantage values

    All must have shape (N,), where N = model_input.length.
    """
    prompt_tokens = rollout.prompt_tokens
    gen_tokens = rollout.gen_tokens
    sampler_logprobs = rollout.logprobs

    n_prompt = len(prompt_tokens)
    n_gen = len(gen_tokens)

    # Standard next-token shift: the model sees full_tokens[:-1] and is
    # trained against full_tokens[1:].
    full_tokens = prompt_tokens + gen_tokens
    input_tokens = full_tokens[:-1]
    target_tokens = full_tokens[1:]
    n_input = len(input_tokens)

    # Prompt positions carry zero logprob and zero advantage so they
    # contribute nothing to the loss; only generated tokens are trained on.
    full_logprobs = [0.0] * (n_prompt - 1) + sampler_logprobs
    full_advantages = [0.0] * (n_prompt - 1) + [advantage] * n_gen

    assert len(target_tokens) == n_input, f"target_tokens length mismatch: {len(target_tokens)} vs {n_input}"
    assert len(full_logprobs) == n_input, f"logprobs length mismatch: {len(full_logprobs)} vs {n_input}"
    assert len(full_advantages) == n_input, f"advantages length mismatch: {len(full_advantages)} vs {n_input}"

    return types.Datum(
        model_input=types.ModelInput.from_ints(input_tokens),
        loss_fn_inputs=dict(
            target_tokens=target_tokens,
            logprobs=full_logprobs,
            advantages=full_advantages,
        ),
    )
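
# Alignment sketch (illustrative token counts): with a 5-token prompt
# [p0..p4] and 3 generated tokens [g0, g1, g2]:
#
#   full_tokens   = [p0 p1 p2 p3 p4 g0 g1 g2]              (8 tokens)
#   input_tokens  = [p0 p1 p2 p3 p4 g0 g1]                 (N = 7)
#   target_tokens = [p1 p2 p3 p4 g0 g1 g2]                 (N = 7)
#   logprobs      = [ 0  0  0  0  lp(g0) lp(g1) lp(g2)]    (N = 7)
#   advantages    = [ 0  0  0  0    A      A      A   ]    (N = 7)
#
# The first generated token g0 is predicted at the last prompt position,
# which is why only n_prompt - 1 zeros are prepended.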


async def collect_rollouts(
    sampling_client: tinker.SamplingClient,
    renderer: renderers.Renderer,
    train_data: List[Dict],
    groups_per_batch: int,
    group_size: int
) -> List[RolloutGroup]:
    """Collect rollouts organized by problem groups."""
    stop_sequences = renderer.get_stop_sequences()
    params = types.SamplingParams(
        max_tokens=RL_MAX_TOKENS,
        temperature=RL_TEMPERATURE,
        stop=stop_sequences
    )

    # Sample distinct problems for this batch (capped at the dataset size so
    # replace=False cannot fail).
    n_problems = min(groups_per_batch, len(train_data))
    problem_indices = np.random.choice(len(train_data), size=n_problems, replace=False)

    rollout_groups = []

    for problem_idx in problem_indices:
        example = train_data[problem_idx]
        gold = example.get("categories", [])
        messages = example.get("messages", [])

        # Drop the final (assistant) message; the model must generate it.
        prompt_messages = messages[:-1] if messages else []
        if not prompt_messages:
            continue

        prompt = renderer.build_generation_prompt(prompt_messages)
        prompt_tokens = prompt.to_ints()

        # group_size samples for the same prompt form one rollout group.
        # NOTE: .result() blocks the event loop, which is acceptable for
        # this script's sequential collection.
        result = sampling_client.sample(
            prompt=prompt,
            sampling_params=params,
            num_samples=group_size
        ).result()

        rollouts = []
        for seq in result.sequences:
            response, success = renderer.parse_response(seq.tokens)
            predicted = response["content"] if success else ""
            reward, info = compute_reward(predicted, gold)

            # Keep only sequences whose sampler logprobs align one-to-one
            # with the sampled tokens; importance_sampling needs them.
            if seq.logprobs and len(seq.logprobs) == len(seq.tokens):
                rollouts.append(Rollout(
                    prompt_tokens=prompt_tokens,
                    gen_tokens=seq.tokens,
                    logprobs=seq.logprobs,
                    reward=reward,
                    predicted=predicted,
                    gold=gold
                ))

        if rollouts:
            rollout_groups.append(RolloutGroup(
                problem_id=int(problem_idx),
                rollouts=rollouts
            ))

    return rollout_groups


def filter_constant_reward_groups(groups: List[RolloutGroup]) -> List[RolloutGroup]:
    """
    Remove groups where all rollouts have the same reward.
    These provide no learning signal (gradient is zero).
    """
    return [g for g in groups if not g.is_constant_reward()]
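
# Illustrative effect of the filter: with RL_GROUP_SIZE = 4, a group whose
# rewards are [1.0, 1.0, 1.0, 1.0] (or all -1.0) centers to zero advantage
# at every position, so dropping it skips work without changing the update.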


async def run_sft(
    service_client: tinker.ServiceClient,
    training_client: tinker.TrainingClient,
    tokenizer,
    renderer: renderers.Renderer,
    train_data: List[Dict],
    test_data: List[Dict],
    logger: TrainingLogger
) -> Tuple[str, str]:
    """Run the SFT phase; returns (state checkpoint path, sampler checkpoint path)."""
    print("\n" + "=" * 70)
    print("PHASE 1: SUPERVISED FINE-TUNING")
    print("=" * 70)

    lr = get_lr(BASE_MODEL)
    print(f"Learning rate: {lr:.2e}")
    print(f"Steps: {SFT_STEPS}, Batch size: {SFT_BATCH_SIZE}")
    print()

    def to_datum(item):
        messages = item.get("messages", [])
        tokens, weights = renderer.build_supervised_example(messages)
        if hasattr(tokens, 'tolist'):
            tokens = tokens.tolist()
        if hasattr(weights, 'tolist'):
            weights = weights.tolist()
        # Same next-token shift as build_rl_datum: inputs are tokens[:-1],
        # targets are tokens[1:].
        return types.Datum(
            model_input=types.ModelInput.from_ints(tokens[:-1]),
            loss_fn_inputs=dict(target_tokens=tokens[1:], weights=weights[1:])
        )

    train_datums = [to_datum(item) for item in train_data]
    test_datums = [to_datum(item) for item in test_data[:50]]

    for step in range(SFT_STEPS):
        step_start = time.time()

        # Cycle through the dataset, wrapping around at the end.
        batch_idx = (step * SFT_BATCH_SIZE) % len(train_datums)
        batch = train_datums[batch_idx:batch_idx + SFT_BATCH_SIZE]
        if len(batch) < SFT_BATCH_SIZE:
            batch = batch + train_datums[:SFT_BATCH_SIZE - len(batch)]

        # Submit forward/backward and the optimizer step together; requests
        # execute in submission order on the server.
        fwd_future = await training_client.forward_backward_async(batch, loss_fn="cross_entropy")
        optim_future = await training_client.optim_step_async(
            types.AdamParams(learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
        )

        fwd_result = await fwd_future.result_async()
        await optim_future.result_async()

        # Weighted NLL over the supervised positions.
        logprobs = np.concatenate([o['logprobs'].tolist() for o in fwd_result.loss_fn_outputs])
        weights = np.concatenate([d.loss_fn_inputs['weights'].tolist() for d in batch])
        train_loss = -np.dot(logprobs, weights) / max(weights.sum(), 1)

        step_time = time.time() - step_start
        metrics = {"train_loss": float(train_loss), "step_time": step_time}

        # Periodic held-out loss and checkpoint. NOTE: this reuses
        # forward_backward, which also accumulates gradients from the test
        # batch; if the client in use exposes a forward-only pass, prefer it.
        if step % 10 == 0 or step == SFT_STEPS - 1:
            eval_future = await training_client.forward_backward_async(test_datums, loss_fn="cross_entropy")
            eval_result = await eval_future.result_async()
            test_logprobs = np.concatenate([o['logprobs'].tolist() for o in eval_result.loss_fn_outputs])
            test_weights = np.concatenate([d.loss_fn_inputs['weights'].tolist() for d in test_datums])
            test_loss = -np.dot(test_logprobs, test_weights) / max(test_weights.sum(), 1)
            metrics["test_loss"] = float(test_loss)

            save_future = await training_client.save_weights_for_sampler_async(name=f"sft_step_{step:04d}")
            save_result = await save_future.result_async()
            metrics["checkpoint"] = save_result.path

        logger.log_sft(step, metrics)

    # Full state (weights + optimizer) for resuming RL, plus sampler weights
    # for evaluation.
    state_future = await training_client.save_state_async(name="sft_final")
    state_result = await state_future.result_async()

    sampler_future = await training_client.save_weights_for_sampler_async(name="sft_final_sampler")
    sampler_result = await sampler_future.result_async()

    print(f"\nSFT Complete. State checkpoint: {state_result.path}")

    return state_result.path, sampler_result.path


async def run_rl(
    service_client: tinker.ServiceClient,
    training_client: tinker.TrainingClient,
    sft_state_path: str,
    tokenizer,
    renderer: renderers.Renderer,
    train_data: List[Dict],
    test_data: List[Dict],
    logger: TrainingLogger
) -> str:
    """Run the RL phase with group-centered advantage computation."""
    print("\n" + "=" * 70)
    print("PHASE 2: REINFORCEMENT LEARNING")
    print("=" * 70)

    # Resume weights and optimizer state from the end of SFT.
    print(f"Loading SFT state from: {sft_state_path}")
    await training_client.load_state_async(sft_state_path)

    print(f"Iterations: {RL_ITERATIONS}")
    print(f"Groups per batch: {RL_GROUPS_PER_BATCH}")
    print(f"Group size: {RL_GROUP_SIZE}")
    print(f"Learning rate: {RL_LR:.2e}")
    print()

    for iteration in range(RL_ITERATIONS):
        iter_start = time.time()

        # 1. Snapshot the current policy for sampling.
        save_future = await training_client.save_weights_for_sampler_async(name=f"rl_iter_{iteration:03d}")
        save_result = await save_future.result_async()
        sampling_client = service_client.create_sampling_client(model_path=save_result.path)

        # 2. Collect grouped rollouts from that snapshot.
        rollout_groups = await collect_rollouts(
            sampling_client, renderer, train_data,
            RL_GROUPS_PER_BATCH, RL_GROUP_SIZE
        )

        # 3. Drop constant-reward groups: their centered advantages are all
        #    zero, so they contribute no gradient.
        active_groups = filter_constant_reward_groups(rollout_groups)

        # Reward statistics are reported over ALL groups, filtered or not.
        all_rewards = []
        for group in rollout_groups:
            all_rewards.extend(group.get_rewards())

        # 4. Center (and normalize) rewards within each active group.
        group_advantages = compute_group_advantages(active_groups)

        # 5. Build importance_sampling datums, skipping any that fail the
        #    length-alignment asserts.
        training_data = []
        for group, advantages in zip(active_groups, group_advantages):
            for rollout, advantage in zip(group.rollouts, advantages):
                try:
                    datum = build_rl_datum(rollout, advantage)
                    training_data.append(datum)
                except AssertionError as e:
                    print(f"Warning: Skipping datum due to length mismatch: {e}")
                    continue

        # 6. One forward/backward + optimizer step, then estimate the
        #    sampler-vs-learner KL from the returned logprobs.
        kl_samples = []
        if training_data:
            fwd_future = await training_client.forward_backward_async(
                training_data, loss_fn="importance_sampling"
            )
            optim_future = await training_client.optim_step_async(
                types.AdamParams(learning_rate=RL_LR, beta1=0.9, beta2=0.95, eps=1e-8)
            )

            fwd_result = await fwd_future.result_async()
            await optim_future.result_async()

            for i, output in enumerate(fwd_result.loss_fn_outputs):
                new_logprobs = output['logprobs'].tolist()
                old_logprobs = training_data[i].loss_fn_inputs['logprobs'].tolist()

                for new_lp, old_lp in zip(new_logprobs, old_logprobs):
                    # Prompt positions were padded with exactly 0.0 in
                    # build_rl_datum; only generated tokens count.
                    if old_lp != 0.0:
                        # k1 estimator of KL(sampler || learner).
                        kl_samples.append(old_lp - new_lp)

        iter_time = time.time() - iter_start

        mean_reward = np.mean(all_rewards) if all_rewards else 0
        std_reward = np.std(all_rewards) if all_rewards else 0
        accuracy = sum(1 for r in all_rewards if r > 0) / len(all_rewards) if all_rewards else 0
        kl_divergence = np.mean(kl_samples) if kl_samples else 0

        metrics = {
            "mean_reward": float(mean_reward),
            "std_reward": float(std_reward),
            "accuracy": accuracy,
            "kl_divergence": float(kl_divergence),
            "total_groups": len(rollout_groups),
            "active_groups": len(active_groups),
            "num_training_examples": len(training_data),
            "iter_time": iter_time,
            "checkpoint": save_result.path
        }

        logger.log_rl(iteration, metrics)

        if abs(kl_divergence) > 0.01:
            print(f"Warning: KL divergence {kl_divergence:.4f} exceeds threshold 0.01")

    final_future = await training_client.save_weights_for_sampler_async(name="rl_final")
    final_result = await final_future.result_async()

    print(f"\nRL Complete. Final checkpoint: {final_result.path}")

    return final_result.path
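
# KL tracking note: the statistic logged above is the simple k1 estimator
#     KL(sampler || learner) ~= mean over sampled tokens of
#         log p_sampler(token) - log p_learner(token),
# evaluated only at generated-token positions. Values drifting past the
# 0.01 threshold suggest the policy has moved far from the weights that
# produced the rollouts, making the importance-sampling correction
# unreliable for that batch.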


async def evaluate_model(
    service_client: tinker.ServiceClient,
    model_path: str,
    renderer: renderers.Renderer,
    test_data: List[Dict],
    n_samples: int = 100
) -> Dict[str, float]:
    """Evaluate model on test data."""
    print(f"\nEvaluating: {model_path}")

    sampling_client = service_client.create_sampling_client(model_path=model_path)
    stop_sequences = renderer.get_stop_sequences()
    params = types.SamplingParams(max_tokens=100, temperature=0.1, stop=stop_sequences)

    correct_any = 0
    correct_exact = 0
    total_f1 = 0.0
    n_evaluated = 0

    for item in test_data[:n_samples]:
        messages = item.get("messages", [])
        gold = item.get("categories", [])

        prompt_messages = messages[:-1] if messages else []
        if not prompt_messages:
            continue
        n_evaluated += 1

        prompt = renderer.build_generation_prompt(prompt_messages)
        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, _ = renderer.parse_response(result.sequences[0].tokens)
        pred = response["content"]

        pred_set = {c.strip().lower() for c in pred.split(",")
                    if c.strip().lower() in VALID_CATEGORIES}
        gold_set = {c.lower() for c in gold}

        if pred_set & gold_set:
            correct_any += 1
        if pred_set == gold_set:
            correct_exact += 1

        if pred_set and gold_set:
            tp = len(pred_set & gold_set)
            precision = tp / len(pred_set)
            recall = tp / len(gold_set)
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
            total_f1 += f1

    # Normalize by the number of items actually evaluated; counting skipped
    # items would deflate every metric.
    n = max(n_evaluated, 1)
    return {
        "any_match": correct_any / n,
        "exact_match": correct_exact / n,
        "f1": total_f1 / n
    }
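
# Metric semantics: any_match counts predictions sharing at least one label
# with gold, exact_match requires identical label sets, and f1 is the mean
# per-example F1 (accumulated only where both sets are non-empty, matching
# the loop above).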


async def main():
    print("=" * 70)
    print("MEMORY ROUTING AGENT - FIXED TRAINING PIPELINE")
    print("=" * 70)
    print(f"Log directory: {LOG_DIR}")
    print(f"Model: {BASE_MODEL}")
    print()

    # Clients, tokenizer, and chat renderer.
    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)

    # Load datasets.
    with open(TRAIN_DATA, "r") as f:
        train_data = json.load(f)
    with open(TEST_DATA, "r") as f:
        test_data = json.load(f)

    print(f"Train: {len(train_data)}, Test: {len(test_data)}")

    logger = TrainingLogger(LOG_DIR)

    training_client = await service_client.create_lora_training_client_async(
        base_model=BASE_MODEL, rank=LORA_RANK
    )

    # Phase 1: supervised fine-tuning.
    sft_state, sft_sampler = await run_sft(
        service_client, training_client, tokenizer, renderer,
        train_data, test_data, logger
    )

    print("\n" + "-" * 70)
    sft_results = await evaluate_model(service_client, sft_sampler, renderer, test_data)
    print(f"SFT Results: Any={sft_results['any_match']:.1%}, Exact={sft_results['exact_match']:.1%}, F1={sft_results['f1']:.1%}")

    # Phase 2: RL, resumed from the SFT state checkpoint.
    rl_final = await run_rl(
        service_client, training_client, sft_state,
        tokenizer, renderer, train_data, test_data, logger
    )

    print("\n" + "-" * 70)
    rl_results = await evaluate_model(service_client, rl_final, renderer, test_data)
    print(f"RL Results: Any={rl_results['any_match']:.1%}, Exact={rl_results['exact_match']:.1%}, F1={rl_results['f1']:.1%}")

    logger.close()

    print("\n" + "=" * 70)
    print("TRAINING COMPLETE - SUMMARY")
    print("=" * 70)
    print(f"Logs: {LOG_DIR}")
    print(f"SFT Checkpoint: {sft_sampler}")
    print(f"RL Checkpoint: {rl_final}")
    print()
    print("Performance Comparison:")
    print(f"{'Metric':<15} {'SFT':>10} {'RL':>10} {'Delta':>10}")
    print("-" * 45)
    for metric in ['any_match', 'exact_match', 'f1']:
        sft_val = sft_results[metric]
        rl_val = rl_results[metric]
        delta = rl_val - sft_val
        print(f"{metric:<15} {sft_val:>10.1%} {rl_val:>10.1%} {delta:>+10.1%}")


if __name__ == "__main__":
    asyncio.run(main())