Factor Studios commited on
Commit
32137d1
·
verified ·
1 Parent(s): 812be1c

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +43 -15
torch_vgpu.py CHANGED
@@ -35,17 +35,35 @@ class VGPUDevice:
35
  with vgpu.mode():
36
  tensor = torch.randn(2, 3) # Will be on vGPU
37
  """
 
 
38
  def __init__(self, vram: Optional[VirtualVRAM] = None):
39
  self.vram = vram or VirtualVRAM()
40
  self.tensor_cores = None # Will be initialized when needed
41
- self.device_name = "privateuseone" # Use privateuseone as base device type
 
42
  self._register_device()
43
 
44
  def _register_device(self):
45
  """Register vGPU device using PyTorch's device system"""
46
  try:
47
- # Create device instance using privateuseone backend
48
- self._device = torch.device(f"{self.device_name}:0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  # Define custom operations for the device
51
  class VGPUAllocator:
@@ -53,10 +71,12 @@ class VGPUDevice:
53
  self.vram = vram
54
 
55
  def __call__(self, size, dtype=None, device=None):
56
- # Create tensor on CPU first
57
- cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
58
- # Move to vGPU storage
59
- return to_vgpu(cpu_tensor, self.vram)
 
 
60
 
61
  # Set this device as the default for tensor allocation
62
  self._allocator = VGPUAllocator(self.vram)
@@ -66,17 +86,17 @@ class VGPUDevice:
66
 
67
  @property
68
  def type(self):
69
- return "vgpu" # User-facing device type name
70
 
71
  def __str__(self):
72
- return "vgpu"
73
 
74
  def __repr__(self):
75
- return "vgpu"
76
 
77
  def device(self):
78
- """Get the PyTorch device object"""
79
- return self._device
80
 
81
  def mode(self):
82
  """Get a context manager for vGPU operations"""
@@ -121,11 +141,19 @@ def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.T
121
  if not isinstance(tensor, torch.Tensor):
122
  tensor = torch.tensor(tensor)
123
 
124
- device = VGPUDevice(vram)
 
 
 
 
 
 
 
 
125
  tensor_id = device._to_vram(tensor)
126
  result = device._from_vram(tensor_id)
127
  result.requires_grad = tensor.requires_grad
128
 
129
- # Set the device correctly
130
- result.data = result.data.to(device.device())
131
  return result
 
35
  with vgpu.mode():
36
  tensor = torch.randn(2, 3) # Will be on vGPU
37
  """
38
+ _VGPU_INSTANCES = {} # Class-level dict to track instances
39
+
40
  def __init__(self, vram: Optional[VirtualVRAM] = None):
41
  self.vram = vram or VirtualVRAM()
42
  self.tensor_cores = None # Will be initialized when needed
43
+ self.internal_name = "privateuseone" # PyTorch backend name
44
+ self.device_name = "vgpu" # User-facing device name
45
  self._register_device()
46
 
47
  def _register_device(self):
48
  """Register vGPU device using PyTorch's device system"""
49
  try:
50
+ # Create internal device using privateuseone backend
51
+ self._device = torch.device(f"{self.internal_name}:0")
52
+
53
+ # Store this instance for device mapping
54
+ VGPUDevice._VGPU_INSTANCES[self.device_name] = self
55
+
56
+ # Register custom dispatcher for device mapping
57
+ def device_mapper(device_str):
58
+ if device_str.startswith(self.device_name):
59
+ # Map vgpu -> privateuseone
60
+ idx = device_str.split(":", 1)[1] if ":" in device_str else "0"
61
+ return torch.device(f"{self.internal_name}:{idx}")
62
+ return None
63
+
64
+ # Register the mapper with PyTorch
65
+ if not hasattr(torch, '_vgpu_device_mapper'):
66
+ torch._vgpu_device_mapper = device_mapper
67
 
68
  # Define custom operations for the device
69
  class VGPUAllocator:
 
71
  self.vram = vram
72
 
73
  def __call__(self, size, dtype=None, device=None):
74
+ if device is None or str(device).startswith("vgpu"):
75
+ # Create tensor on CPU first
76
+ cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
77
+ # Move to vGPU storage
78
+ return to_vgpu(cpu_tensor, self.vram)
79
+ return torch.empty(size, dtype=dtype, device=device)
80
 
81
  # Set this device as the default for tensor allocation
82
  self._allocator = VGPUAllocator(self.vram)
 
86
 
87
  @property
88
  def type(self):
89
+ return self.device_name
90
 
91
  def __str__(self):
92
+ return f"{self.device_name}:0"
93
 
94
  def __repr__(self):
95
+ return f"{self.device_name}:0"
96
 
97
  def device(self):
98
+ """Get the PyTorch device object that maps to our vGPU"""
99
+ return torch.device(str(self))
100
 
101
  def mode(self):
102
  """Get a context manager for vGPU operations"""
 
141
  if not isinstance(tensor, torch.Tensor):
142
  tensor = torch.tensor(tensor)
143
 
144
+ # Get or create vGPU device
145
+ if not VGPUDevice._VGPU_INSTANCES:
146
+ device = VGPUDevice(vram)
147
+ else:
148
+ device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
149
+ if vram is not None:
150
+ device.vram = vram
151
+
152
+ # Move data to vRAM
153
  tensor_id = device._to_vram(tensor)
154
  result = device._from_vram(tensor_id)
155
  result.requires_grad = tensor.requires_grad
156
 
157
+ # Set the device using the user-facing name
158
+ result.data = result.data.to(f"{device.device_name}:0")
159
  return result