|
|
|
|
|
|
|
|
import json |
|
|
import os |
|
|
import shutil |
|
|
import subprocess |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Dict |
|
|
|
|
|
from omegaconf import OmegaConf |
|
|
|
|
|
from core.args import dataclass_from_dict |
|
|
|
|
|
|
|
|
@dataclass
class StoolArgs:
    """Options controlling how a training run is packaged and submitted to SLURM."""

    config: Any = None  # training config (OmegaConf), loaded from the CLI-provided path
    launcher: str = "sbatch"  # command used to submit the generated script
    script: str = "apps.main.train"  # module launched via `python -u -m <script>`
    copy_code: bool = True  # snapshot the current working tree into <dump_dir>/code
    dirs_exists_ok: bool = False  # tolerate a pre-existing dump directory
    override: bool = False  # delete an existing dump directory (after confirmation)
    nodes: int = -1  # number of nodes to request
    ngpu: int = 8  # GPUs per node
    ncpu: int = 16  # CPUs per GPU
    mem: str = ""  # memory request; empty becomes "0" (all node memory)
    anaconda: str = "default"  # conda env prefix, or "default" for the active python
    constraint: str = ""  # optional sbatch --constraint value
    exclude: str = ""  # optional sbatch --exclude node list
    time: int = -1  # time limit in minutes; -1 = use the partition maximum
    account: str = ""  # optional sbatch --account value
    qos: str = ""  # optional sbatch --qos value
    partition: str = "learn"  # target SLURM partition
    stdout: bool = False  # if True, srun inherits stdout instead of per-task log files
|
|
|
|
|
|
|
|
# Template for the generated sbatch submission script, filled via str.format()
# in launch_job(). The {exclude}/{qos}/{account}/{constraint} placeholders
# receive either a complete "#SBATCH --..." directive (see validate_args) or an
# empty string, which renders as a harmless blank line.
# NOTE(review): SLURM does not normally create the {dump_dir}/logs/%j directory
# used by --output/--error — presumably something else provisions it; confirm.
SBATCH_COMMAND = """#!/bin/bash

{exclude}
{qos}
{account}
{constraint}
#SBATCH --job-name={name}
#SBATCH --nodes={nodes}
#SBATCH --gres=gpu:{ngpus}
#SBATCH --cpus-per-gpu={ncpu}
#SBATCH --time={time}
#SBATCH --partition={partition}
#SBATCH --mem={mem}

#SBATCH --output={dump_dir}/logs/%j/%j.stdout
#SBATCH --error={dump_dir}/logs/%j/%j.stderr

#SBATCH --open-mode=append
#SBATCH --signal=USR2@120
#SBATCH --distribution=block

# Mimic the effect of "conda init", which doesn't work for scripts
eval "$({conda_exe} shell.bash hook)"
source activate {conda_env_path}

{go_to_code_dir}

export OMP_NUM_THREADS=1
export LAUNCH_WITH="SBATCH"
export DUMP_DIR={dump_dir}
srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
"""
|
|
|
|
|
|
|
|
def copy_dir(input_dir: str, output_dir: str) -> None:
    """Mirror the ``*.py`` and ``*.yaml`` files of *input_dir* into *output_dir*.

    Uses rsync with include/exclude filters so only Python and YAML files (and
    the directories containing them) are copied; symlinks are dereferenced.

    Raises:
        AssertionError: if either path is not an existing directory.
        subprocess.CalledProcessError: if rsync exits with a non-zero status.
    """
    print(f"Copying : {input_dir}\n" f"to : {output_dir} ...")
    assert os.path.isdir(input_dir), f"{input_dir} is not a directory"
    assert os.path.isdir(output_dir), f"{output_dir} is not a directory"
    # List form (shell=False) avoids shell-quoting issues with the paths.
    rsync_cmd = [
        "rsync",
        "-arm",
        "--copy-links",
        "--include=**/",
        "--include=*.py",
        "--include=*.yaml",
        "--exclude=*",
        f"{input_dir}/",
        output_dir,
    ]
    print(f"Copying command: {' '.join(rsync_cmd)}")
    # check=True: a failed copy must not go unnoticed (the original discarded
    # the return code of subprocess.call).
    subprocess.run(rsync_cmd, check=True)
    print("Copy done.")
|
|
|
|
|
|
|
|
def retrieve_max_time_per_partition() -> Dict[str, int]:
    """Return a mapping of SLURM partition name -> max job time in minutes.

    Queries ``sinfo --json``; partitions reporting an "infinite" limit are
    capped at 14 days.
    """
    raw = subprocess.check_output("sinfo --json", shell=True)
    partitions = json.loads(raw)["sinfo"]

    limits: Dict[str, int] = {}
    for entry in partitions:
        name = entry["partition"]["name"]
        time_info = entry["partition"]["maximums"]["time"]
        limits[name] = 14 * 24 * 60 if time_info["infinite"] else time_info["number"]

    return limits
