import enum
import random
from typing import Any, ClassVar, NamedTuple

import numpy as np
import torch

from fastvideo.logger import init_logger
|
|
| logger = init_logger(__name__) |
|
|
|
|
class AttentionBackendEnum(enum.Enum):
    """Identifiers for the attention backend implementations selectable at
    runtime (see :meth:`Platform.get_attn_backend_cls`).

    Member values come from ``enum.auto()`` and carry no meaning; compare
    members by identity, not by value.
    """
    FLASH_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()
    SAGE_ATTN = enum.auto()
    SAGE_ATTN_THREE = enum.auto()
    ATTN_QAT_INFER = enum.auto()
    ATTN_QAT_TRAIN = enum.auto()
    VIDEO_SPARSE_ATTN = enum.auto()
    BSA_ATTN = enum.auto()
    VMOBA_ATTN = enum.auto()
    SLA_ATTN = enum.auto()
    SAGE_SLA_ATTN = enum.auto()
    SPARSE_FP4_ATTN = enum.auto()
    SPARSE_FP4_OURS_P_ATTN = enum.auto()
    SPARSE_FP4_COMPRESS_ATTN = enum.auto()
    # Disables attention entirely (e.g. for models without attention layers).
    NO_ATTENTION = enum.auto()
|
|
|
|
class PlatformEnum(enum.Enum):
    """Hardware platform families recognized by fastvideo.

    ``OOT`` stands for an out-of-tree platform plugin; ``UNSPECIFIED`` is the
    placeholder used before platform detection has run.
    """
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    MPS = enum.auto()
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()
    NPU = enum.auto()
|
|
|
|
class CpuArchEnum(enum.Enum):
    """CPU instruction-set architectures distinguished by the platform layer."""
    X86 = enum.auto()
    ARM = enum.auto()
    UNSPECIFIED = enum.auto()
|
|
|
|
class DeviceCapability(NamedTuple):
    """A device compute capability expressed as a ``(major, minor)`` pair.

    Being a NamedTuple, instances compare lexicographically, so
    ``DeviceCapability(8, 6) >= (8, 0)`` works as expected.
    """

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Return the capability as a dotted version string, e.g. ``"8.6"``."""
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.

        Raises:
            ValueError: if ``minor`` is not a single decimal digit.
        """
        # Raise explicitly rather than `assert` so the check survives
        # running under `python -O` (asserts are stripped there).
        if not 0 <= self.minor < 10:
            raise ValueError(
                f"minor version must be a single digit, got {self.minor}")
        return self.major * 10 + self.minor
|
|
|
|
class Platform:
    """Abstract description of a hardware platform (CUDA, ROCm, TPU, ...).

    Concrete platforms subclass this, fill in the class attributes below and
    override the methods that raise ``NotImplementedError``.  The ``is_*``
    query helpers work purely off ``_enum`` and need no per-platform code.
    """

    # Populated by each concrete subclass.
    _enum: PlatformEnum
    device_name: str
    device_type: str

    # PyTorch dispatch key for this platform (used when registering ops).
    dispatch_key: str = "CPU"

    # Environment variable that controls device visibility on this platform
    # (e.g. CUDA_VISIBLE_DEVICES); subclasses replace the placeholder.
    device_control_env_var: str = "FASTVIDEO_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # Device resource key understood by Ray; "" when the platform is not
    # usable through Ray.
    ray_device_key: str = ""

    # Backend name handed to torch.compile for simple compilation paths.
    simple_compile_backend: str = "inductor"

    # Quantization methods this platform accepts.  An empty list disables the
    # check in verify_quantization.  ClassVar so type checkers treat these as
    # class-level defaults shared by (and overridable in) subclasses.
    supported_quantization: ClassVar[list[str]] = []

    # Extra environment variables that must be forwarded to worker processes.
    additional_env_vars: ClassVar[list[str]] = []

    def is_cuda(self) -> bool:
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        return self._enum == PlatformEnum.TPU

    def is_xpu(self) -> bool:
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        return self._enum == PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Return True if the platform exposes a CUDA-like stack (CUDA or ROCm)."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_mps(self) -> bool:
        return self._enum == PlatformEnum.MPS

    def is_npu(self) -> bool:
        return self._enum == PlatformEnum.NPU

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None, head_size: int,
                             dtype: torch.dtype) -> str:
        """Get the attention backend class of a device.

        Returns the fully-qualified class path as a string; the base
        implementation returns "" (no backend available).
        """
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Stateless version of :func:`torch.cuda.get_device_capability`.

        Returns None when the platform does not report a capability.
        """
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The ``capability`` argument can either be:

        - A tuple ``(major, minor)``.
        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            # NamedTuple comparison is lexicographic over (major, minor).
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
        """
        Check if the current platform supports async output.
        """
        raise NotImplementedError

    @classmethod
    def get_torch_device(cls) -> Any:
        """
        Check if the current platform supports torch device.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: int | None = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            # torch.manual_seed already seeds every device (including all
            # CUDA devices), so no separate torch.cuda.manual_seed_all call
            # is needed.
            torch.manual_seed(seed)

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
          the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        Raises:
            ValueError: if ``quant`` is not listed in
                ``cls.supported_quantization`` (and the list is non-empty).
        """
        if cls.supported_quantization and quant not in cls.supported_quantization:
            raise ValueError(f"{quant} quantization is currently not supported in "
                             f"{cls.device_name}.")

    @classmethod
    def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """Get the CPU architecture of the current platform."""
        return CpuArchEnum.UNSPECIFIED
|
|
|
|
class UnspecifiedPlatform(Platform):
    """Fallback platform used before (or when) no concrete platform is detected."""
    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""
|
|