ColabWan / preprocessing /sam3 /model /device_utils.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
1.07 kB
from contextlib import nullcontext
import torch
def mps_is_available() -> bool:
    """Tell whether the MPS backend can be used.

    Returns:
        ``False`` when this torch build has no ``torch.backends.mps``
        attribute; otherwise whatever ``torch.backends.mps.is_available()``
        reports.
    """
    # Check for the attribute first so the lookup below cannot fail on
    # builds that do not expose an MPS backend at all.
    if not hasattr(torch.backends, "mps"):
        return False
    return torch.backends.mps.is_available()
def get_accelerator_device() -> torch.device:
    """Select the device to run on, preferring CUDA, then MPS, then CPU.

    Returns:
        ``torch.device("cuda")`` when CUDA is available, else
        ``torch.device("mps")`` when ``mps_is_available()`` is true, else
        ``torch.device("cpu")``.
    """
    # Order matters: CUDA is probed first and MPS is only consulted when
    # CUDA is unavailable.
    if torch.cuda.is_available():
        backend = "cuda"
    elif mps_is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
def is_accelerator_device(device) -> bool:
    """Tell whether *device* names a CUDA or MPS device.

    Args:
        device: Anything ``torch.device(...)`` accepts (for example a
            string such as ``"cuda:0"`` or an existing ``torch.device``),
            or ``None``.

    Returns:
        ``True`` when the device type is ``"cuda"`` or ``"mps"``;
        ``False`` for every other type and for ``None``.
    """
    if device is not None:
        # Normalise through torch.device so an index suffix ("cuda:1")
        # does not affect the type comparison.
        kind = torch.device(device).type
        return kind == "cuda" or kind == "mps"
    return False
def accelerator_autocast(dtype=torch.bfloat16):
    """Build an autocast context for the currently selected accelerator.

    The device is the one returned by ``get_accelerator_device()``; it is
    not taken from an argument.

    Args:
        dtype: Dtype forwarded to ``torch.autocast``. Defaults to
            ``torch.bfloat16``.

    Returns:
        A ``torch.autocast`` context manager when the selected device type
        is ``"cuda"`` or ``"mps"``; otherwise a ``nullcontext()`` that does
        nothing.
    """
    backend = get_accelerator_device().type
    if backend not in ("cuda", "mps"):
        # No accelerator selected: hand back a no-op context manager so
        # callers can use the result in a `with` statement unconditionally.
        return nullcontext()
    return torch.autocast(device_type=backend, dtype=dtype)
def empty_accelerator_cache():
    """Synchronize the available accelerator and release its cached memory.

    * CUDA available: ``torch.cuda.synchronize()``, then
      ``torch.cuda.empty_cache()``, then a best-effort
      ``torch.cuda.ipc_collect()``.
    * Otherwise, MPS available: ``torch.mps.synchronize()`` followed by
      ``torch.mps.empty_cache()``.
    * Neither available: does nothing.

    A failure inside ``torch.cuda.ipc_collect()`` never propagates to the
    caller; it is recorded at DEBUG level instead of being dropped silently.
    """
    # Function-scope import keeps this block self-contained; the module's
    # top-level imports are left as they are.
    import logging

    if torch.cuda.is_available():
        # Synchronize before emptying the cache, same order as MPS below.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        try:
            torch.cuda.ipc_collect()
        except Exception:
            # Still best-effort: the cache has already been emptied, so an
            # IPC collection failure must not abort the caller. Log it so
            # the failure is at least discoverable when debugging.
            logging.getLogger(__name__).debug(
                "torch.cuda.ipc_collect() failed; ignoring", exc_info=True
            )
    elif mps_is_available():
        torch.mps.synchronize()
        torch.mps.empty_cache()