|
|
|
|
|
|
|
|
|
|
| import os
|
| from enum import Enum
|
| from pathlib import Path
|
| from typing import Any, Dict, Optional
|
|
|
|
|
class ClusterType(Enum):
    """Closed set of cluster identifiers understood by the helpers in this module."""

    # Currently the only member; its value is the short string "cw".
    CW = "cw"
|
|
|
|
|
def _guess_cluster_type() -> ClusterType:
    """Return the cluster type to assume when none was given explicitly.

    No detection is performed: ``ClusterType.CW`` is returned unconditionally.
    """
    fallback_cluster = ClusterType.CW
    return fallback_cluster
|
|
|
|
|
def get_cluster_type(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[ClusterType]:
    """Resolve the cluster type to use.

    Args:
        cluster_type: An explicit cluster type, or ``None`` to fall back to
            ``_guess_cluster_type()``.

    Returns:
        ``cluster_type`` unchanged when it is not ``None``; otherwise whatever
        ``_guess_cluster_type()`` returns.
    """
    # An explicit choice always wins over the guess.
    if cluster_type is not None:
        return cluster_type
    return _guess_cluster_type()
|
|
|
|
|
def get_slurm_account(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the SLURM account name associated with a cluster.

    Args:
        cluster_type: Cluster to look up; ``None`` is resolved through
            ``get_cluster_type``.

    Returns:
        The account string for the resolved cluster, or ``None`` when no
        cluster type could be resolved.

    Raises:
        KeyError: If the resolved cluster has no entry in the account table.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    accounts_by_cluster = {ClusterType.CW: "fair_amaia_cw_explore"}
    # Strict lookup: a cluster without an account is an error, not a None.
    return accounts_by_cluster[resolved]
|
|
|
|
|
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the root checkpoint directory for a cluster.

    Args:
        cluster_type: Cluster to look up; ``None`` is resolved through
            ``get_cluster_type``.

    Returns:
        ``/`` joined with the cluster's checkpoint directory name, or ``None``
        when no cluster type could be resolved.

    Raises:
        KeyError: If the resolved cluster has no entry in the directory table.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    # The CW entry is an empty string, so the resulting path is just "/".
    dirname_by_cluster = {ClusterType.CW: ""}
    filesystem_root = Path("/")
    return filesystem_root.joinpath(dirname_by_cluster[resolved])
|
|
|
|
|
def get_user_checkpoint_path(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[Path]:
    """Return the current user's checkpoint directory for a cluster.

    The directory is the cluster checkpoint root (see ``get_checkpoint_path``)
    joined with the value of the ``USER`` environment variable.

    Args:
        cluster_type: Cluster to look up; ``None`` is resolved by
            ``get_checkpoint_path``.

    Returns:
        ``<checkpoint root>/<USER>``, or ``None`` when ``get_checkpoint_path``
        returns ``None``.

    Raises:
        RuntimeError: If the ``USER`` environment variable is unset or empty.
    """
    checkpoint_path = get_checkpoint_path(cluster_type)
    if checkpoint_path is None:
        return None

    username = os.environ.get("USER")
    # Explicit check instead of ``assert``: asserts are removed under
    # ``python -O``, which would turn a missing USER into an opaque TypeError
    # from ``Path / None``. An empty USER is rejected too, because joining ""
    # would silently return the shared checkpoint root as the "user" directory.
    if not username:
        raise RuntimeError(
            "Cannot determine the user checkpoint path: the USER environment "
            "variable is not set or is empty."
        )
    return checkpoint_path / username
|
|
|
|
|
def get_slurm_qos(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the SLURM QOS name associated with a cluster.

    Args:
        cluster_type: Cluster to look up; ``None`` is resolved through
            ``get_cluster_type``.

    Returns:
        The QOS string for the resolved cluster, or ``None`` when no cluster
        type could be resolved or the cluster has no QOS entry.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    qos_by_cluster = {ClusterType.CW: "explore"}
    # Lenient lookup: a cluster without a QOS entry yields None, not KeyError.
    return qos_by_cluster.get(resolved)
|
|
|
|
|
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the SLURM partition name associated with a cluster.

    Args:
        cluster_type: Cluster to look up; ``None`` is resolved through
            ``get_cluster_type``.

    Returns:
        The partition string for the resolved cluster, or ``None`` when no
        cluster type could be resolved.

    Raises:
        KeyError: If the resolved cluster has no entry in the partition table.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    partition_by_cluster = {ClusterType.CW: "learn"}
    # Strict lookup: a cluster without a partition is an error, not a None.
    return partition_by_cluster[resolved]
|
|
|
|
|
def get_slurm_executor_parameters(
    nodes: int,
    num_gpus_per_node: int,
    cluster_type: Optional[ClusterType] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Build a dictionary of SLURM executor parameters for a job.

    Args:
        nodes: Number of nodes, stored under ``"nodes"``.
        num_gpus_per_node: Stored under both ``"gpus_per_node"`` and
            ``"tasks_per_node"`` (i.e. one task per GPU).
        cluster_type: Cluster to configure for; ``None`` is resolved through
            ``get_cluster_type``.
        **kwargs: Extra parameters merged in last; they override any default
            with the same key.

    Returns:
        A dict with the keys ``mem_gb``, ``gpus_per_node``, ``tasks_per_node``,
        ``cpus_per_task``, ``nodes`` and ``slurm_partition``, plus whatever was
        passed in ``kwargs``.
    """
    partition = get_slurm_partition(cluster_type)
    resolved = get_cluster_type(cluster_type)

    # 10 CPUs per task is the generic default; the CW cluster uses 16.
    cpus_per_task = 16 if resolved == ClusterType.CW else 10

    defaults = {
        # NOTE(review): 0 presumably lets the scheduler decide the memory
        # limit -- confirm against the executor's documentation.
        "mem_gb": 0,
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,
        "cpus_per_task": cpus_per_task,
        "nodes": nodes,
        "slurm_partition": partition,
    }
    # Caller-supplied values take precedence over the defaults.
    return {**defaults, **kwargs}
|
|
|