import copy
import itertools
import pathlib
from typing import List, Optional, Tuple

import yaml

BASE_PATH = pathlib.Path(__file__).parent.resolve()

# Cluster identifier -> cluster config name.
_CLUSTER_CONFIGS = {
    "dgxh100_eos": "eos",
    "dgxa100_dracooci": "draco-oci-iad",
    "dgxa100_dracooci-ord": "draco-oci-ord",
    "dgxh100_coreweave": "coreweave",
}

# Cluster identifier -> artifact config name.
_ARTIFACT_CONFIGS = {
    "dgxh100_eos": "eos_lustre",
    "dgxa100_dracooci": "draco-oci_lustre",
    "dgxa100_dracooci-ord": "draco-oci-ord_lustre",
    "dgxh100_coreweave": "coreweave_lustre",
}


class dotdict(dict):
    """dot.notation access to dictionary attributes.

    Attribute reads go through ``dict.get``, so a missing key reads as ``None``
    instead of raising ``AttributeError``. Consequently ``hasattr(d, name)`` is
    always ``True``; test for a key with ``name in d`` instead. Keys that collide
    with ``dict`` methods (e.g. ``items``) must be read as ``d["items"]``.
    Only the top level is wrapped: nested dicts stay plain ``dict`` objects and
    need ``["key"]`` access.
    """

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def resolve_cluster_config(cluster: str) -> str:
    """Return the cluster config name for the cluster identifier ``cluster``.

    Raises:
        ValueError: if ``cluster`` is not a known cluster identifier.
    """
    try:
        return _CLUSTER_CONFIGS[cluster]
    except KeyError:
        raise ValueError(f"Unknown cluster {cluster} provided.") from None


def resolve_artifact_config(cluster: str) -> str:
    """Return the artifact config name for the cluster identifier ``cluster``.

    Raises:
        ValueError: if ``cluster`` is not a known cluster identifier.
    """
    try:
        return _ARTIFACT_CONFIGS[cluster]
    except KeyError:
        raise ValueError(f"Unknown cluster {cluster} provided.") from None


def flatten_products(workload_manifest: dotdict) -> dotdict:
    """Flattens a nested dict of products.

    ``workload_manifest.products`` is a list of groups. Every group that has a
    ``products`` key contributes, for each mapping ``{param: [values, ...]}`` in
    that list, one flat dict per combination of the cartesian product of the
    value lists, tagged with the group's first ``test_case`` entry. For example
    ``{"a": [1, 2], "b": ["x"]}`` yields ``{"a": 1, "b": "x", "test_case": ...}``
    and ``{"a": 2, "b": "x", "test_case": ...}``. Groups without a ``products``
    key are dropped; a missing/empty ``products`` attribute yields ``[]``.

    The manifest is modified in place and also returned.
    """
    workload_manifest.products = [
        # dict(**a, **b) raises TypeError if a product already defines "test_case".
        dict(**dict(zip(inp.keys(), values)), **{"test_case": product["test_case"][0]})
        for product in (workload_manifest.products or [])
        if "products" in product
        for inp in product["products"]
        for values in itertools.product(*inp.values())
    ]
    return workload_manifest


def flatten_workload(workload_manifest: dotdict) -> List[dotdict]:
    """Flattens a workload with products into a list of workloads that don't have products.

    Each entry of ``products`` becomes one workload: a deep copy of the manifest
    (minus ``products``) whose ``spec`` is overlaid with the product's keys, the
    product's values taking precedence. The input manifest is not modified.
    """
    base = dict(workload_manifest)
    products = base.pop("products", None) or []
    workload_manifests = []
    for product in products:
        workload = copy.deepcopy(base)
        # Drop overridden keys first so the product's keys come after the base keys.
        spec = {key: value for key, value in workload["spec"].items() if key not in product}
        spec.update(product)
        workload["spec"] = spec
        workload_manifests.append(dotdict(workload))
    return workload_manifests


def set_build_dependency(workload_manifests: List[dotdict]) -> List[dotdict]:
    """Resolve each workload's ``spec["build"]`` template in place.

    ``spec["build"]`` is used as a ``str.format`` template whose placeholders are
    filled from the same workload's ``spec``. Returns the same list for chaining.
    """
    for workload_manifest in workload_manifests:
        spec = workload_manifest.spec
        spec["build"] = spec["build"].format(**spec)
    return workload_manifests


