Spaces:
Running
Running
File size: 2,533 Bytes
from asymptotes import asymptotic_behaviors
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from prefect_dask import DaskTaskRunner
# ==============================================================================
# 1. JOB CONFIGURATION
# ==============================================================================
# The model to benchmark can be given in one of two forms.
# Form A: the name of a registered model, as a string
# (e.g., "MACE-MP(M)", "CHGNet").
MODEL = "NequIP-OAM-L"
# Form B: a custom ASE Calculator object, for example:
# from mace.calculators import mace_mp
# MODEL = mace_mp(model="medium", dispersion=False, default_dtype="float64", device="cuda")

# Settings for the SLURM environment; the cluster setup further down in this
# script reads every key defined here.
SLURM_CONFIG = dict(
    nodes=1,
    gpus_per_node=1,
    account="matgen",
    qos="preempt",
    walltime="02:00:00",
    # Shell lines handed to the cluster as its job-script prologue.
    # This is an example virtual environment and must be modified for your
    # own system.
    job_script_prologue=[
        "source ~/.bashrc",
        "module load python",
        "module load cudatoolkit/12.4",
        "source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena",
    ],
)
# Label embedded in the SLURM job name: the model string itself when MODEL is
# a registered name, otherwise the class name of the calculator object.
if isinstance(MODEL, str):
    job_model_name = MODEL
else:
    job_model_name = MODEL.__class__.__name__

# Keyword arguments later unpacked into SLURMCluster(...).
cluster_kwargs = {
    "cores": 1,
    "memory": "64 GB",
    "processes": 1,
    "shebang": "#!/bin/bash",
    "account": SLURM_CONFIG["account"],
    "walltime": SLURM_CONFIG["walltime"],
    "job_script_prologue": SLURM_CONFIG["job_script_prologue"],
    # NOTE(review): presumably suppresses dask-jobqueue's own task-count,
    # CPU, and job-name directives so the explicit ones below apply --
    # confirm against the dask-jobqueue documentation.
    "job_directives_skip": ["-n", "--cpus-per-task", "-J"],
    "job_extra_directives": [
        f"-J benchmark_{job_model_name}",  # job name carries the model label
        f"-q {SLURM_CONFIG['qos']}",  # QOS taken from the config
        f"-N {SLURM_CONFIG['nodes']}",  # node count taken from the config
        "-C gpu",
        f"-G {SLURM_CONFIG['gpus_per_node']}",  # GPU count taken from the config
        "-o /dev/null",  # discard the job's stdout
        "-e /dev/null",  # discard the job's stderr
    ],
}
# Build the SLURM-backed Dask cluster from the keyword arguments in
# cluster_kwargs.
cluster = SLURMCluster(**cluster_kwargs)
# Print the job script the cluster generates, framed by two rule lines, so it
# can be inspected in the console.
print("--------------------------------------------------------------------------------")
print(f"Generating SLURM cluster jobs with script:\n{cluster.job_script()}")
print("--------------------------------------------------------------------------------")
# Turn on adaptive scaling, bounded to between 1 and 50 jobs.
cluster.adapt(minimum_jobs=1, maximum_jobs=50)
# Attach a Dask client to the cluster; the execution step below reads the
# scheduler address from this client.
client = Client(cluster)
print(f"Dask dashboard available at: {client.dashboard_link}")
# ==============================================================================
# 2. JOB EXECUTION
# ==============================================================================
# Run the asymptotic-behaviors benchmark for MODEL, using a task runner that
# points at the scheduler of the Dask client created earlier in this script.
# NOTE(review): asymptotic_behaviors is presumably a Prefect flow (it exposes
# with_options) -- confirm in the asymptotes module.
dask_runner = DaskTaskRunner(address=client.scheduler.address)
benchmark_flow = asymptotic_behaviors.with_options(
    task_runner=dask_runner,
    log_prints=True,
)
benchmark_flow(model=MODEL)
# TODO: Add Distribution Shift, Stability and Reactivity
|