# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional


class ClusterType(Enum):
    """Closed set of cluster identifiers known to this module."""

    CW = "cw"


def _guess_cluster_type() -> ClusterType:
    """Return the cluster type to use when the caller did not specify one.

    Only one cluster type is defined, so this always returns
    ``ClusterType.CW``.
    """
    return ClusterType.CW


def get_cluster_type(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[ClusterType]:
    """Resolve an optional cluster type to a concrete one.

    Args:
        cluster_type: Explicit cluster type, or ``None`` to fall back to
            ``_guess_cluster_type()``.

    Returns:
        ``cluster_type`` unchanged when it is not ``None``; otherwise the
        guessed cluster type.
    """
    if cluster_type is None:
        return _guess_cluster_type()
    return cluster_type


def get_slurm_account(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the SLURM account name configured for the cluster.

    Args:
        cluster_type: Explicit cluster type, or ``None`` to use the guessed one.

    Returns:
        The account name, or ``None`` if the cluster type resolves to ``None``.

    Raises:
        KeyError: If the resolved cluster type has no account entry.
    """
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    return {
        ClusterType.CW: "fair_amaia_cw_explore",
    }[cluster_type]


def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the base checkpoint directory for the cluster.

    The path is the filesystem root joined with the per-cluster directory
    name. The ``ClusterType.CW`` entry is an empty string, so the result for
    that cluster is ``Path("/")`` itself.

    Args:
        cluster_type: Explicit cluster type, or ``None`` to use the guessed one.

    Returns:
        The checkpoint directory, or ``None`` if the cluster type resolves to
        ``None``.

    Raises:
        KeyError: If the resolved cluster type has no directory entry.
    """
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    CHECKPOINT_DIRNAMES = {
        ClusterType.CW: "",
    }
    return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]


def get_user_checkpoint_path(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[Path]:
    """Return the current user's directory under the base checkpoint path.

    The user name is read from the ``USER`` environment variable.

    Args:
        cluster_type: Explicit cluster type, or ``None`` to use the guessed one.

    Returns:
        ``get_checkpoint_path(cluster_type) / $USER``, or ``None`` if the base
        checkpoint path is ``None``.

    Raises:
        AssertionError: If the ``USER`` environment variable is not set.
    """
    checkpoint_path = get_checkpoint_path(cluster_type)
    if checkpoint_path is None:
        return None

    username = os.environ.get("USER")
    if username is None:
        # Raised explicitly rather than via ``assert`` so the check still runs
        # under ``python -O``; otherwise the failure would surface as an
        # opaque TypeError from ``Path / None``.
        raise AssertionError("USER environment variable is not set")
    return checkpoint_path / username


def get_slurm_qos(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the SLURM QoS name configured for the cluster.

    Args:
        cluster_type: Explicit cluster type, or ``None`` to use the guessed one.

    Returns:
        The QoS name, or ``None`` if the cluster type resolves to ``None`` or
        has no QoS entry.
    """
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    return {
        ClusterType.CW: "explore",
    }.get(cluster_type)


def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the SLURM partition name configured for the cluster.

    Args:
        cluster_type: Explicit cluster type, or ``None`` to use the guessed one.

    Returns:
        The partition name, or ``None`` if the cluster type resolves to
        ``None``.

    Raises:
        KeyError: If the resolved cluster type has no partition entry.
    """
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    SLURM_PARTITIONS = {
        ClusterType.CW: "learn",
    }
    return SLURM_PARTITIONS[cluster_type]


def get_slurm_executor_parameters(
    nodes: int,
    num_gpus_per_node: int,
    cluster_type: Optional[ClusterType] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Build a dictionary of SLURM executor parameters.

    Defaults are created first, then cluster-specific adjustments are applied,
    and finally ``kwargs`` are merged in so that callers can add parameters or
    override any of the values set here.

    Args:
        nodes: Number of nodes to request.
        num_gpus_per_node: Number of GPUs per node; also used as the number of
            tasks per node (one task per GPU).
        cluster_type: Explicit cluster type, or ``None`` to use the guessed one.
        **kwargs: Additional parameters / overrides, applied last.

    Returns:
        The parameter dictionary.
    """
    # Resolve once so the partition lookup and the cluster-specific
    # adjustments below agree on the same cluster.
    cluster_type = get_cluster_type(cluster_type)

    # create default parameters
    params: Dict[str, Any] = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }

    # apply cluster-specific adjustments
    if cluster_type == ClusterType.CW:
        params["cpus_per_task"] = 16

    # set additional parameters / apply overrides
    params.update(kwargs)
    return params