Spaces:
Runtime error
Runtime error
Factor Studios
commited on
Update torch_vgpu.py
Browse files- torch_vgpu.py +35 -10
torch_vgpu.py
CHANGED
|
@@ -7,12 +7,35 @@ from typing import Optional, Union, Tuple
|
|
| 7 |
import numpy as np
|
| 8 |
from virtual_vram import VirtualVRAM
|
| 9 |
|
| 10 |
-
#
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
class VGPUStorage(torch.Storage):
|
| 18 |
"""Custom storage class that uses our virtual VRAM"""
|
|
@@ -47,15 +70,17 @@ class VGPUDevice:
|
|
| 47 |
def __init__(self, vram: Optional[VirtualVRAM] = None):
|
| 48 |
self.vram = vram or VirtualVRAM()
|
| 49 |
self.tensor_cores = None # Will be initialized when needed
|
| 50 |
-
self.
|
| 51 |
-
self.device_name = "vgpu" # User-facing device name
|
| 52 |
self._register_device()
|
| 53 |
|
| 54 |
def _register_device(self):
|
| 55 |
"""Register vGPU device using PyTorch's device system"""
|
| 56 |
try:
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
# Store this instance for reuse
|
| 61 |
VGPUDevice._VGPU_INSTANCES[self.device_name] = self
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
from virtual_vram import VirtualVRAM
|
| 9 |
|
| 10 |
+
# Initialize custom backend
|
| 11 |
+
def init_vgpu_backend():
|
| 12 |
+
try:
|
| 13 |
+
# First rename the backend
|
| 14 |
+
torch.utils.rename_privateuse1_backend("vgpu")
|
| 15 |
+
|
| 16 |
+
# Then generate all the necessary methods
|
| 17 |
+
torch.utils.generate_methods_for_privateuse1_backend(
|
| 18 |
+
for_tensor=True,
|
| 19 |
+
for_module=True,
|
| 20 |
+
for_packed_sequence=True,
|
| 21 |
+
for_storage=True
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Register our custom library
|
| 25 |
+
lib = Library("vgpu", "DEF")
|
| 26 |
+
lib.define("custom_op(Tensor self) -> Tensor")
|
| 27 |
+
|
| 28 |
+
@impl("vgpu", "custom_op", "Tensor")
|
| 29 |
+
def custom_op_impl(tensor):
|
| 30 |
+
return tensor.clone()
|
| 31 |
+
|
| 32 |
+
return True
|
| 33 |
+
except Exception as e:
|
| 34 |
+
print(f"Backend initialization warning: {e}")
|
| 35 |
+
return False
|
| 36 |
+
|
| 37 |
+
# Initialize the backend
|
| 38 |
+
VGPU_BACKEND_INITIALIZED = init_vgpu_backend()
|
| 39 |
|
| 40 |
class VGPUStorage(torch.Storage):
|
| 41 |
"""Custom storage class that uses our virtual VRAM"""
|
|
|
|
| 70 |
def __init__(self, vram: Optional[VirtualVRAM] = None):
|
| 71 |
self.vram = vram or VirtualVRAM()
|
| 72 |
self.tensor_cores = None # Will be initialized when needed
|
| 73 |
+
self.device_name = "vgpu" # Both internal and user-facing name
|
|
|
|
| 74 |
self._register_device()
|
| 75 |
|
| 76 |
def _register_device(self):
|
| 77 |
"""Register vGPU device using PyTorch's device system"""
|
| 78 |
try:
|
| 79 |
+
if not VGPU_BACKEND_INITIALIZED:
|
| 80 |
+
raise RuntimeError("VGPU backend not properly initialized")
|
| 81 |
+
|
| 82 |
+
# Create device with explicit index
|
| 83 |
+
self._device = torch.device("vgpu")
|
| 84 |
|
| 85 |
# Store this instance for reuse
|
| 86 |
VGPUDevice._VGPU_INSTANCES[self.device_name] = self
|