Factor Studios commited on
Commit
962c8c7
·
verified ·
1 Parent(s): 560e47a

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +183 -75
torch_vgpu.py CHANGED
@@ -12,28 +12,11 @@ def init_vgpu_backend():
12
  global VGPU_BACKEND_INITIALIZED
13
  try:
14
  if not VGPU_BACKEND_INITIALIZED:
15
- # First define our core library
16
- lib = Library("vgpu", "DEF")
17
- lib.define("custom_allocate(Device? device) -> Tensor")
18
- lib.define("custom_to_cpu(Tensor self) -> Tensor")
19
- lib.define("custom_from_cpu(Tensor self) -> Tensor")
20
-
21
- # Then implement the operations
22
- impl_lib = Library("vgpu", "IMPL", "PrivateUse1")
23
-
24
- @impl(impl_lib, "custom_allocate")
25
- def custom_allocate(device=None):
26
- return torch.empty((), device="cpu")
27
-
28
- @impl(impl_lib, "custom_to_cpu")
29
- def custom_to_cpu(tensor):
30
- return tensor.clone()
31
-
32
- @impl(impl_lib, "custom_from_cpu")
33
- def custom_from_cpu(tensor):
34
- return tensor.clone()
35
-
36
- # Generate all methods for our backend
37
  torch.utils.generate_methods_for_privateuse1_backend(
38
  for_tensor=True,
39
  for_module=True,
@@ -41,47 +24,133 @@ def init_vgpu_backend():
41
  for_storage=True
42
  )
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  VGPU_BACKEND_INITIALIZED = True
45
 
46
  return VGPU_BACKEND_INITIALIZED
47
  except Exception as e:
48
- print(f"Backend initialization warning: {e}")
 
 
49
  return False
50
 
51
  class VGPUStorage(torch.Storage):
52
  """Custom storage class that uses our virtual VRAM"""
53
 
54
  def __init__(self, *args, **kwargs):
 
 
 
 
55
  super().__init__(*args, **kwargs)
56
- self.vram = kwargs.get("vram")
57
  if not self.vram:
58
- from virtual_vram import VirtualVRAM
59
  self.vram = VirtualVRAM()
60
- self.tensor_id = kwargs.get("tensor_id", f"tensor_{id(self)}")
 
61
 
62
  def _new_shared(self, size):
63
  return VGPUStorage(size, vram=self.vram)
64
 
65
- class VGPUTensor:
66
  """Tensor implementation that uses vGPU for computations"""
 
67
  @staticmethod
68
- def __new__(cls, elem):
69
- return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
 
 
 
 
 
 
 
 
 
 
70
 
71
  class VGPUDevice:
72
  """
73
  Custom PyTorch device implementation that routes operations through vGPU.
74
  Usage:
75
  vgpu = VGPUDevice()
76
- with vgpu.mode():
77
- tensor = torch.randn(2, 3) # Will be on vGPU
78
  """
79
  _VGPU_INSTANCES = {} # Class-level dict to track instances
80
 
81
  def __init__(self, vram: Optional[VirtualVRAM] = None):
 
 
 
 
82
  self.vram = vram or VirtualVRAM()
83
  self.tensor_cores = None # Will be initialized when needed
84
- self.device_name = "privateuseone" # Our registered device type
85
  self._register_device()
86
 
87
  def _register_device(self):
@@ -91,57 +160,53 @@ class VGPUDevice:
91
  raise RuntimeError("VGPU backend not properly initialized")
92
 
93
  # Create device using our registered device type
94
- self._device = torch.device(self.device_name)
95
 
96
  # Store this instance for reuse
97
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self
98
 
99
- # Define custom operations for the device
100
- class VGPUAllocator:
101
- def __init__(self, vram, device):
102
- self.vram = vram
103
- self.device = device
104
-
105
- def __call__(self, size, dtype=None, device=None):
106
- # Create tensor directly in vGPU memory
107
- tensor_id = f"tensor_empty_{id(size)}"
108
- # Initialize empty array of the right size and dtype
109
- shape = size if isinstance(size, (tuple, list)) else (size,)
110
- data = np.empty(shape, dtype=np.float32 if dtype is None else dtype)
111
- # Store directly in vRAM
112
- self.vram.storage.store_tensor(tensor_id, data)
113
- # Create tensor with our device type
114
- result = torch.as_tensor(data, device=self.device)
115
- return result
116
-
117
- # Set up allocator
118
- self._allocator = VGPUAllocator(self.vram, self._device)
119
-
120
  except Exception as e:
