File size: 2,976 Bytes
e34b94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
import logging
import os
import random
import time
from datetime import datetime
from typing import Type

import numpy as np
import torch

from larm.common.config import Config
from larm.common.logger import setup_logger
from larm.common.registry import registry
from larm.task import Task, BaseRunner

def set_seed(random_seed: int, use_gpu: bool) -> None:
    """Seed every RNG source (python, numpy, torch) for reproducibility.

    Args:
        random_seed: Seed value applied to all RNGs.
        use_gpu: When True, also seed all visible CUDA devices.
    """
    random.seed(random_seed)
    # Make hash-based ordering (e.g. set/dict iteration of strings) reproducible.
    os.environ['PYTHONHASHSEED'] = str(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    # Safe on CPU-only builds: torch defers CUDA seeding until CUDA is initialized.
    torch.cuda.manual_seed(random_seed)
    if use_gpu:
        # Seed every CUDA device, not only the current one (multi-GPU runs).
        torch.cuda.manual_seed_all(random_seed)

    # Trade kernel-selection speed for deterministic cuDNN behavior.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Use the logging module for consistency with the rest of this file
    # (original used a bare print).
    logging.info(f"set seed: {random_seed}")

def parse_args():
    """Build the CLI parser and return the parsed arguments namespace."""
    arg_parser = argparse.ArgumentParser(description="Language Reasoning and Memory")
    arg_parser.add_argument(
        "--cfg-path",
        required=True,
        help="path to configuration file.",
    )
    arg_parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return arg_parser.parse_args()

def get_save_dir(config) -> str:
    """Resolve the output directory for this run.

    Prefers ``run_cfg.output_dir`` from the config; falls back to a
    timestamped path under ``results/<method>/`` when it is absent.

    Args:
        config: Parsed run configuration; must expose ``run_cfg`` (a mapping
            with ``.get``) and, on the fallback path, ``method``.

    Returns:
        Directory path where run artifacts should be saved.
    """
    output_dir = config.run_cfg.get("output_dir", None)
    if output_dir is None:
        # Fallback to time-based naming if not specified. Named `timestamp`
        # so it does not shadow the module-level `import time` (as the
        # original local `time` did).
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        output_dir = os.path.join("results", config.method, timestamp)
        logging.warning(f"output_dir not specified in config, using default: {output_dir}")
    return output_dir

def get_runner_class(config) -> Type[BaseRunner]:
    """Look up the runner class registered under ``config.method``.

    Args:
        config: Parsed configuration exposing ``method``, used as the
            registry lookup key.

    Returns:
        The runner class (not an instance) registered for the method —
        hence ``Type[BaseRunner]``; the original ``-> BaseRunner``
        annotation incorrectly described an instance.
    """
    # Debug-level log instead of the original stray print().
    logging.debug("resolving runner class for method: %s", config.method)
    return registry.get_runner_class(config.method)

def main():
    """Entry point: parse config, seed RNGs, build task components, and run.

    Raises:
        ValueError: If ``run_cfg.mode`` is neither ``"train"`` nor
            ``"evaluate"`` (the original silently did nothing).
    """
    # parse configs
    args = parse_args()
    config = Config(args)

    set_seed(config.run_cfg.seed, use_gpu=True)

    # set up save folder
    save_dir = get_save_dir(config)
    config.run_cfg.save_dir = save_dir

    # set up logger
    config.run_cfg.log_dir = os.path.join(save_dir, "logs")
    setup_logger(output_dir=config.run_cfg.log_dir)

    config.pretty_print()

    task = Task(config)
    datasets_dict = task.build_dataset()
    env_and_gens_dict = task.build_env_and_generator()
    model = task.build_model()

    # build runner
    runner_cls = get_runner_class(config)
    # For multimodal models, use processor; otherwise use tokenizer
    processing_class = getattr(model, 'processor', model.tokenizer)
    runner = runner_cls(
        model=model,
        processing_class=processing_class,
        configs=config,
        datasets_dict=datasets_dict,
        env_and_gens_dict=env_and_gens_dict,
    )

    # train or evaluate; fail loudly on a misconfigured mode instead of
    # exiting silently as the original two independent `if`s did.
    mode = config.run_cfg.mode
    if mode == "train":
        runner.train()
    elif mode == "evaluate":
        runner.evaluate()
    else:
        raise ValueError(f"Unknown run mode: {mode!r} (expected 'train' or 'evaluate')")

# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()