| |
| """ |
| Main training orchestration script. |
| Loops training + evaluation until SUCCESS CRITERIA are met. |
| """ |
|
|
import json
import subprocess
import time
from datetime import datetime, timezone
from pathlib import Path

import yaml
|
|
| |
| CONFIG_PATH = Path("/home/finetune/router-finetune-agent/config/training_config.yaml") |
| EVAL_RESULTS_PATH = Path("/home/finetune/router-finetune-agent/logs/eval_results.json") |
| RETRY_HISTORY_PATH = Path("/home/finetune/router-finetune-agent/logs/retry_history.json") |
| CHECKPOINT_DIR = Path("/home/finetune/router-finetune-agent/checkpoints") |
|
|
| |
| SUCCESS_CRITERIA = { |
| "routing_accuracy": 0.85, |
| "macro_f1": 0.80, |
| "avg_routing_latency_ms": 20.0, |
| } |
|
|
| MAX_RETRIES = 10 |
|
|
|
|
| def load_config(): |
| with open(CONFIG_PATH, 'r') as f: |
| return yaml.safe_load(f) |
|
|
|
|
| def save_config(config): |
| with open(CONFIG_PATH, 'w') as f: |
| yaml.dump(config, f) |
|
|
|
|
| def load_eval_results(): |
| if not EVAL_RESULTS_PATH.exists(): |
| return None |
| with open(EVAL_RESULTS_PATH, 'r') as f: |
| return json.load(f) |
|
|
|
|
| def load_retry_history(): |
| if RETRY_HISTORY_PATH.exists(): |
| with open(RETRY_HISTORY_PATH, 'r') as f: |
| return json.load(f) |
| return [] |
|
|
|
|
| def save_retry_history(history): |
| with open(RETRY_HISTORY_PATH, 'w') as f: |
| json.dump(history, f, indent=2) |
|
|
|
|
| def check_success(results): |
| """Check if all success criteria are met.""" |
| failures = [] |
|
|
| if results['routing_accuracy'] < SUCCESS_CRITERIA['routing_accuracy']: |
| failures.append(f"accuracy: {results['routing_accuracy']:.4f} < {SUCCESS_CRITERIA['routing_accuracy']}") |
|
|
| if results['macro_f1'] < SUCCESS_CRITERIA['macro_f1']: |
| failures.append(f"macro_f1: {results['macro_f1']:.4f} < {SUCCESS_CRITERIA['macro_f1']}") |
|
|
| if results['avg_routing_latency_ms'] > SUCCESS_CRITERIA['avg_routing_latency_ms']: |
| failures.append(f"latency: {results['avg_routing_latency_ms']:.2f}ms > {SUCCESS_CRITERIA['avg_routing_latency_ms']}ms") |
|
|
| return len(failures) == 0, failures |
|
|
|
|
| def run_command(cmd, description): |
| """Run a shell command and return success status.""" |
| print(f"\n{'='*60}") |
| print(f"[{description}]") |
| print(f"{'='*60}") |
| print(f"CMD: {' '.join(cmd)}") |
|
|
| result = subprocess.run(cmd, cwd=Path("/home/finetune/router-finetune-agent")) |
| return result.returncode == 0 |
|
|
|
|
| def apply_retry_strategy(retry_count, config, results): |
| """Apply retry strategy based on failure type.""" |
| accuracy = results['routing_accuracy'] |
| macro_f1 = results['macro_f1'] |
| latency = results['avg_routing_latency_ms'] |
|
|
| print(f"\nApplying retry strategy #{retry_count}...") |
|
|
| if accuracy < 0.75: |
| |
| old_samples = config.get('num_samples', 10000) |
| new_samples = old_samples + 5000 |
| config['num_samples'] = new_samples |
| reason = f"accuracy {accuracy:.4f} < 0.75, increased samples {old_samples} -> {new_samples}" |
|
|
| elif accuracy < 0.85: |
| |
| old_epochs = config.get('num_train_epochs', 7) |
| old_lr = config.get('learning_rate', 0.0001) |
| config['num_train_epochs'] = old_epochs + 2 |
| config['learning_rate'] = old_lr * 0.5 |
| reason = f"accuracy {accuracy:.4f} < 0.85, epochs {old_epochs} -> {old_epochs+2}, lr {old_lr} -> {old_lr*0.5}" |
|
|
| else: |
| |
| old_lr = config.get('learning_rate', 0.0001) |
| old_warmup = config.get('warmup_ratio', 0.05) |
| config['learning_rate'] = old_lr * 0.3 |
| config['warmup_ratio'] = old_warmup + 0.03 |
| reason = f"convergence issue, lr {old_lr} -> {old_lr*0.3}, warmup {old_warmup} -> {old_warmup+0.03}" |
|
|
| |
| config['resume_checkpoint'] = None |
|
|
| save_config(config) |
| print(f"Config updated: {reason}") |
|
|
| return reason |
|
|
|
|
| def main(): |
| print("="*60) |
| print("ROUTER FINETUNE AUTONOMOUS PIPELINE") |
| print("="*60) |
|
|
| retry_history = load_retry_history() |
| config = load_config() |
|
|
| for attempt in range(1, MAX_RETRIES + 1): |
| print(f"\n{'#'*60}") |
| print(f"# ATTEMPT {attempt}/{MAX_RETRIES}") |
| print(f"{'#'*60}") |
|
|
| |
| config['retry_count'] = attempt - 1 |
| save_config(config) |
|
|
| |
| if attempt == 1: |
| print("\n[STAGE 1] Generating data...") |
| success = run_command(['python3', 'pipeline/01_generate_data.py'], "Data Generation") |
| if not success: |
| print("ERROR: Data generation failed") |
| continue |
|
|
| |
| print("\n[STAGE 2] Labeling data...") |
| success = run_command(['python3', 'pipeline/02_label_data.py'], "Data Labeling") |
| if not success: |
| print("ERROR: Data labeling failed") |
| continue |
|
|
| |
| print("\n[STAGE 3] Training model...") |
| success = run_command([ |
| 'torchrun', |
| '--nproc_per_node=8', |
| 'pipeline/03_finetune.py' |
| ], "Model Training") |
| if not success: |
| print("ERROR: Training failed") |
| |
| retry_history.append({ |
| "attempt": attempt, |
| "stage": "training", |
| "success": False, |
| "timestamp": datetime.utcnow().isoformat() + "Z" |
| }) |
| save_retry_history(retry_history) |
| continue |
|
|
| |
| print("\n[STAGE 4] Evaluating model...") |
| success = run_command(['python3', 'pipeline/04_evaluate.py'], "Model Evaluation") |
| if not success: |
| print("ERROR: Evaluation failed") |
| retry_history.append({ |
| "attempt": attempt, |
| "stage": "evaluation", |
| "success": False, |
| "timestamp": datetime.utcnow().isoformat() + "Z" |
| }) |
| save_retry_history(retry_history) |
| continue |
|
|
| |
| results = load_eval_results() |
| if results is None: |
| print("ERROR: No evaluation results found") |
| continue |
|
|
| print("\n" + "="*60) |
| print("EVALUATION RESULTS") |
| print("="*60) |
| print(f" Accuracy: {results['routing_accuracy']:.4f} (target: >= 0.85)") |
| print(f" Macro F1: {results['macro_f1']:.4f} (target: >= 0.80)") |
| print(f" Latency: {results['avg_routing_latency_ms']:.2f}ms (target: <= 20ms)") |
| print(f" Loss: {results['eval_loss']:.4f}") |
|
|
| |
| success, failures = check_success(results) |
|
|
| if success: |
| print("\n" + "="*60) |
| print("SUCCESS! ALL CRITERIA MET!") |
| print("="*60) |
|
|
| |
| print("\n[STAGE 5] Exporting model...") |
| run_command(['python3', 'pipeline/05_export_model.py'], "Model Export") |
|
|
| retry_history.append({ |
| "attempt": attempt, |
| "stage": "complete", |
| "success": True, |
| "results": results, |
| "timestamp": datetime.utcnow().isoformat() + "Z" |
| }) |
| save_retry_history(retry_history) |
| print("\n*** PIPELINE COMPLETE ***") |
| return True |
|
|
| else: |
| print("\n" + "="*60) |
| print("CRITERIA NOT MET") |
| print("="*60) |
| for f in failures: |
| print(f" - {f}") |
|
|
| |
| reason = apply_retry_strategy(attempt, config, results) |
|
|
| retry_history.append({ |
| "attempt": attempt, |
| "stage": "retry", |
| "success": False, |
| "failures": failures, |
| "action": reason, |
| "timestamp": datetime.utcnow().isoformat() + "Z" |
| }) |
| save_retry_history(retry_history) |
|
|
| |
| time.sleep(5) |
|
|
| print(f"\n{'!'*60}") |
| print(f"MAX RETRIES ({MAX_RETRIES}) EXCEEDED - PIPELINE FAILED") |
| print(f"{'!'*60}") |
| return False |
|
|
|
|
| if __name__ == "__main__": |
| success = main() |
| exit(0 if success else 1) |
|
|