Factor Studios commited on
Commit
705a65c
·
verified ·
1 Parent(s): 8051ee2

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +8 -8
torch_vgpu.py CHANGED
@@ -15,11 +15,8 @@ def init_vgpu_backend():
15
  global VGPU_BACKEND_INITIALIZED
16
  try:
17
  if not VGPU_BACKEND_INITIALIZED:
18
- # First rename privateuse1 to vgpu
19
- torch.utils.rename_privateuse1_backend("vgpu")
20
-
21
- # Create an aten library implementation for our backend
22
- lib = Library("aten", "IMPL", "vgpu")
23
 
24
  # Register essential operations for the backend
25
  @impl(lib, "empty.memory_format")
@@ -35,9 +32,12 @@ def init_vgpu_backend():
35
  @impl(lib, "_to_copy")
36
  def to_impl(self, dtype=None, device=None, non_blocking=False, copy=False):
37
  # Handle tensor device transfer
38
- if device and str(device).startswith("vgpu"):
39
  return self.cpu().clone()
40
  return self.clone()
 
 
 
41
 
42
  # Generate all methods for our backend
43
  torch.utils.generate_methods_for_privateuse1_backend(
@@ -97,8 +97,8 @@ class VGPUDevice:
97
  if not VGPU_BACKEND_INITIALIZED:
98
  raise RuntimeError("VGPU backend not properly initialized")
99
 
100
- # Create device with explicit index
101
- self._device = torch.device("vgpu:0")
102
 
103
  # Store this instance for reuse
104
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self
 
15
  global VGPU_BACKEND_INITIALIZED
16
  try:
17
  if not VGPU_BACKEND_INITIALIZED:
18
+ # Create an aten library implementation for our backend with PrivateUse1 dispatch key
19
+ lib = Library("aten", "IMPL", "PrivateUse1")
 
 
 
20
 
21
  # Register essential operations for the backend
22
  @impl(lib, "empty.memory_format")
 
32
  @impl(lib, "_to_copy")
33
  def to_impl(self, dtype=None, device=None, non_blocking=False, copy=False):
34
  # Handle tensor device transfer
35
+ if device and str(device).startswith("privateuse1"):
36
  return self.cpu().clone()
37
  return self.clone()
38
+
39
+ # Finally rename privateuse1 to vgpu after implementing the ops
40
+ torch.utils.rename_privateuse1_backend("vgpu")
41
 
42
  # Generate all methods for our backend
43
  torch.utils.generate_methods_for_privateuse1_backend(
 
97
  if not VGPU_BACKEND_INITIALIZED:
98
  raise RuntimeError("VGPU backend not properly initialized")
99
 
100
+ # Create device with explicit index - use privateuse1 since that's the actual device type
101
+ self._device = torch.device("privateuse1:0")
102
 
103
  # Store this instance for reuse
104
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self