|
|
|
|
|
|
|
|
import argparse |
|
|
import importlib |
|
|
import os |
|
|
import sys |
|
|
from datetime import datetime |
|
|
sys.dont_write_bytecode = True |
|
|
from scepter.modules.solver.registry import SOLVERS |
|
|
from scepter.modules.utils.config import Config |
|
|
from scepter.modules.utils.distribute import we |
|
|
from scepter.modules.utils.file_system import FS |
|
|
from scepter.modules.utils.logger import get_logger |
|
|
|
|
|
# Optional local extension: when the current working directory contains an
# ``__init__.py``, execute it as a module registered under the name
# ``scepter_ext``. This runs at import time, i.e. before any solver is built
# (presumably so the extension can register custom components — confirm).
# NOTE(review): the path is relative, so this depends on the process CWD.
if os.path.exists('__init__.py'):
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule has been loaded; import it explicitly before using it.
    import importlib.util

    package_name = 'scepter_ext'
    spec = importlib.util.spec_from_file_location(package_name, '__init__.py')
    package = importlib.util.module_from_spec(spec)
    # Register before executing so code inside the extension that imports
    # ``scepter_ext`` (or refers to itself) resolves to this module object.
    sys.modules[package_name] = package
    spec.loader.exec_module(package)
|
|
|
|
|
def run_task(cfg):
    """Build the configured solver and run the stage selected on the command line.

    This is the function handed to ``we.init_env`` as ``fn``.

    Steps:
      1. Build the solver described by ``cfg.SOLVER`` using the ``scepter`` logger.
      2. Run the solver's set-up phases: ``set_up_pre()`` then ``set_up()``.
      3. On rank 0 only, copy ``cfg.args.cfg_file`` to
         ``<solver.work_dir>/train.yaml``. The destination name is
         ``train.yaml`` regardless of the selected stage.
      4. Dispatch on ``cfg.args.stage``: ``"train"`` calls ``solver.solve()``,
         ``"eval"`` calls ``solver.run_eval()``. Any other value runs no stage.

    Args:
        cfg: The loaded configuration. It must expose ``SOLVER`` and ``args``
            (with ``cfg_file`` and ``stage``).
    """
    solver = SOLVERS.build(cfg.SOLVER, logger=get_logger(name='scepter'))
    solver.set_up_pre()
    solver.set_up()

    # Only one process (rank 0) writes the config snapshot, presumably so
    # every worker does not copy the same file into the work dir.
    if we.rank == 0:
        src_cfg = cfg.args.cfg_file
        snapshot = os.path.join(solver.work_dir, "train.yaml")
        FS.put_object_from_local_file(src_cfg, snapshot)

    stage = cfg.args.stage
    if stage == "train":
        solver.solve()
    elif stage == "eval":
        solver.run_eval()
|
|
|
|
|
|
|
|
# Command-line overrides applied by ``update_config``. Each entry is:
# (attribute name on ``cfg.args``, attribute path starting at ``cfg``,
#  converter applied to the raw command-line value).
_CLI_OVERRIDES = (
    ('learning_rate', ('SOLVER', 'OPTIMIZER', 'LEARNING_RATE'), float),
    ('max_steps', ('SOLVER', 'MAX_STEPS'), int),
)


def _apply_cli_override(cfg, arg_name, path, convert):
    """Copy ``cfg.args.<arg_name>`` (converted) into the config node at ``path``.

    Does nothing when the argument is missing or ``None``, which is the
    parser's "not supplied" default. The target path is resolved only when an
    override is actually present, so config sections are never touched
    otherwise. Prints ``"<arg_name> change from <old> to <new>"`` before
    assigning.

    Raises:
        ValueError: if ``convert`` rejects the supplied value (e.g. ``float('x')``).
    """
    value = getattr(cfg.args, arg_name, None)
    # Compare against None rather than relying on truthiness, so an explicit
    # zero override (e.g. max_steps=0 or learning_rate=0.0) is honoured.
    if value is None:
        return
    *parents, leaf = path
    node = cfg
    for attr in parents:
        node = getattr(node, attr)
    print(f'{arg_name} change from {getattr(node, leaf)} to {value}')
    setattr(node, leaf, convert(value))


def update_config(cfg):
    """Apply command-line overrides to ``cfg`` and timestamp its work directory.

    When supplied (not ``None``), overrides are applied as follows, and each
    change is printed:

    * ``cfg.args.learning_rate`` replaces ``cfg.SOLVER.OPTIMIZER.LEARNING_RATE``
      (converted with ``float``).
    * ``cfg.args.max_steps`` replaces ``cfg.SOLVER.MAX_STEPS`` (converted with
      ``int``).

    Afterwards, a ``YYYYmmddHHMMSS`` subdirectory (naive local time) is always
    appended to ``cfg.SOLVER.WORK_DIR``.

    Args:
        cfg: The configuration object. It is mutated in place.

    Returns:
        The same ``cfg`` object.

    Raises:
        ValueError: if an override value cannot be converted to its numeric type.
    """
    for arg_name, path, convert in _CLI_OVERRIDES:
        _apply_cli_override(cfg, arg_name, path, convert)

    # NOTE(review): the timestamp comes from the wall clock of whichever process
    # calls this. If several launcher processes each run it independently, their
    # WORK_DIRs could differ across a second boundary. Confirm the launch model.
    stamp = datetime.now().strftime('%Y%m%d%H%M%S')
    cfg.SOLVER.WORK_DIR = os.path.join(cfg.SOLVER.WORK_DIR, stamp)
    return cfg
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point:
    #   1. Declare the command-line flags.
    #   2. Hand the parser to Config (load=True, presumably parsing argv and
    #      loading the referenced config file — confirm in Config).
    #   3. Apply the overrides.
    #   4. Launch run_task through the distributed environment helper.
    parser = argparse.ArgumentParser(description='Argparser for Scepter:\n')
    parser.add_argument("--stage",
                        dest="stage",
                        choices=["train", "eval"],
                        default="train",
                        help="Running stage!")

    # Optional numeric overrides. They have no argparse ``type``, so they
    # arrive as strings (or None when omitted) and update_config converts them.
    for flag, dest, help_text in (
        ('--learning_rate', 'learning_rate', 'The learning rate for our network!'),
        ('--max_steps', 'max_steps', 'The max steps for training!'),
    ):
        parser.add_argument(flag, dest=dest, help=help_text, default=None)

    cfg = update_config(Config(load=True, parser_ins=parser))
    we.init_env(cfg, logger=None, fn=run_task)
|
|
|