121
  raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
122
 
123
  @property
124
  def type(self):
125
- return self.internal_name
126
 
127
  def __str__(self):
128
- return f"{self.internal_name}:0"
129
 
130
  def __repr__(self):
131
- return f"vgpu(device='{self.internal_name}:0')"
132
 
133
  def device(self):
134
  """Get the PyTorch device object that maps to our vGPU"""
135
- return self._device # Return the already created device object
136
 
137
- def mode(self):
138
  """Get a context manager for vGPU operations"""
139
- return torch.device(self._device)
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  def _init_tensor_cores(self):
142
  if self.tensor_cores is None:
143
- from tensor_core import TensorCoreArray
144
- self.tensor_cores = TensorCoreArray()
 
 
 
 
145
 
146
  def _to_vram(self, tensor: torch.Tensor) -> str:
147
  """Store tensor data in virtual VRAM"""
@@ -163,14 +228,21 @@ class VGPUDevice:
163
  a_id = self._to_vram(a)
164
  b_id = self._to_vram(b)
165
 
166
- # Perform matmul using tensor cores
167
- result = self.tensor_cores.matmul(
168
- self.vram.storage.load_tensor(a_id),
169
- self.vram.storage.load_tensor(b_id)
170
- )
 
 
 
 
 
 
171
 
172
  # Create new tensor with result
173
- return torch.from_numpy(result)
 
174
 
175
  def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
176
  """Move a tensor to vGPU device"""
@@ -185,11 +257,47 @@ def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.T
185
  if vram is not None:
186
  device.vram = vram
187
 
188
- # Move data to vRAM
189
- tensor_id = device._to_vram(tensor)
190
- result = device._from_vram(tensor_id)
191
- result.requires_grad = tensor.requires_grad
 
 
 
 
 
 
 
 
 
 
192
 
193
- # Set the device using the internal name
194
- result.data = result.data.to(device._device)
195
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  global VGPU_BACKEND_INITIALIZED
13
  try:
14
  if not VGPU_BACKEND_INITIALIZED:
15
+ # Step 1: Register the backend name using PrivateUse1
16
+ backend_name = "vgpu"
17
+ torch._C._dispatch._rename_privateuse1_backend(backend_name)
18
+
19
+ # Step 2: Generate methods for the backend
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  torch.utils.generate_methods_for_privateuse1_backend(
21
  for_tensor=True,
22
  for_module=True,
 
24
  for_storage=True
25
  )
26
 
