Factor Studios committed
Commit e4541c8 · verified · 1 parent: 2456f91

Upload 2 files

Files changed (2):
  1. test_ai_integration_http.py  +198 -199
  2. torch_vgpu.py  +200 -190
test_ai_integration_http.py CHANGED
@@ -1,200 +1,199 @@
- import logging
- import os
- import time
- from contextlib import contextmanager
- from typing import Any, Optional
-
- import torch
- from transformers import pipeline
- from virtual_vram import VirtualVRAM
- from http_storage import HTTPGPUStorage
- from torch_vgpu import VGPUDevice, to_vgpu
-
- def setup_vgpu():
-     """Setup vGPU device"""
-     try:
-         # Initialize the backend first
-         from torch_vgpu import init_vgpu_backend, VGPUDevice
-         if not init_vgpu_backend():
-             raise RuntimeError("Failed to initialize vGPU backend")
-
-         # Create and register vGPU device
-         vgpu = VGPUDevice()
-         device = vgpu.device()
-
-         # Set as default device for tensor operations
-         return device
-
-     except Exception as e:
-         logging.error(f"vGPU setup failed: {str(e)}")
-         raise
-
- # Configure logging
- logging.basicConfig(
-     level=logging.INFO,
-     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
- )
- logger = logging.getLogger(__name__)
-
- @contextmanager
- def gpu_context():
-     """Context manager for vGPU resources"""
-     storage = None
-     try:
-         storage = HTTPGPUStorage()
-         yield storage
-     finally:
-         if storage:
-             storage.close()
-             logger.info("vGPU resources cleaned up")
-
- def get_model_size(model):
-     """Calculate model size in parameters and memory footprint"""
-     param_size = 0
-     for param in model.parameters():
-         param_size += param.nelement() * param.element_size()
-     buffer_size = 0
-     for buffer in model.buffers():
-         buffer_size += buffer.nelement() * buffer.element_size()
-     return param_size + buffer_size
-
- def prepare_prompt(instruction: str) -> str:
-     """Prepare a prompt for Llama-2 using its chat format."""
-     # Format: <s>[INST] instruction [/INST] assistant response </s>[INST] ...
-     return f"<s>[INST] {instruction} [/INST]"
-
- def test_ai_integration_http():
-     """Test GPT OSS model on vGPU with text generation"""
-     logger.info("Starting vGPU text generation test")
-
-     status = {
-         'pipeline_loaded': False,
-         'model_on_vgpu': False,
-         'generation_complete': False,
-         'cleanup_success': False
-     }
-
-     with gpu_context() as storage:
-         try:
-             # Initialize vRAM with monitoring
-             initial_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
-             vram = VirtualVRAM(size_gb=None, storage=storage)
-
-             # Initialize vGPU device
-             device = setup_vgpu()
-             logger.info(f"vGPU initialized with device {device}")
-
-             # Load model using pipeline
-             model_id = "openai/gpt-oss-20b"
-             logger.info(f"Loading {model_id}")
-
-             try:
-                 # Disable transformers logging temporarily
-                 transformers_logger = logging.getLogger("transformers")
-                 original_level = transformers_logger.level
-                 transformers_logger.setLevel(logging.ERROR)
-
-                 try:
-                     # Create pipeline with direct vGPU device mapping
-                     pipe = pipeline(
-                         "text-generation",
-                         model=model_id,
-                         torch_dtype=torch.float32,  # Use full precision
-                         device=device,  # Load directly to vGPU
-                         use_safetensors=True,
-                         trust_remote_code=True,
-                         model_kwargs={
-                             "device_map": device,  # Ensure all model parts go to vGPU
-                             "vram": vram  # Pass our vRAM manager
-                         }
-                     )
-                     status["pipeline_loaded"] = True
-                     status['model_on_vgpu'] = True
-                     pipe.model.eval()
-
-                     # Log model details
-                     logger.info(f"Pipeline created with model: {model_id}")
-
-                     # Log model size
-                     model_size = get_model_size(pipe.model)
-                     logger.info(f"Model loaded: {model_size/1e9:.2f} GB in parameters")
-                     logger.info(f"Model architecture: {pipe.model.__class__.__name__}")
-
-                     # Verify model location
-                     with torch.device(device):
-                         current_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
-                         logger.info(f"Model memory usage: {(current_mem - initial_mem)/1e9:.2f} GB")
-
-                 finally:
-                     # Restore original logging level
-                     transformers_logger.setLevel(original_level)
-
-             except Exception as e:
-                 logger.error(f"Model loading failed: {str(e)}")
-                 raise
-             except Exception as e:
-                 logger.error(f"Model transfer to vGPU failed: {str(e)}")
-                 raise
-
-             # Run text generation
-             logger.info("Running text generation...")
-             start = time.time()
-             peak_mem = initial_mem
-
-             try:
-                 # Prepare input prompt
-                 prompt = "Explain how virtual GPUs work in simple terms."
-
-                 with torch.no_grad():
-                     outputs = pipe(
-                         prompt,
-                         max_new_tokens=256,
-                         temperature=0.7,
-                         top_p=0.95,
-                         top_k=40,
-                         num_beams=1,
-                         do_sample=True,
-                         return_full_text=True
-                     )
-
-                 if hasattr(storage, 'get_used_memory'):
-                     peak_mem = max(peak_mem, storage.get_used_memory())
-
-                 inference_time = time.time() - start
-                 status['generation_complete'] = True
-
-                 # Log performance metrics
-                 logger.info(f"\nGeneration stats:")
-                 logger.info(f"- Time: {inference_time:.4f}s")
-                 logger.info(f"- Memory peak: {(peak_mem - initial_mem)/1e9:.2f} GB")
-                 logger.info(f"- Generated text: {outputs[0]['generated_text']}")
-
-             except Exception as e:
-                 logger.error(f"Text generation failed: {str(e)}")
-                 raise
-
-         except Exception as e:
-             logger.error(f"Test failed: {str(e)}")
-             raise
-         finally:
-             # Cleanup and status report
-             try:
-                 if 'pipe' in locals():
-                     del pipe
-                 if 'outputs' in locals():
-                     del outputs
-                 torch.cuda.empty_cache() if hasattr(torch, 'cuda') else None
-                 status['cleanup_success'] = True
-             except Exception as e:
-                 logger.error(f"Cleanup error: {str(e)}")
-
-             logger.info("\nTest Summary:")
-             for key, value in status.items():
-                 logger.info(f"- {key}: {'✓' if value else '✗'}")
-
-             final_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
-             if final_mem > initial_mem:
-                 logger.warning(f"Memory leak detected: {(final_mem - initial_mem)/1e6:.2f} MB")
-
- if __name__ == "__main__":
+ import logging
+ import os
+ import time
+ from contextlib import contextmanager
+ from typing import Any, Optional
+
+ import torch
+ from transformers import pipeline
+ from virtual_vram import VirtualVRAM
+ from http_storage import HTTPGPUStorage
+ from torch_vgpu import VGPUDevice, to_vgpu
+
+ def setup_vgpu():
+     """Setup vGPU device"""
+     try:
+         # Initialize the backend first
+         from torch_vgpu import init_vgpu_backend, VGPUDevice
+         if not init_vgpu_backend():
+             raise RuntimeError("Failed to initialize vGPU backend")
+
+         # Create and register vGPU device
+         vgpu = VGPUDevice()
+         device = vgpu.device()
+
+         # Set as default device for tensor operations
+         return device
+
+     except Exception as e:
+         logging.error(f"vGPU setup failed: {str(e)}")
+         raise
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ @contextmanager
+ def gpu_context():
+     """Context manager for vGPU resources"""
+     storage = None
+     try:
+         storage = HTTPGPUStorage()
+         yield storage
+     finally:
+         if storage:
+             storage.close()
+             logger.info("vGPU resources cleaned up")
+
+ def get_model_size(model):
+     """Calculate model size in parameters and memory footprint"""
+     param_size = 0
+     for param in model.parameters():
+         param_size += param.nelement() * param.element_size()
+     buffer_size = 0
+     for buffer in model.buffers():
+         buffer_size += buffer.nelement() * buffer.element_size()
+     return param_size + buffer_size
+
+ def prepare_prompt(instruction: str) -> str:
+     """Prepare a prompt for Llama-2 using its chat format."""
+     # Format: <s>[INST] instruction [/INST] assistant response </s>[INST] ...
+     return f"<s>[INST] {instruction} [/INST]"
+
+ def test_ai_integration_http():
+     """Test GPT OSS model on vGPU with text generation"""
+     logger.info("Starting vGPU text generation test")
+
+     status = {
+         'pipeline_loaded': False,
+         'model_on_vgpu': False,
+         'generation_complete': False,
+         'cleanup_success': False
+     }
+
+     with gpu_context() as storage:
+         try:
+             # Initialize vRAM with monitoring
+             initial_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
+             vram = VirtualVRAM(size_gb=None, storage=storage)
+
+             # Initialize vGPU device
+             device = setup_vgpu()
+             logger.info(f"vGPU initialized with device {device}")
+
+             # Load model using pipeline
+             model_id = "openai/gpt-oss-20b"
+             logger.info(f"Loading {model_id}")
+
+             try:
+                 # Disable transformers logging temporarily
+                 transformers_logger = logging.getLogger("transformers")
+                 original_level = transformers_logger.level
+                 transformers_logger.setLevel(logging.ERROR)
+
+                 try:
+                     # Create pipeline with direct vGPU device mapping
+                     pipe = pipeline(
+                         "text-generation",
+                         model=model_id,
+                         torch_dtype=torch.float32,  # Use full precision
+                         device=device,  # Load directly to vGPU
+                         use_safetensors=True,
+                         trust_remote_code=True,
+                         model_kwargs={
+                             "device_map": device  # Ensure all model parts go to vGPU
+                         }
+                     )
+                     status["pipeline_loaded"] = True
+                     status['model_on_vgpu'] = True
+                     pipe.model.eval()
+
+                     # Log model details
+                     logger.info(f"Pipeline created with model: {model_id}")
+
+                     # Log model size
+                     model_size = get_model_size(pipe.model)
+                     logger.info(f"Model loaded: {model_size/1e9:.2f} GB in parameters")
+                     logger.info(f"Model architecture: {pipe.model.__class__.__name__}")
+
+                     # Verify model location
+                     with torch.device(device):
+                         current_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
+                         logger.info(f"Model memory usage: {(current_mem - initial_mem)/1e9:.2f} GB")
+
+                 finally:
+                     # Restore original logging level
+                     transformers_logger.setLevel(original_level)
+
+             except Exception as e:
+                 logger.error(f"Model loading failed: {str(e)}")
+                 raise
+             except Exception as e:
+                 logger.error(f"Model transfer to vGPU failed: {str(e)}")
+                 raise
+
+             # Run text generation
+             logger.info("Running text generation...")
+             start = time.time()
+             peak_mem = initial_mem
+
+             try:
+                 # Prepare input prompt
+                 prompt = "Explain how virtual GPUs work in simple terms."
+
+                 with torch.no_grad():
+                     outputs = pipe(
+                         prompt,
+                         max_new_tokens=256,
+                         temperature=0.7,
+                         top_p=0.95,
+                         top_k=40,
+                         num_beams=1,
+                         do_sample=True,
+                         return_full_text=True
+                     )
+
+                 if hasattr(storage, 'get_used_memory'):
+                     peak_mem = max(peak_mem, storage.get_used_memory())
+
+                 inference_time = time.time() - start
+                 status['generation_complete'] = True
+
+                 # Log performance metrics
+                 logger.info(f"\nGeneration stats:")
+                 logger.info(f"- Time: {inference_time:.4f}s")
+                 logger.info(f"- Memory peak: {(peak_mem - initial_mem)/1e9:.2f} GB")
+                 logger.info(f"- Generated text: {outputs[0]['generated_text']}")
+
+             except Exception as e:
+                 logger.error(f"Text generation failed: {str(e)}")
+                 raise
+
+         except Exception as e:
+             logger.error(f"Test failed: {str(e)}")
+             raise
+         finally:
+             # Cleanup and status report
+             try:
+                 if 'pipe' in locals():
+                     del pipe
+                 if 'outputs' in locals():
+                     del outputs
+                 torch.cuda.empty_cache() if hasattr(torch, 'cuda') else None
+                 status['cleanup_success'] = True
+             except Exception as e:
+                 logger.error(f"Cleanup error: {str(e)}")
+
+             logger.info("\nTest Summary:")
+             for key, value in status.items():
+                 logger.info(f"- {key}: {'✓' if value else '✗'}")
+
+             final_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
+             if final_mem > initial_mem:
+                 logger.warning(f"Memory leak detected: {(final_mem - initial_mem)/1e6:.2f} MB")
+
+ if __name__ == "__main__":
      test_ai_integration_http()
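The only functional change in this file is the removal of the "vram": vram entry from model_kwargs; transformers forwards model_kwargs to the model's from_pretrained call, which would not recognize a vram key, so the manager is presumably handed to the backend through the new module-level global in torch_vgpu.py instead. A minimal sketch of the new wiring, assuming the repo's virtual_vram and torch_vgpu modules behave as this test expects:

    # Sketch only (not part of the commit): the vRAM manager now reaches the
    # backend via torch_vgpu's module-level global rather than model_kwargs.
    from virtual_vram import VirtualVRAM
    from torch_vgpu import init_vgpu_backend, VGPUDevice, get_current_vram

    if init_vgpu_backend():
        vram = VirtualVRAM()
        vgpu = VGPUDevice(vram)            # __init__ calls set_current_vram(vram)
        assert get_current_vram() is vram  # backend code can now look it up globally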
torch_vgpu.py CHANGED
@@ -1,190 +1,200 @@
- import torch
- from torch.library import Library, impl
- from typing import Optional, Union, Tuple
- import numpy as np
- from virtual_vram import VirtualVRAM
-
- # Global flag for backend initialization
- VGPU_BACKEND_INITIALIZED = False
-
- def init_vgpu_backend():
-     """Initialize the vGPU backend. Must be called before creating any VGPUDevice instances."""
-     global VGPU_BACKEND_INITIALIZED
-     try:
-         if not VGPU_BACKEND_INITIALIZED:
-             # First define our core library
-             lib = Library("vgpu", "DEF")
-             lib.define("custom_allocate(Device? device) -> Tensor")
-             lib.define("custom_to_cpu(Tensor self) -> Tensor")
-             lib.define("custom_from_cpu(Tensor self) -> Tensor")
-
-             # Then implement the operations
-             impl_lib = Library("vgpu", "IMPL", "PrivateUse1")
-
-             @impl(impl_lib, "custom_allocate")
-             def custom_allocate(device=None):
-                 return torch.empty((), device="cpu")
-
-             @impl(impl_lib, "custom_to_cpu")
-             def custom_to_cpu(tensor):
-                 return tensor.clone()
-
-             @impl(impl_lib, "custom_from_cpu")
-             def custom_from_cpu(tensor):
-                 return tensor.clone()
-
-             # Generate all methods for our backend
-             torch.utils.generate_methods_for_privateuse1_backend(
-                 for_tensor=True,
-                 for_module=True,
-                 for_packed_sequence=True,
-                 for_storage=True
-             )
-
-             VGPU_BACKEND_INITIALIZED = True
-
-         return VGPU_BACKEND_INITIALIZED
-     except Exception as e:
-         print(f"Backend initialization warning: {e}")
-         return False
-
- class VGPUStorage(torch.Storage):
-     """Custom storage class that uses our virtual VRAM"""
-
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.vram = kwargs.get("vram")
-         if not self.vram:
-             from virtual_vram import VirtualVRAM
-             self.vram = VirtualVRAM()
-         self.tensor_id = kwargs.get("tensor_id", f"tensor_{id(self)}")
-
-     def _new_shared(self, size):
-         return VGPUStorage(size, vram=self.vram)
-
- class VGPUTensor:
-     """Tensor implementation that uses vGPU for computations"""
-     @staticmethod
-     def __new__(cls, elem):
-         return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
-
- class VGPUDevice:
-     """
-     Custom PyTorch device implementation that routes operations through vGPU.
-     Usage:
-         vgpu = VGPUDevice()
-         with vgpu.mode():
-             tensor = torch.randn(2, 3)  # Will be on vGPU
-     """
-     _VGPU_INSTANCES = {}  # Class-level dict to track instances
-
-     def __init__(self, vram: Optional[VirtualVRAM] = None):
-         self.vram = vram or VirtualVRAM()
-         self.tensor_cores = None  # Will be initialized when needed
-         self.device_name = "privateuseone"  # Our registered device type
-         self._register_device()
-
-     def _register_device(self):
-         """Register vGPU device using PyTorch's device system"""
-         try:
-             if not VGPU_BACKEND_INITIALIZED:
-                 raise RuntimeError("VGPU backend not properly initialized")
-
-             # Create device using our registered device type
-             self._device = torch.device(self.device_name)
-
-             # Store this instance for reuse
-             VGPUDevice._VGPU_INSTANCES[self.device_name] = self
-
-             # Define custom operations for the device
-             class VGPUAllocator:
-                 def __init__(self, vram, device):
-                     self.vram = vram
-                     self.device = device
-
-                 def __call__(self, size, dtype=None, device=None):
-                     # Create tensor on CPU first
-                     cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
-                     # Move to vGPU storage
-                     return to_vgpu(cpu_tensor, self.vram)
-
-             # Set up allocator
-             self._allocator = VGPUAllocator(self.vram, self._device)
-
-         except Exception as e:
-             raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
-
-     @property
-     def type(self):
-         return self.device_name
-
-     def __str__(self):
-         return f"{self.device_name}:0"
-
-     def __repr__(self):
-         return f"vgpu(device='{self.device_name}:0')"
-
-     def device(self):
-         """Get the PyTorch device object that maps to our vGPU"""
-         return self._device  # Return the already created device object
-
-     def mode(self):
-         """Get a context manager for vGPU operations"""
-         return torch.device(self._device)
-
-     def _init_tensor_cores(self):
-         if self.tensor_cores is None:
-             from tensor_core import TensorCoreArray
-             self.tensor_cores = TensorCoreArray()
-
-     def _to_vram(self, tensor: torch.Tensor) -> str:
-         """Store tensor data in virtual VRAM"""
-         tensor_id = f"tensor_{id(tensor)}"
-         data = tensor.detach().cpu().numpy()
-         self.vram.storage.store_tensor(tensor_id, data)
-         return tensor_id
-
-     def _from_vram(self, tensor_id: str) -> torch.Tensor:
-         """Retrieve tensor data from virtual VRAM"""
-         data = self.vram.storage.load_tensor(tensor_id)
-         return torch.from_numpy(data)
-
-     def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
-         """Matrix multiplication using tensor cores"""
-         self._init_tensor_cores()
-
-         # Store inputs in VRAM
-         a_id = self._to_vram(a)
-         b_id = self._to_vram(b)
-
-         # Perform matmul using tensor cores
-         result = self.tensor_cores.matmul(
-             self.vram.storage.load_tensor(a_id),
-             self.vram.storage.load_tensor(b_id)
-         )
-
-         # Create new tensor with result
-         return torch.from_numpy(result)
-
- def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
-     """Move a tensor to vGPU device"""
-     if not isinstance(tensor, torch.Tensor):
-         tensor = torch.tensor(tensor)
-
-     # Get or create vGPU device
-     if not VGPUDevice._VGPU_INSTANCES:
-         device = VGPUDevice(vram)
-     else:
-         device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
-         if vram is not None:
-             device.vram = vram
-
-     # Move data to vRAM
-     tensor_id = device._to_vram(tensor)
-     result = device._from_vram(tensor_id)
-     result.requires_grad = tensor.requires_grad
-
-     # Set the device using the internal name
-     result.data = result.data.to(device._device)
-     return result
-
+ import torch
+ from torch.library import Library, impl
+ from typing import Optional, Union, Tuple
+ import numpy as np
+ from virtual_vram import VirtualVRAM
+
+ # Global variables for backend state
+ VGPU_BACKEND_INITIALIZED = False
+ CURRENT_VRAM = None  # Global reference to current vRAM manager
+
+ def set_current_vram(vram):
+     """Set the current vRAM manager globally"""
+     global CURRENT_VRAM
+     CURRENT_VRAM = vram
+
+ def get_current_vram():
+     """Get the current vRAM manager"""
+     return CURRENT_VRAM
+
+ def init_vgpu_backend():
+     """Initialize the vGPU backend. Must be called before creating any VGPUDevice instances."""
+     global VGPU_BACKEND_INITIALIZED
+     try:
+         if not VGPU_BACKEND_INITIALIZED:
+             # First define our core library
+             lib = Library("vgpu", "DEF")
+             lib.define("custom_allocate(Device? device) -> Tensor")
+             lib.define("custom_to_cpu(Tensor self) -> Tensor")
+             lib.define("custom_from_cpu(Tensor self) -> Tensor")
+
+             # Then implement the operations
+             impl_lib = Library("vgpu", "IMPL", "PrivateUse1")
+
+             @impl(impl_lib, "custom_allocate")
+             def custom_allocate(device=None):
+                 return torch.empty((), device="cpu")
+
+             @impl(impl_lib, "custom_to_cpu")
+             def custom_to_cpu(tensor):
+                 return tensor.clone()
+
+             @impl(impl_lib, "custom_from_cpu")
+             def custom_from_cpu(tensor):
+                 return tensor.clone()
+
+             # Generate all methods for our backend
+             torch.utils.generate_methods_for_privateuse1_backend(
+                 for_tensor=True,
+                 for_module=True,
+                 for_packed_sequence=True,
+                 for_storage=True
+             )
+
+             VGPU_BACKEND_INITIALIZED = True
+
+         return VGPU_BACKEND_INITIALIZED
+     except Exception as e:
+         print(f"Backend initialization warning: {e}")
+         return False
+
+ class VGPUStorage(torch.Storage):
+     """Custom storage class that uses our virtual VRAM"""
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.vram = kwargs.get("vram")
+         if not self.vram:
+             from virtual_vram import VirtualVRAM
+             self.vram = VirtualVRAM()
+         self.tensor_id = kwargs.get("tensor_id", f"tensor_{id(self)}")
+
+     def _new_shared(self, size):
+         return VGPUStorage(size, vram=self.vram)
+
+ class VGPUTensor:
+     """Tensor implementation that uses vGPU for computations"""
+     @staticmethod
+     def __new__(cls, elem):
+         return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
+
+ class VGPUDevice:
+     """
+     Custom PyTorch device implementation that routes operations through vGPU.
+     Usage:
+         vgpu = VGPUDevice()
+         with vgpu.mode():
+             tensor = torch.randn(2, 3)  # Will be on vGPU
+     """
+     _VGPU_INSTANCES = {}  # Class-level dict to track instances
+
+     def __init__(self, vram: Optional[VirtualVRAM] = None):
+         self.vram = vram or VirtualVRAM()
+         self.tensor_cores = None  # Will be initialized when needed
+         self.device_name = "privateuseone"  # Our registered device type
+         set_current_vram(self.vram)  # Set up global vRAM reference
+         self._register_device()
+
+     def _register_device(self):
+         """Register vGPU device using PyTorch's device system"""
+         try:
+             if not VGPU_BACKEND_INITIALIZED:
+                 raise RuntimeError("VGPU backend not properly initialized")
+
+             # Create device using our registered device type
+             self._device = torch.device(self.device_name)
+
+             # Store this instance for reuse
+             VGPUDevice._VGPU_INSTANCES[self.device_name] = self
+
+             # Define custom operations for the device
+             class VGPUAllocator:
+                 def __init__(self, vram, device):
+                     self.vram = vram
+                     self.device = device
+
+                 def __call__(self, size, dtype=None, device=None):
+                     # Create tensor on CPU first
+                     cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
+                     # Move to vGPU storage
+                     return to_vgpu(cpu_tensor, self.vram)
+
+             # Set up allocator
+             self._allocator = VGPUAllocator(self.vram, self._device)
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
+
+     @property
+     def type(self):
+         return self.device_name
+
+     def __str__(self):
+         return f"{self.device_name}:0"
+
+     def __repr__(self):
+         return f"vgpu(device='{self.device_name}:0')"
+
+     def device(self):
+         """Get the PyTorch device object that maps to our vGPU"""
+         return self._device  # Return the already created device object
+
+     def mode(self):
+         """Get a context manager for vGPU operations"""
+         return torch.device(self._device)
+
+     def _init_tensor_cores(self):
+         if self.tensor_cores is None:
+             from tensor_core import TensorCoreArray
+             self.tensor_cores = TensorCoreArray()
+
+     def _to_vram(self, tensor: torch.Tensor) -> str:
+         """Store tensor data in virtual VRAM"""
+         tensor_id = f"tensor_{id(tensor)}"
+         data = tensor.detach().cpu().numpy()
+         self.vram.storage.store_tensor(tensor_id, data)
+         return tensor_id
+
+     def _from_vram(self, tensor_id: str) -> torch.Tensor:
+         """Retrieve tensor data from virtual VRAM"""
+         data = self.vram.storage.load_tensor(tensor_id)
+         return torch.from_numpy(data)
+
+     def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+         """Matrix multiplication using tensor cores"""
+         self._init_tensor_cores()
+
+         # Store inputs in VRAM
+         a_id = self._to_vram(a)
+         b_id = self._to_vram(b)
+
+         # Perform matmul using tensor cores
+         result = self.tensor_cores.matmul(
+             self.vram.storage.load_tensor(a_id),
+             self.vram.storage.load_tensor(b_id)
+         )
+
+         # Create new tensor with result
+         return torch.from_numpy(result)
+
+ def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
+     """Move a tensor to vGPU device"""
+     if not isinstance(tensor, torch.Tensor):
+         tensor = torch.tensor(tensor)
+
+     # Get or create vGPU device
+     if not VGPUDevice._VGPU_INSTANCES:
+         device = VGPUDevice(vram)
+     else:
+         device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
+         if vram is not None:
+             device.vram = vram
+
+     # Move data to vRAM
+     tensor_id = device._to_vram(tensor)
+     result = device._from_vram(tensor_id)
+     result.requires_grad = tensor.requires_grad
+
+     # Set the device using the internal name
+     result.data = result.data.to(device._device)
+     return result
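As a usage note, a hypothetical smoke test for the backend defined above; this is a sketch only, and assumes the repo's virtual_vram and tensor_core modules are importable and behave as the code expects:

    # Sketch only (not part of the commit): round-trip a tensor through the vGPU backend.
    import torch
    from torch_vgpu import init_vgpu_backend, VGPUDevice, to_vgpu

    if init_vgpu_backend():
        vgpu = VGPUDevice()
        x = torch.randn(2, 3)
        y = to_vgpu(x)            # round-trips x through VirtualVRAM storage
        print(y.device)           # expected: privateuseone:0
        z = vgpu.matmul(x, x.T)   # (2,3) @ (3,2), routed through TensorCoreArray
        print(z.shape)            # expected: torch.Size([2, 2])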