| import torch | |
| from contextlib import contextmanager | |
@contextmanager
def device_context(device: "torch.device | str"):
    """Make ``device`` the active device for its backend within a ``with`` block.

    The backend module is looked up as ``torch.<device.type>`` (for example
    ``torch.cuda`` or ``torch.xpu``). If that module exists and exposes a
    ``device`` attribute, that attribute is called with ``device`` and used as
    the enclosing context manager. Otherwise the block runs unchanged, as a
    no-op context.

    Args:
        device: A ``torch.device``, or anything ``torch.device(...)`` accepts
            (e.g. ``"cuda:0"``).

    Yields:
        None. Exceptions raised inside the ``with`` body propagate unchanged.
    """
    # Without the @contextmanager decorator above, this generator could not be
    # used in a ``with`` statement (a bare generator has no __enter__/__exit__).
    if not isinstance(device, torch.device):
        device = torch.device(device)

    # getattr(None, ...) is safe, so a missing backend module and a backend
    # without a ``device`` guard both end up as ``guard is None``.
    backend = getattr(torch, device.type, None)
    guard = getattr(backend, "device", None)

    if guard is None:
        yield
        return

    with guard(device):
        yield