File size: 4,462 Bytes
1482463 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import argparse
import os
import torch_npu
from automr import AutoMR, AutoMRTrainer, AutoMREvaluator, AutoMRConfig
from automr.data_loader import DataLoader
def parse_args(argv=None):
    """Parse command-line arguments for the AutoMR CLI.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse reads from ``sys.argv`` (backward compatible with
            the original zero-argument call).

    Returns:
        argparse.Namespace with all AutoMR settings.
    """
    parser = argparse.ArgumentParser()
    # Mode — NOTE: help text previously advertised a 'train_eval' mode that
    # was never in `choices`; the message now matches the accepted values.
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'eval'], help='Mode: train or eval')
    # Model settings
    parser.add_argument('--model_name', type=str, default='Qwen/Qwen2.5-3B-Instruct', help='Pretrained LLM model name')
    parser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu', 'npu'], help='Device to use')
    parser.add_argument('--token_budget', type=int, default=256, help='Token budget for reasoning')
    parser.add_argument('--hidden_size', type=int, default=4096, help='Hidden size of the model')
    # Training settings
    parser.add_argument('--learning_rate', type=float, default=5e-4, help='Learning rate')
    parser.add_argument('--num_epochs', type=int, default=5, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
    parser.add_argument('--num_samples', type=int, default=1, help='Number of skeletons to sample per query (M)')
    # Data paths
    parser.add_argument('--train_data', type=str, default='data/train.json', help='Path to training data')
    parser.add_argument('--val_data', type=str, default='data/val.json', help='Path to validation data')
    parser.add_argument('--test_data', type=str, default='data/test.json', help='Path to test data')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='Directory to save checkpoints')
    parser.add_argument('--results_dir', type=str, default='results', help='Directory to save results')
    # Checkpoint
    parser.add_argument('--load_checkpoint', type=str, default=None, help='Path to checkpoint to load')
    # Task type
    parser.add_argument('--task_type', type=str, default='math', choices=['math', 'multiple_choice'], help='Task type')
    return parser.parse_args(argv)
def main():
    """CLI entry point: build an AutoMRConfig from parsed flags, then run
    either training or evaluation depending on ``--mode``."""
    args = parse_args()

    # Assemble the configuration object from the parsed CLI options.
    config = AutoMRConfig(
        model_name=args.model_name,
        device=args.device,
        token_budget=args.token_budget,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        num_samples_per_query=args.num_samples,
        train_data_path=args.train_data,
        val_data_path=args.val_data,
        test_data_path=args.test_data,
        checkpoint_dir=args.checkpoint_dir,
        results_dir=args.results_dir,
        task_type=args.task_type,
        hidden_size=args.hidden_size,
    )

    # Startup banner summarizing the run configuration.
    bar = "=" * 80
    print(bar)
    print("AutoMR: Automatic Meta-Reasoning Skeleton Search")
    print(bar)
    print("\nConfiguration:")
    print(f" Model: {config.model_name}")
    print(f" Device: {config.device}")
    print(f" Token Budget: {config.token_budget}")
    print(f" Task Type: {config.task_type}")
    print(f" Mode: {args.mode}")
    print(bar)

    # Build the model and optionally restore weights from disk.
    model = AutoMR(config)
    if args.load_checkpoint and os.path.exists(args.load_checkpoint):
        model.load_checkpoint(args.load_checkpoint)

    if args.mode == 'train':
        print(f"\n{bar}")
        print("TRAINING")
        print(bar)
        # Load train/validation splits and run the trainer.
        train_data = DataLoader.load_data(config.train_data_path)
        val_data = DataLoader.load_data(config.val_data_path)
        print(f"Loaded {len(train_data)} training samples from {config.train_data_path}")
        AutoMRTrainer(model, config).train(train_data, val_data)
    elif args.mode == 'eval':
        print(f"\n{bar}")
        print("EVALUATION")
        print(bar)
        # Load the held-out test split and score the model on it.
        test_data = DataLoader.load_data(config.test_data_path)
        print(f"Loaded {len(test_data)} test samples from {config.test_data_path}")
        evaluator = AutoMREvaluator(model, config)
        accuracy, results = evaluator.evaluate(test_data)
        print(f"\n{bar}")
        print(f"Final Accuracy: {accuracy:.2%}")
        print(bar)
    else:
        # argparse restricts --mode to {train, eval}, so this is unreachable
        # in practice; kept as a guard against future mode additions.
        raise NotImplementedError

if __name__ == "__main__":
    main()