| | |
| |
|
| | import argparse |
| | import sys |
| | import os |
| | from train_agent import train_agent |
| | from test_agent import TestAgent, run_test_session |
| | from lightbulb import main as world_model_main |
| | from lightbulb_inf import main as inference_main |
| | from twisted.internet import reactor, task |
| |
|
def parse_main_args(argv=None):
    """Parse the command-line options of the task-selection menu.

    Besides the task selector and the training options, this also exposes
    the inference tuning flags (``--beam_size``, ``--n_tokens_predict``,
    ``--mcts_iterations``, ``--mcts_exploration_constant``, ``--load_model``)
    that ``main`` forwards to the inference script. Their defaults equal the
    fallback values ``main`` used when the flags did not exist, so omitting
    them changes nothing.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the selected ``task`` and every option.
    """
    task_choices = [
        'train_llm_world',
        'train_agent',
        'test_agent',
        'inference_llm',
        'inference_world_model',
        'advanced_inference',
    ]

    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
    parser.add_argument(
        '--task',
        type=str,
        choices=task_choices,
        required=True,
        help='Choose task to execute: ' + ', '.join(task_choices),
    )

    # Options shared by the training and inference tasks.
    parser.add_argument('--model_name', type=str, default='gpt2',
                        help='Pretrained model name for LLM')
    parser.add_argument('--dataset_name', type=str, default='wikitext',
                        help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1',
                        help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3,
                        help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128,
                        help='Maximum sequence length for training')
    parser.add_argument('--mode', type=str, choices=['train', 'inference'],
                        default='train',
                        help='Train or inference mode for LLM')
    parser.add_argument('--query', type=str, default='',
                        help='Query for the test_agent or inference tasks')

    # Inference-only options. main() already looks these attributes up on the
    # namespace; defining them here makes them settable from the CLI.
    parser.add_argument('--beam_size', type=int, default=5,
                        help='Value forwarded as --beam_size to the inference script')
    parser.add_argument('--n_tokens_predict', type=int, default=3,
                        help='Value forwarded as --n_tokens_predict to the inference script')
    parser.add_argument('--mcts_iterations', type=int, default=10,
                        help='Value forwarded as --mcts_iterations to the inference script')
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414,
                        help='Value forwarded as --mcts_exploration_constant to the inference script')
    parser.add_argument('--load_model', type=str, default=None,
                        help='Value forwarded as --load_model to the inference script (omitted when unset)')

    return parser.parse_args(argv)
| |
|
def _report_training_failure(failure):
    """Errback for the agent-training Deferred: print the error and traceback.

    ``print`` has no ``exc_info`` keyword (that belongs to ``logging``), so
    passing it raised ``TypeError`` inside the errback and hid the real
    error. The traceback is taken from the failure object instead.
    """
    print(f"An error occurred: {failure}")
    print(failure.getTraceback(), file=sys.stderr)


def _build_inference_argv(args, inference_mode):
    """Return the argv list handed to the inference script for ``args``."""
    argv = [
        'lightbulb_custom.py',
        '--mode', 'inference',
        '--model_name', args.model_name,
        '--query', args.query,
        '--max_length', str(args.max_length),
        '--inference_mode', inference_mode,
        # getattr fallbacks keep this working with a parser that does not
        # define the tuning options.
        '--beam_size', str(getattr(args, 'beam_size', 5)),
        '--n_tokens_predict', str(getattr(args, 'n_tokens_predict', 3)),
        '--mcts_iterations', str(getattr(args, 'mcts_iterations', 10)),
        '--mcts_exploration_constant', str(getattr(args, 'mcts_exploration_constant', 1.414)),
    ]

    # Only forward a checkpoint path when one was actually supplied.
    load_model = getattr(args, 'load_model', None)
    if load_model:
        argv += ['--load_model', load_model]
    return argv


def main():
    """Run the task selected with ``--task``.

    Tasks:
        * ``train_llm_world``: rewrites ``sys.argv`` and calls
          ``world_model_main``.
        * ``train_agent``: schedules ``train_agent`` on the Twisted reactor
          and stops the reactor when it finishes or fails.
        * ``test_agent``: answers ``--query`` once, or starts the test
          session on the reactor when no query is given.
        * ``inference_llm`` / ``inference_world_model`` /
          ``advanced_inference``: rewrites ``sys.argv`` and calls
          ``inference_main`` with the matching ``--inference_mode``.

    Exits with status 1 when the task is not recognised.
    """
    args = parse_main_args()

    if args.task == 'train_llm_world':
        print("Starting LLM and World Model Training...")
        # The callee is invoked without arguments, so its options are handed
        # over by replacing sys.argv (presumably it parses sys.argv itself).
        sys.argv = [
            'lightbulb_custom.py',
            '--mode', args.mode,
            '--model_name', args.model_name,
            '--dataset_name', args.dataset_name,
            '--dataset_config', args.dataset_config,
            '--batch_size', str(args.batch_size),
            '--num_epochs', str(args.num_epochs),
            '--max_length', str(args.max_length),
        ]
        world_model_main()

    elif args.task == 'train_agent':
        print("Starting Agent Training...")
        # Run the training once the reactor is up; stop the reactor on both
        # success and failure so the process terminates.
        d = task.deferLater(reactor, 0, train_agent)
        d.addErrback(_report_training_failure)
        d.addBoth(lambda _: reactor.stop())
        reactor.run()

    elif args.task == 'test_agent':
        print("Starting Test Agent...")
        test_agent = TestAgent()
        if args.query:
            # One-shot mode: answer the supplied query and return.
            result = test_agent.process_query(args.query)
            print("\nAgent's response:")
            print(result)
        else:
            # No query given: start the test session on the reactor.
            reactor.callWhenRunning(run_test_session)
            reactor.run()

    elif args.task in ['inference_llm', 'inference_world_model', 'advanced_inference']:
        print("Starting Inference Task...")

        # Map the menu task onto the --inference_mode value.
        inference_mode_map = {
            'inference_llm': 'without_world_model',
            'inference_world_model': 'world_model',
            'advanced_inference': 'world_model_tree_of_thought',
        }
        selected_inference_mode = inference_mode_map.get(
            args.task, 'world_model_tree_of_thought'
        )

        # Same sys.argv hand-over as for training.
        sys.argv = _build_inference_argv(args, selected_inference_mode)
        inference_main()

    else:
        # Defensive: argparse `choices` should make this unreachable.
        print(f"Unknown task: {args.task}")
        sys.exit(1)
| |
|
# Script entry point: dispatch the selected task only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
| |
|
| |
|