# drbh: Migrated from kernels-community/finegrained-fp8 (commit 45428ec).
import torch
from contextlib import contextmanager
@contextmanager
def device_context(device: "torch.device | str | int"):
    """Make ``device`` the active device for its backend while the block runs.

    The backend module is looked up as ``torch.<device.type>`` (for example
    ``torch.cuda`` or ``torch.xpu``). If that module exposes a ``device``
    attribute, it is used as a context manager around the body. Otherwise the
    body simply runs with no device switch.

    Args:
        device: Target device. A ``torch.device`` is used as-is. Any other
            spec that ``torch.device`` accepts (e.g. the string ``"cuda:1"``)
            is normalized first.

    Yields:
        None.
    """
    # Accept string/int device specs as well as torch.device instances;
    # previously a string crashed on the ``.type`` attribute access below.
    if not isinstance(device, torch.device):
        device = torch.device(device)

    backend = getattr(torch, device.type, None)
    # getattr on None with a default is safe, so this also covers the
    # "no such backend module" case.
    device_guard = getattr(backend, "device", None)

    if device_guard is None:
        # No per-backend device guard available: nothing to switch.
        yield
        return

    with device_guard(device):
        yield