Factor Studios committed on
Commit
01c7c6f
·
verified ·
1 Parent(s): 976d2c2

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +45 -27
torch_vgpu.py CHANGED
@@ -2,6 +2,7 @@
2
  Custom PyTorch device implementation that routes operations through our virtual GPU.
3
  """
4
  import torch
 
5
  from typing import Optional, Union, Tuple
6
  import numpy as np
7
  from virtual_vram import VirtualVRAM
@@ -41,9 +42,35 @@ class VGPUDevice:
41
  self._register_device()
42
 
43
  def _register_device(self):
44
- """Register vGPU device using privateuse1 backend"""
45
  try:
46
- torch._C._dispatch._rename_privateuse1_backend(self.device_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  except Exception as e:
48
  raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
49
 
@@ -59,32 +86,26 @@ class VGPUDevice:
59
 
60
  def device(self):
61
  """Get the PyTorch device object"""
62
- return torch.device(str(self))
63
 
64
  def mode(self):
65
  """Get a context manager for vGPU operations"""
66
- from torch.utils._python_dispatch import TorchFunctionMode
67
-
68
- class _VGPUMode(TorchFunctionMode):
69
  def __init__(self, device):
70
  self.device = device
71
 
72
- def __torch_function__(self, func, types, args=(), kwargs=None):
73
- kwargs = kwargs or {}
74
 
75
- # Handle tensor creation and device placement
76
- if func is torch.tensor or 'device' in kwargs:
77
- kwargs['device'] = str(self.device)
78
-
79
- # Handle tensor operations
80
- new_args = []
81
- for arg in args:
82
- if isinstance(arg, torch.Tensor):
83
- if not str(arg.device).startswith(self.device.device_name):
84
- arg = to_vgpu(arg, self.device.vram)
85
- new_args.append(arg)
86
-
87
- return func(*new_args, **kwargs)
88
 
89
  return _VGPUMode(self)
90
 
@@ -126,9 +147,6 @@ def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.T
126
  """Move a tensor to vGPU device"""
127
  device = VGPUDevice(vram)
128
  tensor_id = device._to_vram(tensor)
129
- return VGPUTensor(device._from_vram(tensor_id))
130
-
131
- def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
132
- """Helper function to move tensors to vGPU"""
133
- device = VGPUDevice(vram)
134
- return tensor.to(device=device)
 
2
  Custom PyTorch device implementation that routes operations through our virtual GPU.
3
  """
4
  import torch
5
+ from torch.library import Library, impl
6
  from typing import Optional, Union, Tuple
7
  import numpy as np
8
  from virtual_vram import VirtualVRAM
 
42
  self._register_device()
43
 
44
  def _register_device(self):
45
+ """Register vGPU device using torch.library"""
46
  try:
47
+ # Create library for vGPU backend
48
+ lib = torch.library.Library(self.device_name, "IMPL")
49
+
50
+ # Register basic tensor operations
51
+ @torch.library.impl(lib, "aten::empty.memory_format")
52
+ def empty_impl(size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
53
+ # Create empty tensor in CPU and move to vGPU
54
+ cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
55
+ return to_vgpu(cpu_tensor, self.vram)
56
+
57
+ @torch.library.impl(lib, "aten::add.Tensor")
58
+ def add_impl(self, other):
59
+ # Custom implementation of add operation
60
+ # Move tensors to CPU, add, then back to vGPU
61
+ cpu_result = self.cpu() + other.cpu()
62
+ return to_vgpu(cpu_result, self.vram)
63
+
64
+ @torch.library.impl(lib, "aten::copy_")
65
+ def copy_impl(self, src, non_blocking=False):
66
+ # Handle tensor copy operations
67
+ if not isinstance(src, torch.Tensor):
68
+ src = torch.tensor(src)
69
+ return to_vgpu(src.cpu(), self.vram)
70
+
71
+ # Get device after registration
72
+ self._device = torch.device(f"{self.device_name}:0")
73
+
74
  except Exception as e:
75
  raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
76
 
 
86
 
87
  def device(self):
88
  """Get the PyTorch device object"""
89
+ return self._device
90
 
91
  def mode(self):
92
  """Get a context manager for vGPU operations"""
93
+ class _VGPUMode:
 
 
94
  def __init__(self, device):
95
  self.device = device
96
 
97
+ def __enter__(self):
98
+ return self
99
 
100
+ def __exit__(self, exc_type, exc_val, exc_tb):
101
+ pass
102
+
103
+ def __call__(self, fn):
104
+ def wrapped(*args, **kwargs):
105
+ if 'device' in kwargs:
106
+ kwargs['device'] = str(self.device)
107
+ return fn(*args, **kwargs)
108
+ return wrapped
 
 
 
 
109
 
110
  return _VGPUMode(self)
111
 
 
147
  """Move a tensor to vGPU device"""
148
  device = VGPUDevice(vram)
149
  tensor_id = device._to_vram(tensor)
150
+ result = device._from_vram(tensor_id)
151
+ result.requires_grad = tensor.requires_grad
152
+ return result