from contextlib import contextmanager, nullcontext
from typing import Iterator, Union

import torch


@contextmanager
def device_context(device: Union[torch.device, str]) -> Iterator[None]:
    """Make ``device`` the active device of its backend for the ``with`` body.

    Looks up the backend module ``torch.<device.type>`` (e.g. ``torch.cuda``,
    ``torch.xpu``). If that module exists and exposes a ``device`` attribute,
    that attribute is used as a context manager around the body. Otherwise
    the context is a no-op.

    Args:
        device: A ``torch.device`` or anything ``torch.device`` accepts,
            such as ``"cuda:1"`` or ``"cpu"``.

    Yields:
        None.

    Raises:
        RuntimeError: if ``device`` is a string ``torch.device`` cannot parse.
    """
    # Normalize so string specs work too; a torch.device passes through unchanged.
    device = torch.device(device)
    backend = getattr(torch, device.type, None)
    # getattr on None with a default returns None, so this covers both
    # "no such backend module" and "backend module has no `device`".
    backend_device_cm = getattr(backend, "device", None)
    ctx = backend_device_cm(device) if backend_device_cm is not None else nullcontext()
    with ctx:
        yield