def load_config(config_path: str) -> dotdict:
    """Load a YAML file and return its top-level mapping as a ``dotdict``.

    Raises:
        yaml.YAMLError: if the file is not valid YAML.
        TypeError: if the document is empty or its top level is not a mapping.
    """
    with open(config_path, encoding="utf-8") as stream:
        content = yaml.safe_load(stream)
    if not isinstance(content, dict):
        raise TypeError(f"{config_path} does not contain a YAML mapping at the top level.")
    return dotdict(content)


def load_and_flatten(config_path: str) -> List[dotdict]:
    """Load ``config_path`` and return its fully flattened, build-resolved workloads."""
    return set_build_dependency(
        flatten_workload(flatten_products(load_config(config_path=config_path)))
    )


def _filter_by_spec_value(
    workload_manifests: List[dotdict], key: str, value: object
) -> List[dotdict]:
    """Return manifests whose ``spec`` contains ``key`` equal to ``value``.

    Manifests lacking ``key`` are skipped (not an error). Prints a notice when
    nothing matches.
    """
    matches = [
        manifest
        for manifest in workload_manifests
        if key in manifest["spec"] and manifest["spec"][key] == value
    ]
    if not matches:
        print("No test_case found!")
    return matches


def filter_by_test_case(workload_manifests: List[dotdict], test_case: str) -> Optional[dotdict]:
    """Return the single workload whose ``spec["test_case"]`` equals ``test_case``.

    Prints a message and returns ``None`` when there is no match or more than one.
    """
    matches = _filter_by_spec_value(workload_manifests, "test_case", test_case)
    if len(matches) > 1:
        print("Duplicate test_case found!")
        return None
    return matches[0] if matches else None


def filter_by_scope(workload_manifests: List[dotdict], scope: str) -> List[dotdict]:
    """Returns all workloads whose ``spec["scope"]`` equals ``scope``."""
    return _filter_by_spec_value(workload_manifests, "scope", scope)


def filter_by_environment(workload_manifests: List[dotdict], environment: str) -> List[dotdict]:
    """Returns all workloads whose ``spec["environment"]`` equals ``environment``."""
    return _filter_by_spec_value(workload_manifests, "environment", environment)


def filter_by_platform(workload_manifests: List[dotdict], platform: str) -> List[dotdict]:
    """Returns all workloads whose ``spec["platforms"]`` equals ``platform``."""
    return _filter_by_spec_value(workload_manifests, "platforms", platform)


def filter_by_model(workload_manifests: List[dotdict], model: str) -> List[dotdict]:
    """Returns all workloads whose ``spec["model"]`` equals ``model``."""
    return _filter_by_spec_value(workload_manifests, "model", model)


def filter_by_tag(workload_manifests: List[dotdict], tag: str) -> List[dotdict]:
    """Returns all workloads whose ``spec["tag"]`` equals ``tag``."""
    return _filter_by_spec_value(workload_manifests, "tag", tag)


def filter_by_test_cases(workload_manifests: List[dotdict], test_cases: str) -> List[dotdict]:
    """Returns all workloads whose ``spec["test_case"]`` is listed in ``test_cases``.

    ``test_cases`` is a comma-separated list of names. Whitespace around each name
    is ignored and a name listed twice matches its workload only once. Workload
    order is preserved. Prints a notice and returns ``[]`` when nothing matches.
    """
    wanted = {name.strip() for name in test_cases.split(",")}
    wanted.discard("")
    # ``spec`` is a plain dict, so it must be indexed rather than attribute-accessed.
    matches = [
        manifest
        for manifest in workload_manifests
        if manifest["spec"].get("test_case") in wanted
    ]
    if not matches:
        print("No test_case found!")
        return []
    return matches


