Factor Studios commited on
Commit
2548b28
·
verified ·
1 Parent(s): d2fdc72

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +24 -31
torch_vgpu.py CHANGED
@@ -15,44 +15,38 @@ def init_vgpu_backend():
15
  global VGPU_BACKEND_INITIALIZED
16
  try:
17
  if not VGPU_BACKEND_INITIALIZED:
18
- # Create an aten library implementation for our backend with PrivateUse1 dispatch key
19
- lib = Library("aten", "IMPL", "PrivateUse1")
 
 
 
20
 
21
- # Register essential operations for the backend
22
- @impl(lib, "empty.memory_format")
23
- def empty_impl(size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
24
- # Create empty tensor on CPU first, will be moved to vGPU storage later
25
- return torch.empty(size, dtype=dtype, device='cpu')
26
 
27
- @impl(lib, "copy.from")
28
- def copy_impl(self, src, non_blocking=False):
29
- # Handle tensor copying between devices
30
- return torch.tensor(src.cpu().numpy(), device='cpu')
31
 
32
- @impl(lib, "_to_copy")
33
- def to_impl(self, dtype=None, device=None, non_blocking=False, copy=False):
34
- # Handle tensor device transfer
35
- if device and str(device).startswith("privateuse1"):
36
- return self.cpu().clone()
37
- return self.clone()
38
-
39
- # Finally rename privateuse1 to vgpu after implementing the ops
40
- torch.utils.rename_privateuse1_backend("vgpu")
41
 
42
- # Generate all methods for our backend
43
- torch.utils.generate_methods_for_privateuse1_backend(
44
- for_tensor=True,
45
- for_module=True,
46
- for_packed_sequence=True,
47
- for_storage=True
48
- )
49
 
 
50
  VGPU_BACKEND_INITIALIZED = True
51
 
52
  return VGPU_BACKEND_INITIALIZED
53
  except Exception as e:
54
  print(f"Backend initialization warning: {e}")
55
  return False
 
56
 
57
  class VGPUStorage(torch.Storage):
58
  """Custom storage class that uses our virtual VRAM"""
@@ -87,9 +81,8 @@ class VGPUDevice:
87
  def __init__(self, vram: Optional[VirtualVRAM] = None):
88
  self.vram = vram or VirtualVRAM()
89
  self.tensor_cores = None # Will be initialized when needed
90
- self.device_name = "privateuseone" # Use original name first
91
  self._register_device()
92
- self.device_name = "vgpu" # Then switch to renamed backend
93
 
94
  def _register_device(self):
95
  """Register vGPU device using PyTorch's device system"""
@@ -97,8 +90,8 @@ class VGPUDevice:
97
  if not VGPU_BACKEND_INITIALIZED:
98
  raise RuntimeError("VGPU backend not properly initialized")
99
 
100
- # Create device with explicit index - use privateuse1 since that's the actual device type
101
- self._device = torch.device("privateuse1")
102
 
103
  # Store this instance for reuse
104
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self
 
15
  global VGPU_BACKEND_INITIALIZED
16
  try:
17
  if not VGPU_BACKEND_INITIALIZED:
18
+ # First define our core library
19
+ lib = Library("vgpu", "DEF")
20
+ lib.define("custom_allocate(Device? device) -> Tensor")
21
+ lib.define("custom_to_cpu(Tensor self) -> Tensor")
22
+ lib.define("custom_from_cpu(Tensor self) -> Tensor")
23
 
24
+ # Then implement the operations
25
+ impl_lib = Library("vgpu", "IMPL", "PrivateUse1")
 
 
 
26
 
27
+ @impl(impl_lib, "custom_allocate")
28
+ def custom_allocate(device=None):
29
+ return torch.empty((), device='cpu')
 
30
 
31
+ @impl(impl_lib, "custom_to_cpu")
32
+ def custom_to_cpu(tensor):
33
+ return tensor.clone()
 
 
 
 
 
 
34
 
35
+ @impl(impl_lib, "custom_from_cpu")
36
+ def custom_from_cpu(tensor):
37
+ return tensor.clone()
38
+
39
+ # Register our device type
40
+ torch._C._register_device_type("vgpu", 10) # Use a custom type ID
 
41
 
42
+ # Mark initialization as complete
43
  VGPU_BACKEND_INITIALIZED = True
44
 
45
  return VGPU_BACKEND_INITIALIZED
46
  except Exception as e:
47
  print(f"Backend initialization warning: {e}")
48
  return False
49
+
50
 
51
  class VGPUStorage(torch.Storage):
52
  """Custom storage class that uses our virtual VRAM"""
 
81
  def __init__(self, vram: Optional[VirtualVRAM] = None):
82
  self.vram = vram or VirtualVRAM()
83
  self.tensor_cores = None # Will be initialized when needed
84
+ self.device_name = "vgpu" # Our registered device type
85
  self._register_device()
 
86
 
87
  def _register_device(self):
88
  """Register vGPU device using PyTorch's device system"""
 
90
  if not VGPU_BACKEND_INITIALIZED:
91
  raise RuntimeError("VGPU backend not properly initialized")
92
 
93
+ # Create device using our registered device type
94
+ self._device = torch.device("vgpu")
95
 
96
  # Store this instance for reuse
97
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self