NeMo
Megatron-LM / tests /functional_tests /python_test_utils /test_pretraining_regular_pipeline.py
KexuanShi's picture
Upload folder using huggingface_hub
88e6849 verified
Raw
History Blame Contribute Delete
2.22 kB
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import logging
from typing import Dict, List, Optional
import yaml
from tests.functional_tests.python_test_utils import common
# Module import side effect: request INFO-level logging configuration via
# logging.basicConfig, then create the logger named after this module that the
# functions below use for reporting.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Maps each metric name to the ordered list of tests applied to it.
# Two recipes are used:
#   * an approximate comparison only (atol=0, rtol=0.05);
#   * a deterministic test followed by the same approximate comparison.
# Every metric gets its own freshly constructed test objects; insertion order
# of the keys matches the order in which the metrics are listed below.
CHECK_THRESHOLDS = {
    **{
        metric_name: [common.ApproximateTest(atol=0, rtol=0.05)]
        for metric_name in (
            "iteration-time",
            "mem-allocated-bytes",
            "mem-max-allocated-bytes",
        )
    },
    **{
        metric_name: [
            common.DeterministicTest(),
            common.ApproximateTest(atol=0, rtol=0.05),
        ]
        for metric_name in (
            "lm loss",
            "mtp_1 loss",
            "num-zeros",
            "generated_tokens",
            "logprobs",
        )
    },
}
def test_regular_pipeline(
    compare_approximate_results: bool,
    golden_values: Dict[str, common.GoldenValueMetric],
    actual_values: Dict[str, common.GoldenValueMetric],
    model_config_path: str,
    checks: Optional[Dict[str, List[common.Test]]] = None,
):
    """Run the shared comparison pipeline for a regular (non-resume) test.

    When ``checks`` is not supplied, the metrics to verify are read from the
    ``METRICS`` entry of the YAML file at ``model_config_path`` (defaulting to
    ``["lm loss", "num-zeros"]`` when the entry is absent) and each metric is
    mapped to its list of tests through ``CHECK_THRESHOLDS``. Every metric
    selected this way must be present in ``golden_values``.

    The resulting ``checks`` are then forwarded, together with the golden and
    actual values, to ``common.pipeline``.

    Args:
        compare_approximate_results: Forwarded unchanged to ``common.pipeline``.
        golden_values: Reference metrics, keyed by metric name.
        actual_values: Metrics produced by the run under test, keyed by
            metric name.
        model_config_path: Path of the YAML model config; only read when
            ``checks`` is ``None``.
        checks: Optional explicit mapping of metric name to the tests to run.
            When given, the model config is not opened.

    Raises:
        KeyError: If the config requests a metric that has no entry in
            ``CHECK_THRESHOLDS``.
        AssertionError: If a required metric is missing from ``golden_values``.
    """
    if checks is None:
        with open(model_config_path) as f:
            # An empty YAML document loads as None; treat it as an empty
            # config so the default metric list applies instead of a TypeError.
            model_config = yaml.safe_load(f) or {}

        checks_types = (
            model_config["METRICS"] if "METRICS" in model_config else ["lm loss", "num-zeros"]
        )

        # Report every unsupported metric at once with an actionable message
        # rather than failing on the first bare KeyError.
        unknown_metrics = [metric for metric in checks_types if metric not in CHECK_THRESHOLDS]
        if unknown_metrics:
            raise KeyError(
                f"No checks defined in CHECK_THRESHOLDS for metrics: {unknown_metrics}; "
                f"supported metrics are: {sorted(CHECK_THRESHOLDS)}"
            )

        checks = {metric: CHECK_THRESHOLDS[metric] for metric in checks_types}

        # NOTE(review): this validation is kept scoped to config-derived
        # checks; the original indentation was not available — confirm.
        missing_metrics = [metric for metric in checks if metric not in golden_values]
        if missing_metrics:
            message = (
                "The following metrics are required but not provided in golden values: "
                f"{', '.join(missing_metrics)}"
            )
            logger.error(message)
            # Raise explicitly (not `assert False`) so the failure carries the
            # message and is not stripped when Python runs with -O.
            raise AssertionError(message)

    common.pipeline(
        compare_approximate_results=compare_approximate_results,
        golden_values=golden_values,
        actual_values=actual_values,
        checks=checks,
    )