Factor Studios commited on
Commit
cebc24a
·
verified ·
1 Parent(s): def660c

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +35 -10
torch_vgpu.py CHANGED
@@ -7,12 +7,35 @@ from typing import Optional, Union, Tuple
7
  import numpy as np
8
  from virtual_vram import VirtualVRAM
9
 
10
- # Register and rename privateuse1 backend to vgpu
11
- try:
12
- torch.utils.rename_privateuse1_backend("vgpu")
13
- except (AttributeError, RuntimeError) as e:
14
- # Fallback for older PyTorch versions or if already renamed
15
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  class VGPUStorage(torch.Storage):
18
  """Custom storage class that uses our virtual VRAM"""
@@ -47,15 +70,17 @@ class VGPUDevice:
47
  def __init__(self, vram: Optional[VirtualVRAM] = None):
48
  self.vram = vram or VirtualVRAM()
49
  self.tensor_cores = None # Will be initialized when needed
50
- self.internal_name = "privateuse1" # PyTorch backend name
51
- self.device_name = "vgpu" # User-facing device name
52
  self._register_device()
53
 
54
  def _register_device(self):
55
  """Register vGPU device using PyTorch's device system"""
56
  try:
57
- # Register device using privateuse1
58
- self._device = torch.device(self.internal_name)
 
 
 
59
 
60
  # Store this instance for reuse
61
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self
 
7
  import numpy as np
8
  from virtual_vram import VirtualVRAM
9
 
10
+ # Initialize custom backend
11
+ def init_vgpu_backend():
12
+ try:
13
+ # First rename the backend
14
+ torch.utils.rename_privateuse1_backend("vgpu")
15
+
16
+ # Then generate all the necessary methods
17
+ torch.utils.generate_methods_for_privateuse1_backend(
18
+ for_tensor=True,
19
+ for_module=True,
20
+ for_packed_sequence=True,
21
+ for_storage=True
22
+ )
23
+
24
+ # Register our custom library
25
+ lib = Library("vgpu", "DEF")
26
+ lib.define("custom_op(Tensor self) -> Tensor")
27
+
28
+ @impl("vgpu", "custom_op", "Tensor")
29
+ def custom_op_impl(tensor):
30
+ return tensor.clone()
31
+
32
+ return True
33
+ except Exception as e:
34
+ print(f"Backend initialization warning: {e}")
35
+ return False
36
+
37
+ # Initialize the backend
38
+ VGPU_BACKEND_INITIALIZED = init_vgpu_backend()
39
 
40
  class VGPUStorage(torch.Storage):
41
  """Custom storage class that uses our virtual VRAM"""
 
70
  def __init__(self, vram: Optional[VirtualVRAM] = None):
71
  self.vram = vram or VirtualVRAM()
72
  self.tensor_cores = None # Will be initialized when needed
73
+ self.device_name = "vgpu" # Both internal and user-facing name
 
74
  self._register_device()
75
 
76
  def _register_device(self):
77
  """Register vGPU device using PyTorch's device system"""
78
  try:
79
+ if not VGPU_BACKEND_INITIALIZED:
80
+ raise RuntimeError("VGPU backend not properly initialized")
81
+
82
+ # Create device with explicit index
83
+ self._device = torch.device("vgpu")
84
 
85
  # Store this instance for reuse
86
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self