Factor Studios commited on
Commit
e64ebad
Β·
verified Β·
1 Parent(s): 962c8c7

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +263 -228
torch_vgpu.py CHANGED
@@ -3,301 +3,336 @@ from torch.library import Library, impl
3
  from typing import Optional, Union, Tuple
4
  import numpy as np
5
  from virtual_vram import VirtualVRAM
 
6
 
7
  # Global flag for backend initialization
8
  VGPU_BACKEND_INITIALIZED = False
9
 
 
 
 
 
 
10
  def init_vgpu_backend():
11
  """Initialize the vGPU backend. Must be called before creating any VGPUDevice instances."""
12
  global VGPU_BACKEND_INITIALIZED
13
  try:
14
  if not VGPU_BACKEND_INITIALIZED:
15
- # Step 1: Register the backend name using PrivateUse1
16
  backend_name = "vgpu"
17
- torch._C._dispatch._rename_privateuse1_backend(backend_name)
18
-
19
- # Step 2: Generate methods for the backend
20
- torch.utils.generate_methods_for_privateuse1_backend(
21
- for_tensor=True,
22
- for_module=True,
23
- for_packed_sequence=True,
24
- for_storage=True
25
- )
26
-
27
- # Step 3: Define and implement core operations
28
- lib = Library(backend_name, "DEF")
29
- impl_lib = Library(backend_name, "IMPL", "PrivateUse1")
30
-
31
- # Define core tensor operations
32
- lib.define("empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor")
33
- lib.define("empty_strided(int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor")
34
- lib.define("copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)")
35
-
36
- @impl(impl_lib, "empty.memory_format")
37
- def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
38
- # Create tensor on CPU first, then we'll handle device placement
39
- dtype = dtype or torch.float32
40
- cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
41
- # Mark it as being on our custom device
42
- return cpu_tensor
43
 
44
- @impl(impl_lib, "empty_strided")
45
- def empty_strided(size, stride, dtype=None, layout=None, device=None, pin_memory=None):
46
- dtype = dtype or torch.float32
47
- # Create strided tensor on CPU
48
- cpu_tensor = torch.empty_strided(size, stride, dtype=dtype, device='cpu')
49
- return cpu_tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- @impl(impl_lib, "copy_")
52
- def copy_impl(self, src, non_blocking=False):
53
- # Handle copying between devices
54
- if src.device.type == 'cpu':
55
- # Copy from CPU to vGPU
56
- self.data.copy_(src.data)
57
- elif src.device.type == backend_name:
58
- # Copy from vGPU to vGPU
59
- self.data.copy_(src.data)
60
- else:
61
- # Copy from other device to vGPU
62
- cpu_src = src.cpu()
63
- self.data.copy_(cpu_src.data)
64
- return self
 
 
 
 
65
 
66
- # Register device guard
67
- class VGPUGuard:
68
- def __init__(self, device):
69
- self.device = device
70
- self.prev_device = None
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- def __enter__(self):
73
- # Store current device state
74
- self.prev_device = torch.cuda.current_device() if torch.cuda.is_available() else None
 
75
  return self
76
 
77
- def __exit__(self, exc_type, exc_val, exc_tb):
78
- # Restore previous device state
79
- if self.prev_device is not None and torch.cuda.is_available():
80
- torch.cuda.set_device(self.prev_device)
81
-
82
- # Register allocator functions
83
- def vgpu_allocator(size, dtype=None, device=None):
84
- """Custom allocator for vGPU tensors"""
85
- dtype = dtype or torch.float32
86
- # Create on CPU but track as vGPU
87
- tensor = torch.empty(size, dtype=dtype, device='cpu')
88
- return tensor
89
-
90
- # Register the allocator
91
- torch._C._set_print_device_type(backend_name, True)
 
 
 
 
92
 
93
  VGPU_BACKEND_INITIALIZED = True
94
 
95
  return VGPU_BACKEND_INITIALIZED
 
96
  except Exception as e:
97
  print(f"Backend initialization error: {e}")
98
  import traceback
99
  traceback.print_exc()
100
  return False
101
 
102
- class VGPUStorage(torch.Storage):
103
- """Custom storage class that uses our virtual VRAM"""
104
 
