import enum
import random
from typing import Any, NamedTuple

import numpy as np
import torch

from fastvideo.logger import init_logger

logger = init_logger(__name__)


class AttentionBackendEnum(enum.Enum):
    """Attention-kernel backends selectable at runtime."""
    FLASH_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()
    SAGE_ATTN = enum.auto()
    SAGE_ATTN_THREE = enum.auto()
    ATTN_QAT_INFER = enum.auto()
    ATTN_QAT_TRAIN = enum.auto()
    VIDEO_SPARSE_ATTN = enum.auto()
    BSA_ATTN = enum.auto()
    VMOBA_ATTN = enum.auto()
    SLA_ATTN = enum.auto()
    SAGE_SLA_ATTN = enum.auto()
    SPARSE_FP4_ATTN = enum.auto()
    SPARSE_FP4_OURS_P_ATTN = enum.auto()
    SPARSE_FP4_COMPRESS_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()


class PlatformEnum(enum.Enum):
    """Hardware/accelerator platforms that FastVideo can run on."""
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    MPS = enum.auto()
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()
    NPU = enum.auto()


class CpuArchEnum(enum.Enum):
    """Host CPU architectures recognized by the platform layer."""
    X86 = enum.auto()
    ARM = enum.auto()
    UNSPECIFIED = enum.auto()


class DeviceCapability(NamedTuple):
    """A device compute capability as a (major, minor) version pair."""
    major: int
    minor: int

    def as_version_str(self) -> str:
        """Return the capability as a dotted version string, e.g. ``"8.6"``."""
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.
        """
        assert 0 <= self.minor < 10
        return self.major * 10 + self.minor


class Platform:
    """Base class describing a hardware platform.

    Concrete subclasses (CUDA, ROCm, CPU, ...) override the classmethods
    below; the instance predicates (``is_cuda`` etc.) compare against the
    subclass's ``_enum`` tag.
    """

    _enum: PlatformEnum
    device_name: str
    device_type: str
    dispatch_key: str = "CPU"
    # platform-agnostic way to specify the device control environment variable,
    # e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "FASTVIDEO_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"
    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""
    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"
    supported_quantization: list[str] = []
    additional_env_vars: list[str] = []

    def is_cuda(self) -> bool:
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        return self._enum == PlatformEnum.TPU

    def is_xpu(self) -> bool:
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        return self._enum == PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Return True if the platform is CUDA-compatible (CUDA or ROCm)."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_mps(self) -> bool:
        return self._enum == PlatformEnum.MPS

    def is_npu(self) -> bool:
        return self._enum == PlatformEnum.NPU

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None,
                             head_size: int, dtype: torch.dtype) -> str:
        """Get the attention backend class of a device."""
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Stateless version of :func:`torch.cuda.get_device_capability`."""
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The ``capability`` argument can either be:

        - A tuple ``(major, minor)``.
        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
        """
        Check if the current platform supports async output.
        """
        raise NotImplementedError

    @classmethod
    def get_torch_device(cls) -> Any:
        """
        Check if the current platform supports torch device.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: int | None = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
          the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        Raises:
            ValueError: if ``supported_quantization`` is non-empty and does
                not contain ``quant``.
        """
        if cls.supported_quantization and \
            quant not in cls.supported_quantization:
            raise ValueError(f"{quant} quantization is currently not supported in "
                             f"{cls.device_name}.")

    @classmethod
    def get_current_memory_usage(cls,
                                 device: torch.types.Device | None = None
                                 ) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """Get the CPU architecture of the current platform."""
        return CpuArchEnum.UNSPECIFIED


class UnspecifiedPlatform(Platform):
    """Fallback platform used when no concrete platform is detected."""
    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""