Spaces:
Runtime error
Runtime error
Factor Studios
commited on
Update torch_vgpu.py
Browse files- torch_vgpu.py +5 -5
torch_vgpu.py
CHANGED
|
@@ -18,21 +18,21 @@ def init_vgpu_backend():
|
|
| 18 |
# First rename privateuse1 to vgpu
|
| 19 |
torch.utils.rename_privateuse1_backend("vgpu")
|
| 20 |
|
| 21 |
-
# Create library for
|
| 22 |
-
lib = Library("
|
| 23 |
|
| 24 |
# Register essential operations for the backend
|
| 25 |
-
@impl(lib, "
|
| 26 |
def empty_impl(size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
|
| 27 |
# Create empty tensor on CPU first, will be moved to vGPU storage later
|
| 28 |
return torch.empty(size, dtype=dtype, device='cpu')
|
| 29 |
|
| 30 |
-
@impl(lib, "
|
| 31 |
def copy_impl(self, src, non_blocking=False):
|
| 32 |
# Handle tensor copying between devices
|
| 33 |
return torch.tensor(src.cpu().numpy(), device='cpu')
|
| 34 |
|
| 35 |
-
@impl(lib, "
|
| 36 |
def to_impl(self, dtype=None, device=None, non_blocking=False, copy=False):
|
| 37 |
# Handle tensor device transfer
|
| 38 |
if device and str(device).startswith("vgpu"):
|
|
|
|
| 18 |
# First rename privateuse1 to vgpu
|
| 19 |
torch.utils.rename_privateuse1_backend("vgpu")
|
| 20 |
|
| 21 |
+
# Create an aten library implementation for our backend
|
| 22 |
+
lib = Library("aten", "IMPL", "vgpu")
|
| 23 |
|
| 24 |
# Register essential operations for the backend
|
| 25 |
+
@impl(lib, "empty.memory_format")
|
| 26 |
def empty_impl(size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
|
| 27 |
# Create empty tensor on CPU first, will be moved to vGPU storage later
|
| 28 |
return torch.empty(size, dtype=dtype, device='cpu')
|
| 29 |
|
| 30 |
+
@impl(lib, "copy.from")
|
| 31 |
def copy_impl(self, src, non_blocking=False):
|
| 32 |
# Handle tensor copying between devices
|
| 33 |
return torch.tensor(src.cpu().numpy(), device='cpu')
|
| 34 |
|
| 35 |
+
@impl(lib, "_to_copy")
|
| 36 |
def to_impl(self, dtype=None, device=None, non_blocking=False, copy=False):
|
| 37 |
# Handle tensor device transfer
|
| 38 |
if device and str(device).startswith("vgpu"):
|