105
- def __init__(self, *args, **kwargs):
106
- # Extract our custom kwargs before calling parent
107
- self.vram = kwargs.pop("vram", None)
108
- self.tensor_id = kwargs.pop("tensor_id", None)
109
-
110
- super().__init__(*args, **kwargs)
111
-
112
- if not self.vram:
113
- self.vram = VirtualVRAM()
114
- if not self.tensor_id:
115
- self.tensor_id = f"tensor_{id(self)}"
116
 
117
- def _new_shared(self, size):
118
- return VGPUStorage(size, vram=self.vram)
 
 
 
 
 
119
 
120
  class VGPUTensor(torch.Tensor):
121
- """Tensor implementation that uses vGPU for computations"""
122
 
123
- @staticmethod
124
- def __new__(cls, data, device=None, requires_grad=False):
125
- # Ensure we have a proper tensor
126
  if not isinstance(data, torch.Tensor):
127
  data = torch.as_tensor(data)
128
 
129
- # Create the subclass
130
- r = torch.Tensor._make_subclass(cls, data, requires_grad)
 
 
131
  return r
132
 
133
- def __init__(self, data, device=None, requires_grad=False):
134
- super().__init__()
135
- self._vgpu_device = device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  class VGPUDevice:
138
  """
139
  Custom PyTorch device implementation that routes operations through vGPU.
140
  Usage:
141
  vgpu = VGPUDevice()
142
- tensor = torch.randn(2, 3, device=vgpu.device())
143
  """
144
- _VGPU_INSTANCES = {} # Class-level dict to track instances
145
 
146
- def __init__(self, vram: Optional[VirtualVRAM] = None):
147
- # Initialize backend first
148
  if not init_vgpu_backend():
149
- raise RuntimeError("Failed to initialize vGPU backend")
150
 
151
  self.vram = vram or VirtualVRAM()
152
- self.tensor_cores = None # Will be initialized when needed
153
- self.device_name = "vgpu" # Our registered device type
154
- self._register_device()
155
-
156
- def _register_device(self):
157
- """Register vGPU device using PyTorch's device system"""
158
- try:
159
- if not VGPU_BACKEND_INITIALIZED:
160
- raise RuntimeError("VGPU backend not properly initialized")
161
-
162
- # Create device using our registered device type
163
- self._device = torch.device(f"{self.device_name}:0")
164
-
165
- # Store this instance for reuse
166
- VGPUDevice._VGPU_INSTANCES[self.device_name] = self
167
-
168
- except Exception as e:
169
- raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
170
-
171
- @property
172
- def type(self):
173
- return self.device_name
174
 
175
- def __str__(self):
176
- return f"{self.device_name}:0"
177
-
178
- def __repr__(self):
179
- return f"vgpu(device='{self.device_name}:0')"
180
 
 
 
181
  def device(self):
182
- """Get the PyTorch device object that maps to our vGPU"""
183
  return self._device
184
 
185
- def context(self):
186
- """Get a context manager for vGPU operations"""
187
- class VGPUContext:
188
- def __init__(self, device):
189
- self.device = device
190
- self.prev_device = None
191
-
192
- def __enter__(self):
193
- # Could store previous device context here
194
- return self.device
195
-
196
- def __exit__(self, exc_type, exc_val, exc_tb):
197
- # Could restore previous device context here
198
- pass
199
-
200
- return VGPUContext(self._device)
201
-
202
- def _init_tensor_cores(self):
203
- if self.tensor_cores is None:
204
- try:
205
- from tensor_core import TensorCoreArray
206
- self.tensor_cores = TensorCoreArray()
207
- except ImportError:
208
- print("Warning: tensor_core module not available")
209
- self.tensor_cores = None
210
-
211
- def _to_vram(self, tensor: torch.Tensor) -> str:
212
- """Store tensor data in virtual VRAM"""
213
- tensor_id = f"tensor_{id(tensor)}"
214
- data = tensor.detach().cpu().numpy()
215
- self.vram.storage.store_tensor(tensor_id, data)
216
- return tensor_id
217
-
218
- def _from_vram(self, tensor_id: str) -> torch.Tensor:
219
- """Retrieve tensor data from virtual VRAM"""
220
- data = self.vram.storage.load_tensor(tensor_id)
221
- return torch.from_numpy(data)
222
-
223
- def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
224
- """Matrix multiplication using tensor cores"""
225
- self._init_tensor_cores()
226
 
