Factor Studios commited on
Commit
48c05bf
·
verified ·
1 Parent(s): de03e5f

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +13 -44
torch_vgpu.py CHANGED
@@ -40,36 +40,22 @@ class VGPUDevice:
40
  self.tensor_cores = None # Will be initialized when needed
41
  self.device_name = "vgpu"
42
  self._register_device()
43
-
44
  def _register_device(self):
45
- """Register vGPU device using torch.library"""
46
  try:
47
- # Create library for vGPU backend
48
- lib = torch.library.Library(self.device_name, "IMPL")
49
 
50
- # Register basic tensor operations
51
- @torch.library.impl(lib, "aten::empty.memory_format")
52
- def empty_impl(size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
53
- # Create empty tensor in CPU and move to vGPU
54
  cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
 
55
  return to_vgpu(cpu_tensor, self.vram)
56
 
57
- @torch.library.impl(lib, "aten::add.Tensor")
58
- def add_impl(self, other):
59
- # Custom implementation of add operation
60
- # Move tensors to CPU, add, then back to vGPU
61
- cpu_result = self.cpu() + other.cpu()
62
- return to_vgpu(cpu_result, self.vram)
63
-
64
- @torch.library.impl(lib, "aten::copy_")
65
- def copy_impl(self, src, non_blocking=False):
66
- # Handle tensor copy operations
67
- if not isinstance(src, torch.Tensor):
68
- src = torch.tensor(src)
69
- return to_vgpu(src.cpu(), self.vram)
70
-
71
- # Get device after registration
72
- self._device = torch.device(f"{self.device_name}:0")
73
 
74
  except Exception as e:
75
  raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
@@ -79,10 +65,10 @@ class VGPUDevice:
79
  return self.device_name
80
 
81
  def __str__(self):
82
- return f"{self.device_name}:0"
83
 
84
  def __repr__(self):
85
- return f"{self.device_name}:0"
86
 
87
  def device(self):
88
  """Get the PyTorch device object"""
@@ -90,24 +76,7 @@ class VGPUDevice:
90
 
91
  def mode(self):
92
  """Get a context manager for vGPU operations"""
93
- class _VGPUMode:
94
- def __init__(self, device):
95
- self.device = device
96
-
97
- def __enter__(self):
98
- return self
99
-
100
- def __exit__(self, exc_type, exc_val, exc_tb):
101
- pass
102
-
103
- def __call__(self, fn):
104
- def wrapped(*args, **kwargs):
105
- if 'device' in kwargs:
106
- kwargs['device'] = str(self.device)
107
- return fn(*args, **kwargs)
108
- return wrapped
109
-
110
- return _VGPUMode(self)
111
 
112
  def _init_tensor_cores(self):
113
  if self.tensor_cores is None:
 
40
  self.tensor_cores = None # Will be initialized when needed
41
  self.device_name = "vgpu"
42
  self._register_device()
43
+
44
  def _register_device(self):
45
+ """Register vGPU device using PyTorch's device system"""
46
  try:
47
+ # Register vGPU as a device
48
+ self._device = torch.device(self.device_name)
49
 
50
+ # Set up default allocator for vGPU device
51
+ def vgpu_allocator(size, dtype=None, device=None):
52
+ # Create tensor on CPU first
 
53
  cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
54
+ # Move to vGPU storage
55
  return to_vgpu(cpu_tensor, self.vram)
56
 
57
+ # Register allocator for the device
58
+ torch.utils.set_default_tensor_type(vgpu_allocator)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  except Exception as e:
61
  raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
 
65
  return self.device_name
66
 
67
  def __str__(self):
68
+ return str(self._device)
69
 
70
  def __repr__(self):
71
+ return str(self._device)
72
 
73
  def device(self):
74
  """Get the PyTorch device object"""
 
76
 
77
  def mode(self):
78
  """Get a context manager for vGPU operations"""
79
+ return torch.device(self._device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  def _init_tensor_cores(self):
82
  if self.tensor_cores is None: