File size: 7,002 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import argparse
import logging
import os
from pathlib import Path
from dinov3.logging import setup_logging
from dinov3.utils.cluster import (
get_slurm_account,
get_slurm_executor_parameters,
get_slurm_partition,
get_slurm_qos,
get_user_checkpoint_path,
)
from dinov3.utils.custom_callable import load_custom_callable
# Module-level logger registered under the "dinov3" name; used by the functions below.
logger = logging.getLogger("dinov3")
def get_submitit_parser():
    """Build the parser that holds the job-scheduling (submitit / Slurm) options.

    The parser is created with ``add_help=False`` so that it can be used as a
    ``parents=[...]`` entry of another parser. The defaults for the partition,
    account and QoS options are read from the cluster helper functions.

    Returns:
        argparse.ArgumentParser: parser with all scheduling options registered.
    """
    # Cluster-dependent defaults, looked up once before the parser is built.
    default_partition = get_slurm_partition()
    default_account = get_slurm_account()
    default_qos = get_slurm_qos()

    # One (flag, add_argument keyword arguments) entry per option, in
    # registration order (which is also the order shown by --help).
    option_table = (
        (
            "--ngpus",
            dict(default=8, type=int, help="Number of gpus to request on each node, default: %(default)s"),
        ),
        (
            "--nodes",
            dict(default=1, type=int, help="Number of nodes to request, default: %(default)s"),
        ),
        (
            "--timeout",
            dict(default=2800, type=int, help="Duration of the job, default: %(default)s"),
        ),
        (
            "--slurm-partition",
            dict(default=default_partition, type=str, help="Partition where to submit, default: %(default)s"),
        ),
        (
            "--slurm-qos",
            dict(
                default=default_qos,
                metavar="SLURM_QOS",
                type=str,
                dest="slurm_qos",
                help="slurm QoS to use for jobs in cluster environment, default: %(default)s",
            ),
        ),
        (
            "--slurm-array-parallelism",
            dict(
                default=256,
                type=int,
                help="Maximum number of jobs that will be executed in parallel, default: %(default)s",
            ),
        ),
        (
            "--slurm-nice",
            dict(default=0, type=int, help="Adjusted scheduling priority within Slurm, default: %(default)s"),
        ),
        (
            "--slurm-account",
            dict(default=default_account, type=str, help="Slurm account name, default: %(default)s"),
        ),
        (
            "--comment",
            dict(
                default="",
                type=str,
                help="Comment to pass to scheduler, e.g. priority message, default: '%(default)s'",
            ),
        ),
        (
            "--exclude",
            dict(default="", type=str, help="Nodes to exclude, default: '%(default)s'"),
        ),
        (
            "--output-dir",
            dict(type=str, help="output dir"),
        ),
    )

    submitit_parser = argparse.ArgumentParser("Submitit arguments", add_help=False)
    for flag, options in option_table:
        submitit_parser.add_argument(flag, **options)
    return submitit_parser
def get_run_parser():
    """Build the launcher's command-line parser.

    Inherits every option of ``get_submitit_parser()`` and adds the positional
    ``module_path`` plus the ``--callable-name`` option (default ``"main"``).

    Returns:
        argparse.ArgumentParser: the complete launcher parser.
    """
    module_path_help = (
        "Full path to the program/script to be launched in parallel, "
        "followed by all the arguments for the training script."
    )

    run_parser = argparse.ArgumentParser(
        "Launcher arguments",
        parents=[get_submitit_parser()],
    )
    run_parser.add_argument("module_path", type=str, help=module_path_help)
    run_parser.add_argument(
        "--callable-name",
        default="main",
        type=str,
        help="Name of the callable to execute in the script",
    )
    return run_parser
def get_shared_folder() -> Path:
    """Return the ``experiments`` directory under the user checkpoint path.

    The directory is created when it does not exist yet.

    Returns:
        Path: ``<user checkpoint path>/experiments``.

    Raises:
        RuntimeError: if ``get_user_checkpoint_path()`` returns ``None``.
    """
    user_checkpoint_path = get_user_checkpoint_path()
    if user_checkpoint_path is None:
        raise RuntimeError("Path to user checkpoint cannot be determined")
    path = user_checkpoint_path / "experiments"
    # parents=True so a first-time user, whose checkpoint directory has not been
    # created yet, does not hit FileNotFoundError; this matches the mkdir call
    # used for the output directory in submit_jobs().
    path.mkdir(parents=True, exist_ok=True)
    return path