227
- # Store inputs in VRAM
228
- a_id = self._to_vram(a)
229
- b_id = self._to_vram(b)
230
-
231
- # Perform matmul using tensor cores if available
232
- if self.tensor_cores:
233
- result = self.tensor_cores.matmul(
234
- self.vram.storage.load_tensor(a_id),
235
- self.vram.storage.load_tensor(b_id)
236
- )
237
  else:
238
- # Fallback to numpy
239
- a_data = self.vram.storage.load_tensor(a_id)
240
- b_data = self.vram.storage.load_tensor(b_id)
241
- result = np.matmul(a_data, b_data)
242
 
243
- # Create new tensor with result
244
- result_tensor = torch.from_numpy(result)
245
- return result_tensor.to(self._device)
246
-
247
- def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
248
- """Move a tensor to vGPU device"""
249
- if not isinstance(tensor, torch.Tensor):
250
- tensor = torch.tensor(tensor)
251
 
252
- # Get or create vGPU device
253
- if not VGPUDevice._VGPU_INSTANCES:
254
- device = VGPUDevice(vram)
255
- else:
256
- device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
257
- if vram is not None:
258
- device.vram = vram
259
 
260
- # Move tensor to vGPU device
261
- return tensor.to(device.device())
262
-
263
- # Convenience function for creating tensors directly on vGPU
264
- def vgpu_tensor(*args, **kwargs):
265
- """Create a tensor directly on vGPU device"""
266
- # Remove device from kwargs if present
267
- kwargs.pop('device', None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
- # Get or create vGPU device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  if not VGPUDevice._VGPU_INSTANCES:
271
- device = VGPUDevice()
272
  else:
273
  device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
274
 
275
- # Create tensor on vGPU
276
- return torch.tensor(*args, device=device.device(), **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
  # Example usage and testing
279
  if __name__ == "__main__":
280
- # Initialize the backend
 
 
281
  if init_vgpu_backend():
282
- print("βœ“ vGPU backend initialized successfully")
283
-
284
- # Create vGPU device
 
 
 
285
  vgpu = VGPUDevice()
286
  print(f"βœ“ vGPU device created: {vgpu}")
287
 
288
  # Test tensor creation
289
- try:
290
- x = torch.randn(2, 3, device=vgpu.device())
291
- print(f"βœ“ Tensor created on {x.device}: shape {x.shape}")
292
-
293
- # Test tensor operations
294
- y = torch.randn(3, 4, device=vgpu.device())
295
- z = torch.mm(x, y)
296
- print(f"βœ“ Matrix multiplication result shape: {z.shape}")
297
-
298
- except Exception as e:
299
- print(f"βœ— Tensor operation failed: {e}")
300
- import traceback
301
- traceback.print_exc()
302
- else:
303
- print("βœ— Failed to initialize vGPU backend")
 
 
 
 
3
  from typing import Optional, Union, Tuple
4
  import numpy as np
5
  from virtual_vram import VirtualVRAM
6
+ import warnings
7
 
8
  # Global flag for backend initialization
9
  VGPU_BACKEND_INITIALIZED = False
10
 
11
+ def get_pytorch_version():
12
+ """Get PyTorch version as tuple for comparison"""
13
+ version = torch.__version__.split('.')
14
+ return tuple(int(x.split('+')[0]) for x in version[:2])
15
+
16
  def init_vgpu_backend():
17
  """Initialize the vGPU backend. Must be called before creating any VGPUDevice instances."""
18
  global VGPU_BACKEND_INITIALIZED
19
  try:
20
  if not VGPU_BACKEND_INITIALIZED:
21
+ pytorch_version = get_pytorch_version()
22
  backend_name = "vgpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ # Method 1: Try modern PyTorch approach (2.0+)
25
+ if pytorch_version >= (2, 0):
26
+ try:
27
+ # Try the new API first
28
+ if hasattr(torch._C, '_dispatch') and hasattr(torch._C._dispatch, '_rename_privateuse1_backend'):
29
+ torch._C._dispatch._rename_privateuse1_backend(backend_name)
30
+ elif hasattr(torch, '_register_privateuse1_backend'):
31
+ # Alternative API in some PyTorch versions
32
+ torch._register_privateuse1_backend(backend_name)
33
+ else:
34
+ # Fallback: use torch.utils approach
35
+ raise AttributeError("Modern API not available")
36
+
37
+ # Generate methods for the backend
38
+ torch.utils.generate_methods_for_privateuse1_backend(
39
+ for_tensor=True,
40
+ for_module=True,
41
+ for_packed_sequence=True,
42
+ for_storage=True
43
+ )
44
+ backend_registered = True
45
+ except (AttributeError, RuntimeError) as e:
46
+ print(f"Modern backend registration failed: {e}")
47
+ backend_registered = False
48
+ else:
49
+ backend_registered = False
50
 
51
+ # Method 2: Fallback approach for older PyTorch or when modern approach fails
52
+ if not backend_registered:
53
+ print(f"Using fallback registration method for PyTorch {torch.__version__}")
54
+
55
+ # Create a mock device type that behaves like a custom device
56
+ class VGPUDeviceType:
57
+ def __init__(self, name):
58
+ self.name = name
59
+ self.index = 0
60
+
61
+ def __str__(self):
62
+ return f"{self.name}:{self.index}"
63
+
64
+ def __repr__(self):
65
+ return f"device(type='{self.name}', index={self.index})"
66
+
67
+ # Register our device type manually
68
+ backend_name = "vgpu"
69
 
70
+ # Define core operations using Library
71
+ try:
72
+ lib = Library(backend_name, "DEF")
73
+ impl_lib = Library(backend_name, "IMPL", "PrivateUse1")
74
+
75
+ # Define essential operations
76
+ lib.define("empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor")
77
+ lib.define("copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)")
78
+ lib.define("add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
79
+ lib.define("mm(Tensor self, Tensor mat2) -> Tensor")
80
+
81
+ @impl(impl_lib, "empty.memory_format")
82
+ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
83
+ dtype = dtype or torch.float32
84
+ # Create on CPU but track metadata for vGPU
85
+ result = torch.empty(size, dtype=dtype, device='cpu')
86
+ return result
87
 
88
+ @impl(impl_lib, "copy_")
89
+ def copy_impl(self, src, non_blocking=False):
90
+ if isinstance(src, torch.Tensor):
91
+ self.data.copy_(src.cpu().data if hasattr(src, 'cpu') else src.data)
92
  return self
93
 
94
+ @impl(impl_lib, "add.Tensor")
95
+ def add_tensor(self, other, alpha=1):
96
+ # Perform add on CPU then return result
97
+ self_cpu = self.cpu() if hasattr(self, 'cpu') else self
98
+ other_cpu = other.cpu() if hasattr(other, 'cpu') else other
99
+ result = torch.add(self_cpu, other_cpu, alpha=alpha)
100
+ return result
101
+
102
+ @impl(impl_lib, "mm")
103
+ def mm_impl(self, mat2):
104
+ # Perform matmul on CPU
105
+ self_cpu = self.cpu() if hasattr(self, 'cpu') else self
106
+ mat2_cpu = mat2.cpu() if hasattr(mat2, 'cpu') else mat2
107
+ result = torch.mm(self_cpu, mat2_cpu)
108
+ return result
109
+
110
+ except Exception as e:
111
+ print(f"Library registration warning: {e}")
112
+ # Continue without library registration
113
 
114
  VGPU_BACKEND_INITIALIZED = True
115
 
116
  return VGPU_BACKEND_INITIALIZED
117
+
118
  except Exception as e:
119
  print(f"Backend initialization error: {e}")
120
  import traceback
121
  traceback.print_exc()
122
  return False
123
 
124
+ class VGPUDeviceMock:
125
+ """Mock device class that behaves like a PyTorch device"""
126
 
127
+ def __init__(self, device_name="vgpu", index=0):
128
+ self.type = device_name
129
+ self.index = index
130
+
131
+ def __str__(self):
132
+ return f"{self.type}:{self.index}"
133
+
134
+ def __repr__(self):
135
+ return f"device(type='{self.type}', index={self.index})"
 
 
136
 
137
+ def __eq__(self, other):
138
+ if isinstance(other, (VGPUDeviceMock, torch.device)):
139
+ return str(self) == str(other)
140
+ return str(self) == str(other)
141
+
142
+ def __hash__(self):
143
+ return hash(str(self))
144
 
145
  class VGPUTensor(torch.Tensor):
146
+ """Custom tensor class that handles vGPU operations"""
147
 
148
+ @staticmethod
149
+ def __new__(cls, data, device=None, requires_grad=False, vram=None):
 
150
  if not isinstance(data, torch.Tensor):
151
  data = torch.as_tensor(data)
152
 
153
+ # Create tensor on CPU but track vGPU device
154
+ r = torch.Tensor._make_subclass(cls, data.cpu(), requires_grad)
155
+ r._vgpu_device = device
156
+ r._vram = vram
157
  return r
158
 
159
+ @property
160
+ def device(self):
161
+ """Return the vGPU device"""
162
+ return self._vgpu_device or VGPUDeviceMock()
163
+
164
+ def cpu(self):
165
+ """Move tensor to CPU"""
166
+ cpu_tensor = torch.Tensor(self.data)
167
+ cpu_tensor.requires_grad = self.requires_grad
168
+ return cpu_tensor
169
+
170
+ def to(self, device, **kwargs):
171
+ """Handle device transfers"""
172
+ if isinstance(device, (VGPUDeviceMock, str)) and ('vgpu' in str(device)):
173
+ # Stay on vGPU
174
+ return self
175
+ else:
176
+ # Move to requested device
177
+ return self.data.to(device, **kwargs)
178
 
179
  class VGPUDevice:
180
  """
181
  Custom PyTorch device implementation that routes operations through vGPU.
182
  Usage:
183
  vgpu = VGPUDevice()
184
+ tensor = vgpu.tensor([1, 2, 3]) # Create tensor on vGPU
185
  """
186
+ _VGPU_INSTANCES = {}
187
 
188
+ def __init__(self, vram: Optional[VirtualVRAM] = None, device_index: int = 0):
189
+ # Initialize backend
190
  if not init_vgpu_backend():
191
+ print("Warning: Backend initialization incomplete, using fallback mode")
192
 
193
  self.vram = vram or VirtualVRAM()
194
+ self.tensor_cores = None
195
+ self.device_name = "vgpu"
196
+ self.device_index = device_index
197
+ self._device = VGPUDeviceMock(self.device_name, device_index)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
+ # Store this instance
200
+ VGPUDevice._VGPU_INSTANCES[f"{self.device_name}:{device_index}"] = self
 
 
 
201
 
202
+ print(f"βœ“ vGPU device initialized: {self._device}")
203
+
204
  def device(self):
205
+ """Get the device object"""
206
  return self._device
207
 
208
+ def tensor(self, data, **kwargs):
209
+ """Create a tensor on this vGPU device"""
210
+ kwargs.pop('device', None) # Remove device if specified
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
+ if isinstance(data, torch.Tensor):
213
+ result = VGPUTensor(data, device=self._device, vram=self.vram, **kwargs)
 
 
 
 
 
 
 
 
214
  else:
215
+ cpu_tensor = torch.tensor(data, **kwargs)
216
+ result = VGPUTensor(cpu_tensor, device=self._device, vram=self.vram)
 
 
217
 
218
+ # Store in vRAM
219
+ self._to_vram(result)
220
+ return result
 
 
 
 
 
221
 
222
+ def randn(self, *size, **kwargs):
223
+ """Create random tensor on vGPU"""
224
+ kwargs.pop('device', None)
225
+ cpu_tensor = torch.randn(*size, **kwargs)
226
+ result = VGPUTensor(cpu_tensor, device=self._device, vram=self.vram)
227
+ self._to_vram(result)
228
+ return result
229
 
230
+ def zeros(self, *size, **kwargs):
231
+ """Create zero tensor on vGPU"""
232
+ kwargs.pop('device', None)
233
+ cpu_tensor = torch.zeros(*size, **kwargs)
234
+ result = VGPUTensor(cpu_tensor, device=self._device, vram=self.vram)
235
+ self._to_vram(result)
236
+ return result
237
+
238
+ def ones(self, *size, **kwargs):
239
+ """Create ones tensor on vGPU"""
240
+ kwargs.pop('device', None)
241
+ cpu_tensor = torch.ones(*size, **kwargs)
242
+ result = VGPUTensor(cpu_tensor, device=self._device, vram=self.vram)
243
+ self._to_vram(result)
244
+ return result
245
+
246
+ def empty(self, *size, **kwargs):
247
+ """Create empty tensor on vGPU"""
248
+ kwargs.pop('device', None)
249
+ cpu_tensor = torch.empty(*size, **kwargs)
250
+ result = VGPUTensor(cpu_tensor, device=self._device, vram=self.vram)
251
+ self._to_vram(result)
252
+ return result
253
+
254
+ def _to_vram(self, tensor):
255
+ """Store tensor in vRAM"""
256
+ if hasattr(tensor, '_vram') and tensor._vram:
257
+ tensor_id = f"tensor_{id(tensor)}"
258
+ data = tensor.detach().cpu().numpy()
259
+ tensor._vram.storage.store_tensor(tensor_id, data)
260
+ tensor._vram_id = tensor_id
261
 
262
+ def _from_vram(self, tensor):
263
+ """Load tensor from vRAM"""
264
+ if hasattr(tensor, '_vram_id') and hasattr(tensor, '_vram'):
265
+ data = tensor._vram.storage.load_tensor(tensor._vram_id)
266
+ return torch.from_numpy(data)
267
+ return tensor.cpu()
268
+
269
+ def __str__(self):
270
+ return str(self._device)
271
+
272
+ def __repr__(self):
273
+ return f"VGPUDevice({self._device})"
274
+
275
+ # Convenience functions
276
+ def to_vgpu(tensor, vram=None):
277
+ """Move tensor to vGPU"""
278
  if not VGPUDevice._VGPU_INSTANCES:
279
+ device = VGPUDevice(vram)
280
  else:
281
  device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
282
 
283
+ if isinstance(tensor, VGPUTensor):
284
+ return tensor
285
+
286
+ result = VGPUTensor(tensor, device=device.device(), vram=device.vram)
287
+ device._to_vram(result)
288
+ return result
289
+
290
+ # Monkey patch torch functions to handle vGPU device strings
291
+ original_device = torch.device
292
+
293
+ def patched_device(device_spec):
294
+ """Patched device function to handle vGPU devices"""
295
+ if isinstance(device_spec, str) and device_spec.startswith('vgpu'):
296
+ parts = device_spec.split(':')
297
+ device_name = parts[0]
298
+ device_index = int(parts[1]) if len(parts) > 1 else 0
299
+ return VGPUDeviceMock(device_name, device_index)
300
+ return original_device(device_spec)
301
+
302
+ # Apply the patch
303
+ torch.device = patched_device
304
 
305
  # Example usage and testing
306
  if __name__ == "__main__":
307
+ print(f"PyTorch version: {torch.__version__}")
308
+
309
+ # Test backend initialization
310
  if init_vgpu_backend():
311
+ print("βœ“ vGPU backend initialized")
312
+ else:
313
+ print("! vGPU backend initialization incomplete, using fallback")
314
+
315
+ # Create vGPU device
316
+ try:
317
  vgpu = VGPUDevice()
318
  print(f"βœ“ vGPU device created: {vgpu}")
319
 
320
  # Test tensor creation
321
+ x = vgpu.randn(2, 3)
322
+ print(f"βœ“ Random tensor created on {x.device}: shape {x.shape}")
323
+
324
+ y = vgpu.ones(3, 4)
325
+ print(f"βœ“ Ones tensor created on {y.device}: shape {y.shape}")
326
+
327
+ # Test basic operations
328
+ z = x.data @ y.data # Matrix multiply on CPU data
329
+ print(f"βœ“ Matrix multiplication result shape: {z.shape}")
330
+
331
+ # Test device string parsing
332
+ device_str = torch.device("vgpu:0")
333
+ print(f"βœ“ Device string parsing: {device_str}")
334
+
335
+ except Exception as e:
336
+ print(f"βœ— Test failed: {e}")
337
+ import traceback
338
+ traceback.print_exc()