Factor Studios commited on
Commit
319082b
·
verified ·
1 Parent(s): 8aea612

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +82 -93
torch_vgpu.py CHANGED
@@ -4,64 +4,43 @@ from typing import Optional, Union, Tuple
4
  import numpy as np
5
  from virtual_vram import VirtualVRAM
6
 
7
- # Global variables for backend state
8
  VGPU_BACKEND_INITIALIZED = False
9
- CURRENT_VRAM = None # Global reference to current vRAM manager
10
-
11
- def set_current_vram(vram):
12
- """Set the current vRAM manager globally"""
13
- global CURRENT_VRAM
14
- CURRENT_VRAM = vram
15
-
16
- def get_current_vram():
17
- """Get the current vRAM manager"""
18
- return CURRENT_VRAM
19
-
20
- def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
21
- """Move a tensor to vGPU memory"""
22
- if vram is None:
23
- vram = get_current_vram()
24
- if vram is None:
25
- raise RuntimeError("No vRAM manager available. Initialize VGPUDevice first.")
26
-
27
- # Get data and store in vRAM
28
- cpu_data = tensor.detach().cpu().numpy()
29
- tensor_id = f"tensor_{id(tensor)}"
30
- vram.store(tensor_id, cpu_data)
31
-
32
- # Create vGPU tensor
33
- device = torch.device("privateuseone")
34
- vgpu_storage = VGPUStorage(
35
- cpu_data.size,
36
- vram=vram,
37
- tensor_id=tensor_id
38
- )
39
- vgpu_tensor = torch.tensor(
40
- [],
41
- device=device,
42
- requires_grad=tensor.requires_grad
43
- )
44
- vgpu_tensor.set_(vgpu_storage)
45
-
46
- return vgpu_tensor
47
 
48
  def init_vgpu_backend():
49
  """Initialize the vGPU backend. Must be called before creating any VGPUDevice instances."""
50
  global VGPU_BACKEND_INITIALIZED
51
  try:
52
  if not VGPU_BACKEND_INITIALIZED:
53
- # Create library for custom ops
54
  lib = Library("vgpu", "DEF")
55
- lib.define("custom_from_cpu(Tensor x) -> Tensor")
56
-
57
- impl_lib = Library("vgpu", "IMPL")
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  @impl(impl_lib, "custom_from_cpu")
60
- def custom_from_cpu(x):
61
- """Copy tensor to our vGPU memory"""
62
- return x.clone()
63
 
64
- # Set initialization flag
 
 
 
 
 
 
 
65
  VGPU_BACKEND_INITIALIZED = True
66
 
67
  return VGPU_BACKEND_INITIALIZED
@@ -89,17 +68,6 @@ class VGPUTensor:
89
  def __new__(cls, elem):
90
  return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
91
 
92
- from contextlib import contextmanager
93
-
94
- # Custom allocator for vGPU tensors
95
- class VGPUAllocator:
96
- def __init__(self, vram):
97
- self.vram = vram
98
-
99
- def __call__(self, size, dtype=None, device=None):
100
- cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
101
- return to_vgpu(cpu_tensor, self.vram)
102
-
103
  class VGPUDevice:
