Spaces:
Running on Zero
Running on Zero
| import copy | |
| import itertools | |
| import pathlib | |
| from typing import List, Optional | |
| import yaml | |
# Absolute directory containing this script; the recipe directories used by
# load_workloads ("../recipes", "../local_recipes") are resolved relative to it.
BASE_PATH = pathlib.Path(__file__).parent.resolve()
class dotdict(dict):
    """Dictionary whose keys can also be read, written and deleted as attributes.

    Reading a missing key as an attribute yields ``None`` instead of raising
    ``AttributeError``. As a consequence ``hasattr(instance, name)`` is always
    ``True``; use ``name in instance`` to test whether a key exists.
    """

    def __getattr__(self, key):
        # Only consulted after normal attribute lookup fails.
        return dict.get(self, key)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        # A missing key raises KeyError (not AttributeError), like dict.__delitem__.
        dict.__delitem__(self, key)
# Cluster identifier -> (cluster config name, artifact config name).
# One table feeds both resolvers, so they always accept exactly the same clusters.
_CLUSTER_CONFIGS = {
    "dgxh100_eos": ("eos", "eos_lustre"),
    "dgxa100_dracooci": ("draco-oci-iad", "draco-oci_lustre"),
    "dgxa100_dracooci-ord": ("draco-oci-ord", "draco-oci-ord_lustre"),
    "dgxh100_coreweave": ("coreweave", "coreweave_lustre"),
}


def _lookup_cluster(cluster: str) -> tuple:
    """Return the (cluster config, artifact config) pair for *cluster*.

    Raises:
        ValueError: if *cluster* is not a known cluster identifier.
    """
    try:
        return _CLUSTER_CONFIGS[cluster]
    except KeyError:
        raise ValueError(f"Unknown cluster {cluster} provided.") from None


def resolve_cluster_config(cluster: str) -> str:
    """Map a cluster identifier to its cluster config name.

    Raises:
        ValueError: if *cluster* is not a known cluster identifier.
    """
    return _lookup_cluster(cluster)[0]


def resolve_artifact_config(cluster: str) -> str:
    """Map a cluster identifier to its artifact config name.

    Raises:
        ValueError: if *cluster* is not a known cluster identifier.
    """
    return _lookup_cluster(cluster)[1]
def flatten_products(workload_manifest: dotdict) -> dotdict:
    """Expand every nested product into concrete parameter combinations, in place.

    Each entry of ``workload_manifest.products`` is read as
    ``{"test_case": [name, ...], "products": [{key: [values, ...]}, ...]}``.
    For every inner mapping, the cartesian product of its value lists produces
    one flat dict per combination, tagged with ``test_case[0]`` (only the first
    test case is used). Entries without a ``"products"`` key are dropped.
    ``products`` is replaced by the flat list, which is ``[]`` when there are
    no products, and the same manifest object is returned.
    """
    flattened = []
    for product in workload_manifest.products or []:
        if "products" not in product:
            continue
        for axes in product["products"]:
            names = axes.keys()
            for combination in itertools.product(*axes.values()):
                flattened.append(
                    dict(**dict(zip(names, combination)), test_case=product["test_case"][0])
                )
    workload_manifest.products = flattened
    return workload_manifest
def flatten_workload(workload_manifest: dotdict) -> List[dotdict]:
    """Split a manifest into one workload per entry of its ``products`` list.

    Every output is a deep copy of the manifest without ``products``. Its
    ``spec`` is merged with one product: keys the product overrides are removed
    from the spec first, so the product's entries come after the untouched spec
    keys. The input manifest is not modified.
    """
    template = dict(workload_manifest)
    products = template.pop("products")

    expanded = []
    for product in products:
        clone = copy.deepcopy(template)
        untouched = {key: value for key, value in clone["spec"].items() if key not in product}
        clone["spec"] = dict(**untouched, **product)
        expanded.append(dotdict(**clone))
    return expanded
def set_build_dependency(workload_manifests: List[dotdict]) -> List[dotdict]:
    """Render each workload's ``spec["build"]`` template against its own spec.

    ``spec["build"]`` is treated as a ``str.format`` template whose fields are
    filled from the spec's keys (e.g. ``"{model}"``). The workloads are updated
    in place, and the same list is returned.
    """
    for manifest in workload_manifests:
        spec = manifest.spec
        template = spec["build"]
        spec["build"] = template.format(**dict(spec))
    return workload_manifests
def load_config(config_path: str) -> dotdict:
    """Load a YAML recipe file and return its top-level mapping as a ``dotdict``.

    Only the top level is wrapped; nested mappings stay plain ``dict`` objects.

    Raises:
        yaml.YAMLError: if the file is not valid YAML.
        TypeError: if the document is empty or its top level is not a mapping.
    """
    # safe_load: recipes are plain data and must never construct arbitrary objects.
    with open(config_path, encoding="utf-8") as stream:
        content = yaml.safe_load(stream)
    if not isinstance(content, dict):
        raise TypeError(
            f"{config_path}: expected a YAML mapping at the top level, "
            f"got {type(content).__name__}"
        )
    return dotdict(**content)
