# multimodalart (HF Staff) — commit d1f1097 (verified):
# "Initial best-effort ZeroGPU port of Khala song generation"
import copy
import itertools
import pathlib
from typing import List, Optional
import yaml
# Absolute directory containing this module; the recipe directories
# ("recipes" / "local_recipes") are located relative to it.
BASE_PATH = pathlib.Path(__file__).parent.resolve()
class dotdict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    Reading a missing attribute yields ``None`` (via ``dict.get``) instead of
    raising ``AttributeError``. As a consequence ``hasattr(obj, name)`` is
    always ``True`` for names not found by normal attribute lookup, so it cannot
    be used to test whether a key is present; use ``name in obj`` instead.
    Deleting a missing attribute raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only reached after normal attribute lookup has failed.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
def resolve_cluster_config(cluster: str) -> str:
    """Return the cluster-config name for a cluster identifier.

    Raises:
        ValueError: if ``cluster`` is not one of the known identifiers.
    """
    known_clusters = (
        ("dgxh100_eos", "eos"),
        ("dgxa100_dracooci", "draco-oci-iad"),
        ("dgxa100_dracooci-ord", "draco-oci-ord"),
        ("dgxh100_coreweave", "coreweave"),
    )
    for cluster_id, config_name in known_clusters:
        if cluster == cluster_id:
            return config_name
    raise ValueError(f"Unknown cluster {cluster} provided.")
def resolve_artifact_config(cluster: str) -> str:
    """Return the artifact-storage config name for a cluster identifier.

    Raises:
        ValueError: if ``cluster`` is not one of the known identifiers.
    """
    known_clusters = (
        ("dgxh100_eos", "eos_lustre"),
        ("dgxa100_dracooci", "draco-oci_lustre"),
        ("dgxa100_dracooci-ord", "draco-oci-ord_lustre"),
        ("dgxh100_coreweave", "coreweave_lustre"),
    )
    for cluster_id, artifact_name in known_clusters:
        if cluster == cluster_id:
            return artifact_name
    raise ValueError(f"Unknown cluster {cluster} provided.")
def flatten_products(workload_manifest: dotdict) -> dotdict:
    """Expand the manifest's nested ``products`` into a flat list of combinations.

    For each entry of ``workload_manifest.products`` that has a ``products``
    key, every inner mapping of ``key -> list of values`` is expanded into its
    Cartesian product. Each combination becomes one dict, tagged with the
    entry's first ``test_case``. Entries without a ``products`` key are
    dropped, and a missing/``None`` ``products`` yields an empty list.

    The manifest is modified in place (its ``products`` is replaced) and
    returned.
    """
    expanded = []
    for entry in workload_manifest.products or []:
        if "products" not in entry:
            continue
        for axes in entry["products"]:
            for combination in itertools.product(*axes.values()):
                row = dict(zip(axes.keys(), combination))
                expanded.append(dict(**row, test_case=entry["test_case"][0]))
    workload_manifest.products = expanded
    return workload_manifest
def flatten_workload(workload_manifest: dotdict) -> List[dotdict]:
    """Expand a manifest with flattened ``products`` into one manifest per product.

    Every returned manifest is an independent deep copy of the input without
    the ``products`` key. Its ``spec`` is the original spec with each of the
    product's keys set to the product's value. The input manifest itself is
    not modified.

    Raises:
        KeyError: if the manifest has no ``products`` entry, or (when there is
            at least one product) no ``spec`` entry.
    """
    template = dict(workload_manifest)
    products = template.pop("products")
    expanded = []
    for product in products:
        variant = copy.deepcopy(template)
        # Drop the overridden keys first so the merge below never sees a key twice.
        base_spec = {
            key: value
            for key, value in variant["spec"].items()
            if key not in product.keys()
        }
        variant["spec"] = dict(**base_spec, **product)
        expanded.append(dotdict(**variant))
    return expanded
def set_build_dependency(workload_manifests: List[dotdict]) -> List[dotdict]:
    """Fill ``{placeholder}`` fields in each workload's ``spec["build"]`` in place.

    The build string is formatted with the workload's own spec entries as
    keyword arguments; e.g. a ``build`` of ``"image-{platforms}"`` picks up
    ``spec["platforms"]``. The same (mutated) list is returned.
    """
    for manifest in workload_manifests:
        spec = manifest.spec
        template = spec["build"]
        spec["build"] = template.format(**dict(spec))
    return workload_manifests
def load_config(config_path: str) -> dotdict:
    """Load a YAML file and return its top-level mapping as a ``dotdict``.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed top-level mapping wrapped in a ``dotdict``; nested values
        are left exactly as ``yaml.safe_load`` produced them.

    Raises:
        yaml.YAMLError: if the file is not valid YAML.
        TypeError: if the document is empty or its top level is not a mapping.
    """
    # Read as UTF-8 explicitly instead of depending on the platform's default encoding.
    with open(config_path, encoding="utf-8") as stream:
        data = yaml.safe_load(stream)
    if not isinstance(data, dict):
        # safe_load returns None for an empty file; ``dotdict(**None)`` would
        # otherwise fail with an opaque "argument after ** must be a mapping".
        raise TypeError(
            f"Expected a YAML mapping at the top level of {config_path}, "
            f"got {type(data).__name__}."
        )
    return dotdict(**data)
