File size: 4,462 Bytes
1482463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import os
import torch_npu
from automr import AutoMR, AutoMRTrainer, AutoMREvaluator, AutoMRConfig
from automr.data_loader import DataLoader


def parse_args():
    """Parse command-line arguments for the AutoMR train/eval CLI.

    Returns:
        argparse.Namespace: parsed arguments covering the run mode, model
        settings, training hyperparameters, data/checkpoint paths, and the
        task type.
    """
    parser = argparse.ArgumentParser(
        description='AutoMR: Automatic Meta-Reasoning Skeleton Search')
    
    # Mode
    # Fix: the help text previously advertised a 'train_eval' mode that is
    # not in `choices` and would be rejected; help now matches the choices.
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'eval'], help='Mode: train or eval')
    
    # Model settings
    parser.add_argument('--model_name', type=str, default='Qwen/Qwen2.5-3B-Instruct', help='Pretrained LLM model name')
    parser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu', 'npu'], help='Device to use')
    parser.add_argument('--token_budget', type=int, default=256, help='Token budget for reasoning')
    parser.add_argument('--hidden_size', type=int, default=4096, help='Hidden size of the model')
    
    # Training settings
    parser.add_argument('--learning_rate', type=float, default=5e-4, help='Learning rate')
    parser.add_argument('--num_epochs', type=int, default=5, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
    parser.add_argument('--num_samples', type=int, default=1, help='Number of skeletons to sample per query (M)')
    
    # Data paths
    parser.add_argument('--train_data', type=str, default='data/train.json', help='Path to training data')
    parser.add_argument('--val_data', type=str, default='data/val.json', help='Path to validation data')
    parser.add_argument('--test_data', type=str, default='data/test.json', help='Path to test data')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='Directory to save checkpoints')
    parser.add_argument('--results_dir', type=str, default='results', help='Directory to save results')
    
    # Checkpoint
    parser.add_argument('--load_checkpoint', type=str, default=None, help='Path to checkpoint to load')
    
    # Task type
    parser.add_argument('--task_type', type=str, default='math', choices=['math', 'multiple_choice'], help='Task type')
    
    return parser.parse_args()


def main():
    """CLI entry point.

    Builds an AutoMRConfig from the parsed command-line arguments, prints a
    configuration banner, constructs the model (restoring a checkpoint when
    one is given and exists), then dispatches on --mode to either training
    or evaluation.
    """
    args = parse_args()
    banner = "=" * 80

    # Translate CLI arguments into the library configuration object.
    config = AutoMRConfig(
        model_name=args.model_name,
        device=args.device,
        token_budget=args.token_budget,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        num_samples_per_query=args.num_samples,
        train_data_path=args.train_data,
        val_data_path=args.val_data,
        test_data_path=args.test_data,
        checkpoint_dir=args.checkpoint_dir,
        results_dir=args.results_dir,
        task_type=args.task_type,
        hidden_size=args.hidden_size,
    )

    # Configuration summary banner.
    print(banner)
    print("AutoMR: Automatic Meta-Reasoning Skeleton Search")
    print(banner)
    print("\nConfiguration:")
    print(f"  Model: {config.model_name}")
    print(f"  Device: {config.device}")
    print(f"  Token Budget: {config.token_budget}")
    print(f"  Task Type: {config.task_type}")
    print(f"  Mode: {args.mode}")
    print(banner)

    model = AutoMR(config)

    # Restore weights only when a checkpoint path was supplied and is present.
    if args.load_checkpoint and os.path.exists(args.load_checkpoint):
        model.load_checkpoint(args.load_checkpoint)

    if args.mode == 'train':
        print(f"\n{banner}")
        print("TRAINING")
        print(banner)

        train_data = DataLoader.load_data(config.train_data_path)
        val_data = DataLoader.load_data(config.val_data_path)
        print(f"Loaded {len(train_data)} training samples from {config.train_data_path}")

        AutoMRTrainer(model, config).train(train_data, val_data)

    elif args.mode == 'eval':
        print(f"\n{banner}")
        print("EVALUATION")
        print(banner)

        test_data = DataLoader.load_data(config.test_data_path)
        print(f"Loaded {len(test_data)} test samples from {config.test_data_path}")

        accuracy, results = AutoMREvaluator(model, config).evaluate(test_data)

        print(f"\n{banner}")
        print(f"Final Accuracy: {accuracy:.2%}")
        print(banner)

    else:
        # Unreachable in practice: argparse restricts --mode to train/eval.
        raise NotImplementedError


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()