class CheckpointableSubmitter:
    """Callable job payload: loads a callable from a module file and runs it.

    Instances are meant to be handed to a submitit executor. Calling the
    instance prepares the arguments and logging, loads ``callable_name`` from
    ``module_path`` and invokes it with the argument list. ``checkpoint()``
    builds the resubmission object used to requeue the same work.
    """

    def __init__(self, module_path, callable_name, args, output_dir):
        """Store the job description.

        Args:
            module_path: path to the script/module that defines the callable.
            callable_name: name of the callable to load from ``module_path``.
            args: list of command-line style string arguments passed to the
                callable (stored by reference, not copied).
            output_dir: output directory; may contain the ``%j`` placeholder,
                which is replaced by the job id when the job runs.
        """
        self.args = args
        self.callable_name = callable_name
        # Both paths are resolved with realpath at construction time.
        self.module_path = os.path.realpath(module_path)
        self.output_dir = os.path.realpath(output_dir)

    def __call__(self):
        """Prepare arguments and logging, then load and run the target callable."""
        self._setup_args()
        callable_ = load_custom_callable(self.module_path, self.callable_name)
        callable_(self.args)

    def checkpoint(self):
        """Return a ``DelayedSubmission`` wrapping a fresh copy of this submitter.

        NOTE(review): presumably invoked by submitit when the job is preempted
        or times out -- confirm against the submitit checkpointing docs.
        """
        import submitit

        logger.info(f"Requeuing {self.callable_name} from {self.module_path} with {self.args}")
        empty_class = type(self)(self.module_path, self.callable_name, self.args, self.output_dir)
        return submitit.helpers.DelayedSubmission(empty_class)

    @staticmethod
    def _has_output_dir_arg(args):
        """Return True if ``args`` already carries ``--output-dir`` in either form.

        Recognises both the two-token form (``--output-dir PATH``) and the
        single-token form (``--output-dir=PATH``).
        """
        return any(arg == "--output-dir" or arg.startswith("--output-dir=") for arg in args)

    def _setup_args(self):
        """Resolve ``%j`` in the output dir, inject ``--output-dir`` and set up logging."""
        import submitit

        job_env = submitit.JobEnvironment()
        self.output_dir = str(self.output_dir).replace("%j", str(job_env.job_id))
        # This method itself inserts the "--output-dir=<dir>" form, and
        # checkpoint() reuses the same args list, so a bare-token membership
        # test would prepend a duplicate on every requeue.
        if not self._has_output_dir_arg(self.args):
            self.args.insert(0, f"--output-dir={self.output_dir}")

        # Setup logging with exact same arguments as in fairvit/run/init.py
        # to use lru_cache memoization and avoid setting up the logger twice
        setup_logging(output=self.output_dir, level=logging.INFO)

        logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
        logger.info(f"Module Path: {self.module_path}")
        logger.info(f"Callable Name: {self.callable_name}")
        logger.info(f'Args: {" ".join(self.args)}')
def submit_jobs(class_to_submit, output_dir, submitit_args, name="fairvit"):
    """Configure a submitit executor and submit a single job.

    Args:
        class_to_submit: callable object handed to ``executor.submit``.
        output_dir: folder given to the executor; it is created first. A ``%j``
            in the path is replaced by the job id in the final log message.
        submitit_args: parsed options; reads ``nodes``, ``ngpus``, ``timeout``,
            ``slurm_partition``, ``slurm_qos``, ``slurm_nice``, ``comment``
            and ``exclude``.
        name: job name passed to ``executor.update_parameters``.
    """
    import submitit

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    executor = submitit.AutoExecutor(folder=output_dir, slurm_max_num_timeout=30)

    # Comment / exclude are only forwarded when they are non-empty.
    scheduler_extras = {
        option: value
        for option, value in (
            ("slurm_comment", submitit_args.comment),
            ("slurm_exclude", submitit_args.exclude),
        )
        if value
    }

    # NOTE: submitit_args.slurm_account is not forwarded here (it was
    # commented out in the parameter list).
    executor_params = get_slurm_executor_parameters(
        nodes=submitit_args.nodes,
        num_gpus_per_node=submitit_args.ngpus,
        timeout_min=submitit_args.timeout,  # max is 60 * 72
        slurm_signal_delay_s=120,
        slurm_partition=submitit_args.slurm_partition,
        slurm_qos=submitit_args.slurm_qos,
        slurm_additional_parameters={"nice": submitit_args.slurm_nice},
        **scheduler_extras,
    )
    executor.update_parameters(name=name, **executor_params)

    job = executor.submit(class_to_submit)
    logger.info(f"Submitted job_id: {job.job_id}")

    resolved_output_dir = os.path.abspath(output_dir).replace("%j", str(job.job_id))
    logger.info(f"Logs and checkpoints will be saved at: {resolved_output_dir}")
def main():
    """Parse the command line and submit the requested script as a job.

    Arguments not recognised by the launcher parser are forwarded untouched to
    the target callable. When ``--output-dir`` is not given, the output
    directory defaults to ``<shared folder>/%j``.

    Raises:
        FileNotFoundError: if ``module_path`` does not exist.
    """
    setup_logging(level=logging.INFO)
    args, script_args = get_run_parser().parse_known_args()

    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would let a bad path through to the scheduler.
    if not os.path.exists(args.module_path):
        raise FileNotFoundError(f"The module path does not exist: {args.module_path}")

    # Job name is "<script file name without extension>:<callable name>".
    file_name = os.path.splitext(os.path.basename(args.module_path))[0]
    name = f"{file_name}:{args.callable_name}"

    if args.output_dir is None:
        args.output_dir = get_shared_folder() / "%j"

    class_to_submit = CheckpointableSubmitter(args.module_path, args.callable_name, script_args, args.output_dir)
    submit_jobs(class_to_submit, args.output_dir, args, name=name)
# Run the launcher only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|