def load_and_flatten(config_path: str) -> List[dotdict]:
    """Load one recipe file and expand it into concrete workloads.

    Runs load_config, flatten_products, flatten_workload and
    set_build_dependency, in that order.
    """
    manifest = load_config(config_path=config_path)
    manifest = flatten_products(manifest)
    workloads = flatten_workload(manifest)
    return set_build_dependency(workloads)
def filter_by_test_case(workload_manifests: List[dotdict], test_case: str) -> Optional[dotdict]:
    """Return the single workload whose ``spec["test_case"]`` equals *test_case*.

    Nothing is raised. When no workload matches, or more than one does, a
    message is printed and ``None`` is returned, so callers must check for
    ``None``.
    """
    matches = [
        manifest for manifest in workload_manifests if manifest.spec["test_case"] == test_case
    ]
    if len(matches) > 1:
        print("Duplicate test_case found!")
        return None
    if not matches:
        print("No test_case found!")
        return None
    return matches[0]
def filter_by_scope(workload_manifests: List[dotdict], scope: str) -> List[dotdict]:
    """Keep only the workloads whose ``spec["scope"]`` equals *scope*.

    Prints a notice and returns an empty list when nothing matches.
    """
    matching = [manifest for manifest in workload_manifests if manifest.spec["scope"] == scope]
    if not matching:
        print("No test_case found!")
    return matching
def filter_by_environment(workload_manifests: List[dotdict], environment: str) -> List[dotdict]:
    """Keep only the workloads whose ``spec["environment"]`` equals *environment*.

    Workloads whose spec has no ``environment`` key are skipped rather than
    raising ``KeyError``. Prints a notice and returns an empty list when
    nothing matches.
    """
    # Test key membership directly: hasattr() on a dotdict is always True because
    # missing attributes read as None, so it cannot detect a missing key.
    matching = [
        manifest
        for manifest in workload_manifests
        if "environment" in manifest["spec"] and manifest["spec"]["environment"] == environment
    ]
    if not matching:
        print("No test_case found!")
        return []
    return matching
def filter_by_platform(workload_manifests: List[dotdict], platform: str) -> List[dotdict]:
    """Keep only the workloads whose ``spec["platforms"]`` equals *platform*.

    Workloads whose spec has no ``platforms`` key are skipped rather than
    raising ``KeyError``. Prints a notice and returns an empty list when
    nothing matches.
    """
    # Test key membership directly: hasattr() on a dotdict is always True because
    # missing attributes read as None, so it cannot detect a missing key.
    matching = [
        manifest
        for manifest in workload_manifests
        if "platforms" in manifest["spec"] and manifest["spec"]["platforms"] == platform
    ]
    if not matching:
        print("No test_case found!")
        return []
    return matching
def filter_by_model(workload_manifests: List[dotdict], model: str) -> List[dotdict]:
    """Keep only the workloads whose ``spec["model"]`` equals *model*.

    Prints a notice and returns an empty list when nothing matches.
    """
    selected = []
    for manifest in workload_manifests:
        if manifest.spec["model"] == model:
            selected.append(manifest)
    if selected:
        return selected
    print("No test_case found!")
    return []
def filter_by_tag(workload_manifests: List[dotdict], tag: str) -> List[dotdict]:
    """Keep only the workloads whose ``spec["tag"]`` equals *tag*.

    Workloads whose spec has no ``tag`` key are skipped rather than raising
    ``KeyError``. Prints a notice and returns an empty list when nothing
    matches.
    """
    # Test key membership directly: hasattr() on a dotdict is always True because
    # missing attributes read as None, so it cannot detect a missing key.
    matching = [
        manifest
        for manifest in workload_manifests
        if "tag" in manifest["spec"] and manifest["spec"]["tag"] == tag
    ]
    if not matching:
        print("No test_case found!")
        return []
    return matching
def filter_by_test_cases(workload_manifests: List[dotdict], test_cases: str) -> List[dotdict]:
    """Keep the workloads whose ``spec["test_case"]`` is named in *test_cases*.

    *test_cases* is a comma-separated list of names. Whitespace around each
    name is ignored, and empty entries are skipped. Each workload is returned
    at most once, in input order, even if its name is listed more than once.
    Prints a notice and returns an empty list when nothing matches.
    """
    wanted = {name.strip() for name in test_cases.split(",") if name.strip()}
    # ``spec`` is a plain dict (not a dotdict), so it must be subscripted.
    matching = [
        manifest for manifest in workload_manifests if manifest.spec["test_case"] in wanted
    ]
    if not matching:
        print("No test_case found!")
        return []
    return matching
