Factor Studios commited on
Commit
4023c98
·
verified ·
1 Parent(s): 2548b28

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +13 -12
torch_vgpu.py CHANGED
@@ -1,6 +1,3 @@
1
- """
2
- Custom PyTorch device implementation that routes operations through our virtual GPU.
3
- """
4
  import torch
5
  from torch.library import Library, impl
6
  from typing import Optional, Union, Tuple
@@ -26,7 +23,7 @@ def init_vgpu_backend():
26
 
27
  @impl(impl_lib, "custom_allocate")
28
  def custom_allocate(device=None):
29
- return torch.empty((), device='cpu')
30
 
31
  @impl(impl_lib, "custom_to_cpu")
32
  def custom_to_cpu(tensor):
@@ -36,28 +33,31 @@ def init_vgpu_backend():
36
  def custom_from_cpu(tensor):
37
  return tensor.clone()
38
 
39
- # Register our device type
40
- torch._C._register_device_type("vgpu", 10) # Use a custom type ID
 
 
 
 
 
41
 
42
- # Mark initialization as complete
43
  VGPU_BACKEND_INITIALIZED = True
44
 
45
  return VGPU_BACKEND_INITIALIZED
46
  except Exception as e:
47
  print(f"Backend initialization warning: {e}")
48
  return False
49
-
50
 
51
  class VGPUStorage(torch.Storage):
52
  """Custom storage class that uses our virtual VRAM"""
53
 
54
  def __init__(self, *args, **kwargs):
55
  super().__init__(*args, **kwargs)
56
- self.vram = kwargs.get('vram')
57
  if not self.vram:
58
  from virtual_vram import VirtualVRAM
59
  self.vram = VirtualVRAM()
60
- self.tensor_id = kwargs.get('tensor_id', f"tensor_{id(self)}")
61
 
62
  def _new_shared(self, size):
63
  return VGPUStorage(size, vram=self.vram)
@@ -81,7 +81,7 @@ class VGPUDevice:
81
  def __init__(self, vram: Optional[VirtualVRAM] = None):
82
  self.vram = vram or VirtualVRAM()
83
  self.tensor_cores = None # Will be initialized when needed
84
- self.device_name = "vgpu" # Our registered device type
85
  self._register_device()
86
 
87
  def _register_device(self):
@@ -91,7 +91,7 @@ class VGPUDevice:
91
  raise RuntimeError("VGPU backend not properly initialized")
92
 
93
  # Create device using our registered device type
94
- self._device = torch.device("vgpu")
95
 
96
  # Store this instance for reuse
97
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self
@@ -187,3 +187,4 @@ def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.T
187
  # Set the device using the internal name
188
  result.data = result.data.to(device._device)
189
  return result
 
 
 
 
 
1
  import torch
2
  from torch.library import Library, impl
3
  from typing import Optional, Union, Tuple
 
23
 
24
  @impl(impl_lib, "custom_allocate")
25
  def custom_allocate(device=None):
26
+ return torch.empty((), device="cpu")
27
 
28
  @impl(impl_lib, "custom_to_cpu")
29
  def custom_to_cpu(tensor):
 
33
  def custom_from_cpu(tensor):
34
  return tensor.clone()
35
 
36
+ # Generate all methods for our backend
37
+ torch.utils.generate_methods_for_privateuse1_backend(
38
+ for_tensor=True,
39
+ for_module=True,
40
+ for_packed_sequence=True,
41
+ for_storage=True
42
+ )
43
 
 
44
  VGPU_BACKEND_INITIALIZED = True
45
 
46
  return VGPU_BACKEND_INITIALIZED
47
  except Exception as e:
48
  print(f"Backend initialization warning: {e}")
49
  return False
 
50
 
51
  class VGPUStorage(torch.Storage):
52
  """Custom storage class that uses our virtual VRAM"""
53
 
54
  def __init__(self, *args, **kwargs):
55
  super().__init__(*args, **kwargs)
56
+ self.vram = kwargs.get("vram")
57
  if not self.vram:
58
  from virtual_vram import VirtualVRAM
59
  self.vram = VirtualVRAM()
60
+ self.tensor_id = kwargs.get("tensor_id", f"tensor_{id(self)}")
61
 
62
  def _new_shared(self, size):
63
  return VGPUStorage(size, vram=self.vram)
 
81
  def __init__(self, vram: Optional[VirtualVRAM] = None):
82
  self.vram = vram or VirtualVRAM()
83
  self.tensor_cores = None # Will be initialized when needed
84
+ self.device_name = "privateuseone" # Our registered device type
85
  self._register_device()
86
 
87
  def _register_device(self):
 
91
  raise RuntimeError("VGPU backend not properly initialized")
92
 
93
  # Create device using our registered device type
94
+ self._device = torch.device(self.device_name)
95
 
96
  # Store this instance for reuse
97
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self
 
187
  # Set the device using the internal name
188
  result.data = result.data.to(device._device)
189
  return result
190
+