27
+ # Step 3: Define and implement core operations
28
+ lib = Library(backend_name, "DEF")
29
+ impl_lib = Library(backend_name, "IMPL", "PrivateUse1")
30
+
31
+ # Define core tensor operations
32
+ lib.define("empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor")
33
+ lib.define("empty_strided(int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor")
34
+ lib.define("copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)")
35
+
36
+ @impl(impl_lib, "empty.memory_format")
37
+ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
38
+ # Create tensor on CPU first, then we'll handle device placement
39
+ dtype = dtype or torch.float32
40
+ cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
41
+ # Mark it as being on our custom device
42
+ return cpu_tensor
43
+
44
+ @impl(impl_lib, "empty_strided")
45
+ def empty_strided(size, stride, dtype=None, layout=None, device=None, pin_memory=None):
46
+ dtype = dtype or torch.float32
47
+ # Create strided tensor on CPU
48
+ cpu_tensor = torch.empty_strided(size, stride, dtype=dtype, device='cpu')
49
+ return cpu_tensor
50
+
51
+ @impl(impl_lib, "copy_")
52
+ def copy_impl(self, src, non_blocking=False):
53
+ # Handle copying between devices
54
+ if src.device.type == 'cpu':
55
+ # Copy from CPU to vGPU
56
+ self.data.copy_(src.data)
57
+ elif src.device.type == backend_name:
58
+ # Copy from vGPU to vGPU
59
+ self.data.copy_(src.data)
60
+ else:
61
+ # Copy from other device to vGPU
62
+ cpu_src = src.cpu()
63
+ self.data.copy_(cpu_src.data)
64
+ return self
65
+
66
+ # Register device guard
67
+ class VGPUGuard:
68
+ def __init__(self, device):
69
+ self.device = device
70
+ self.prev_device = None
71
+
72
+ def __enter__(self):
73
+ # Store current device state
74
+ self.prev_device = torch.cuda.current_device() if torch.cuda.is_available() else None
75
+ return self
76
+
77
+ def __exit__(self, exc_type, exc_val, exc_tb):
78
+ # Restore previous device state
79
+ if self.prev_device is not None and torch.cuda.is_available():
80
+ torch.cuda.set_device(self.prev_device)
81
+
82
+ # Register allocator functions
83
+ def vgpu_allocator(size, dtype=None, device=None):
84
+ """Custom allocator for vGPU tensors"""
85
+ dtype = dtype or torch.float32
86
+ # Create on CPU but track as vGPU
87
+ tensor = torch.empty(size, dtype=dtype, device='cpu')
88
+ return tensor
89
+
90
+ # Register the allocator
91
+ torch._C._set_print_device_type(backend_name, True)
92
+
93
  VGPU_BACKEND_INITIALIZED = True
94
 
95
  return VGPU_BACKEND_INITIALIZED
96
  except Exception as e:
97
+ print(f"Backend initialization error: {e}")
98
+ import traceback
99
+ traceback.print_exc()
100
  return False
101
 
102
  class VGPUStorage(torch.Storage):
103
  """Custom storage class that uses our virtual VRAM"""
104
 
105
  def __init__(self, *args, **kwargs):
106
+ # Extract our custom kwargs before calling parent
107
+ self.vram = kwargs.pop("vram", None)
108
+ self.tensor_id = kwargs.pop("tensor_id", None)
109
+
110
  super().__init__(*args, **kwargs)
111
+
112
  if not self.vram:
 
113
  self.vram = VirtualVRAM()
114
+ if not self.tensor_id:
115
+ self.tensor_id = f"tensor_{id(self)}"
116
 
117
  def _new_shared(self, size):
118
  return VGPUStorage(size, vram=self.vram)
119
 
120
+ class VGPUTensor(torch.Tensor):
121
  """Tensor implementation that uses vGPU for computations"""
122
+
123
  @staticmethod
124
+ def __new__(cls, data, device=None, requires_grad=False):
125
+ # Ensure we have a proper tensor
126
+ if not isinstance(data, torch.Tensor):
127
+ data = torch.as_tensor(data)
128
+
129
+ # Create the subclass
130
+ r = torch.Tensor._make_subclass(cls, data, requires_grad)
131
+ return r
132
+
133
+ def __init__(self, data, device=None, requires_grad=False):
134
+ super().__init__()
135
+ self._vgpu_device = device
136
 
137
  class VGPUDevice:
138
  """
139
  Custom PyTorch device implementation that routes operations through vGPU.
140
  Usage:
141
  vgpu = VGPUDevice()
142
+ tensor = torch.randn(2, 3, device=vgpu.device())
 
143
  """
144
  _VGPU_INSTANCES = {} # Class-level dict to track instances
145
 
146
  def __init__(self, vram: Optional[VirtualVRAM] = None):
147
+ # Initialize backend first
148
+ if not init_vgpu_backend():
149
+ raise RuntimeError("Failed to initialize vGPU backend")
150
+
151
  self.vram = vram or VirtualVRAM()
152
  self.tensor_cores = None # Will be initialized when needed
153
+ self.device_name = "vgpu" # Our registered device type
154
  self._register_device()
155
 
156
  def _register_device(self):
 
160
  raise RuntimeError("VGPU backend not properly initialized")
161
 
162
  # Create device using our registered device type