104
  """
105
  Custom PyTorch device implementation that routes operations through vGPU.
@@ -111,49 +79,69 @@ class VGPUDevice:
111
  _VGPU_INSTANCES = {} # Class-level dict to track instances
112
 
113
  def __init__(self, vram: Optional[VirtualVRAM] = None):
114
- """Initialize a vGPU device with optional vRAM manager"""
115
  self.vram = vram or VirtualVRAM()
116
- self.device_name = "privateuseone" # Our device type
117
- self._init_device()
 
118
 
119
- def _init_device(self):
120
- """Initialize the device backend and settings"""
121
- if not VGPU_BACKEND_INITIALIZED:
122
- raise RuntimeError("VGPU backend not properly initialized")
 
 
 
 
123
 
124
- # Setup device and global vRAM
125
- self._device = torch.device(self.device_name)
126
- set_current_vram(self.vram)
127
-
128
- # Register instance
129
- VGPUDevice._VGPU_INSTANCES[self.device_name] = self
130
-
131
- # Setup allocator
132
- self._allocator = VGPUAllocator(self.vram)
133
 
134
- def device(self) -> torch.device:
135
- """Get the PyTorch device object for this vGPU"""
136
- return self._device
137
-
138
- @contextmanager
139
- def mode(self):
140
- """Context manager for using this device as the default"""
141
- prev_device = torch.device("cpu")
142
- try:
143
- prev_device = torch.cuda.current_device() if torch.cuda.is_available() else prev_device
144
- torch.set_device(self._device)
145
- yield
146
- finally:
147
- torch.set_device(prev_device)
148
 
 
 
 
 
 
 
 
 
 
 
149
  def __str__(self):
150
- """String representation of the device"""
151
- return f"{self.device_name}:0"
152
 
153
  def __repr__(self):
154
- """Detailed string representation"""
155
- return f"vgpu(device='{self.device_name}:0')"
156
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  return tensor_id
158
 
159
  def _from_vram(self, tensor_id: str) -> torch.Tensor:
@@ -199,3 +187,4 @@ def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.T
199
  # Set the device using the internal name
200
  result.data = result.data.to(device._device)
201
  return result
 
 
4
  import numpy as np
5
  from virtual_vram import VirtualVRAM
6
 
7
+ # Global flag for backend initialization
8
  VGPU_BACKEND_INITIALIZED = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def init_vgpu_backend():
11
  """Initialize the vGPU backend. Must be called before creating any VGPUDevice instances."""
12
  global VGPU_BACKEND_INITIALIZED
13
  try:
14
  if not VGPU_BACKEND_INITIALIZED:
15
+ # First define our core library
16
  lib = Library("vgpu", "DEF")
17
+ lib.define("custom_allocate(Device? device) -> Tensor")
18
+ lib.define("custom_to_cpu(Tensor self) -> Tensor")
19
+ lib.define("custom_from_cpu(Tensor self) -> Tensor")
20
+
21
+ # Then implement the operations
22
+ impl_lib = Library("vgpu", "IMPL", "PrivateUse1")
23
+
24
+ @impl(impl_lib, "custom_allocate")
25
+ def custom_allocate(device=None):
26
+ return torch.empty((), device="cpu")
27
+
28
+ @impl(impl_lib, "custom_to_cpu")
29
+ def custom_to_cpu(tensor):
30
+ return tensor.clone()
31
 
32
  @impl(impl_lib, "custom_from_cpu")
33
+ def custom_from_cpu(tensor):
34
+ return tensor.clone()
 
35
 
36
+ # Generate all methods for our backend
37
+ torch.utils.generate_methods_for_privateuse1_backend(
38
+ for_tensor=True,
39
+ for_module=True,
40
+ for_packed_sequence=True,
41
+ for_storage=True
42
+ )
43
+
44
  VGPU_BACKEND_INITIALIZED = True
45
 
46
  return VGPU_BACKEND_INITIALIZED
 
68
  def __new__(cls, elem):
69
  return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
70
 
 
 
 
 
 
 
 
 
 
 
 
71
  class VGPUDevice:
72
  """
73
  Custom PyTorch device implementation that routes operations through vGPU.
 
79
  _VGPU_INSTANCES = {} # Class-level dict to track instances
80
 
81
  def __init__(self, vram: Optional[VirtualVRAM] = None):
 
82
  self.vram = vram or VirtualVRAM()
83
+ self.tensor_cores = None # Will be initialized when needed
84
+ self.device_name = "privateuseone" # Our registered device type
85
+ self._register_device()
86
 
87
+ def _register_device(self):
88
+ """Register vGPU device using PyTorch's device system"""
89
+ try:
90
+ if not VGPU_BACKEND_INITIALIZED:
91
+ raise RuntimeError("VGPU backend not properly initialized")
92
+
93
+ # Create device using our registered device type
94
+ self._device = torch.device(self.device_name)
95
 
96
+ # Store this instance for reuse
97
+ VGPUDevice._VGPU_INSTANCES[self.device_name] = self
 
 
 
 
 
 
 
98
 
99
+ # Define custom operations for the device
100
+ class VGPUAllocator:
101
+ def __init__(self, vram, device):
102
+ self.vram = vram
103
+ self.device = device
104
+
105
+ def __call__(self, size, dtype=None, device=None):
106
+ # Create tensor on CPU first
107
+ cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
108
+ # Move to vGPU storage
109
+ return to_vgpu(cpu_tensor, self.vram)
 
 
 
110
 
111
+ # Set up allocator
112
+ self._allocator = VGPUAllocator(self.vram, self._device)
113
+
114
+ except Exception as e:
115
+ raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
116
+
117
+ @property
118
+ def type(self):
119
+ return self.internal_name
120
+
121
  def __str__(self):
122
+ return f"{self.internal_name}:0"
 
123
 
124
  def __repr__(self):
125
+ return f"vgpu(device='{self.internal_name}:0')"
126
+
127
+ def device(self):
128
+ """Get the PyTorch device object that maps to our vGPU"""
129
+ return self._device # Return the already created device object
130
+
131
+ def mode(self):
132
+ """Get a context manager for vGPU operations"""
133
+ return torch.device(self._device)
134
+
135
+ def _init_tensor_cores(self):
136
+ if self.tensor_cores is None:
137
+ from tensor_core import TensorCoreArray
138
+ self.tensor_cores = TensorCoreArray()
139
+
140
+ def _to_vram(self, tensor: torch.Tensor) -> str:
141
+ """Store tensor data in virtual VRAM"""
142
+ tensor_id = f"tensor_{id(tensor)}"
143
+ data = tensor.detach().cpu().numpy()
144
+ self.vram.storage.store_tensor(tensor_id, data)
145
  return tensor_id
146
 
147
  def _from_vram(self, tensor_id: str) -> torch.Tensor:
 
187
  # Set the device using the internal name
188
  result.data = result.data.to(device._device)
189
  return result
190
+