himipo's picture
first
11aa70b
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional
class ClusterType(Enum):
    """Clusters for which this module provides Slurm/checkpoint settings.

    Only one member is currently defined; its value is the short identifier
    string ``"cw"``.
    """

    CW = "cw"
def _guess_cluster_type() -> ClusterType:
    """Return the cluster type assumed when the caller does not name one.

    No environment inspection is performed: the answer is always
    ``ClusterType.CW``.
    """
    default_cluster = ClusterType.CW
    return default_cluster
def get_cluster_type(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[ClusterType]:
    """Resolve which cluster type to use.

    Args:
        cluster_type: An explicit cluster type, or ``None`` to fall back to
            ``_guess_cluster_type()``.

    Returns:
        ``cluster_type`` unchanged when it is given, otherwise the guessed type.
    """
    return _guess_cluster_type() if cluster_type is None else cluster_type
def get_slurm_account(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the Slurm account to charge jobs to on the given cluster.

    Args:
        cluster_type: Cluster to look up; ``None`` resolves it via
            ``get_cluster_type``.

    Returns:
        The account name, or ``None`` if no cluster type could be resolved.

    Raises:
        KeyError: If the resolved cluster type has no account entry.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is not None:
        account_by_cluster = {ClusterType.CW: "fair_amaia_cw_explore"}
        return account_by_cluster[resolved]
    return None
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the root directory under which checkpoints are stored.

    The result is ``/`` joined with a per-cluster directory name. For
    ``ClusterType.CW`` that name is the empty string, so the result is
    ``Path("/")``.
    NOTE(review): an empty directory name looks like a placeholder — confirm
    that the filesystem root is really the intended checkpoint root.

    Args:
        cluster_type: Cluster to look up; ``None`` resolves it via
            ``get_cluster_type``.

    Returns:
        The checkpoint root, or ``None`` if no cluster type could be resolved.

    Raises:
        KeyError: If the resolved cluster type has no directory entry.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None
    dirname_by_cluster = {ClusterType.CW: ""}
    return Path("/").joinpath(dirname_by_cluster[resolved])
def get_user_checkpoint_path(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[Path]:
    """Return the per-user checkpoint directory.

    This is ``get_checkpoint_path(cluster_type) / $USER``.

    Args:
        cluster_type: Cluster to look up; ``None`` resolves it via
            ``get_checkpoint_path`` / ``get_cluster_type``.

    Returns:
        The per-user checkpoint directory, or ``None`` when
        ``get_checkpoint_path`` returns ``None``.

    Raises:
        RuntimeError: If the ``USER`` environment variable is unset or empty.
    """
    checkpoint_path = get_checkpoint_path(cluster_type)
    if checkpoint_path is None:
        return None
    username = os.environ.get("USER")
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would turn a missing USER into an opaque TypeError
    # from ``Path / None``. An empty USER is rejected too, because joining ""
    # would silently return the shared parent directory.
    if not username:
        raise RuntimeError(
            "The USER environment variable must be set to build the per-user "
            "checkpoint path"
        )
    return checkpoint_path / username
def get_slurm_qos(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the Slurm quality-of-service name for the given cluster.

    Unlike the account and partition helpers, an unknown cluster type yields
    ``None`` here rather than raising ``KeyError``.

    Args:
        cluster_type: Cluster to look up; ``None`` resolves it via
            ``get_cluster_type``.

    Returns:
        The QOS name, or ``None`` if no cluster type could be resolved or the
        resolved type has no QOS entry.
    """
    qos_by_cluster = {ClusterType.CW: "explore"}
    # ``dict.get`` also maps an unresolved (None) cluster type to None.
    return qos_by_cluster.get(get_cluster_type(cluster_type))
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the Slurm partition to submit jobs to on the given cluster.

    Args:
        cluster_type: Cluster to look up; ``None`` resolves it via
            ``get_cluster_type``.

    Returns:
        The partition name, or ``None`` if no cluster type could be resolved.

    Raises:
        KeyError: If the resolved cluster type has no partition entry.
    """
    resolved = get_cluster_type(cluster_type)
    partition_by_cluster = {ClusterType.CW: "learn"}
    return None if resolved is None else partition_by_cluster[resolved]
def get_slurm_executor_parameters(
    nodes: int,
    num_gpus_per_node: int,
    cluster_type: Optional[ClusterType] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Build the parameter dict describing a Slurm job.

    Args:
        nodes: Number of nodes to request.
        num_gpus_per_node: GPUs requested per node; one task is launched per
            GPU.
        cluster_type: Target cluster; ``None`` resolves it via
            ``get_cluster_type``.
        **kwargs: Extra parameters. They are applied last, so they override
            any default computed here.

    Returns:
        A new dict with keys ``mem_gb``, ``gpus_per_node``, ``tasks_per_node``,
        ``cpus_per_task``, ``nodes`` and ``slurm_partition``, updated with
        ``kwargs``.
    """
    # Resolve once so that the partition lookup and the cluster-specific
    # adjustment below are guaranteed to act on the same cluster type.
    cluster_type = get_cluster_type(cluster_type)

    # create default parameters
    params: Dict[str, Any] = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }
    # apply cluster-specific adjustments
    if cluster_type == ClusterType.CW:
        params["cpus_per_task"] = 16
    # set additional parameters / apply overrides
    params.update(kwargs)
    return params