""" Hyperparameter sweep script for Q-TensorFormer v3. Runs a grid/search over key hyperparameters and produces comparative evaluation results. Usage: python scripts/sweep.py --preset sweep --output results/ """ import sys import os import json import itertools from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) import torch from src.config import ExperimentConfig, ModelConfig, TrainingConfig from src.models import create_model from src.baselines import StandardTransformer from src.data import load_wikitext2, load_synthetic_data from src.training import Trainer from src.metrics import evaluate_model, print_comparison_table, compute_pareto_frontier def run_sweep(base_config, sweep_params, train_loader, val_loader, test_loader, device="cpu", output_dir="./outputs/sweep/"): """ Run a hyperparameter sweep. Args: base_config: Base ExperimentConfig. sweep_params: Dict of param_name → [values]. """ keys = list(sweep_params.keys()) values = list(sweep_params.values()) os.makedirs(output_dir, exist_ok=True) results = {} configs = [] for combo in itertools.product(*values): config = ExperimentConfig( model=ModelConfig(**base_config.model.__dict__), training=TrainingConfig(**base_config.training.__dict__), ) # Apply sweep params param_dict = dict(zip(keys, combo)) for k, v in param_dict.items(): if "." in k: section, key = k.split(".") getattr(getattr(config, section), key).__class__.__dict__ setattr(getattr(config, section), key, v) else: if hasattr(config.model, k): setattr(config.model, k, v) elif hasattr(config.training, k): setattr(config.training, k, v) name = "_".join(f"{k}={v}" for k, v in param_dict.items()) config.experiment_name = name configs.append((name, config)) print(f"Running {len(configs)} configurations...") for i, (name, config) in enumerate(configs): print(f"\n[{i+1}/{len(configs)}] {name}") # Create model model = create_model(config, "qtensor") # Train trainer = Trainer( model, config, train_loader=train_loader, val_loader=val_loader, test_loader=test_loader, device=device, output_dir=f"{output_dir}/{name}", ) trainer.train() # Evaluate results[name] = evaluate_model(model, test_loader, device) # Save sweep results with open(f"{output_dir}/sweep_results.json", "w") as f: clean = {} for name, r in results.items(): clean[name] = {k: (float(v) if hasattr(v, "item") else v) for k, v in r.items()} json.dump(clean, f, indent=2) # Print summary print("\n" + "=" * 70) print("SWEEP RESULTS") print("=" * 70) print_comparison_table(results) pareto = compute_pareto_frontier(results) print(f"\nPareto-optimal: {pareto}") # Best by metric best_ppl = min(results.items(), key=lambda x: x[1]["test_ppl"]) best_params = min(results.items(), key=lambda x: x[1]["total_params"]) print(f"\nBest PPL: {best_ppl[0]} ({best_ppl[1]['test_ppl']:.2f})") print(f"Fewest params: {best_params[0]} ({best_params[1]['total_params']:,})") return results def main(): import argparse parser = argparse.ArgumentParser() parser.add_argument("--epochs", type=int, default=5) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--device", type=str, default="cpu") parser.add_argument("--output", type=str, default="./outputs/sweep/") parser.add_argument("--synthetic", action="store_true") args = parser.parse_args() torch.manual_seed(42) # Base config config = ExperimentConfig( model=ModelConfig(d_model=128, n_layers=2, n_heads=4, tt_rank=8, vocab_size=10000, max_seq_len=128), training=TrainingConfig(max_epochs=args.epochs, batch_size=args.batch_size), ) # Load data if 

def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--output", type=str, default="./outputs/sweep/")
    parser.add_argument("--synthetic", action="store_true")
    args = parser.parse_args()

    torch.manual_seed(42)

    # Base config
    config = ExperimentConfig(
        model=ModelConfig(d_model=128, n_layers=2, n_heads=4, tt_rank=8,
                          vocab_size=10000, max_seq_len=128),
        training=TrainingConfig(max_epochs=args.epochs, batch_size=args.batch_size),
    )

    # Load data
    if args.synthetic:
        train_loader = load_synthetic_data(batch_size=args.batch_size)
        val_loader = None
        test_loader = train_loader  # smoke-test mode: evaluate on the training data
    else:
        train_loader, val_loader, test_loader, tokenizer = load_wikitext2(
            seq_len=128, batch_size=args.batch_size
        )
        config.model.vocab_size = tokenizer.vocab_size

    # Sweep parameters. The full grid (4 * 2 * 3 * 3 = 72 combos) is kept here
    # for reference but commented out in favor of a reduced 4 * 2 = 8-combo
    # sweep that varies tt_rank and use_quantum only, for manageable runtime.
    # sweep = {
    #     "tt_rank": [2, 4, 8, 16],
    #     "use_quantum": [True, False],
    #     "quantum_sparsity": [0.5, 0.7, 0.9],
    #     "rank_alpha": [1.0, 2.0, 3.0],
    # }
    sweep = {
        "tt_rank": [2, 4, 8, 16],
        "use_quantum": [True, False],
        "quantum_sparsity": [0.7],  # fixed for now
        "rank_alpha": [2.0],        # fixed for now
    }

    run_sweep(config, sweep, train_loader, val_loader, test_loader,
              args.device, args.output)


if __name__ == "__main__":
    main()
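
# Invocation sketch (flags correspond to the argparse definition in main()):
#
#     python scripts/sweep.py --epochs 3 --batch-size 16 --device cuda --output results/
#     python scripts/sweep.py --synthetic    # quick smoke test without WikiText-2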