| | |
| | |
| | |
| | |
| | |
| |
|
| | import gc |
| | import subprocess |
| | import time |
| | from dataclasses import dataclass |
| | from typing import Optional |
| |
|
| | import torch |
| | from torch._utils import _get_available_device_type, _get_device_module |
| |
|
| | from torchtitan.tools.logging import logger |
| |
|
| |
|
| | def get_device_info(): |
| | device_type = _get_available_device_type() |
| | if device_type is None: |
| | device_type = "cuda" |
| | device_module = _get_device_module(device_type) |
| | return device_type, device_module |
| |
|
| |
|
| | device_type, device_module = get_device_info() |
| |
|
| |
|
| | |
class GarbageCollection:
    """Replace Python's automatic garbage collection with periodic explicit runs.

    Construction disables the automatic collector (``gc.disable()``) and runs
    one collection immediately. After that, ``run(step_count)`` triggers a
    collection whenever ``step_count`` is a multiple of ``gc_freq``. Step 1 is
    always skipped.

    Args:
        gc_freq: Collection period in steps; must be positive.
    """

    def __init__(self, gc_freq=1000):
        assert gc_freq > 0, "gc_freq must be a positive integer"
        self.gc_freq = gc_freq
        # Stop automatic collection so that collections are driven
        # explicitly through ``collect``.
        gc.disable()
        self.collect("Initial GC collection.")

    def run(self, step_count):
        """Collect if ``step_count`` (> 1) lands on a multiple of ``gc_freq``."""
        if step_count > 1 and step_count % self.gc_freq == 0:
            self.collect("Performing periodical GC collection.")

    @staticmethod
    def collect(reason: str, generation: int = 1):
        """Run ``gc.collect(generation)`` and log how long it took.

        Args:
            reason: Text included in the ``[GC]`` timing log line.
            generation: Generation passed to ``gc.collect`` (0-2). Defaults
                to 1, the value previously hard-coded here.
        """
        begin = time.monotonic()
        gc.collect(generation)
        logger.info("[GC] %s %.2f seconds.", reason, time.monotonic() - begin)
| |
|
| |
|
| | |
def get_peak_flops(device_name: str) -> float:
    """Return a hard-coded peak FLOPS figure for the given accelerator.

    If ``lspci`` runs and reports NVIDIA H100 devices, those lines replace
    ``device_name``. This lets the H100 variant (NVL / PCIe / other) be
    distinguished. Otherwise ``device_name`` is matched as given.

    NOTE(review): the constants look like vendor peak dense BF16/FP16 tensor
    throughput figures -- confirm against vendor datasheets before relying on
    them.

    Args:
        device_name: Device name string; matched by substring.

    Returns:
        Peak FLOPS: a float for the table entries, or an int for the Intel
        Data Center GPU Max 1550, whose value is computed. Unrecognized
        devices log a warning and fall back to the A100 value.
    """
    try:
        # Keep only NVIDIA H100 entries from lspci; their full product
        # string carries the NVL / PCIe markers checked below.
        result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
        filtered_lines = [
            line
            for line in result.stdout.splitlines()
            if "NVIDIA" in line and "H100" in line
        ]
        device_name = " ".join(filtered_lines) or device_name
    except OSError as e:
        # FileNotFoundError (lspci not installed) is the common case, but any
        # OS-level launch failure (e.g. PermissionError) should also fall back
        # to the provided device_name instead of aborting.
        logger.warning(f"Error running lspci: {e}, fallback to use device_name")

    if "A100" in device_name:
        return 312e12
    elif "H100" in device_name:
        if "NVL" in device_name:
            return 835e12
        elif "PCIe" in device_name:
            return 756e12
        else:
            return 989e12
    elif "H200" in device_name:
        return 989e12
    elif "MI300X" in device_name or "MI325X" in device_name:
        return 1300e12
    elif "MI250X" in device_name:
        return 191.5e12
    elif "Data Center GPU Max 1550" in device_name:
        # Computed from the device's compute-unit count. Presumably
        # 512 FLOPs/cycle/unit at a 1300 MHz clock -- confirm against
        # Intel's specification.
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6
    else:
        logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
        return 312e12
| |
|
| |
|
@dataclass(frozen=True)
class Color:
    """ANSI SGR escape sequences for the eight standard foreground colors.

    The values are unannotated class attributes, not dataclass fields, so
    they are read directly as ``Color.red`` and so on. ``reset`` is SGR 39,
    the default foreground color.
    """

    black, red, green, yellow = "\033[30m", "\033[31m", "\033[32m", "\033[33m"
    blue, magenta, cyan, white = "\033[34m", "\033[35m", "\033[36m", "\033[37m"
    reset = "\033[39m"
| |
|
| |
|
@dataclass(frozen=True)
class NoColor:
    """Drop-in stand-in for ``Color`` that disables coloring.

    It exposes the same attribute names as ``Color``, but every value is the
    empty string, so formatted output contains no escape sequences.
    """

    black = red = green = yellow = blue = magenta = cyan = white = reset = ""
| |
|
| |
|
def check_if_feature_in_pytorch(
    feature_name: str,
    pull_request: str,
    min_nightly_version: Optional[str] = None,
) -> None:
    """Warn when the installed PyTorch may lack the change a feature relies on.

    Args:
        feature_name: Human-readable feature name used in the warning text.
        pull_request: Link or identifier of the PyTorch PR that introduced
            the change; it is included in the warning text.
        min_nightly_version: Earliest PyTorch version that contains the
            change. If given and the installed version is older, a warning
            is logged.
    """
    if "git" in torch.__version__:
        # A source build's version string presumably can't tell us whether
        # the PR is included, so ask the user to verify it manually.
        logger.warning(
            "detected that the pytorch is built from source. Please make sure the PR "
            f"({pull_request}) is included in pytorch for correct {feature_name}."
        )
    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
        # torch.__version__ is a TorchVersion, whose comparison operators
        # parse the operands as versions rather than comparing strings.
        logger.warning(
            f"detected that the pytorch version {torch.__version__} is older than "
            f"{min_nightly_version}. Please upgrade a newer version to include the "
            f"change in ({pull_request}) for correct {feature_name}."
        )
| |
|