Factor Studios commited on
Commit
8051ee2
·
verified ·
1 Parent(s): d654bd0

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +5 -5
torch_vgpu.py CHANGED
@@ -18,21 +18,21 @@ def init_vgpu_backend():
18
  # First rename privateuse1 to vgpu
19
  torch.utils.rename_privateuse1_backend("vgpu")
20
 
21
- # Create library for the backend
22
- lib = Library("vgpu", "IMPL")
23
 
24
  # Register essential operations for the backend
25
- @impl(lib, "aten::empty.memory_format")
26
  def empty_impl(size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
27
  # Create empty tensor on CPU first, will be moved to vGPU storage later
28
  return torch.empty(size, dtype=dtype, device='cpu')
29
 
30
- @impl(lib, "aten::copy.from")
31
  def copy_impl(self, src, non_blocking=False):
32
  # Handle tensor copying between devices
33
  return torch.tensor(src.cpu().numpy(), device='cpu')
34
 
35
- @impl(lib, "aten::_to_copy")
36
  def to_impl(self, dtype=None, device=None, non_blocking=False, copy=False):
37
  # Handle tensor device transfer
38
  if device and str(device).startswith("vgpu"):
 
18
  # First rename privateuse1 to vgpu
19
  torch.utils.rename_privateuse1_backend("vgpu")
20
 
21
+ # Create an aten library implementation for our backend
22
+ lib = Library("aten", "IMPL", "vgpu")
23
 
24
  # Register essential operations for the backend
25
+ @impl(lib, "empty.memory_format")
26
  def empty_impl(size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
27
  # Create empty tensor on CPU first, will be moved to vGPU storage later
28
  return torch.empty(size, dtype=dtype, device='cpu')
29
 
30
+ @impl(lib, "copy.from")
31
  def copy_impl(self, src, non_blocking=False):
32
  # Handle tensor copying between devices
33
  return torch.tensor(src.cpu().numpy(), device='cpu')
34
 
35
+ @impl(lib, "_to_copy")
36
  def to_impl(self, dtype=None, device=None, non_blocking=False, copy=False):
37
  # Handle tensor device transfer
38
  if device and str(device).startswith("vgpu"):