Factor Studios commited on
Commit
4a613fd
·
verified ·
1 Parent(s): 7a9b3ec

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +11 -6
torch_vgpu.py CHANGED
@@ -81,7 +81,7 @@ class VGPUDevice:
81
  def __init__(self, vram: Optional[VirtualVRAM] = None):
82
  self.vram = vram or VirtualVRAM()
83
  self.tensor_cores = None # Will be initialized when needed
84
- self.device_name = "privateuseone" # Our registered device type
85
  self._register_device()
86
 
87
  def _register_device(self):
@@ -103,10 +103,16 @@ class VGPUDevice:
103
  self.device = device
104
 
105
  def __call__(self, size, dtype=None, device=None):
106
- # Create tensor on CPU first
107
- cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
108
- # Move to vGPU storage
109
- return to_vgpu(cpu_tensor, self.vram)
 
 
 
 
 
 
110
 
111
  # Set up allocator
112
  self._allocator = VGPUAllocator(self.vram, self._device)
@@ -187,4 +193,3 @@ def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.T
187
  # Set the device using the internal name
188
  result.data = result.data.to(device._device)
189
  return result
190
-
 
81
  def __init__(self, vram: Optional[VirtualVRAM] = None):
82
  self.vram = vram or VirtualVRAM()
83
  self.tensor_cores = None # Will be initialized when needed
84
+ self.device_name = "vgpu" # Our registered device type
85
  self._register_device()
86
 
87
  def _register_device(self):
 
103
  self.device = device
104
 
105
  def __call__(self, size, dtype=None, device=None):
106
+ # Create tensor directly in vGPU memory
107
+ tensor_id = f"tensor_empty_{id(size)}"
108
+ # Initialize empty array of the right size and dtype
109
+ shape = size if isinstance(size, (tuple, list)) else (size,)
110
+ data = np.empty(shape, dtype=np.float32 if dtype is None else dtype)
111
+ # Store directly in vRAM
112
+ self.vram.storage.store_tensor(tensor_id, data)
113
+ # Create tensor with our device type
114
+ result = torch.as_tensor(data, device=self.device)
115
+ return result
116
 
117
  # Set up allocator
118
  self._allocator = VGPUAllocator(self.vram, self._device)
 
193
  # Set the device using the internal name
194
  result.data = result.data.to(device._device)
195
  return result