| |
| |
| |
| |
| |
|
|
| import gc |
| import subprocess |
| import time |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| import torch |
| from torch._utils import _get_available_device_type, _get_device_module |
|
|
| from torchtitan.tools.logging import logger |
|
|
|
|
| def get_device_info(): |
| device_type = _get_available_device_type() |
| if device_type is None: |
| device_type = "cuda" |
| device_module = _get_device_module(device_type) |
| return device_type, device_module |
|
|
|
|
| device_type, device_module = get_device_info() |
|
|
|
|
| |
class GarbageCollection:
    """Manual garbage-collection scheduler for training loops.

    Disables Python's automatic GC and instead runs a generation-1
    collection every ``gc_freq`` steps, so GC pauses happen at
    predictable points instead of mid-step.
    """

    def __init__(self, gc_freq=1000):
        # Validate with a real exception: ``assert`` is stripped under
        # ``python -O`` and must not guard user-supplied configuration.
        if gc_freq <= 0:
            raise ValueError("gc_freq must be a positive integer")
        self.gc_freq = gc_freq
        gc.disable()
        self.collect("Initial GC collection.")

    def run(self, step_count):
        """Run a periodic collection on every ``gc_freq``-th step.

        Step 1 is skipped because ``__init__`` already collected.
        """
        if step_count > 1 and step_count % self.gc_freq == 0:
            # Fixed typo in the log message: "Peforming" -> "Performing".
            self.collect("Performing periodical GC collection.")

    @staticmethod
    def collect(reason: str):
        # Collect up to generation 1 only (cheaper than a full gen-2 pass)
        # and time the pause so it is visible in the logs.
        begin = time.monotonic()
        gc.collect(1)
        logger.info("[GC] %s %.2f seconds.", reason, time.monotonic() - begin)
|
|
|
|
| |
def get_peak_flops(device_name: str) -> float:
    """Return the theoretical peak FLOPS for the given accelerator.

    First tries to refine ``device_name`` via ``lspci`` (the torch-reported
    name may not include the H100 form factor), then maps known device
    names to their published peak throughput, falling back to A100 numbers
    for unknown devices.

    Args:
        device_name: Device name, e.g. from ``torch.cuda.get_device_name()``.

    Returns:
        Peak FLOPS as a float (e.g. ``312e12`` for A100).
        (Annotation fixed: every branch returns a float, not an int.)
    """
    try:
        # lspci can expose the exact board variant (NVL / PCIe / SXM),
        # which matters for the H100 peak-FLOPS lookup below.
        result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
        filtered_lines = [
            line
            for line in result.stdout.splitlines()
            if "NVIDIA" in line and "H100" in line
        ]
        # Keep the caller-supplied name when no H100 entry is found.
        device_name = " ".join(filtered_lines) or device_name
    except FileNotFoundError as e:
        logger.warning(f"Error running lspci: {e}, fallback to use device_name")
    if "A100" in device_name:
        # A100 peak: 312 TFLOPS.
        return 312e12
    elif "H100" in device_name:
        # H100 peak depends on form factor:
        # NVL 835 TFLOPS, PCIe 756 TFLOPS, SXM (default) 989 TFLOPS.
        if "NVL" in device_name:
            return 835e12
        elif "PCIe" in device_name:
            return 756e12
        else:
            return 989e12
    elif "H200" in device_name:
        # H200 shares the H100 SXM compute throughput.
        return 989e12
    elif "MI300X" in device_name or "MI325X" in device_name:
        # AMD MI300X / MI325X peak: 1300 TFLOPS.
        return 1300e12
    elif "MI250X" in device_name:
        # MI250X peak quoted per GCD (single die): 191.5 TFLOPS.
        return 191.5e12
    elif "Data Center GPU Max 1550" in device_name:
        # Intel Data Center GPU Max: 512 ops/clock per compute unit at
        # a 1300 MHz clock; scaled by the actual compute-unit count.
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6
    else:
        logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
        return 312e12
|
|
|
|
@dataclass(frozen=True)
class Color:
    # ANSI SGR escape codes for terminal foreground colors. These are plain
    # class attributes (deliberately unannotated), so the dataclass has no
    # fields -- the class acts as a frozen namespace of string constants.
    black = "\033[30m"
    red = "\033[31m"
    green = "\033[32m"
    yellow = "\033[33m"
    blue = "\033[34m"
    magenta = "\033[35m"
    cyan = "\033[36m"
    white = "\033[37m"
    # "\033[39m" restores the *default foreground* color only (not the full
    # "\033[0m" reset), matching the foreground-only codes above.
    reset = "\033[39m"
|
|
|
|
@dataclass(frozen=True)
class NoColor:
    # Drop-in replacement for ``Color`` with every escape sequence blanked
    # out -- used when colored output is disabled (e.g. non-TTY streams).
    # Same attribute names so call sites need no changes.
    black = ""
    red = ""
    green = ""
    yellow = ""
    blue = ""
    magenta = ""
    cyan = ""
    white = ""
    reset = ""
|
|
|
|
def check_if_feature_in_pytorch(
    feature_name: str,
    pull_request: str,
    min_nightly_version: Optional[str] = None,
) -> None:
    """Warn when the installed PyTorch may be missing a required fix.

    Args:
        feature_name: Human-readable feature description for the log line.
        pull_request: Link (or id) of the PyTorch PR introducing the fix.
        min_nightly_version: Earliest version containing the PR; when
            given, older installed versions trigger a warning.
    """
    if "git" in torch.__version__:
        # Built from source: the version string cannot tell us whether the
        # PR is included, so ask the user to verify manually.
        # Bug fix: the original referenced an undefined name
        # ``pull_request_link`` (NameError) -- the parameter is ``pull_request``.
        logger.warning(
            "detected that the pytorch is built from source. Please make sure the PR "
            f"({pull_request}) is included in pytorch for correct {feature_name}."
        )
    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
        # NOTE(review): plain ``<`` on version strings is lexicographic, not
        # semantic -- adequate for nightly date tags, but worth confirming
        # for general version strings.
        logger.warning(
            f"detected that the pytorch version {torch.__version__} is older than "
            f"{min_nightly_version}. Please upgrade a newer version to include the "
            f"change in ({pull_request}) for correct {feature_name}."
        )
|
|