def load_and_flatten(config_path: str) -> List[dotdict]:
    """Load one recipe file and return its fully expanded workloads.

    Runs load_config -> flatten_products -> flatten_workload ->
    set_build_dependency and returns the resulting list.
    """
    manifest = load_config(config_path=config_path)
    manifest = flatten_products(manifest)
    workloads = flatten_workload(manifest)
    return set_build_dependency(workloads)
def filter_by_test_case(workload_manifests: List[dotdict], test_case: str) -> Optional[dotdict]:
    """Return the single workload whose ``spec["test_case"]`` equals ``test_case``.

    This does not raise on a bad match count. When there is no match, or more
    than one, it prints a notice and returns ``None``.
    """
    matches = [m for m in workload_manifests if m.spec["test_case"] == test_case]
    if len(matches) == 1:
        return matches[0]
    if matches:
        print("Duplicate test_case found!")
    else:
        print("No test_case found!")
    return None
def filter_by_scope(workload_manifests: List[dotdict], scope: str) -> List[dotdict]:
    """Return every workload whose ``spec["scope"]`` equals ``scope``.

    Prints a notice and returns an empty list when nothing matches.
    """
    selected = []
    for manifest in workload_manifests:
        if manifest.spec["scope"] == scope:
            selected.append(manifest)
    if not selected:
        print("No test_case found!")
        return []
    return selected
def filter_by_environment(workload_manifests: List[dotdict], environment: str) -> List[dotdict]:
    """Return every workload whose spec has an ``environment`` equal to ``environment``.

    Workloads whose spec has no ``environment`` key are skipped. Prints a
    notice and returns an empty list when nothing matches.
    """
    # Test key membership directly: hasattr() on a dotdict is always True
    # (missing attributes read as None), so it cannot detect a missing key and a
    # workload without "environment" would otherwise raise KeyError.
    selected = [
        manifest
        for manifest in workload_manifests
        if "environment" in manifest["spec"]
        and manifest["spec"]["environment"] == environment
    ]
    if not selected:
        print("No test_case found!")
        return []
    return selected
def filter_by_platform(workload_manifests: List[dotdict], platform: str) -> List[dotdict]:
    """Return every workload whose ``spec["platforms"]`` equals ``platform``.

    Workloads whose spec has no ``platforms`` key are skipped. Prints a notice
    and returns an empty list when nothing matches.
    """
    # Test key membership directly: hasattr() on a dotdict is always True
    # (missing attributes read as None), so it cannot detect a missing key and a
    # workload without "platforms" would otherwise raise KeyError.
    selected = [
        manifest
        for manifest in workload_manifests
        if "platforms" in manifest["spec"]
        and manifest["spec"]["platforms"] == platform
    ]
    if not selected:
        print("No test_case found!")
        return []
    return selected
def filter_by_model(workload_manifests: List[dotdict], model: str) -> List[dotdict]:
    """Return every workload whose ``spec["model"]`` equals ``model``.

    Prints a notice and returns an empty list when nothing matches.
    """

    def _has_model(manifest):
        return manifest.spec["model"] == model

    hits = list(filter(_has_model, workload_manifests))
    if not hits:
        print("No test_case found!")
        return []
    return hits
def filter_by_tag(workload_manifests: List[dotdict], tag: str) -> List[dotdict]:
    """Return every workload whose spec has a ``tag`` equal to ``tag``.

    Workloads whose spec has no ``tag`` key are skipped. Prints a notice and
    returns an empty list when nothing matches.
    """
    # Test key membership directly: hasattr() on a dotdict is always True
    # (missing attributes read as None), so it cannot detect a missing key and a
    # workload without "tag" would otherwise raise KeyError.
    selected = [
        manifest
        for manifest in workload_manifests
        if "tag" in manifest["spec"] and manifest["spec"]["tag"] == tag
    ]
    if not selected:
        print("No test_case found!")
        return []
    return selected
def filter_by_test_cases(workload_manifests: List[dotdict], test_cases: str) -> List[dotdict]:
    """Return every workload whose ``spec["test_case"]`` is in a comma-separated list.

    Args:
        workload_manifests: Workloads to filter; each exposes ``.spec`` holding
            a ``"test_case"`` entry.
        test_cases: Comma-separated test-case names, e.g. ``"a,b"``. Whitespace
            around each name is ignored, as are empty names.

    Returns:
        The matching workloads in their original order, each at most once.
        Prints a notice and returns an empty list when nothing matches.
    """
    wanted = {name.strip() for name in test_cases.split(",") if name.strip()}
    # ``spec`` may be a plain dict (flatten_workload builds it with ``dict(...)``),
    # so it must be subscripted; ``spec.test_case`` raises AttributeError there.
    selected = [
        manifest
        for manifest in workload_manifests
        if manifest.spec["test_case"] in wanted
    ]
    if not selected:
        print("No test_case found!")
        return []
    return selected
