Spaces:
Runtime error
Runtime error
Factor Studios
commited on
Update torch_vgpu.py
Browse files- torch_vgpu.py +9 -2
torch_vgpu.py
CHANGED
|
@@ -7,6 +7,13 @@ from typing import Optional, Union, Tuple
|
|
| 7 |
import numpy as np
|
| 8 |
from virtual_vram import VirtualVRAM
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
class VGPUStorage(torch.Storage):
|
| 11 |
"""Custom storage class that uses our virtual VRAM"""
|
| 12 |
|
|
@@ -47,8 +54,8 @@ class VGPUDevice:
|
|
| 47 |
def _register_device(self):
|
| 48 |
"""Register vGPU device using PyTorch's device system"""
|
| 49 |
try:
|
| 50 |
-
# Register device using
|
| 51 |
-
self._device = torch.device(
|
| 52 |
|
| 53 |
# Store this instance for reuse
|
| 54 |
VGPUDevice._VGPU_INSTANCES[self.device_name] = self
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
from virtual_vram import VirtualVRAM
|
| 9 |
|
| 10 |
+
# Register and rename privateuse1 backend to vgpu
|
| 11 |
+
try:
|
| 12 |
+
torch.utils.rename_privateuse1_backend("vgpu")
|
| 13 |
+
except (AttributeError, RuntimeError) as e:
|
| 14 |
+
# Fallback for older PyTorch versions or if already renamed
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
class VGPUStorage(torch.Storage):
|
| 18 |
"""Custom storage class that uses our virtual VRAM"""
|
| 19 |
|
|
|
|
| 54 |
def _register_device(self):
|
| 55 |
"""Register vGPU device using PyTorch's device system"""
|
| 56 |
try:
|
| 57 |
+
# Register device using privateuse1
|
| 58 |
+
self._device = torch.device(self.internal_name)
|
| 59 |
|
| 60 |
# Store this instance for reuse
|
| 61 |
VGPUDevice._VGPU_INSTANCES[self.device_name] = self
|