"""Command-line entry point for AutoMR: train or evaluate a meta-reasoning model."""

import argparse
import os

from automr import AutoMR, AutoMRConfig, AutoMREvaluator, AutoMRTrainer
from automr.data_loader import DataLoader

# torch_npu is only needed for Ascend NPU hardware. Importing it unconditionally
# would crash the script on CUDA/CPU-only machines, so guard the import and
# check availability later only if the user actually asks for --device npu.
try:
    import torch_npu  # noqa: F401  -- side-effect import: registers the 'npu' device
    _HAS_NPU = True
except ImportError:
    _HAS_NPU = False


def parse_args():
    """Parse and return the command-line arguments for AutoMR.

    Returns:
        argparse.Namespace: parsed arguments (mode, model/training settings,
        data paths, checkpoint options, task type).
    """
    parser = argparse.ArgumentParser()

    # Mode (help text kept consistent with `choices` — only train/eval exist)
    parser.add_argument('--mode', type=str, default='train',
                        choices=['train', 'eval'],
                        help='Mode: train or eval')

    # Model settings
    parser.add_argument('--model_name', type=str, default='Qwen/Qwen2.5-3B-Instruct',
                        help='Pretrained LLM model name')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu', 'npu'],
                        help='Device to use')
    parser.add_argument('--token_budget', type=int, default=256,
                        help='Token budget for reasoning')
    parser.add_argument('--hidden_size', type=int, default=4096,
                        help='Hidden size of the model')

    # Training settings
    parser.add_argument('--learning_rate', type=float, default=5e-4,
                        help='Learning rate')
    parser.add_argument('--num_epochs', type=int, default=5,
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--num_samples', type=int, default=1,
                        help='Number of skeletons to sample per query (M)')

    # Data paths
    parser.add_argument('--train_data', type=str, default='data/train.json',
                        help='Path to training data')
    parser.add_argument('--val_data', type=str, default='data/val.json',
                        help='Path to validation data')
    parser.add_argument('--test_data', type=str, default='data/test.json',
                        help='Path to test data')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints',
                        help='Directory to save checkpoints')
    parser.add_argument('--results_dir', type=str, default='results',
                        help='Directory to save results')

    # Checkpoint
    parser.add_argument('--load_checkpoint', type=str, default=None,
                        help='Path to checkpoint to load')

    # Task type
    parser.add_argument('--task_type', type=str, default='math',
                        choices=['math', 'multiple_choice'],
                        help='Task type')

    return parser.parse_args()


def main():
    """Build the AutoMR configuration and run the selected mode (train/eval)."""
    args = parse_args()

    # Fail fast with a clear message instead of an obscure AttributeError later.
    if args.device == 'npu' and not _HAS_NPU:
        raise RuntimeError(
            "--device npu was requested but torch_npu is not installed"
        )

    # Create configuration
    config = AutoMRConfig(
        model_name=args.model_name,
        device=args.device,
        token_budget=args.token_budget,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        num_samples_per_query=args.num_samples,
        train_data_path=args.train_data,
        val_data_path=args.val_data,
        test_data_path=args.test_data,
        checkpoint_dir=args.checkpoint_dir,
        results_dir=args.results_dir,
        task_type=args.task_type,
        hidden_size=args.hidden_size,
    )

    # Ensure output directories exist so downstream saves cannot fail on them.
    os.makedirs(config.checkpoint_dir, exist_ok=True)
    os.makedirs(config.results_dir, exist_ok=True)

    print("="*80)
    print("AutoMR: Automatic Meta-Reasoning Skeleton Search")
    print("="*80)
    print(f"\nConfiguration:")
    print(f"  Model: {config.model_name}")
    print(f"  Device: {config.device}")
    print(f"  Token Budget: {config.token_budget}")
    print(f"  Task Type: {config.task_type}")
    print(f"  Mode: {args.mode}")
    print("="*80)

    # Initialize model
    model = AutoMR(config)

    # Load checkpoint if specified; warn (instead of silently ignoring) when
    # the user passed a path that does not exist.
    if args.load_checkpoint:
        if os.path.exists(args.load_checkpoint):
            model.load_checkpoint(args.load_checkpoint)
        else:
            print(f"WARNING: checkpoint '{args.load_checkpoint}' not found; "
                  f"continuing without loading")

    if args.mode == 'train':
        print(f"\n{'='*80}")
        print("TRAINING")
        print("="*80)

        # Load training data
        train_data = DataLoader.load_data(config.train_data_path)
        val_data = DataLoader.load_data(config.val_data_path)
        print(f"Loaded {len(train_data)} training samples from {config.train_data_path}")

        # Train
        trainer = AutoMRTrainer(model, config)
        trainer.train(train_data, val_data)

    elif args.mode == 'eval':
        print(f"\n{'='*80}")
        print("EVALUATION")
        print("="*80)

        # Load test data
        test_data = DataLoader.load_data(config.test_data_path)
        print(f"Loaded {len(test_data)} test samples from {config.test_data_path}")

        # Evaluate
        evaluator = AutoMREvaluator(model, config)
        accuracy, results = evaluator.evaluate(test_data)

        print(f"\n{'='*80}")
        print(f"Final Accuracy: {accuracy:.2%}")
        print("="*80)

    else:
        # Unreachable via argparse `choices`, but keep a self-describing error
        # in case main() is ever called with a hand-built namespace.
        raise NotImplementedError(f"Unsupported mode: {args.mode!r}")


if __name__ == "__main__":
    main()