st24hour's picture
Upload folder using huggingface_hub
e101805 verified
import argparse
import os
import sys
# Ensure project root is on Python path for module resolution
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import uuid
from pathlib import Path
import submitit
from exaonepath.create_patches import create_patches_fp
from exaonepath.utils.cluster import get_slurm_partition
def parse_args():
    """Parse command-line options for submitting the patch-creation job.

    The parser inherits every option of ``create_patches_fp.get_args_parser()``
    and adds the Slurm/submitit-specific options listed below.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    default_partition = get_slurm_partition()
    base_parser = create_patches_fp.get_args_parser()
    parser = argparse.ArgumentParser("Submitit for create patches", parents=[base_parser])

    # (flag, default, type, help) for each launcher-specific option.
    option_specs = (
        ("--tasks_per_node", 26, int, "Number of processes to request on each node (CPU-only)"),
        ("--nodes", 4, int, "Number of nodes to request"),
        ("--timeout", 20160, int, "Duration of the job - minute"),
        ("--partition", default_partition, str, "Partition where to submit"),
        ("--job_name", "create_patches", str, "Job name"),
        ("--nodelist", None, str, "Specific node to request (biolabslur-a3ultranodeset-[0-3])"),
        ("--cpus_per_task", 4, int, "Number of CPUs per task"),
        ("--qos", None, str, "Quality of Service for Slurm, e.g. low"),
        ("--logs_dir", "create_patches/logs/%j", str, "Directory to save logs and checkpoints"),
    )
    for flag, default_value, value_type, description in option_specs:
        parser.add_argument(flag, default=default_value, type=value_type, help=description)

    return parser.parse_args()
def get_shared_folder(args) -> Path:
    """Return the absolute, job-independent folder derived from ``args.logs_dir``.

    The folder is created if it does not exist yet.

    Two properties matter to callers:

    * The path is absolute, because callers turn paths below it into
      ``file://`` URIs via ``Path.as_uri()``, which raises ``ValueError`` for
      relative paths.
    * Everything from the first component containing the ``%j`` job-id
      placeholder onwards is dropped, so no directory literally named ``%j``
      is created before the job id is known.

    Args:
        args: Namespace-like object exposing a ``logs_dir`` attribute.

    Returns:
        Path: The existing, absolute shared folder.
    """
    absolute_parts = Path(f"{args.logs_dir}").absolute().parts

    stable_parts = []
    for part in absolute_parts:
        # Stop at the first component that still carries the placeholder.
        if "%j" in part:
            break
        stable_parts.append(part)

    p = Path(*stable_parts)
    p.mkdir(parents=True, exist_ok=True)
    return p
def get_init_file(args):
    """Return a fresh, absolute path for an init file in the shared folder.

    The file itself is not created; only its parent directory is guaranteed to
    exist (``get_shared_folder`` creates it). The returned path is absolute so
    that callers can safely convert it with ``Path.as_uri()``, which raises
    ``ValueError`` on relative paths.

    Args:
        args: Namespace-like object exposing a ``logs_dir`` attribute.

    Returns:
        Path: Absolute path named ``<random hex>_init`` that does not exist.
    """
    # The init file must not exist, but its parent dir must exist.
    shared_folder = get_shared_folder(args).absolute()
    init_file = shared_folder / f"{uuid.uuid4().hex}_init"
    if init_file.exists():
        os.remove(str(init_file))
    return init_file
class Trainer:
    """Callable wrapper around ``create_patches_fp.main`` submitted via submitit."""

    def __init__(self, args):
        """Store the parsed arguments that will be handed to the job."""
        self.args = args

    def __call__(self):
        """Fill in per-task fields on ``self.args`` and run patch creation."""
        from exaonepath.create_patches import create_patches_fp

        self._setup_gpu_args()
        create_patches_fp.main(self.args)

    def checkpoint(self):
        """Build the resubmission for this job with a fresh ``dist_url``.

        Returns:
            submitit.helpers.DelayedSubmission: Wraps a new instance of this
            class carrying the same (updated) arguments.
        """
        import submitit

        fresh_init_file = get_init_file(self.args)
        self.args.dist_url = fresh_init_file.as_uri()
        print("Requeuing ", self.args)
        resubmitted = self.__class__(self.args)
        return submitit.helpers.DelayedSubmission(resubmitted)

    def _setup_gpu_args(self):
        """Populate ``self.args`` from the submitit job environment.

        Substitutes the ``%j`` placeholder in ``logs_dir`` with the job id and
        records the local rank, global rank and task count.
        """
        import submitit
        from pathlib import Path

        job_env = submitit.JobEnvironment()

        resolved_logs_dir = str(self.args.logs_dir).replace("%j", str(job_env.job_id))
        self.args.logs_dir = Path(resolved_logs_dir)

        self.args.gpu = job_env.local_rank
        self.args.rank = job_env.global_rank
        self.args.world_size = job_env.num_tasks
        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
def _print_executor_debug_info(executor):
    """Print a best-effort description of the underlying submitit executor.

    This pokes at private submitit attributes (``executor._executor``), so any
    failure is reported and otherwise ignored instead of aborting submission.
    """
    try:
        import inspect

        backend = executor._executor
        # Show how the sbatch submission file would be generated, if exposed.
        if hasattr(backend, "_make_submission_file"):
            print("SLURM submission command would be created with:")
            print(inspect.getsource(backend._make_submission_file))

        # Simpler overview: executor type and its public attributes.
        print("\nExecutor type:", type(backend))
        print("Available methods:", [method for method in dir(backend) if not method.startswith('_')])

        # Parameters that will actually be used for the submission.
        print("\nSubmission parameters:")
        for key, value in backend.parameters.items():
            print(f" {key}: {value}")
    except Exception as e:
        print(f"Error inspecting executor: {e}")


def main():
    """Configure a submitit executor from the CLI options and submit the job.

    Returns:
        int: ``0`` after the job has been submitted.
    """
    args = parse_args()
    executor = submitit.AutoExecutor(folder=args.logs_dir, slurm_max_num_timeout=30)

    # Print the submitit version in use.
    print(f"Submitit version: {submitit.__version__}")

    executor_params = {
        # Memory requested per node, in GB.
        # NOTE(review): the original comment claimed this requests all memory
        # on a node (see https://slurm.schedmd.com/sbatch.html) - confirm that
        # 512 GB matches the target hardware.
        "mem_gb": 512,
        "gpus_per_node": 0,  # CPU-only job
        "tasks_per_node": args.tasks_per_node,
        "cpus_per_task": args.cpus_per_task,
        "nodes": args.nodes,
        # Honour --partition; its default is already get_slurm_partition().
        "slurm_partition": args.partition,
        "timeout_min": args.timeout,  # Job timeout in minutes
        "slurm_signal_delay_s": 120,
    }
    # Add Slurm QoS option if provided
    if args.qos:
        executor_params["slurm_qos"] = args.qos
    # Add specific nodelist constraint to Slurm parameters
    if args.nodelist:
        executor_params["slurm_nodelist"] = args.nodelist
    executor.update_parameters(name=args.job_name, **executor_params)

    # Inspect what will actually be submitted (debugging aid only).
    _print_executor_debug_info(executor)

    args.dist_url = get_init_file(args).as_uri()
    trainer = Trainer(args)
    job = executor.submit(trainer)
    print(f"Submitted job_id: {job.job_id}")
    print(f"Logs and checkpoints will be saved at: {args.logs_dir}")
    return 0
# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())