def load_workloads(
    container_tag: str,
    n_repeat: int = 1,
    time_limit: int = 1800,
    tag: Optional[str] = None,
    environment: Optional[str] = None,
    platform: Optional[str] = None,
    test_cases: str = "all",
    scope: Optional[str] = None,
    model: Optional[str] = None,
    test_case: Optional[str] = None,
    container_image: Optional[str] = None,
    record_checkpoints: Optional[str] = None,
) -> List[dotdict]:
    """Load all recipe workloads from disk and narrow them down with the given filters.

    Every ``*.yaml`` file in ``../recipes`` and ``../local_recipes`` (relative to
    this module) is loaded and flattened. Files whose name starts with
    ``_build`` are additionally kept, unflattened, as build workloads.

    Args:
        container_tag: Tag appended (``image:tag``) to the image of every build
            workload that gets attached.
        n_repeat: Stored as ``spec["n_repeat"]`` on every selected workload.
        time_limit: Stored as ``spec["time_limit"]`` on every selected workload.
        tag, environment, platform, scope, model: Optional equality filters on
            the corresponding spec entries (``platform`` is compared against
            ``spec["platforms"]``). A falsy value disables the filter.
        test_cases: Comma-separated test-case names, or ``"all"`` for no filter.
        test_case: If given, keep only the single workload with this test case.
        container_image: Image used for attached build workloads. When falsy,
            each build workload's own ``spec["source"]["image"]`` is used.
        record_checkpoints: When exactly ``"true"``, a checkpoint artifact
            output is attached to each selected workload.

    Returns:
        The selected workloads followed by the build workloads they reference
        through ``spec["build"]``, or an empty list if nothing matches.
    """
    recipes_dir = BASE_PATH / ".." / "recipes"
    local_dir = BASE_PATH / ".." / "local_recipes"
    workloads: List[dotdict] = []
    build_workloads: List[dotdict] = []
    for file in list(recipes_dir.glob("*.yaml")) + list(local_dir.glob("*.yaml")):
        workloads += load_and_flatten(config_path=str(file))
        if file.stem.startswith("_build"):
            build_workloads.append(load_config(config_path=str(file)))

    if scope:
        workloads = filter_by_scope(workload_manifests=workloads, scope=scope)
    if workloads and environment:
        workloads = filter_by_environment(workload_manifests=workloads, environment=environment)
    if workloads and model:
        workloads = filter_by_model(workload_manifests=workloads, model=model)
    if workloads and tag:
        workloads = filter_by_tag(workload_manifests=workloads, tag=tag)
    if workloads and platform:
        workloads = filter_by_platform(workload_manifests=workloads, platform=platform)
    if workloads and test_cases != "all":
        workloads = filter_by_test_cases(workload_manifests=workloads, test_cases=test_cases)
    if workloads and test_case:
        matched = filter_by_test_case(workload_manifests=workloads, test_case=test_case)
        # filter_by_test_case returns None (not a list) when there is no unique
        # match; wrapping that None in a list would crash below on ``None.spec``.
        workloads = [matched] if matched is not None else []
    if not workloads:
        return []

    # Iterate over a snapshot, so build workloads appended below are not visited
    # by this loop.
    for workload in list(workloads):
        for build_workload in build_workloads:
            if (
                workload.spec["build"] == build_workload.spec["name"]
                and build_workload not in workloads
            ):
                # Per-build local: rebinding ``container_image`` here would make
                # every later build workload reuse the first build's default image.
                image = container_image or build_workload.spec["source"]["image"]
                build_workload.spec["source"]["image"] = f"{image}:{container_tag}"
                workloads.append(build_workload)
        workload.spec["n_repeat"] = n_repeat
        workload.spec["time_limit"] = time_limit
        # Substitute the literal "{platforms}" placeholder in every artifact path.
        workload.spec["artifacts"] = {
            key: value.replace(r"{platforms}", workload.spec["platforms"])
            for key, value in workload.spec.get("artifacts", {}).items()
        }
        if record_checkpoints == "true":
            workload.outputs = [
                {
                    "type": "artifact",
                    "key": f"unverified/model/mcore-ci/{container_tag}/{{model}}/{{name}}",
                    "subdir": "checkpoints",
                    "name": r"{model}/{name}",
                    "description": r"Checkpoint of {model}/{name}",
                    "pic": {"name": "Mcore CI", "email": "okoenig@nvidia.com"},
                    "labels": {"origin": "ADLR/Megatron-LM"},
                }
            ]
    return workloads
if __name__ == "__main__":
    # Dump every workload found on disk (no filters, container tag "main")
    # to ./workflows.yaml as a list of plain mappings.
    all_workloads = load_workloads(container_tag="main")
    with open("workflows.yaml", "w") as handle:
        yaml.dump([dict(item) for item in all_workloads], handle)