def load_workloads(
    container_tag: str,
    n_repeat: int = 1,
    time_limit: int = 1800,
    tag: Optional[str] = None,
    environment: Optional[str] = None,
    platform: Optional[str] = None,
    test_cases: str = "all",
    scope: Optional[str] = None,
    model: Optional[str] = None,
    test_case: Optional[str] = None,
    container_image: Optional[str] = None,
    record_checkpoints: Optional[str] = None,
) -> List[dotdict]:
    """Load all recipes from disk, apply the requested filters and attach build jobs.

    Recipes are read from ``../recipes`` and ``../local_recipes`` relative to
    this file. Files whose stem starts with ``_build`` are also kept as build
    workloads. A build workload whose ``spec["name"]`` equals a selected
    workload's ``spec["build"]`` is appended once to the result, with its source
    image set to ``"<image>:<container_tag>"``.

    Args:
        container_tag: Tag appended to each attached build workload's image.
            Also used in the checkpoint artifact key.
        n_repeat: Stored as ``spec["n_repeat"]`` on every selected workload.
        time_limit: Stored as ``spec["time_limit"]`` on every selected workload.
        tag, environment, platform, scope, model: When truthy, keep only
            workloads whose matching spec field equals the value.
        test_cases: Comma-separated test case names; ``"all"`` disables this filter.
        test_case: A single test case name, which must match exactly one workload.
        container_image: Overrides the base image of attached build workloads.
        record_checkpoints: When the string ``"true"``, checkpoint artifact
            outputs are attached to every selected workload.

    Returns:
        The selected workloads followed by the build workloads they depend on,
        or an empty list when nothing matches. ``n_repeat``, ``time_limit``,
        ``artifacts`` and ``outputs`` are set only on the selected workloads,
        not on the appended build workloads.
    """
    recipes_dir = BASE_PATH / ".." / "recipes"
    local_dir = BASE_PATH / ".." / "local_recipes"

    workloads: List[dotdict] = []
    build_workloads: List[dotdict] = []
    # Sorted so the result order does not depend on the filesystem's listing order.
    recipe_files = sorted(recipes_dir.glob("*.yaml")) + sorted(local_dir.glob("*.yaml"))
    for file in recipe_files:
        workloads += load_and_flatten(config_path=str(file))
        if file.stem.startswith("_build"):
            build_workloads.append(load_config(config_path=str(file)))

    # Each filter runs only when requested and while candidates remain.
    if scope:
        workloads = filter_by_scope(workload_manifests=workloads, scope=scope)
    if workloads and environment:
        workloads = filter_by_environment(workload_manifests=workloads, environment=environment)
    if workloads and model:
        workloads = filter_by_model(workload_manifests=workloads, model=model)
    if workloads and tag:
        workloads = filter_by_tag(workload_manifests=workloads, tag=tag)
    if workloads and platform:
        workloads = filter_by_platform(workload_manifests=workloads, platform=platform)
    if workloads and test_cases != "all":
        workloads = filter_by_test_cases(workload_manifests=workloads, test_cases=test_cases)
    if workloads and test_case:
        match = filter_by_test_case(workload_manifests=workloads, test_case=test_case)
        # filter_by_test_case returns None for zero or duplicate matches. Wrapping
        # that in a list would let a [None] "workload" reach the loop below.
        workloads = [match] if match is not None else []

    if not workloads:
        return []

    # Iterate over a snapshot: build workloads appended below are not post-processed.
    for workload in list(workloads):
        for build_workload in build_workloads:
            if (
                workload.spec["build"] == build_workload.spec["name"]
                and build_workload not in workloads
            ):
                # Use a per-build local. Reassigning ``container_image`` itself would
                # make every later build reuse the first build's base image.
                base_image = container_image or build_workload.spec["source"]["image"]
                build_workload.spec["source"]["image"] = f"{base_image}:{container_tag}"
                workloads.append(build_workload)

        workload.spec["n_repeat"] = n_repeat
        workload.spec["time_limit"] = time_limit
        # Substitute the literal "{platforms}" placeholder in each artifact value.
        workload.spec["artifacts"] = {
            key: value.replace(r"{platforms}", workload.spec["platforms"])
            for key, value in workload.spec.get("artifacts", {}).items()
        }
        if record_checkpoints == "true":
            workload.outputs = [
                {
                    "type": "artifact",
                    "key": f"unverified/model/mcore-ci/{container_tag}/{{model}}/{{name}}",
                    "subdir": "checkpoints",
                    "name": r"{model}/{name}",
                    "description": r"Checkpoint of {model}/{name}",
                    "pic": {"name": "Mcore CI", "email": "okoenig@nvidia.com"},
                    "labels": {"origin": "ADLR/Megatron-LM"},
                }
            ]
    return workloads
if __name__ == "__main__":
    # Export every recipe workload (no filters, container tag "main") to YAML.
    # Converting to plain dicts keeps yaml.dump from tagging the dotdict subclass.
    exported = [dict(workflow) for workflow in load_workloads(container_tag="main")]
    output_file = "workflows.yaml"
    with open(output_file, "w") as handle:
        yaml.dump(exported, handle)