|
|
""" |
|
|
Training script with detailed real-time logging. |
|
|
|
|
|
This script logs: |
|
|
- SFT: loss, test loss, accuracy at checkpoints |
|
|
- RL: reward, accuracy, KL divergence per iteration |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import json |
|
|
import os |
|
|
import time |
|
|
from datetime import datetime |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
import tinker |
|
|
from tinker import types |
|
|
from tinker_cookbook import renderers |
|
|
from tinker_cookbook.tokenizer_utils import get_tokenizer |
|
|
from tinker_cookbook.hyperparam_utils import get_lr |
|
|
import numpy as np |
|
|
from collections import Counter |
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model / training hyperparameters
# ---------------------------------------------------------------------------
BASE_MODEL = "meta-llama/Llama-3.1-8B"
LORA_RANK = 32       # LoRA adapter rank
SFT_STEPS = 50       # number of supervised fine-tuning optimizer steps
SFT_BATCH_SIZE = 32  # examples per SFT step
RL_ITERATIONS = 15   # RL outer-loop iterations
RL_BATCH_SIZE = 16   # prompts sampled per RL iteration
RL_GROUP_SIZE = 4    # rollouts sampled per prompt
RL_LR = 1e-5         # RL learning rate (SFT LR comes from get_lr())

# Dataset locations and a per-run log directory.
# NOTE: LOG_DIR is timestamped at import time, so every process start gets a
# fresh directory.
TRAIN_DATA = "training/processed_data/train_data.json"
TEST_DATA = "training/processed_data/test_data.json"
LOG_DIR = "training/logs/run_" + datetime.now().strftime("%Y%m%d_%H%M%S")

# Closed set of routing labels the model may emit. Predictions outside this
# set are discarded by compute_reward() and evaluate_accuracy().
VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none"
}

# System prompt describing the routing task.
# NOTE(review): SYSTEM_PROMPT is not referenced anywhere in this file —
# presumably it is already baked into the dataset's "messages"; confirm.
SYSTEM_PROMPT = """You route marketing conversations into structured memory categories.

Available categories:
- company.brand_core: Voice, values, positioning, identity anchors (Long >1y)
- company.strategic_signatures: Decision frameworks, strategic heuristics (Long >1y)
- company.knowledge_artifacts: Docs, style guides, playbooks (Long >1y)
- company.business_priorities: Quarterly/seasonal goals, active campaigns (Short <3m)
- company.tools_config: Integrations, API keys, workflow settings (Medium ~6m)
- company.performance_context: Campaign metrics, retrospectives, learnings (Rolling ~6m)
- user.communication_style: Tone, verbosity, format expectations (Long >1y)
- user.strategic_approach: Personal priorities, success definitions (Long >1y)
- user.role_context: Title, scope, decision authority (Medium ~1y)
- user.workflow_patterns: Review cadence, collaboration norms (Medium ~1y)
- user.session_history: Immediate context, recent asks (Short <2w)
- user.interaction_preferences: Coaching style, feedback expectations (Evolving)
- none: Irrelevant, vague, or transactional content

