# NOTE(review): the lines above this module in the original capture were
# hosting-page banners ("Spaces:", "Runtime error"), not source code.
| """ | |
| Custom PyTorch device implementation that routes operations through our virtual GPU. | |
| """ | |
| import torch | |
| from torch.library import Library, impl | |
| from typing import Optional, Union, Tuple | |
| import numpy as np | |
| from virtual_vram import VirtualVRAM | |
# ---------------------------------------------------------------------------
# Custom backend registration
# ---------------------------------------------------------------------------
def init_vgpu_backend() -> bool:
    """Register the "vgpu" name on PyTorch's PrivateUse1 backend slot.

    Renames the PrivateUse1 backend to "vgpu", generates the convenience
    methods (``Tensor.vgpu()`` etc.) for tensors, modules, packed sequences
    and storages, then declares a custom-op library for the backend.

    Returns:
        bool: True when registration succeeded, False otherwise.  Failures
        are reported with a printed warning instead of raising, so importing
        this module never fails outright.
    """
    try:
        # The rename must happen before the method generation below.
        torch.utils.rename_privateuse1_backend("vgpu")
        torch.utils.generate_methods_for_privateuse1_backend(
            for_tensor=True,
            for_module=True,
            for_packed_sequence=True,
            for_storage=True,
        )
        # Declare a custom op schema in the new backend's library.
        lib = Library("vgpu", "DEF")
        lib.define("custom_op(Tensor self) -> Tensor")

        # NOTE(review): this implementation is defined but never attached to
        # the library (no ``lib.impl(...)`` call), so the op has a schema but
        # no kernel — confirm whether registration was intended.
        def custom_op_impl(tensor):
            return tensor.clone()

        return True
    except Exception as e:
        # Best-effort: keep the module importable even when this torch
        # version lacks (or has already consumed) the PrivateUse1 hooks.
        print(f"Backend initialization warning: {e}")
        return False


# Initialize the backend once at import time.
VGPU_BACKEND_INITIALIZED = init_vgpu_backend()
class VGPUStorage(torch.Storage):
    """Storage subclass whose backing data lives in our virtual VRAM."""

    def __init__(self, *args, **kwargs):
        # Pop our custom kwargs BEFORE delegating: torch.Storage.__init__
        # does not accept 'vram'/'tensor_id' and would raise TypeError if
        # they were forwarded (the original code passed **kwargs through).
        vram = kwargs.pop('vram', None)
        tensor_id = kwargs.pop('tensor_id', None)
        super().__init__(*args, **kwargs)
        if vram is None:
            # Lazy import keeps construction working even when callers
            # supply their own vram and the module was partially loaded.
            from virtual_vram import VirtualVRAM
            vram = VirtualVRAM()
        self.vram = vram  # virtual-VRAM backend for this storage
        # Stable key used to address this storage's data inside the VRAM.
        self.tensor_id = tensor_id or f"tensor_{id(self)}"

    def _new_shared(self, size):
        """Create a new storage of *size* sharing the same virtual VRAM."""
        return VGPUStorage(size, vram=self.vram)
class VGPUTensor(torch.Tensor):
    """Tensor subclass that routes computations through the vGPU.

    Fix: ``torch.Tensor._make_subclass`` requires *cls* to be a subclass of
    ``torch.Tensor``; the original class had no base, so every construction
    raised TypeError.
    """

    def __new__(cls, elem):
        # Wrap the existing tensor without copying its data, preserving
        # its requires_grad flag for autograd.
        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
