File size: 2,976 Bytes
e34b94f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import os
import argparse
import torch
import random
import numpy as np
import logging
import time
from datetime import datetime
from larm.common.config import Config
from larm.common.logger import setup_logger
from larm.common.registry import registry
from larm.task import Task, BaseRunner
def set_seed(random_seed: int, use_gpu: bool) -> None:
    """Seed all RNGs (Python, hash, NumPy, PyTorch) for reproducibility.

    Args:
        random_seed: Seed value applied to every RNG.
        use_gpu: When True, also seed the CUDA device(s) and force
            deterministic cuDNN kernels (disables autotuning).
    """
    random.seed(random_seed)
    os.environ['PYTHONHASHSEED'] = str(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    if use_gpu:
        # Fix: the current-device CUDA seed was previously set outside this
        # guard, inconsistently with manual_seed_all below.
        torch.cuda.manual_seed(random_seed)
        torch.cuda.manual_seed_all(random_seed)
        # Trade cuDNN autotuning speed for bitwise-reproducible results.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # Use logging (as the rest of this file does) instead of a bare print.
    logging.info(f"set seed: {random_seed}")
def parse_args():
    """Build the CLI parser and return the parsed arguments namespace."""
    arg_parser = argparse.ArgumentParser(description="Language Reasoning and Memory")
    arg_parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    arg_parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    # Parse sys.argv and hand the namespace straight back to the caller.
    return arg_parser.parse_args()
def get_save_dir(config) -> str:
    """Resolve the output directory for this run.

    Reads ``config.run_cfg.output_dir`` when present; otherwise falls back
    to a timestamped directory under ``results/<method>/`` and logs a
    warning.

    Args:
        config: Parsed run configuration; must expose ``run_cfg`` (mapping
            with ``.get``) and, for the fallback path, ``method``.

    Returns:
        The directory path in which run artifacts should be saved.
    """
    output_dir = config.run_cfg.get("output_dir", None)
    if output_dir is None:
        # Fix: this local was previously named `time`, shadowing the `time`
        # module imported at the top of the file.
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        output_dir = os.path.join("results", config.method, timestamp)
        logging.warning(f"output_dir not specified in config, using default: {output_dir}")
    return output_dir
def get_runner_class(config) -> BaseRunner:
    """Look up the runner class registered under ``config.method``.

    Note: the registry returns the class object itself (to be instantiated
    by the caller), not an instance.
    """
    # Fix: was a stray debug `print(config.method)`; use lazy %-style
    # logging for consistency with the rest of the file.
    logging.debug("resolving runner class for method: %s", config.method)
    return registry.get_runner_class(config.method)
def main():
    """Entry point: load config, seed RNGs, build task components, run.

    Pipeline: parse CLI args -> build ``Config`` -> seed RNGs -> resolve
    save/log dirs -> build datasets, environments/generators, and model via
    ``Task`` -> instantiate the registered runner -> train or evaluate
    according to ``run_cfg.mode``.
    """
    # parse configs
    args = parse_args()
    config = Config(args)
    # NOTE(review): use_gpu is hard-coded True; seeding assumes CUDA intent
    # regardless of availability — confirm this is deliberate.
    set_seed(config.run_cfg.seed, use_gpu=True)

    # set up save folder
    save_dir = get_save_dir(config)
    config.run_cfg.save_dir = save_dir

    # set up logger
    config.run_cfg.log_dir = os.path.join(save_dir, "logs")
    setup_logger(output_dir=config.run_cfg.log_dir)
    config.pretty_print()

    # build task components
    task = Task(config)
    datasets_dict = task.build_dataset()
    env_and_gens_dict = task.build_env_and_generator()
    model = task.build_model()

    # build runner
    runner_cls = get_runner_class(config)
    # For multimodal models, use processor; otherwise use tokenizer.
    # Fix: plain getattr(model, 'processor', model.tokenizer) would return
    # None when `processor` exists but is explicitly None — fall back to
    # the tokenizer in that case too.
    processing_class = getattr(model, 'processor', None) or model.tokenizer
    runner = runner_cls(
        model=model,
        processing_class=processing_class,
        configs=config,
        datasets_dict=datasets_dict,
        env_and_gens_dict=env_and_gens_dict,
    )

    # train or evaluate; fix: the modes are mutually exclusive, and an
    # unrecognized mode previously fell through silently.
    if config.run_cfg.mode == "train":
        runner.train()
    elif config.run_cfg.mode == "evaluate":
        runner.evaluate()
    else:
        logging.warning("Unknown run mode %r; expected 'train' or 'evaluate'.",
                        config.run_cfg.mode)
if __name__ == "__main__":
main() |