Factor Studios commited on
Commit
812be1c
·
verified ·
1 Parent(s): 81aa430

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +25 -15
torch_vgpu.py CHANGED
@@ -38,37 +38,41 @@ class VGPUDevice:
38
  def __init__(self, vram: Optional[VirtualVRAM] = None):
39
  self.vram = vram or VirtualVRAM()
40
  self.tensor_cores = None # Will be initialized when needed
41
- self.device_name = "vgpu"
42
  self._register_device()
43
 
44
  def _register_device(self):
45
  """Register vGPU device using PyTorch's device system"""
46
  try:
47
- # Register vGPU as a device
48
- self._device = torch.device(self.device_name)
49
 
50
- # Set up default allocator for vGPU device
51
- def vgpu_allocator(size, dtype=None, device=None):
52
- # Create tensor on CPU first
53
- cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
54
- # Move to vGPU storage
55
- return to_vgpu(cpu_tensor, self.vram)
56
-
57
- # Register allocator for the device
58
- torch.utils.set_default_tensor_type(vgpu_allocator)
59
 
 
 
 
 
 
 
 
 
 
60
  except Exception as e:
61
  raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
62
 
63
  @property
64
  def type(self):
65
- return self.device_name
66
 
67
  def __str__(self):
68
- return str(self._device)
69
 
70
  def __repr__(self):
71
- return str(self._device)
72
 
73
  def device(self):
74
  """Get the PyTorch device object"""
@@ -114,8 +118,14 @@ class VGPUDevice:
114
 
115
  def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
116
  """Move a tensor to vGPU device"""
 
 
 
117
  device = VGPUDevice(vram)
118
  tensor_id = device._to_vram(tensor)
119
  result = device._from_vram(tensor_id)
120
  result.requires_grad = tensor.requires_grad
 
 
 
121
  return result
 
38
  def __init__(self, vram: Optional[VirtualVRAM] = None):
39
  self.vram = vram or VirtualVRAM()
40
  self.tensor_cores = None # Will be initialized when needed
41
+ self.device_name = "privateuseone" # Use privateuseone as base device type
42
  self._register_device()
43
 
44
  def _register_device(self):
45
  """Register vGPU device using PyTorch's device system"""
46
  try:
47
+ # Create device instance using privateuseone backend
48
+ self._device = torch.device(f"{self.device_name}:0")
49
 
50
+ # Define custom operations for the device
51
+ class VGPUAllocator:
52
+ def __init__(self, vram):
53
+ self.vram = vram
 
 
 
 
 
54
 
55
+ def __call__(self, size, dtype=None, device=None):
56
+ # Create tensor on CPU first
57
+ cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
58
+ # Move to vGPU storage
59
+ return to_vgpu(cpu_tensor, self.vram)
60
+
61
+ # Set this device as the default for tensor allocation
62
+ self._allocator = VGPUAllocator(self.vram)
63
+
64
  except Exception as e:
65
  raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
66
 
67
  @property
68
  def type(self):
69
+ return "vgpu" # User-facing device type name
70
 
71
  def __str__(self):
72
+ return "vgpu"
73
 
74
  def __repr__(self):
75
+ return "vgpu"
76
 
77
  def device(self):
78
  """Get the PyTorch device object"""
 
118
 
119
  def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
120
  """Move a tensor to vGPU device"""
121
+ if not isinstance(tensor, torch.Tensor):
122
+ tensor = torch.tensor(tensor)
123
+
124
  device = VGPUDevice(vram)
125
  tensor_id = device._to_vram(tensor)
126
  result = device._from_vram(tensor_id)
127
  result.requires_grad = tensor.requires_grad
128
+
129
+ # Set the device correctly
130
+ result.data = result.data.to(device.device())
131
  return result