"""
Supervised Fine-Tuning (SFT) for Memory Routing

This implements Stage 1 of the PRD: Prompt Distillation using Tinker's
cross_entropy loss function with LoRA fine-tuning.

Per Tinker docs (supervised-learning.mdx):
- SFT means maximizing the log-probability of the target tokens
- Use the cross_entropy loss: -(weights * logp(target_tokens)).sum()

Per Tinker docs (lora-primer.mdx):
- LoRA requires a larger LR than full fine-tuning (20-100x)
- Use the get_lr() utility to get the recommended LR
- The default rank of 32 is suitable for classification tasks

Per Tinker docs (async.mdx):
- Use async methods for performance
- Double-await pattern: await the future, then await result_async()

Per PRD Section 7:
- 300-500 steps minimum
- Batch size 128
- Early stopping if test loss plateaus
- Checkpoint every 20 steps
"""
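# Illustrative sketch (not used by the code below): given per-token log-probs
# `logp` and per-token loss weights `w`, the cross_entropy objective described
# above is
#   loss = -(w * logp).sum()
# and the per-token average reported during training divides by w.sum().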
|
|
import asyncio
import json
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
|
|
@dataclass
class SFTConfig:
    """Hyperparameters for SFT training (defaults follow PRD Section 7)."""

    # Model / LoRA
    base_model: str = "meta-llama/Llama-3.1-8B"
    lora_rank: int = 32
    renderer_name: str = "llama3"

    # Optimization
    num_steps: int = 300
    batch_size: int = 128
    learning_rate: Optional[float] = None  # None -> use get_lr() recommendation

    # Adam hyperparameters
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-8

    # Checkpoint / eval cadence (in steps)
    checkpoint_every: int = 20
    eval_every: int = 20

    # Stop after this many evals without test-loss improvement
    early_stopping_patience: int = 5

    # Data and log paths
    train_data_path: str = "training/processed_data/train_data.json"
    test_data_path: str = "training/processed_data/test_data.json"
    log_path: str = "training/logs/sft"
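# Fields can also be overridden programmatically, e.g. SFTConfig(num_steps=500,
# batch_size=64), or as key=value CLI arguments (see main() below).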
|
|
|
|
@dataclass
class TrainingMetrics:
    """Per-step metrics recorded during training."""

    step: int
    train_loss: float
    test_loss: Optional[float] = None
    learning_rate: float = 0.0
    batch_time: float = 0.0
    checkpoint_path: Optional[str] = None
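# One metrics.jsonl record per step is written at the end of training; an
# illustrative (made-up) line looks like:
#   {"step": 20, "train_loss": 0.41, "test_loss": 0.47, "learning_rate": 0.0002,
#    "batch_time": 3.2, "checkpoint_path": ".../sft_step_0020"}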
|
|
|
|
def load_processed_data(path: str) -> List[Dict[str, Any]]:
    """Load preprocessed examples from a JSON file."""
    with open(path, "r") as f:
        return json.load(f)
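# The preprocessed JSON is assumed to be a list of records in one of the two
# shapes handled by convert_to_datum() below (field names beyond what that
# function reads are illustrative):
#   {"model_input": {"chunks": [{"tokens": [...]}]}, "loss_fn_inputs": {...}}
#   {"messages": [{"role": "user", "content": "..."}, ...]}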
|
|
|
|
def create_batch(data: List[Any], batch_size: int, step: int) -> List[Any]:
    """
    Create a batch of data for training.

    Cycles through the data (wrapping back to the start) when
    step * batch_size exceeds the data length.
    """
    start_idx = (step * batch_size) % len(data)
    end_idx = start_idx + batch_size

    if end_idx <= len(data):
        return data[start_idx:end_idx]
    else:
        # Wrap around: take the tail of the data plus a slice from the front.
        batch = data[start_idx:]
        batch.extend(data[: end_idx - len(data)])
        return batch
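# Example: len(data) == 10, batch_size == 4, step == 2 -> start_idx == 8, so the
# batch wraps around and contains items [8, 9, 0, 1].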
|
|
|
|
async def run_sft_training(config: SFTConfig):
    """
    Main SFT training loop.

    Per Tinker docs (training-sampling.mdx):
    1. Create a ServiceClient
    2. Create a TrainingClient with the base model and LoRA config
    3. Loop: forward_backward -> optim_step
    4. Periodically save checkpoints and evaluate
    """
    import os

    import numpy as np
    import tinker
    from dotenv import load_dotenv
    from tinker import types
    from tinker_cookbook import renderers
    from tinker_cookbook.hyperparam_utils import get_lr

    # Load environment variables (e.g. the Tinker API key) from .env
    load_dotenv()

    os.makedirs(config.log_path, exist_ok=True)

    # Fall back to the recommended LoRA learning rate for this base model
    if config.learning_rate is None:
        config.learning_rate = get_lr(config.base_model)
        print(f"Using recommended LR for {config.base_model}: {config.learning_rate:.2e}")

    # Load the preprocessed train/test splits
    print(f"Loading training data from {config.train_data_path}...")
    train_data_raw = load_processed_data(config.train_data_path)
    print(f"Loading test data from {config.test_data_path}...")
    test_data_raw = load_processed_data(config.test_data_path)

    print(f"Train examples: {len(train_data_raw)}")
    print(f"Test examples: {len(test_data_raw)}")

    # Create the Tinker service and LoRA training clients
    print("\nInitializing Tinker ServiceClient...")
    service_client = tinker.ServiceClient()

    print("Creating LoRA training client...")
    print(f"  Base model: {config.base_model}")
    print(f"  LoRA rank: {config.lora_rank}")

    training_client = await service_client.create_lora_training_client_async(
        base_model=config.base_model,
        rank=config.lora_rank,
    )

    # Tokenizer and chat renderer used to build supervised examples from messages
    tokenizer = training_client.get_tokenizer()
    renderer = renderers.get_renderer(name=config.renderer_name, tokenizer=tokenizer)

    print("Converting data to Datum objects...")

    def convert_to_datum(item: Dict) -> types.Datum:
        """Convert a preprocessed item back into a tinker Datum."""
        if "model_input" in item:
            # Already tokenized by the data-prep step: rebuild the Datum directly.
            return types.Datum(
                model_input=types.ModelInput.from_ints(item["model_input"]["chunks"][0]["tokens"]),
                loss_fn_inputs=item["loss_fn_inputs"],
            )
        else:
            # Chat-format item: render the messages into tokens and per-token loss weights.
            messages = item["messages"]
            tokens, weights = renderer.build_supervised_example(messages)

            # The renderer may return array-like objects; normalize to plain lists.
            if hasattr(tokens, "tolist"):
                tokens = tokens.tolist()
            if hasattr(weights, "tolist"):
                weights = weights.tolist()

            # Next-token prediction: inputs are all but the last token;
            # targets (and their weights) are shifted left by one.
            input_tokens = tokens[:-1]
            target_tokens = tokens[1:]
            loss_weights = weights[1:]

            return types.Datum(
                model_input=types.ModelInput.from_ints(input_tokens),
                loss_fn_inputs=dict(
                    target_tokens=target_tokens,
                    weights=loss_weights,
                ),
            )

    train_data = [convert_to_datum(item) for item in train_data_raw]
    test_data = [convert_to_datum(item) for item in test_data_raw]

    print(f"Converted {len(train_data)} train, {len(test_data)} test examples")

    print(f"\n{'='*60}")
    print("Starting SFT Training")
    print(f"{'='*60}")
    print(f"Steps: {config.num_steps}")
    print(f"Batch size: {config.batch_size}")
    print(f"Learning rate: {config.learning_rate:.2e}")
    print(f"Checkpoint every: {config.checkpoint_every} steps")
    print(f"Eval every: {config.eval_every} steps")
    print(f"{'='*60}\n")

    metrics_log = []
    best_test_loss = float("inf")
    no_improvement_count = 0
    final_checkpoint_path = None

    for step in range(config.num_steps):
        step_start = time.time()

        batch = create_batch(train_data, config.batch_size, step)

        # Submit the forward/backward pass and the optimizer step back to back,
        # then await both results (the double-await pattern from async.mdx).
        fwd_bwd_future = await training_client.forward_backward_async(
            batch,
            loss_fn="cross_entropy",
        )

        adam_params = types.AdamParams(
            learning_rate=config.learning_rate,
            beta1=config.beta1,
            beta2=config.beta2,
            eps=config.eps,
        )
        optim_future = await training_client.optim_step_async(adam_params)

        fwd_bwd_result = await fwd_bwd_future.result_async()
        await optim_future.result_async()

        # Weighted per-token NLL: -(weights * logprobs).sum() / weights.sum()
        logprobs = np.concatenate([
            output["logprobs"].tolist()
            for output in fwd_bwd_result.loss_fn_outputs
        ])
        weights = np.concatenate([
            datum.loss_fn_inputs["weights"].tolist()
            for datum in batch
        ])
        train_loss = float(-np.dot(logprobs, weights) / max(weights.sum(), 1))

        step_time = time.time() - step_start

        metrics = TrainingMetrics(
            step=step,
            train_loss=train_loss,
            learning_rate=config.learning_rate,
            batch_time=step_time,
        )

        # Periodic evaluation on a held-out test batch
        if step % config.eval_every == 0 or step == config.num_steps - 1:
            test_batch = create_batch(test_data, min(config.batch_size, len(test_data)), 0)

            eval_future = await training_client.forward_backward_async(
                test_batch,
                loss_fn="cross_entropy",
            )
            eval_result = await eval_future.result_async()

            test_logprobs = np.concatenate([
                output["logprobs"].tolist()
                for output in eval_result.loss_fn_outputs
            ])
            test_weights = np.concatenate([
                datum.loss_fn_inputs["weights"].tolist()
                for datum in test_batch
            ])
            test_loss = float(-np.dot(test_logprobs, test_weights) / max(test_weights.sum(), 1))
            metrics.test_loss = test_loss

            # Track the best test loss for early stopping
            if test_loss < best_test_loss:
                best_test_loss = test_loss
                no_improvement_count = 0
            else:
                no_improvement_count += 1

            if no_improvement_count >= config.early_stopping_patience:
                print(f"\nEarly stopping at step {step} "
                      f"(no improvement for {config.early_stopping_patience} evals)")
                break

        # Periodic checkpointing
        if step % config.checkpoint_every == 0 or step == config.num_steps - 1:
            # Sampler weights are enough to run inference/evals from this checkpoint.
            sampler_future = await training_client.save_weights_for_sampler_async(
                name=f"sft_step_{step:04d}"
            )
            sampler_result = await sampler_future.result_async()
            metrics.checkpoint_path = sampler_result.path

            if step == config.num_steps - 1:
                # On the last step, also save the full training state (for resuming).
                state_future = await training_client.save_state_async(
                    name="sft_final_state"
                )
                state_result = await state_future.result_async()
                final_checkpoint_path = state_result.path
                print(f"  Full state checkpoint: {final_checkpoint_path}")
            else:
                final_checkpoint_path = sampler_result.path

        metrics_log.append(metrics)

        # Per-step progress line
        test_str = f", test_loss={metrics.test_loss:.4f}" if metrics.test_loss is not None else ""
        ckpt_str = f", checkpoint={metrics.checkpoint_path}" if metrics.checkpoint_path else ""
        print(f"Step {step:4d}/{config.num_steps}: train_loss={train_loss:.4f}{test_str}, "
              f"time={step_time:.1f}s{ckpt_str}")

    # Persist per-step metrics as JSONL
    metrics_path = os.path.join(config.log_path, "metrics.jsonl")
    with open(metrics_path, "w") as f:
        for m in metrics_log:
            f.write(json.dumps({
                "step": m.step,
                "train_loss": m.train_loss,
                "test_loss": m.test_loss,
                "learning_rate": m.learning_rate,
                "batch_time": m.batch_time,
                "checkpoint_path": m.checkpoint_path,
            }) + "\n")

    print(f"\n{'='*60}")
    print("SFT Training Complete")
    print(f"{'='*60}")
    print(f"Final train loss: {metrics_log[-1].train_loss:.4f}")
    print(f"Best test loss: {best_test_loss:.4f}")
    print(f"Final checkpoint: {final_checkpoint_path}")
    print(f"Metrics saved to: {metrics_path}")
    print(f"{'='*60}")

    return final_checkpoint_path, metrics_log
|
|
|
|
async def main():
    """Entry point for SFT training. Accepts key=value overrides from the CLI."""
    import sys

    config = SFTConfig()

    # Apply key=value overrides, casting to the type of the existing default.
    for arg in sys.argv[1:]:
        if "=" in arg:
            key, value = arg.split("=", 1)
            if hasattr(config, key):
                current_value = getattr(config, key)
                if isinstance(current_value, int):
                    setattr(config, key, int(value))
                elif isinstance(current_value, float) or key == "learning_rate":
                    # learning_rate defaults to None, so cast it explicitly.
                    setattr(config, key, float(value))
                else:
                    setattr(config, key, value)

    await run_sft_training(config)
|
|
|
|
if __name__ == "__main__":
    asyncio.run(main())
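# Example invocation (the script filename here is illustrative); any SFTConfig
# field can be overridden as key=value:
#   python sft.py num_steps=500 batch_size=64 learning_rate=2e-4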
|
|
|
|