File size: 2,533 Bytes
1508879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from asymptotes import asymptotic_behaviors
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from prefect_dask import DaskTaskRunner

# ==============================================================================
# 1. JOB CONFIGURATION
# ==============================================================================

# MODEL selects the potential to benchmark. It may be either a registered model
# name (a string) or an ASE Calculator instance.
# Example A: Registered string model (e.g., "MACE-MP(M)", "CHGNet")
MODEL = "NequIP-OAM-L"

# Example B: Custom ASE Calculator
# from mace.calculators import mace_mp
# MODEL = mace_mp(model="medium", dispersion=False, default_dtype="float64", device="cuda")

# SLURM environment configuration
SLURM_CONFIG = {
    "nodes": 1,
    "gpus_per_node": 1,
    "account": "matgen",
    "qos": "preempt",
    "walltime": "02:00:00",
    # Example virtual environment setup (must be modified for your site/user)
    "job_script_prologue": [
        "source ~/.bashrc",
        "module load python",
        "module load cudatoolkit/12.4",
        "source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena",
    ],
}

# Label used in the SLURM job name: the string itself for registered models,
# otherwise the class name of the supplied calculator object.
job_model_name = MODEL if isinstance(MODEL, str) else MODEL.__class__.__name__

# Keyword arguments for dask_jobqueue.SLURMCluster: each SLURM job hosts a single
# Dask worker process (cores=1, processes=1).
cluster_kwargs = dict(
    cores=1,
    memory="64 GB",
    processes=1,
    shebang="#!/bin/bash",
    account=SLURM_CONFIG["account"],
    walltime=SLURM_CONFIG["walltime"],
    job_script_prologue=SLURM_CONFIG["job_script_prologue"],
    # Drop these default header directives; the job name (-J) is re-added below.
    # The skip filter is applied to the default header only, before the extra
    # directives are appended, so it does not touch the entries listed below.
    job_directives_skip=["-n", "--cpus-per-task", "-J"],
    job_extra_directives=[
        f"-J benchmark_{job_model_name}",
        f"-q {SLURM_CONFIG['qos']}",
        f"-N {SLURM_CONFIG['nodes']}",
        "-C gpu",
        # `-G`/`--gpus` requests the TOTAL number of GPUs for the job; request
        # them per node so "gpus_per_node" is honored when nodes > 1.
        f"--gpus-per-node={SLURM_CONFIG['gpus_per_node']}",
        "-o /dev/null",  # Redirect stdout to null
        "-e /dev/null",  # Redirect stderr to null
    ],
)

# Launch the SLURM-backed Dask cluster. The cluster and client are used as
# context managers so that the scheduler is shut down and any queued/running
# SLURM worker jobs are cancelled as soon as the flow finishes -- including when
# it raises -- rather than relying on interpreter-exit cleanup or the walltime.
with SLURMCluster(**cluster_kwargs) as cluster:
    print("--------------------------------------------------------------------------------")
    print(f"Generating SLURM cluster jobs with script:\n{cluster.job_script()}")
    print("--------------------------------------------------------------------------------")

    # Adaptive scaling: keep at least 1 worker job and grow to at most 50.
    cluster.adapt(minimum_jobs=1, maximum_jobs=50)

    with Client(cluster) as client:
        print(f"Dask dashboard available at: {client.dashboard_link}")

        # ======================================================================
        # 2. JOB EXECUTION
        # ======================================================================

        # Run the flow, submitting its tasks to the scheduler this client is
        # connected to; print() output inside the flow is captured as logs.
        asymptotic_behaviors.with_options(
            task_runner=DaskTaskRunner(address=client.scheduler.address),
            log_prints=True,
        )(model=MODEL)

# TODO: Add Distribution Shift, Stability and Reactivity