class VGPUDevice:
    """
    Custom PyTorch device implementation that routes operations through vGPU.

    Usage:
        vgpu = VGPUDevice()
        with vgpu.mode():
            tensor = torch.randn(2, 3)  # Will be on vGPU
    """

    _VGPU_INSTANCES = {}  # Class-level registry of live device instances

    def __init__(self, vram: "Optional[VirtualVRAM]" = None):
        # Annotation is quoted so the class body never evaluates
        # VirtualVRAM eagerly.
        self.vram = vram or VirtualVRAM()
        self.tensor_cores = None  # Lazily created by _init_tensor_cores()
        self.device_name = "vgpu"  # Both internal and user-facing name
        self._register_device()

    def _register_device(self):
        """Register the vGPU device with PyTorch's device system."""
        try:
            if not VGPU_BACKEND_INITIALIZED:
                raise RuntimeError("VGPU backend not properly initialized")
            # Create the torch.device handle for our renamed backend.
            self._device = torch.device("vgpu")
            # Store this instance so to_vgpu() can reuse it.
            VGPUDevice._VGPU_INSTANCES[self.device_name] = self

            class VGPUAllocator:
                """Allocates tensors by staging on CPU, then moving to vGPU."""

                def __init__(self, vram, device):
                    self.vram = vram
                    self.device = device

                def __call__(self, size, dtype=None, device=None):
                    # Stage on CPU first; the module-level to_vgpu() helper
                    # handles the round-trip through virtual VRAM.
                    cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
                    return to_vgpu(cpu_tensor, self.vram)

            self._allocator = VGPUAllocator(self.vram, self._device)
        except Exception as e:
            raise RuntimeError(f"Failed to register vGPU device: {str(e)}")

    def type(self):
        # Fix: the original read self.internal_name, which is never set
        # anywhere in the class (the attribute is device_name), so this
        # always raised AttributeError.
        return self.device_name

    def __str__(self):
        return f"{self.device_name}:0"

    def __repr__(self):
        return f"vgpu(device='{self.device_name}:0')"

    def device(self):
        """Get the PyTorch device object that maps to our vGPU."""
        return self._device  # Return the already-created device object

    def mode(self):
        """Get a context manager for vGPU operations."""
        return torch.device(self._device)

    def _init_tensor_cores(self):
        # Imported lazily so tensor_core is only required when compute
        # is actually requested.
        if self.tensor_cores is None:
            from tensor_core import TensorCoreArray
            self.tensor_cores = TensorCoreArray()

    def _to_vram(self, tensor: torch.Tensor) -> str:
        """Store tensor data in virtual VRAM and return its storage key."""
        tensor_id = f"tensor_{id(tensor)}"
        data = tensor.detach().cpu().numpy()
        self.vram.storage.store_tensor(tensor_id, data)
        return tensor_id

    def _from_vram(self, tensor_id: str) -> torch.Tensor:
        """Retrieve tensor data from virtual VRAM as a CPU tensor."""
        data = self.vram.storage.load_tensor(tensor_id)
        return torch.from_numpy(data)

    def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Matrix multiplication using the virtual tensor cores."""
        self._init_tensor_cores()
        # Round-trip the inputs through vRAM so the simulated memory
        # hierarchy is exercised.
        a_id = self._to_vram(a)
        b_id = self._to_vram(b)
        result = self.tensor_cores.matmul(
            self.vram.storage.load_tensor(a_id),
            self.vram.storage.load_tensor(b_id),
        )
        # Wrap the numpy result back into a torch tensor.
        return torch.from_numpy(result)
def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
    """Move a tensor onto the vGPU device.

    Non-tensor inputs are first converted via ``torch.tensor``.  An already
    registered ``VGPUDevice`` is reused when one exists (with *vram*, if
    given, replacing its current backing store); otherwise a fresh device is
    created.  The data round-trips through virtual VRAM before the result is
    tagged with the vgpu device.
    """
    if not isinstance(tensor, torch.Tensor):
        tensor = torch.tensor(tensor)

    # Reuse an existing device instance when one is registered.
    registry = VGPUDevice._VGPU_INSTANCES
    if registry:
        device = next(iter(registry.values()))
        if vram is not None:
            # Caller-supplied vRAM replaces the device's current one.
            device.vram = vram
    else:
        device = VGPUDevice(vram)

    # Round-trip the data through virtual VRAM.
    stored_id = device._to_vram(tensor)
    result = device._from_vram(stored_id)
    result.requires_grad = tensor.requires_grad

    # Tag the rebuilt tensor with the vgpu device; .data keeps the move
    # out of autograd's view, matching the original behavior.
    result.data = result.data.to(device._device)
    return result