Respond with comma-separated categories. Use 'none' only if no other category applies."""
|
|
|
|
|
|
|
|
class TrainingLogger:
    """Append per-step SFT and per-iteration RL metrics to JSONL logs.

    Creates ``sft_metrics.jsonl`` and ``rl_metrics.jsonl`` under ``log_dir``
    and echoes a one-line summary of every record to stdout. Records are
    flushed immediately so the logs can be tailed while training runs.
    """

    def __init__(self, log_dir):
        """Create ``log_dir`` (if needed) and open both metric logs for writing."""
        os.makedirs(log_dir, exist_ok=True)
        self.log_dir = log_dir
        self.sft_log = open(os.path.join(log_dir, "sft_metrics.jsonl"), "w")
        self.rl_log = open(os.path.join(log_dir, "rl_metrics.jsonl"), "w")
        self.start_time = time.time()

    def log_sft(self, step, metrics):
        """Record one SFT step.

        ``metrics`` is not mutated: a copy augmented with ``step`` and
        ``elapsed_time`` (seconds since logger creation) is written as one
        JSON line.
        """
        record = dict(metrics)  # copy so the caller's dict is left untouched
        record["step"] = step
        record["elapsed_time"] = time.time() - self.start_time
        self.sft_log.write(json.dumps(record) + "\n")
        self.sft_log.flush()

        print(f"[SFT Step {step:3d}] "
              f"Loss: {record.get('train_loss', 0):.4f} | "
              f"Test: {record.get('test_loss', 'N/A')} | "
              f"Acc: {record.get('accuracy', 'N/A')} | "
              f"Time: {record.get('step_time', 0):.1f}s")

    def log_rl(self, iteration, metrics):
        """Record one RL iteration. Same copy-and-flush semantics as log_sft."""
        record = dict(metrics)
        record["iteration"] = iteration
        record["elapsed_time"] = time.time() - self.start_time
        self.rl_log.write(json.dumps(record) + "\n")
        self.rl_log.flush()

        print(f"[RL Iter {iteration:3d}] "
              f"Reward: {record.get('mean_reward', 0):.3f} | "
              f"Acc: {record.get('accuracy', 0):.1%} | "
              f"Format: {record.get('format_valid', 0):.1%} | "
              f"Time: {record.get('iter_time', 0):.1f}s")

    def close(self):
        """Close both log files."""
        self.sft_log.close()
        self.rl_log.close()

    # Context-manager support so callers may use ``with TrainingLogger(...)``
    # and get the files closed even on error.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
|
|
|
|
|
|
|
|
def compute_reward(predicted_text, gold_categories):
    """Score one model prediction against the gold categories for RL.

    Args:
        predicted_text: raw model output, expected to be comma-separated
            category names; tokens not in VALID_CATEGORIES are ignored.
        gold_categories: list of gold category names (case-insensitive).

    Returns:
        (reward, info) where reward is -1.0 when the output is empty or
        contains no recognized category (``info = {"format_valid": False}``),
        and otherwise the multi-label F1 between predicted and gold sets
        (``info = {"format_valid": True, "f1": f1}``).
    """
    if not predicted_text or not predicted_text.strip():
        return -1.0, {"format_valid": False}

    predicted = {c.strip().lower() for c in predicted_text.split(",")
                 if c.strip().lower() in VALID_CATEGORIES}

    # Output parsed but contained no recognized category names.
    if not predicted:
        return -1.0, {"format_valid": False}

    gold = {c.lower() for c in gold_categories}

    # ``predicted`` is guaranteed non-empty here; guard the recall division
    # explicitly so an empty gold set yields f1 = 0.0 rather than relying on
    # short-circuit evaluation.
    tp = len(predicted & gold)
    precision = tp / len(predicted)
    recall = tp / len(gold) if gold else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0

    return f1, {"format_valid": True, "f1": f1}
|
|
|
|
|
|
|
|
async def evaluate_accuracy(sampling_client, renderer, test_data, n_samples=20):
    """Estimate routing accuracy on up to ``n_samples`` test examples.

    An example counts as correct when the predicted category set overlaps the
    gold set at all (lenient any-match accuracy, near-deterministic sampling
    at temperature 0.1).

    Args:
        sampling_client: tinker sampling client for the checkpoint under test.
        renderer: chat renderer used to build prompts and parse responses.
        test_data: list of {"messages": ..., "categories": ...} examples.
        n_samples: maximum number of examples to evaluate.

    Returns:
        Fraction correct over the examples actually evaluated (0.0 if none).
    """
    stop_sequences = renderer.get_stop_sequences()
    eval_items = test_data[:n_samples]
    correct = 0

    for item in eval_items:
        messages = item.get("messages", [])
        gold = item.get("categories", [])

        # Drop the final (gold assistant) turn so the model must generate it.
        prompt_messages = messages[:-1]
        prompt = renderer.build_generation_prompt(prompt_messages)
        params = types.SamplingParams(max_tokens=100, temperature=0.1, stop=stop_sequences)

        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, _ = renderer.parse_response(result.sequences[0].tokens)
        pred = response["content"]

        pred_set = {c.strip().lower() for c in pred.split(",")
                    if c.strip().lower() in VALID_CATEGORIES}
        gold_set = {c.lower() for c in gold}

        if pred_set & gold_set:
            correct += 1

    # Divide by the number actually evaluated, not the requested n_samples,
    # so a short test set does not deflate the reported accuracy.
    return correct / len(eval_items) if eval_items else 0.0
|
|
|
|
|
|
|
|
async def run_sft(service_client, training_client, tokenizer, renderer,
                  train_data, test_data, logger):
    """Run supervised fine-tuning with detailed logging.

    Trains for SFT_STEPS optimizer steps of next-token cross-entropy over the
    rendered conversations, evaluating test loss and sampling-based accuracy
    every 10 steps (and on the final step), then saves both the trainer state
    and the sampler weights.

    Args:
        service_client: tinker service client (sampling-client factory).
        training_client: LoRA training client for forward/backward + optim.
        tokenizer: tokenizer for BASE_MODEL (not used directly here; the
            renderer carries its own tokenizer).
        renderer: chat renderer that builds supervised examples and prompts.
        train_data: list of {"messages": ..., "categories": ...} examples.
        test_data: held-out examples (first 50 used for test loss).
        logger: TrainingLogger receiving per-step metrics.

    Returns:
        (state_path, sampler_path): the saved trainer state (used to
        warm-start RL) and the final sampler weights.
    """
    print("\n" + "=" * 70)
    print("PHASE 1: SUPERVISED FINE-TUNING")
    print("=" * 70)

    lr = get_lr(BASE_MODEL)  # cookbook-recommended LR for this base model
    print(f"Learning rate: {lr:.2e}")
    print(f"Steps: {SFT_STEPS}, Batch size: {SFT_BATCH_SIZE}")
    print()

    def to_datum(item):
        # Render one chat example into (tokens, per-token loss weights) and
        # shift by one position for next-token prediction: input token t is
        # trained to predict target token t.
        messages = item.get("messages", [])
        tokens, weights = renderer.build_supervised_example(messages)
        # The renderer may return numpy arrays; normalize to plain lists.
        if hasattr(tokens, 'tolist'):
            tokens = tokens.tolist()
        if hasattr(weights, 'tolist'):
            weights = weights.tolist()
        return types.Datum(
            model_input=types.ModelInput.from_ints(tokens[:-1]),
            loss_fn_inputs=dict(target_tokens=tokens[1:], weights=weights[1:])
        )

    train_datums = [to_datum(item) for item in train_data]
    test_datums = [to_datum(item) for item in test_data[:50]]  # capped eval set

    for step in range(SFT_STEPS):
        step_start = time.time()

        # Sequential batching with wrap-around so every batch is full-sized.
        batch_idx = (step * SFT_BATCH_SIZE) % len(train_datums)
        batch = train_datums[batch_idx:batch_idx + SFT_BATCH_SIZE]
        if len(batch) < SFT_BATCH_SIZE:
            batch = batch + train_datums[:SFT_BATCH_SIZE - len(batch)]

        # Submit forward/backward and the optimizer step back-to-back, then
        # await both results (submission order defines execution order).
        fwd_future = await training_client.forward_backward_async(batch, loss_fn="cross_entropy")
        optim_future = await training_client.optim_step_async(
            types.AdamParams(learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
        )

        fwd_result = await fwd_future.result_async()
        await optim_future.result_async()

        # Weighted mean negative log-likelihood over the batch; the weights
        # mask out positions that should not contribute to the loss.
        logprobs = np.concatenate([o['logprobs'].tolist() for o in fwd_result.loss_fn_outputs])
        weights = np.concatenate([d.loss_fn_inputs['weights'].tolist() for d in batch])
        train_loss = -np.dot(logprobs, weights) / max(weights.sum(), 1)

        step_time = time.time() - step_start

        metrics = {"train_loss": float(train_loss), "step_time": step_time}

        # Periodic checkpoint + evaluation (every 10 steps and on the last).
        if step % 10 == 0 or step == SFT_STEPS - 1:

            # NOTE(review): test loss is computed via forward_backward, which
            # presumably also accumulates gradients from the eval batch —
            # confirm the optimizer-step semantics make this harmless.
            eval_future = await training_client.forward_backward_async(test_datums, loss_fn="cross_entropy")
            eval_result = await eval_future.result_async()
            test_logprobs = np.concatenate([o['logprobs'].tolist() for o in eval_result.loss_fn_outputs])
            test_weights = np.concatenate([d.loss_fn_inputs['weights'].tolist() for d in test_datums])
            test_loss = -np.dot(test_logprobs, test_weights) / max(test_weights.sum(), 1)
            metrics["test_loss"] = float(test_loss)

            # Save weights for sampling and measure accuracy from a fresh
            # sampling client on this checkpoint.
            save_future = await training_client.save_weights_for_sampler_async(name=f"sft_step_{step:04d}")
            save_result = await save_future.result_async()

            sampling_client = service_client.create_sampling_client(model_path=save_result.path)
            accuracy = await evaluate_accuracy(sampling_client, renderer, test_data, n_samples=20)
            metrics["accuracy"] = f"{accuracy:.1%}"  # stored pre-formatted as a string
            metrics["checkpoint"] = save_result.path

        logger.log_sft(step, metrics)

    # Persist final trainer state (for RL warm-start) and sampler weights.
    state_future = await training_client.save_state_async(name="sft_final")
    state_result = await state_future.result_async()

    sampler_future = await training_client.save_weights_for_sampler_async(name="sft_final_sampler")
    sampler_result = await sampler_future.result_async()

    print(f"\nSFT Complete. State: {state_result.path}")

    return state_result.path, sampler_result.path
|
|
|
|
|
|
|
|
async def run_rl(service_client, training_client, sft_state_path,
                 tokenizer, renderer, train_data, test_data, logger):
    """Run reinforcement learning (importance-sampling policy gradient).

    Starting from the SFT checkpoint, each iteration:
      1. saves the current weights and builds a sampling client from them,
      2. samples RL_GROUP_SIZE rollouts for each of up to RL_BATCH_SIZE
         prompts drawn without replacement from ``train_data``,
      3. scores each rollout with compute_reward (multi-label F1, or -1.0
         penalty for unparseable output),
      4. takes one "importance_sampling" gradient step where every generated
         token carries the rollout's reward as its advantage.

    Args:
        service_client: tinker service client (sampling-client factory).
        training_client: LoRA training client, warm-started from SFT state.
        sft_state_path: trainer-state path returned by run_sft.
        tokenizer: not used directly; kept for signature symmetry with run_sft.
        renderer: chat renderer used for prompts and response parsing.
        train_data: prompt pool of {"messages": ..., "categories": ...}.
        test_data: unused here; kept for signature symmetry.
        logger: TrainingLogger receiving per-iteration metrics.

    Returns:
        Path of the final saved sampler weights.
    """
    print("\n" + "=" * 70)
    print("PHASE 2: REINFORCEMENT LEARNING")
    print("=" * 70)

    print(f"Loading SFT weights from: {sft_state_path}")
    await training_client.load_state_async(sft_state_path)

    print(f"Iterations: {RL_ITERATIONS}, Batch: {RL_BATCH_SIZE}, Group: {RL_GROUP_SIZE}")
    print()

    stop_sequences = renderer.get_stop_sequences()

    for iteration in range(RL_ITERATIONS):
        iter_start = time.time()

        # Snapshot the current policy so rollouts are sampled on-policy.
        save_future = await training_client.save_weights_for_sampler_async(name=f"rl_iter_{iteration:03d}")
        save_result = await save_future.result_async()
        sampling_client = service_client.create_sampling_client(model_path=save_result.path)

        # Draw prompts without replacement; cap at the dataset size so a
        # small training set cannot make np.random.choice raise.
        batch_size = min(RL_BATCH_SIZE, len(train_data))
        batch_indices = np.random.choice(len(train_data), size=batch_size, replace=False)

        all_rewards = []
        format_valid_count = 0
        training_data = []

        for idx in batch_indices:
            example = train_data[idx]
            gold = example.get("categories", [])
            messages = example.get("messages", [])
            prompt_messages = messages[:-1]  # drop the gold assistant turn

            if not prompt_messages:
                continue

            prompt = renderer.build_generation_prompt(prompt_messages)
            params = types.SamplingParams(
                max_tokens=100, temperature=0.7, stop=stop_sequences
            )

            result = sampling_client.sample(
                prompt=prompt, sampling_params=params, num_samples=RL_GROUP_SIZE
            ).result()

            for seq in result.sequences:
                response, success = renderer.parse_response(seq.tokens)
                predicted = response["content"] if success else ""
                reward, info = compute_reward(predicted, gold)

                all_rewards.append(reward)
                if info["format_valid"]:
                    format_valid_count += 1

                # Build a training datum only for rollouts that have sampler
                # logprobs and a non-penalty reward.
                if seq.logprobs and reward > -1:
                    prompt_tokens = prompt.to_ints()
                    gen_tokens = seq.tokens
                    logprobs = seq.logprobs

                    n_prompt = len(prompt_tokens) - 1
                    n_gen = len(gen_tokens)

                    if len(logprobs) == n_gen:
                        # Stitch prompt + generation into one next-token
                        # sequence. Prompt positions carry zero advantage so
                        # gradient only flows through generated tokens.
                        full_input = (prompt_tokens + gen_tokens[:-1]) if n_gen > 1 else prompt_tokens
                        full_target = prompt_tokens[1:] + gen_tokens
                        full_logprobs = [0.0] * n_prompt + logprobs
                        full_advantages = [0.0] * n_prompt + [reward] * n_gen

                        if len(full_target) == len(full_input):
                            training_data.append(types.Datum(
                                model_input=types.ModelInput.from_ints(full_input),
                                loss_fn_inputs=dict(
                                    target_tokens=full_target,
                                    logprobs=full_logprobs,
                                    advantages=full_advantages
                                )
                            ))

        if training_data:
            # NOTE(review): advantages are the raw per-rollout rewards and are
            # NOT mean/std-normalized across the batch; confirm that is the
            # intended baseline for the importance_sampling loss.
            fwd_future = await training_client.forward_backward_async(
                training_data, loss_fn="importance_sampling"
            )
            optim_future = await training_client.optim_step_async(
                types.AdamParams(learning_rate=RL_LR, beta1=0.9, beta2=0.95, eps=1e-8)
            )
            await fwd_future.result_async()
            await optim_future.result_async()

        iter_time = time.time() - iter_start

        # Guard the reward statistics so an iteration with zero rollouts
        # (all prompts skipped) logs zeros instead of NaN.
        metrics = {
            "mean_reward": float(np.mean(all_rewards)) if all_rewards else 0.0,
            "std_reward": float(np.std(all_rewards)) if all_rewards else 0.0,
            "accuracy": sum(1 for r in all_rewards if r > 0) / len(all_rewards) if all_rewards else 0,
            "format_valid": format_valid_count / len(all_rewards) if all_rewards else 0,
            "num_rollouts": len(all_rewards),
            "num_training": len(training_data),
            "iter_time": iter_time,
            "checkpoint": save_result.path
        }

        logger.log_rl(iteration, metrics)

    final_future = await training_client.save_weights_for_sampler_async(name="rl_final")
    final_result = await final_future.result_async()

    print(f"\nRL Complete. Final: {final_result.path}")

    return final_result.path
|
|
|
|
|
|
|
|
async def main():
    """Entry point: set up clients and data, then run SFT followed by RL."""
    banner = "=" * 70
    print(banner)
    print("MEMORY ROUTING AGENT - TRAINING WITH DETAILED LOGGING")
    print(banner)
    print(f"Log directory: {LOG_DIR}")
    print(f"Model: {BASE_MODEL}")
    print()

    # Service, tokenizer, and chat-template renderer for the base model.
    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)

    def load_examples(path):
        # Datasets are plain JSON lists of {"messages": ..., "categories": ...}.
        with open(path, "r") as f:
            return json.load(f)

    train_examples = load_examples(TRAIN_DATA)
    test_examples = load_examples(TEST_DATA)
    print(f"Train: {len(train_examples)}, Test: {len(test_examples)}")

    metrics_logger = TrainingLogger(LOG_DIR)

    training_client = await service_client.create_lora_training_client_async(
        base_model=BASE_MODEL, rank=LORA_RANK
    )

    # Phase 1: supervised fine-tuning.
    sft_state, sft_sampler = await run_sft(
        service_client, training_client, tokenizer, renderer,
        train_examples, test_examples, metrics_logger
    )

    # Phase 2: RL, warm-started from the SFT trainer state.
    rl_final = await run_rl(
        service_client, training_client, sft_state,
        tokenizer, renderer, train_examples, test_examples, metrics_logger
    )

    metrics_logger.close()

    print("\n" + banner)
    print("TRAINING COMPLETE")
    print(banner)
    print(f"Logs: {LOG_DIR}")
    print(f"SFT: {sft_sampler}")
    print(f"RL: {rl_final}")
|
|
|
|
|
|
|
|
# Script entry point: run the full SFT -> RL training pipeline.
if __name__ == "__main__":
    asyncio.run(main())
|
|
|
|
|
|