import enum
import random
from typing import Any, NamedTuple
import numpy as np
import torch
from fastvideo.logger import init_logger
logger = init_logger(__name__)
class AttentionBackendEnum(enum.Enum):
FLASH_ATTN = enum.auto()
TORCH_SDPA = enum.auto()
SAGE_ATTN = enum.auto()
SAGE_ATTN_THREE = enum.auto()
ATTN_QAT_INFER = enum.auto()
ATTN_QAT_TRAIN = enum.auto()
VIDEO_SPARSE_ATTN = enum.auto()
BSA_ATTN = enum.auto()
VMOBA_ATTN = enum.auto()
SLA_ATTN = enum.auto()
SAGE_SLA_ATTN = enum.auto()
SPARSE_FP4_ATTN = enum.auto()
SPARSE_FP4_OURS_P_ATTN = enum.auto()
SPARSE_FP4_COMPRESS_ATTN = enum.auto()
NO_ATTENTION = enum.auto()
class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
TPU = enum.auto()
XPU = enum.auto()
CPU = enum.auto()
MPS = enum.auto()
OOT = enum.auto()
UNSPECIFIED = enum.auto()
NPU = enum.auto()
class CpuArchEnum(enum.Enum):
X86 = enum.auto()
ARM = enum.auto()
UNSPECIFIED = enum.auto()
class DeviceCapability(NamedTuple):
major: int
minor: int
def as_version_str(self) -> str:
return f"{self.major}.{self.minor}"
def to_int(self) -> int:
"""
Express device capability as an integer ``<major><minor>``.
It is assumed that the minor version is always a single digit.
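
        For example, capability ``(8, 9)`` is expressed as ``89``:

        >>> DeviceCapability(8, 9).to_int()
        89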
"""
assert 0 <= self.minor < 10
return self.major * 10 + self.minor
class Platform:
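    """Base class describing a hardware platform.

    Concrete platforms subclass this and override the device-specific hooks.
    A minimal sketch (illustrative only; the attribute values and the override
    below are assumptions, not FastVideo's actual CUDA implementation)::

        class CudaPlatform(Platform):
            _enum = PlatformEnum.CUDA
            device_name = "cuda"
            device_type = "cuda"
            dispatch_key = "CUDA"

            @classmethod
            def get_device_name(cls, device_id: int = 0) -> str:
                return torch.cuda.get_device_name(device_id)
    """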
_enum: PlatformEnum
device_name: str
device_type: str
dispatch_key: str = "CPU"
# platform-agnostic way to specify the device control environment variable,
    # e.g. CUDA_VISIBLE_DEVICES for CUDA.
# hint: search for "get_visible_accelerator_ids_env_var" in
# https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
device_control_env_var: str = "FASTVIDEO_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"
# available ray device keys:
# https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
# empty string means the device does not support ray
ray_device_key: str = ""
# The torch.compile backend for compiling simple and
# standalone functions. The default value is "inductor" to keep
# the same behavior as PyTorch.
# NOTE: for the forward part of the model, vLLM has another separate
# compilation strategy.
simple_compile_backend: str = "inductor"
supported_quantization: list[str] = []
additional_env_vars: list[str] = []
def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA
def is_rocm(self) -> bool:
return self._enum == PlatformEnum.ROCM
def is_tpu(self) -> bool:
return self._enum == PlatformEnum.TPU
def is_xpu(self) -> bool:
return self._enum == PlatformEnum.XPU
def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU
def is_out_of_tree(self) -> bool:
return self._enum == PlatformEnum.OOT
    def is_cuda_alike(self) -> bool:
        """Return whether the platform is CUDA-alike, i.e. CUDA or ROCm."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
def is_mps(self) -> bool:
return self._enum == PlatformEnum.MPS
def is_npu(self) -> bool:
return self._enum == PlatformEnum.NPU
@classmethod
def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None, head_size: int,
dtype: torch.dtype) -> str:
"""Get the attention backend class of a device."""
return ""
@classmethod
def get_device_capability(
cls,
device_id: int = 0,
) -> DeviceCapability | None:
"""Stateless version of :func:`torch.cuda.get_device_capability`."""
return None
@classmethod
def has_device_capability(
cls,
capability: tuple[int, int] | int,
device_id: int = 0,
) -> bool:
"""
Test whether this platform is compatible with a device capability.
The ``capability`` argument can either be:
- A tuple ``(major, minor)``.
- An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
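
        For example, on a device reporting capability ``(8, 0)``, both
        ``has_device_capability((8, 0))`` and ``has_device_capability(80)``
        return ``True``.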
"""
current_capability = cls.get_device_capability(device_id=device_id)
if current_capability is None:
return False
if isinstance(capability, tuple):
return current_capability >= capability
return current_capability.to_int() >= capability
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
"""Get the name of a device."""
raise NotImplementedError
@classmethod
def get_device_uuid(cls, device_id: int = 0) -> str:
"""Get the uuid of a device, e.g. the PCI bus ID."""
raise NotImplementedError
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
"""Get the total memory of a device in bytes."""
raise NotImplementedError
@classmethod
def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
"""
Check if the current platform supports async output.
"""
raise NotImplementedError
@classmethod
def get_torch_device(cls) -> Any:
"""
        Return the torch device interface for the current platform.
"""
raise NotImplementedError
@classmethod
def inference_mode(cls):
"""A device-specific wrapper of `torch.inference_mode`.
This wrapper is recommended because some hardware backends such as TPU
do not support `torch.inference_mode`. In such a case, they will fall
back to `torch.no_grad` by overriding this method.
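
        A TPU platform, for instance, would override this method to
        ``return torch.no_grad()`` instead.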
"""
return torch.inference_mode(mode=True)
@classmethod
def seed_everything(cls, seed: int | None = None) -> None:
"""
Set the seed of each random module.
`torch.manual_seed` will set seed on all devices.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
if seed is not None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
"""
Verify whether the current platform supports the specified model
architecture.
        - Raises an error or emits a warning depending on whether the model is
          supported on the current platform.
        - By default, all models are considered supported.
"""
pass
@classmethod
def verify_quantization(cls, quant: str) -> None:
"""
Verify whether the quantization is supported by the current platform.
"""
if cls.supported_quantization and \
quant not in cls.supported_quantization:
raise ValueError(f"{quant} quantization is currently not supported in "
f"{cls.device_name}.")
@classmethod
def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
"""
Return the memory usage in bytes.
"""
raise NotImplementedError
@classmethod
def get_device_communicator_cls(cls) -> str:
"""
Get device specific communicator class for distributed communication.
"""
return "fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa
@classmethod
def get_cpu_architecture(cls) -> CpuArchEnum:
"""Get the CPU architecture of the current platform."""
return CpuArchEnum.UNSPECIFIED
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
device_type = ""