def _load_recipes() -> Tuple[List[dotdict], List[dotdict]]:
    """Load every ``*.yaml`` recipe from ``../recipes`` and ``../local_recipes``.

    Paths are relative to this file's directory. Returns the flattened workloads
    of all files, plus the raw (unflattened) manifests of files whose name starts
    with ``_build``.
    """
    recipes_dir = BASE_PATH / ".." / "recipes"
    local_dir = BASE_PATH / ".." / "local_recipes"
    workloads: List[dotdict] = []
    build_workloads: List[dotdict] = []
    # Sorted so the resulting workload order does not depend on the filesystem.
    for file in sorted(recipes_dir.glob("*.yaml")) + sorted(local_dir.glob("*.yaml")):
        workloads += load_and_flatten(config_path=str(file))
        if file.stem.startswith("_build"):
            build_workloads.append(load_config(config_path=str(file)))
    return workloads, build_workloads


def load_workloads(
    container_tag: str,
    n_repeat: int = 1,
    time_limit: int = 1800,
    tag: Optional[str] = None,
    environment: Optional[str] = None,
    platform: Optional[str] = None,
    test_cases: str = "all",
    scope: Optional[str] = None,
    model: Optional[str] = None,
    test_case: Optional[str] = None,
    container_image: Optional[str] = None,
    record_checkpoints: Optional[str] = None,
) -> List[dotdict]:
    """Return all workloads from disk that match the given filters.

    Each truthy filter argument narrows the list in turn. ``test_cases="all"``
    disables the comma-separated test-case filter. Every matching workload gets
    ``n_repeat``, ``time_limit`` and ``{platforms}``-substituted ``artifacts`` set
    on its spec. The ``_build`` manifest whose ``spec["name"]`` equals the
    workload's ``spec["build"]`` is appended once, with its image set to
    ``<container_image or its own source image>:<container_tag>``. When
    ``record_checkpoints == "true"`` each workload also gets a checkpoint
    ``outputs`` entry.
    """
    workloads, build_workloads = _load_recipes()

    if scope:
        workloads = filter_by_scope(workload_manifests=workloads, scope=scope)
    if workloads and environment:
        workloads = filter_by_environment(workload_manifests=workloads, environment=environment)
    if workloads and model:
        workloads = filter_by_model(workload_manifests=workloads, model=model)
    if workloads and tag:
        workloads = filter_by_tag(workload_manifests=workloads, tag=tag)
    if workloads and platform:
        workloads = filter_by_platform(workload_manifests=workloads, platform=platform)
    if workloads and test_cases != "all":
        workloads = filter_by_test_cases(workload_manifests=workloads, test_cases=test_cases)
    if workloads and test_case:
        match = filter_by_test_case(workload_manifests=workloads, test_case=test_case)
        # None means "no match" or "ambiguous"; never wrap it into the result list.
        workloads = [match] if match is not None else []

    if not workloads:
        return []

    # Iterate a snapshot: build workloads appended below are not post-processed.
    for workload in list(workloads):
        for build_workload in build_workloads:
            if (
                workload.spec["build"] == build_workload.spec["name"]
                and build_workload not in workloads
            ):
                # Local name on purpose: reassigning ``container_image`` would leak the
                # first build recipe's default image into every later build recipe.
                image = container_image or build_workload.spec["source"]["image"]
                build_workload.spec["source"]["image"] = f"{image}:{container_tag}"
                workloads.append(build_workload)

        workload.spec["n_repeat"] = n_repeat
        workload.spec["time_limit"] = time_limit
        workload.spec["artifacts"] = {
            key: value.replace(r"{platforms}", workload.spec["platforms"])
            for key, value in (workload.spec.get("artifacts") or {}).items()
        }
        if record_checkpoints == "true":
            workload.outputs = [
                {
                    "type": "artifact",
                    "key": f"unverified/model/mcore-ci/{container_tag}/{{model}}/{{name}}",
                    "subdir": "checkpoints",
                    "name": r"{model}/{name}",
                    "description": r"Checkpoint of {model}/{name}",
                    "pic": {"name": "Mcore CI", "email": "okoenig@nvidia.com"},
                    "labels": {"origin": "ADLR/Megatron-LM"},
                }
            ]
    return workloads


if __name__ == "__main__":
    workflows = load_workloads(container_tag="main")

    # Save workflows to YAML file
    output_file = "workflows.yaml"
    with open(output_file, "w", encoding="utf-8") as f:
        yaml.dump([dict(workflow) for workflow in workflows], f)