Factor Studios commited on
Commit
897c760
·
verified ·
1 Parent(s): 32137d1

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +16 -29
torch_vgpu.py CHANGED
@@ -47,56 +47,43 @@ class VGPUDevice:
47
  def _register_device(self):
48
  """Register vGPU device using PyTorch's device system"""
49
  try:
50
- # Create internal device using privateuseone backend
51
  self._device = torch.device(f"{self.internal_name}:0")
52
 
53
- # Store this instance for device mapping
54
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self
55
 
56
- # Register custom dispatcher for device mapping
57
- def device_mapper(device_str):
58
- if device_str.startswith(self.device_name):
59
- # Map vgpu -> privateuseone
60
- idx = device_str.split(":", 1)[1] if ":" in device_str else "0"
61
- return torch.device(f"{self.internal_name}:{idx}")
62
- return None
63
-
64
- # Register the mapper with PyTorch
65
- if not hasattr(torch, '_vgpu_device_mapper'):
66
- torch._vgpu_device_mapper = device_mapper
67
-
68
  # Define custom operations for the device
69
  class VGPUAllocator:
70
- def __init__(self, vram):
71
  self.vram = vram
 
72
 
73
  def __call__(self, size, dtype=None, device=None):
74
- if device is None or str(device).startswith("vgpu"):
75
- # Create tensor on CPU first
76
- cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
77
- # Move to vGPU storage
78
- return to_vgpu(cpu_tensor, self.vram)
79
- return torch.empty(size, dtype=dtype, device=device)
80
 
81
- # Set this device as the default for tensor allocation
82
- self._allocator = VGPUAllocator(self.vram)
83
 
84
  except Exception as e:
85
  raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
86
 
87
  @property
88
  def type(self):
89
- return self.device_name
90
 
91
  def __str__(self):
92
- return f"{self.device_name}:0"
93
 
94
  def __repr__(self):
95
- return f"{self.device_name}:0"
96
 
97
  def device(self):
98
  """Get the PyTorch device object that maps to our vGPU"""
99
- return torch.device(str(self))
100
 
101
  def mode(self):
102
  """Get a context manager for vGPU operations"""
@@ -154,6 +141,6 @@ def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.T
154
  result = device._from_vram(tensor_id)
155
  result.requires_grad = tensor.requires_grad
156
 
157
- # Set the device using the user-facing name
158
- result.data = result.data.to(f"{device.device_name}:0")
159
  return result
 
47
  def _register_device(self):
48
  """Register vGPU device using PyTorch's device system"""
49
  try:
50
+ # Register device using privateuseone
51
  self._device = torch.device(f"{self.internal_name}:0")
52
 
53
+ # Store this instance for reuse
54
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self
55
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Define custom operations for the device
57
  class VGPUAllocator:
58
+ def __init__(self, vram, device):
59
  self.vram = vram
60
+ self.device = device
61
 
62
  def __call__(self, size, dtype=None, device=None):
63
+ # Create tensor on CPU first
64
+ cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
65
+ # Move to vGPU storage
66
+ return to_vgpu(cpu_tensor, self.vram)
 
 
67
 
68
+ # Set up allocator
69
+ self._allocator = VGPUAllocator(self.vram, self._device)
70
 
71
  except Exception as e:
72
  raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
73
 
74
  @property
75
  def type(self):
76
+ return self.internal_name
77
 
78
  def __str__(self):
79
+ return f"{self.internal_name}:0"
80
 
81
  def __repr__(self):
82
+ return f"vgpu(device='{self.internal_name}:0')"
83
 
84
  def device(self):
85
  """Get the PyTorch device object that maps to our vGPU"""
86
+ return self._device # Return the already created device object
87
 
88
  def mode(self):
89
  """Get a context manager for vGPU operations"""
 
141
  result = device._from_vram(tensor_id)
142
  result.requires_grad = tensor.requires_grad
143
 
144
+ # Set the device using the internal name
145
+ result.data = result.data.to(device._device)
146
  return result