kernels-bot's picture
Build uploaded using `kernels`.
3ed240d verified
raw
history blame contribute delete
386 Bytes
import torch
from contextlib import contextmanager
@contextmanager
def device_context(device: "torch.device | str"):
    """Make ``device`` the active device of its backend for the ``with`` body.

    The backend module is looked up on ``torch`` by the device type
    (``torch.cuda`` for ``cuda``, ``torch.xpu`` for ``xpu``, ...). When that
    module exposes a ``device`` context manager, the body runs inside it.
    When the backend module is missing or has no ``device`` attribute, the
    body runs unchanged.

    Args:
        device: A ``torch.device`` or a device string such as ``"cuda:0"``.
            Strings are converted with ``torch.device(...)``.

    Yields:
        None.
    """
    # Normalise so plain strings work; an existing torch.device passes
    # through torch.device() unchanged.
    device = torch.device(device)

    backend = getattr(torch, device.type, None)
    if backend is None or not hasattr(backend, "device"):
        # No per-backend device guard available: run the body as-is.
        yield
        return

    with backend.device(device):
        yield