"""Entry point for the Language Reasoning and Memory (LARM) framework.

Parses a config file, seeds every RNG source, builds the task components
(datasets, environments/generators, model), then dispatches to the
registered runner for training or evaluation.
"""

import argparse
import logging
import os
import random
import time  # noqa: F401  (kept for compatibility; not used directly here)
from datetime import datetime

import numpy as np
import torch

from larm.common.config import Config
from larm.common.logger import setup_logger
from larm.common.registry import registry
from larm.task import BaseRunner, Task


def set_seed(random_seed: int, use_gpu: bool) -> None:
    """Seed Python, NumPy and torch RNGs for reproducibility.

    Args:
        random_seed: seed value applied to every RNG source.
        use_gpu: when True, also seed all CUDA devices and force
            deterministic cuDNN kernels (disables benchmark autotuning).
    """
    random.seed(random_seed)
    os.environ["PYTHONHASHSEED"] = str(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)  # no-op when CUDA is unavailable
    if use_gpu:
        torch.cuda.manual_seed_all(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # print (not logging): this runs before setup_logger() is called in main()
    print(f"set seed: {random_seed}")


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments: config path plus optional overrides."""
    parser = argparse.ArgumentParser(description="Language Reasoning and Memory")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser.parse_args()


def get_save_dir(config) -> str:
    """Resolve the output directory for this run.

    Reads ``run_cfg.output_dir`` from the config; falls back to a
    timestamped directory under ``results/<method>/`` when unset.
    """
    output_dir = config.run_cfg.get("output_dir", None)
    if output_dir is None:
        # Named `timestamp` (not `time`) to avoid shadowing the stdlib module.
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        output_dir = os.path.join("results", config.method, timestamp)
        logging.warning(f"output_dir not specified in config, using default: {output_dir}")
    return output_dir


def get_runner_class(config) -> type[BaseRunner]:
    """Look up the runner *class* registered under ``config.method``."""
    print(config.method)
    return registry.get_runner_class(config.method)


def main() -> None:
    """Configure the experiment and launch training or evaluation."""
    # parse configs
    args = parse_args()
    config = Config(args)
    set_seed(config.run_cfg.seed, use_gpu=True)

    # set up save folder
    save_dir = get_save_dir(config)
    config.run_cfg.save_dir = save_dir

    # set up logger
    config.run_cfg.log_dir = os.path.join(save_dir, "logs")
    setup_logger(output_dir=config.run_cfg.log_dir)
    config.pretty_print()

    # build task components
    task = Task(config)
    datasets_dict = task.build_dataset()
    env_and_gens_dict = task.build_env_and_generator()
    model = task.build_model()

    # build runner
    runner_cls = get_runner_class(config)
    # For multimodal models, use processor; otherwise use tokenizer
    processing_class = getattr(model, 'processor', model.tokenizer)
    runner = runner_cls(
        model=model,
        processing_class=processing_class,
        configs=config,
        datasets_dict=datasets_dict,
        env_and_gens_dict=env_and_gens_dict,
    )

    # train or evaluate; fail loudly on an unrecognized mode instead of
    # silently doing nothing (previous behavior on a typo'd mode)
    mode = config.run_cfg.mode
    if mode == "train":
        runner.train()
    elif mode == "evaluate":
        runner.evaluate()
    else:
        raise ValueError(f"Unknown run mode: {mode!r} (expected 'train' or 'evaluate')")


if __name__ == "__main__":
    main()