163
+ self._device = torch.device(f"{self.device_name}:0")
164
 
165
  # Store this instance for reuse
166
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  except Exception as e:
169
  raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
170
 
171
  @property
172
  def type(self):
173
+ return self.device_name
174
 
175
  def __str__(self):
176
+ return f"{self.device_name}:0"
177
 
178
  def __repr__(self):
179
+ return f"vgpu(device='{self.device_name}:0')"
180
 
181
  def device(self):
182
  """Get the PyTorch device object that maps to our vGPU"""
183
+ return self._device
184
 
185
+ def context(self):
186
  """Get a context manager for vGPU operations"""
187
+ class VGPUContext:
188
+ def __init__(self, device):
189
+ self.device = device
190
+ self.prev_device = None
191
+
192
+ def __enter__(self):
193
+ # Could store previous device context here
194
+ return self.device
195
+
196
+ def __exit__(self, exc_type, exc_val, exc_tb):
197
+ # Could restore previous device context here
198
+ pass
199
+
200
+ return VGPUContext(self._device)
201
 
202
  def _init_tensor_cores(self):
203
  if self.tensor_cores is None:
204
+ try:
205
+ from tensor_core import TensorCoreArray
206
+ self.tensor_cores = TensorCoreArray()
207
+ except ImportError:
208
+ print("Warning: tensor_core module not available")
209
+ self.tensor_cores = None
210
 
211
  def _to_vram(self, tensor: torch.Tensor) -> str:
212
  """Store tensor data in virtual VRAM"""
 
228
  a_id = self._to_vram(a)
229
  b_id = self._to_vram(b)
230
 
231
+ # Perform matmul using tensor cores if available
232
+ if self.tensor_cores:
233
+ result = self.tensor_cores.matmul(
234
+ self.vram.storage.load_tensor(a_id),
235
+ self.vram.storage.load_tensor(b_id)
236
+ )
237
+ else:
238
+ # Fallback to numpy
239
+ a_data = self.vram.storage.load_tensor(a_id)
240
+ b_data = self.vram.storage.load_tensor(b_id)
241
+ result = np.matmul(a_data, b_data)
242
 
243
  # Create new tensor with result
244
+ result_tensor = torch.from_numpy(result)
245
+ return result_tensor.to(self._device)
246
 
247
  def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
248
  """Move a tensor to vGPU device"""
 
257
  if vram is not None:
258
  device.vram = vram
259
 
260
+ # Move tensor to vGPU device
261
+ return tensor.to(device.device())
262
+
263
+ # Convenience function for creating tensors directly on vGPU
264
+ def vgpu_tensor(*args, **kwargs):
265
+ """Create a tensor directly on vGPU device"""
266
+ # Remove device from kwargs if present
267
+ kwargs.pop('device', None)
268
+
269
+ # Get or create vGPU device
270
+ if not VGPUDevice._VGPU_INSTANCES:
271
+ device = VGPUDevice()
272
+ else:
273
+ device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
274
 
275
+ # Create tensor on vGPU
276
+ return torch.tensor(*args, device=device.device(), **kwargs)
277
+
278
+ # Example usage and testing
279
+ if __name__ == "__main__":
280
+ # Initialize the backend
281
+ if init_vgpu_backend():
282
+ print("✓ vGPU backend initialized successfully")
283
+
284
+ # Create vGPU device
285
+ vgpu = VGPUDevice()
286
+ print(f"✓ vGPU device created: {vgpu}")
287
+
288
+ # Test tensor creation
289
+ try:
290
+ x = torch.randn(2, 3, device=vgpu.device())
291
+ print(f"✓ Tensor created on {x.device}: shape {x.shape}")
292
+
293
+ # Test tensor operations
294
+ y = torch.randn(3, 4, device=vgpu.device())
295
+ z = torch.mm(x, y)
296
+ print(f"✓ Matrix multiplication result shape: {z.shape}")
297
+
298
+ except Exception as e:
299
+ print(f"✗ Tensor operation failed: {e}")
300
+ import traceback
301
+ traceback.print_exc()
302
+ else:
303
+ print("✗ Failed to initialize vGPU backend")