|
|
|
|
|
|
|
|
def validate_args(args) -> None:
    """Normalize and sanity-check launcher arguments, mutating *args* in place.

    Fills in a default time limit, expands scalar options into complete
    "#SBATCH --..." directive strings, resolves the anaconda python path, and
    asserts the resulting configuration is launchable.

    Raises:
        AssertionError: if a required value is missing or non-positive.
    """
    if args.time == -1:
        # No explicit limit: fall back to the partition maximum (or 3 days
        # if the partition is unknown to sinfo).
        max_times = retrieve_max_time_per_partition()
        args.time = max_times.get(args.partition, 3 * 24 * 60)
        print(
            f"No time limit specified, using max time for partitions: {args.time} minutes"
        )

    # Expand plain values into full sbatch directives; empty values stay
    # empty and render as blank lines in the submission script.
    if args.constraint:
        args.constraint = f"#SBATCH --constraint={args.constraint}"

    if args.account:
        args.account = f"#SBATCH --account={args.account}"

    if args.qos:
        args.qos = f"#SBATCH --qos={args.qos}"

    if getattr(args, "exclude", ""):
        args.exclude = f"#SBATCH --exclude={args.exclude}"

    if hasattr(args, "anaconda") and args.anaconda:
        if args.anaconda == "default":
            # Use whichever python is currently first on PATH.
            args.anaconda = (
                subprocess.check_output("which python", shell=True)
                .decode("ascii")
                .strip()
            )
        else:
            args.anaconda = f"{args.anaconda}/bin/python"
        assert os.path.isfile(args.anaconda)

    args.mem = args.mem or "0"  # "0" asks SLURM for all memory on the node

    assert args.partition
    assert args.ngpu > 0
    assert args.ncpu > 0
    assert args.nodes > 0
    assert args.time > 0
|
|
|
|
|
|
|
|
def launch_job(args: StoolArgs):
    """Prepare the dump directory and submit the training job.

    Steps: validate/normalize args, create <dump_dir> (optionally wiping an
    existing one after interactive confirmation), optionally snapshot the
    working tree into <dump_dir>/code, write the resolved config, render
    SBATCH_COMMAND into <dump_dir>/submit.slurm and hand it to args.launcher.
    """
    validate_args(args)
    dump_dir = args.config["dump_dir"]
    job_name = args.config["name"]
    print("Creating directories...")
    os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
    if args.override:
        confirm = input(
            f"Are you sure you want to delete the directory '{dump_dir}'? This action cannot be undone. (yes/no): "
        )
        if confirm.lower() == "yes":
            shutil.rmtree(dump_dir)
            print(f"Directory '{dump_dir}' has been deleted.")
            # Recreate the directory: the config file (and optional code
            # snapshot) below are written into it. The original left it
            # deleted, which broke the copy_code=False path.
            os.makedirs(dump_dir)
        else:
            print("Operation cancelled.")
            return
    if args.copy_code:
        os.makedirs(f"{dump_dir}/code", exist_ok=args.dirs_exists_ok)
        print("Copying code ...")
        copy_dir(os.getcwd(), f"{dump_dir}/code")

    print("Saving config file ...")
    with open(f"{dump_dir}/base_config.yaml", "w") as cfg:
        cfg.write(OmegaConf.to_yaml(args.config))

    conda_exe = os.environ.get("CONDA_EXE", "conda")
    # validate_args resolved args.anaconda to <env>/bin/python; two dirname
    # calls recover the environment prefix for `source activate`.
    conda_env_path = os.path.dirname(os.path.dirname(args.anaconda))
    # Unless --stdout is set, each srun task logs to its own file under logs/%j/.
    log_output = (
        "-o $DUMP_DIR/logs/%j/%j_%t.out -e $DUMP_DIR/logs/%j/%j_%t.err"
        if not args.stdout
        else ""
    )
    sbatch = SBATCH_COMMAND.format(
        name=job_name,
        script=args.script,
        dump_dir=dump_dir,
        nodes=args.nodes,
        tasks=args.nodes * args.ngpu,  # one task per GPU
        nodes_per_run=args.nodes,
        ngpus=args.ngpu,
        ncpu=args.ncpu,
        mem=args.mem,
        qos=args.qos,
        account=args.account,
        constraint=args.constraint,
        exclude=args.exclude,
        time=args.time,
        partition=args.partition,
        conda_exe=conda_exe,
        conda_env_path=conda_env_path,
        log_output=log_output,
        go_to_code_dir=f"cd {dump_dir}/code/" if args.copy_code else "",
    )

    print("Writing sbatch command ...")
    with open(f"{dump_dir}/submit.slurm", "w") as f:
        f.write(sbatch)

    print("Submitting job ...")
    os.system(f"{args.launcher} {dump_dir}/submit.slurm")

    print("Done.")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    """
    Command-line entry point, driven by OmegaConf's dot-list CLI parsing:
    https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments

    With dataclasses such as

        @dataclass
        class LMTransformerArgs:
            dim: int

        @dataclass
        class DummyArgs:
            name: str
            mode: LMTransformerArgs

    passing model.dim=32 overrides a nested value in LMTransformerArgs, while
    name=tictac sets a top-level attribute.
    """
    cli_cfg = OmegaConf.from_cli()
    # args.config arrives as a path string; replace it with the loaded config.
    cli_cfg.config = OmegaConf.load(cli_cfg.config)
    stool_args = dataclass_from_dict(StoolArgs, cli_cfg)
    launch_job(stool_args)
|
|
|