Factor Studios committed on
Commit 172ea54 · verified · 1 parent: 6797d5d

Upload 2 files

Files changed (2)
  1. test_ai_integration_http.py +208 -203
  2. torch_vgpu.py +182 -178
test_ai_integration_http.py CHANGED
@@ -1,203 +1,208 @@
- """
- Test Llama-2-7b-instruct model integration with vGPU.
- Configure PyTorch to use vGPU as device for text generation.
- """
- import logging
- import os
- import time
- from contextlib import contextmanager
- from typing import Any, Optional
-
- import torch
- from transformers import pipeline
- from virtual_vram import VirtualVRAM
- from http_storage import HTTPGPUStorage
- from torch_vgpu import VGPUDevice, to_vgpu
-
- def setup_vgpu():
-     """Setup vGPU device"""
-     try:
-         # Create and register vGPU device
-         vgpu = VGPUDevice()
-         device = vgpu.device()
-
-         # Set as default device for tensor operations
-         torch.set_default_device(device)
-
-         return device
-
-     except Exception as e:
-         logging.error(f"vGPU setup failed: {str(e)}")
-         raise
-
- # Configure logging
- logging.basicConfig(
-     level=logging.INFO,
-     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
- )
- logger = logging.getLogger(__name__)
-
- @contextmanager
- def gpu_context():
-     """Context manager for vGPU resources"""
-     storage = None
-     try:
-         storage = HTTPGPUStorage()
-         yield storage
-     finally:
-         if storage:
-             storage.close()
-             logger.info("vGPU resources cleaned up")
-
- def get_model_size(model):
-     """Calculate model size in parameters and memory footprint"""
-     param_size = 0
-     for param in model.parameters():
-         param_size += param.nelement() * param.element_size()
-     buffer_size = 0
-     for buffer in model.buffers():
-         buffer_size += buffer.nelement() * buffer.element_size()
-     return param_size + buffer_size
-
- def prepare_prompt(instruction: str) -> str:
-     """Prepare a prompt for Llama-2 using its chat format."""
-     # Format: <s>[INST] instruction [/INST] assistant response </s>[INST] ...
-     return f"<s>[INST] {instruction} [/INST]"
-
- def test_ai_integration_http():
-     """Test GPT OSS model on vGPU with text generation"""
-     logger.info("Starting vGPU text generation test")
-
-     status = {
-         'pipeline_loaded': False,
-         'model_on_vgpu': False,
-         'generation_complete': False,
-         'cleanup_success': False
-     }
-
-     with gpu_context() as storage:
-         try:
-             # Initialize vRAM with monitoring
-             initial_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
-             vram = VirtualVRAM(size_gb=None, storage=storage)
-
-             # Initialize vGPU device
-             device = setup_vgpu()
-             logger.info(f"vGPU initialized with device {device}")
-
-             # Load model using pipeline
-             model_id = "openai/gpt-oss-20b"
-             logger.info(f"Loading {model_id}")
-
-             try:
-                 # Disable transformers logging temporarily
-                 transformers_logger = logging.getLogger("transformers")
-                 original_level = transformers_logger.level
-                 transformers_logger.setLevel(logging.ERROR)
-
-                 try:
-                     # Create pipeline
-                     # Create pipeline with vGPU device
-                     pipe = pipeline(
-                         "text-generation",
-                         model=model_id,
-                         torch_dtype=torch.float32,  # Use full precision
-                         device=device,  # Use our vGPU device
-                         use_safetensors=True,
-                         trust_remote_code=True
-                     )
-                     status['pipeline_loaded'] = True
-
-                     # Move pipeline model to vGPU
-                     pipe.model = to_vgpu(pipe.model, vram=vram)
-                     pipe.model.eval()
-                     status['model_on_vgpu'] = True
-
-                     # Log model details
-                     logger.info(f"Pipeline created with model: {model_id}")
-
-                     # Log model size
-                     model_size = get_model_size(pipe.model)
-                     logger.info(f"Model loaded: {model_size/1e9:.2f} GB in parameters")
-                     logger.info(f"Model architecture: {pipe.model.__class__.__name__}")
-
-                     # Verify model location
-                     with torch.device(device):
-                         current_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
-                         logger.info(f"Model memory usage: {(current_mem - initial_mem)/1e9:.2f} GB")
-
-                 finally:
-                     # Restore original logging level
-                     transformers_logger.setLevel(original_level)
-
-             except Exception as e:
-                 logger.error(f"Model loading failed: {str(e)}")
-                 raise
-             except Exception as e:
-                 logger.error(f"Model transfer to vGPU failed: {str(e)}")
-                 raise
-
-             # Run text generation
-             logger.info("Running text generation...")
-             start = time.time()
-             peak_mem = initial_mem
-
-             try:
-                 # Prepare input prompt
-                 prompt = "Explain how virtual GPUs work in simple terms."
-
-                 with torch.no_grad():
-                     # Generate text
-                     outputs = pipe(
-                         prompt,
-                         max_new_tokens=256,
-                         temperature=0.7,
-                         top_p=0.95,
-                         top_k=40,
-                         num_beams=1,
-                         do_sample=True,
-                         return_full_text=True
-                     )
-
-                 if hasattr(storage, 'get_used_memory'):
-                     peak_mem = max(peak_mem, storage.get_used_memory())
-
-                 inference_time = time.time() - start
-                 status['generation_complete'] = True
-
-                 # Log performance metrics
-                 logger.info(f"\nGeneration stats:")
-                 logger.info(f"- Time: {inference_time:.4f}s")
-                 logger.info(f"- Memory peak: {(peak_mem - initial_mem)/1e9:.2f} GB")
-                 logger.info(f"- Generated text: {outputs[0]['generated_text']}")
-
-             except Exception as e:
-                 logger.error(f"Text generation failed: {str(e)}")
-                 raise
-
-         except Exception as e:
-             logger.error(f"Test failed: {str(e)}")
-             raise
-         finally:
-             # Cleanup and status report
-             try:
-                 if 'pipe' in locals():
-                     del pipe
-                 if 'outputs' in locals():
-                     del outputs
-                 torch.cuda.empty_cache() if hasattr(torch, 'cuda') else None
-                 status['cleanup_success'] = True
-             except Exception as e:
-                 logger.error(f"Cleanup error: {str(e)}")
-
-             logger.info("\nTest Summary:")
-             for key, value in status.items():
-                 logger.info(f"- {key}: {'✓' if value else '✗'}")
-
-             final_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
-             if final_mem > initial_mem:
-                 logger.warning(f"Memory leak detected: {(final_mem - initial_mem)/1e6:.2f} MB")
-
- if __name__ == "__main__":
-     test_ai_integration_http()

+ """
+ Test GPT OSS (openai/gpt-oss-20b) model integration with vGPU.
+ Configure PyTorch to use vGPU as device for text generation.
+ """
+ import logging
+ import os
+ import time
+ from contextlib import contextmanager
+ from typing import Any, Optional
+
+ import torch
+ from transformers import pipeline
+ from virtual_vram import VirtualVRAM
+ from http_storage import HTTPGPUStorage
+ from torch_vgpu import VGPUDevice, init_vgpu_backend, to_vgpu
+
+ def setup_vgpu():
+     """Set up the vGPU device"""
+     try:
+         # Initialize the backend before constructing any device
+         if not init_vgpu_backend():
+             raise RuntimeError("Failed to initialize vGPU backend")
+
+         # Create and register vGPU device
+         vgpu = VGPUDevice()
+         device = vgpu.device()
+
+         # Set as default device for tensor operations
+         torch.set_default_device(device)
+
+         return device
+
+     except Exception as e:
+         logging.error(f"vGPU setup failed: {str(e)}")
+         raise
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ @contextmanager
+ def gpu_context():
+     """Context manager for vGPU resources"""
+     storage = None
+     try:
+         storage = HTTPGPUStorage()
+         yield storage
+     finally:
+         if storage:
+             storage.close()
+             logger.info("vGPU resources cleaned up")
+
+ def get_model_size(model):
+     """Calculate model size in parameters and memory footprint"""
+     param_size = 0
+     for param in model.parameters():
+         param_size += param.nelement() * param.element_size()
+     buffer_size = 0
+     for buffer in model.buffers():
+         buffer_size += buffer.nelement() * buffer.element_size()
+     return param_size + buffer_size
+
+ def prepare_prompt(instruction: str) -> str:
+     """Prepare a prompt for Llama-2 using its chat format."""
+     # Format: <s>[INST] instruction [/INST] assistant response </s>[INST] ...
+     return f"<s>[INST] {instruction} [/INST]"
+
+ def test_ai_integration_http():
+     """Test GPT OSS model on vGPU with text generation"""
+     logger.info("Starting vGPU text generation test")
+
+     status = {
+         'pipeline_loaded': False,
+         'model_on_vgpu': False,
+         'generation_complete': False,
+         'cleanup_success': False
+     }
+
+     with gpu_context() as storage:
+         try:
+             # Initialize vRAM with monitoring
+             initial_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
+             vram = VirtualVRAM(size_gb=None, storage=storage)
+
+             # Initialize vGPU device
+             device = setup_vgpu()
+             logger.info(f"vGPU initialized with device {device}")
+
+             # Load model using pipeline
+             model_id = "openai/gpt-oss-20b"
+             logger.info(f"Loading {model_id}")
+
+             try:
+                 # Disable transformers logging temporarily
+                 transformers_logger = logging.getLogger("transformers")
+                 original_level = transformers_logger.level
+                 transformers_logger.setLevel(logging.ERROR)
+
+                 try:
+                     # Create pipeline with vGPU device
+                     pipe = pipeline(
+                         "text-generation",
+                         model=model_id,
+                         torch_dtype=torch.float32,  # Use full precision
+                         device=device,  # Use our vGPU device
+                         use_safetensors=True,
+                         trust_remote_code=True
+                     )
+                     status['pipeline_loaded'] = True
+
+                     # Move pipeline model to vGPU
+                     pipe.model = to_vgpu(pipe.model, vram=vram)
+                     pipe.model.eval()
+                     status['model_on_vgpu'] = True
+
+                     # Log model details
+                     logger.info(f"Pipeline created with model: {model_id}")
+
+                     # Log model size
+                     model_size = get_model_size(pipe.model)
+                     logger.info(f"Model loaded: {model_size/1e9:.2f} GB in parameters")
+                     logger.info(f"Model architecture: {pipe.model.__class__.__name__}")
+
+                     # Verify model location
+                     with torch.device(device):
+                         current_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
+                         logger.info(f"Model memory usage: {(current_mem - initial_mem)/1e9:.2f} GB")
+
+                 finally:
+                     # Restore original logging level
+                     transformers_logger.setLevel(original_level)
+
+             except Exception as e:
+                 logger.error(f"Model loading or transfer to vGPU failed: {str(e)}")
+                 raise
+
+             # Run text generation
+             logger.info("Running text generation...")
+             start = time.time()
+             peak_mem = initial_mem
+
+             try:
+                 # Prepare input prompt
+                 prompt = "Explain how virtual GPUs work in simple terms."
+
+                 with torch.no_grad():
+                     # Generate text
+                     outputs = pipe(
+                         prompt,
+                         max_new_tokens=256,
+                         temperature=0.7,
+                         top_p=0.95,
+                         top_k=40,
+                         num_beams=1,
+                         do_sample=True,
+                         return_full_text=True
+                     )
+
+                 if hasattr(storage, 'get_used_memory'):
+                     peak_mem = max(peak_mem, storage.get_used_memory())
+
+                 inference_time = time.time() - start
+                 status['generation_complete'] = True
+
+                 # Log performance metrics
+                 logger.info("\nGeneration stats:")
+                 logger.info(f"- Time: {inference_time:.4f}s")
+                 logger.info(f"- Memory peak: {(peak_mem - initial_mem)/1e9:.2f} GB")
+                 logger.info(f"- Generated text: {outputs[0]['generated_text']}")
+
+             except Exception as e:
+                 logger.error(f"Text generation failed: {str(e)}")
+                 raise
+
+         except Exception as e:
+             logger.error(f"Test failed: {str(e)}")
+             raise
+         finally:
+             # Cleanup and status report
+             try:
+                 if 'pipe' in locals():
+                     del pipe
+                 if 'outputs' in locals():
+                     del outputs
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+                 status['cleanup_success'] = True
+             except Exception as e:
+                 logger.error(f"Cleanup error: {str(e)}")
+
+             logger.info("\nTest Summary:")
+             for key, value in status.items():
+                 logger.info(f"- {key}: {'✓' if value else '✗'}")
+
+             final_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
+             if final_mem > initial_mem:
+                 logger.warning(f"Memory leak detected: {(final_mem - initial_mem)/1e6:.2f} MB")
+
+ if __name__ == "__main__":
+     test_ai_integration_http()
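
For reference, a minimal sketch of the initialization order the updated test enforces: init_vgpu_backend() must succeed before any VGPUDevice is constructed, otherwise _register_device() raises. This assumes the repo's torch_vgpu module is on PYTHONPATH; it is an illustrative sketch, not code from this commit.

    import torch
    from torch_vgpu import init_vgpu_backend, VGPUDevice

    # Backend registration must come first; VGPUDevice._register_device()
    # raises RuntimeError while VGPU_BACKEND_INITIALIZED is still False.
    if not init_vgpu_backend():
        raise SystemExit("vGPU backend unavailable on this torch build")

    device = VGPUDevice().device()   # torch.device("vgpu")
    torch.set_default_device(device)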
torch_vgpu.py CHANGED
@@ -1,178 +1,182 @@
- """
- Custom PyTorch device implementation that routes operations through our virtual GPU.
- """
- import torch
- from torch.library import Library, impl
- from typing import Optional, Union, Tuple
- import numpy as np
- from virtual_vram import VirtualVRAM
-
- # Initialize custom backend
- def init_vgpu_backend():
-     try:
-         # First rename the backend
-         torch.utils.rename_privateuse1_backend("vgpu")
-
-         # Then generate all the necessary methods
-         torch.utils.generate_methods_for_privateuse1_backend(
-             for_tensor=True,
-             for_module=True,
-             for_packed_sequence=True,
-             for_storage=True
-         )
-
-         # Register our custom library
-         lib = Library("vgpu", "DEF")
-         lib.define("custom_op(Tensor self) -> Tensor")
-
-         @impl("vgpu", "custom_op", "Tensor")
-         def custom_op_impl(tensor):
-             return tensor.clone()
-
-         return True
-     except Exception as e:
-         print(f"Backend initialization warning: {e}")
-         return False
-
- # Initialize the backend
- VGPU_BACKEND_INITIALIZED = init_vgpu_backend()
-
- class VGPUStorage(torch.Storage):
-     """Custom storage class that uses our virtual VRAM"""
-
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.vram = kwargs.get('vram')
-         if not self.vram:
-             from virtual_vram import VirtualVRAM
-             self.vram = VirtualVRAM()
-         self.tensor_id = kwargs.get('tensor_id', f"tensor_{id(self)}")
-
-     def _new_shared(self, size):
-         return VGPUStorage(size, vram=self.vram)
-
- class VGPUTensor:
-     """Tensor implementation that uses vGPU for computations"""
-     @staticmethod
-     def __new__(cls, elem):
-         return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
-
- class VGPUDevice:
-     """
-     Custom PyTorch device implementation that routes operations through vGPU.
-     Usage:
-         vgpu = VGPUDevice()
-         with vgpu.mode():
-             tensor = torch.randn(2, 3)  # Will be on vGPU
-     """
-     _VGPU_INSTANCES = {}  # Class-level dict to track instances
-
-     def __init__(self, vram: Optional[VirtualVRAM] = None):
-         self.vram = vram or VirtualVRAM()
-         self.tensor_cores = None  # Will be initialized when needed
-         self.device_name = "vgpu"  # Both internal and user-facing name
-         self._register_device()
-
-     def _register_device(self):
-         """Register vGPU device using PyTorch's device system"""
-         try:
-             if not VGPU_BACKEND_INITIALIZED:
-                 raise RuntimeError("VGPU backend not properly initialized")
-
-             # Create device with explicit index
-             self._device = torch.device("vgpu")
-
-             # Store this instance for reuse
-             VGPUDevice._VGPU_INSTANCES[self.device_name] = self
-
-             # Define custom operations for the device
-             class VGPUAllocator:
-                 def __init__(self, vram, device):
-                     self.vram = vram
-                     self.device = device
-
-                 def __call__(self, size, dtype=None, device=None):
-                     # Create tensor on CPU first
-                     cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
-                     # Move to vGPU storage
-                     return to_vgpu(cpu_tensor, self.vram)
-
-             # Set up allocator
-             self._allocator = VGPUAllocator(self.vram, self._device)
-
-         except Exception as e:
-             raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
-
-     @property
-     def type(self):
-         return self.internal_name
-
-     def __str__(self):
-         return f"{self.internal_name}:0"
-
-     def __repr__(self):
-         return f"vgpu(device='{self.internal_name}:0')"
-
-     def device(self):
-         """Get the PyTorch device object that maps to our vGPU"""
-         return self._device  # Return the already created device object
-
-     def mode(self):
-         """Get a context manager for vGPU operations"""
-         return torch.device(self._device)
-
-     def _init_tensor_cores(self):
-         if self.tensor_cores is None:
-             from tensor_core import TensorCoreArray
-             self.tensor_cores = TensorCoreArray()
-
-     def _to_vram(self, tensor: torch.Tensor) -> str:
-         """Store tensor data in virtual VRAM"""
-         tensor_id = f"tensor_{id(tensor)}"
-         data = tensor.detach().cpu().numpy()
-         self.vram.storage.store_tensor(tensor_id, data)
-         return tensor_id
-
-     def _from_vram(self, tensor_id: str) -> torch.Tensor:
-         """Retrieve tensor data from virtual VRAM"""
-         data = self.vram.storage.load_tensor(tensor_id)
-         return torch.from_numpy(data)
-
-     def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
-         """Matrix multiplication using tensor cores"""
-         self._init_tensor_cores()
-
-         # Store inputs in VRAM
-         a_id = self._to_vram(a)
-         b_id = self._to_vram(b)
-
-         # Perform matmul using tensor cores
-         result = self.tensor_cores.matmul(
-             self.vram.storage.load_tensor(a_id),
-             self.vram.storage.load_tensor(b_id)
-         )
-
-         # Create new tensor with result
-         return torch.from_numpy(result)
-
- def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
-     """Move a tensor to vGPU device"""
-     if not isinstance(tensor, torch.Tensor):
-         tensor = torch.tensor(tensor)
-
-     # Get or create vGPU device
-     if not VGPUDevice._VGPU_INSTANCES:
-         device = VGPUDevice(vram)
-     else:
-         device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
-         if vram is not None:
-             device.vram = vram
-
-     # Move data to vRAM
-     tensor_id = device._to_vram(tensor)
-     result = device._from_vram(tensor_id)
-     result.requires_grad = tensor.requires_grad
-
-     # Set the device using the internal name
-     result.data = result.data.to(device._device)
-     return result

+ """
+ Custom PyTorch device implementation that routes operations through our virtual GPU.
+ """
+ import torch
+ from torch.library import Library, impl
+ from typing import Optional, Union, Tuple
+ import numpy as np
+ from virtual_vram import VirtualVRAM
+
+ # Global flag for backend initialization
+ VGPU_BACKEND_INITIALIZED = False
+
+ # Module-level reference so the library registration is not garbage-collected
+ _VGPU_LIB = None
+
+ def init_vgpu_backend():
+     """Initialize the vGPU backend. Must be called before creating any VGPUDevice instances."""
+     global VGPU_BACKEND_INITIALIZED, _VGPU_LIB
+     try:
+         if not VGPU_BACKEND_INITIALIZED:
+             # First rename the reserved privateuse1 backend to "vgpu"
+             torch.utils.rename_privateuse1_backend("vgpu")
+
+             # Then generate all the necessary methods
+             torch.utils.generate_methods_for_privateuse1_backend(
+                 for_tensor=True,
+                 for_module=True,
+                 for_packed_sequence=True,
+                 for_storage=True
+             )
+
+             # Register our custom library
+             _VGPU_LIB = Library("vgpu", "DEF")
+             _VGPU_LIB.define("custom_op(Tensor self) -> Tensor")
+
+             # Implement the op for the PrivateUse1 dispatch key (renamed to "vgpu")
+             @impl(_VGPU_LIB, "custom_op", "PrivateUse1")
+             def custom_op_impl(tensor):
+                 return tensor.clone()
+
+             VGPU_BACKEND_INITIALIZED = True
+
+         return VGPU_BACKEND_INITIALIZED
+     except Exception as e:
+         print(f"Backend initialization warning: {e}")
+         return False
+
+ class VGPUStorage(torch.Storage):
+     """Custom storage class that uses our virtual VRAM"""
+
+     def __init__(self, *args, **kwargs):
+         # Pop our custom kwargs first; torch.Storage does not accept them
+         self.vram = kwargs.pop('vram', None) or VirtualVRAM()
+         self.tensor_id = kwargs.pop('tensor_id', f"tensor_{id(self)}")
+         super().__init__(*args, **kwargs)
+
+     def _new_shared(self, size):
+         return VGPUStorage(size, vram=self.vram)
+
+ class VGPUTensor:
+     """Tensor implementation that uses vGPU for computations"""
+     @staticmethod
+     def __new__(cls, elem):
+         return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
+
+ class VGPUDevice:
+     """
+     Custom PyTorch device implementation that routes operations through vGPU.
+     Usage:
+         vgpu = VGPUDevice()
+         with vgpu.mode():
+             tensor = torch.randn(2, 3)  # Will be on vGPU
+     """
+     _VGPU_INSTANCES = {}  # Class-level dict to track instances
+
+     def __init__(self, vram: Optional[VirtualVRAM] = None):
+         self.vram = vram or VirtualVRAM()
+         self.tensor_cores = None  # Will be initialized when needed
+         self.device_name = "vgpu"  # Both internal and user-facing name
+         self._register_device()
+
+     def _register_device(self):
+         """Register vGPU device using PyTorch's device system"""
+         try:
+             if not VGPU_BACKEND_INITIALIZED:
+                 raise RuntimeError("VGPU backend not properly initialized")
+
+             # Create the torch.device for the renamed backend
+             self._device = torch.device("vgpu")
+
+             # Store this instance for reuse
+             VGPUDevice._VGPU_INSTANCES[self.device_name] = self
+
+             # Define custom operations for the device
+             class VGPUAllocator:
+                 def __init__(self, vram, device):
+                     self.vram = vram
+                     self.device = device
+
+                 def __call__(self, size, dtype=None, device=None):
+                     # Create tensor on CPU first
+                     cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
+                     # Move to vGPU storage
+                     return to_vgpu(cpu_tensor, self.vram)
+
+             # Set up allocator
+             self._allocator = VGPUAllocator(self.vram, self._device)
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
+
+     @property
+     def type(self):
+         return self.device_name
+
+     def __str__(self):
+         return f"{self.device_name}:0"
+
+     def __repr__(self):
+         return f"vgpu(device='{self.device_name}:0')"
+
+     def device(self):
+         """Get the PyTorch device object that maps to our vGPU"""
+         return self._device  # Return the already created device object
+
+     def mode(self):
+         """Get a context manager for vGPU operations"""
+         return torch.device(self._device)
+
+     def _init_tensor_cores(self):
+         if self.tensor_cores is None:
+             from tensor_core import TensorCoreArray
+             self.tensor_cores = TensorCoreArray()
+
+     def _to_vram(self, tensor: torch.Tensor) -> str:
+         """Store tensor data in virtual VRAM"""
+         tensor_id = f"tensor_{id(tensor)}"
+         data = tensor.detach().cpu().numpy()
+         self.vram.storage.store_tensor(tensor_id, data)
+         return tensor_id
+
+     def _from_vram(self, tensor_id: str) -> torch.Tensor:
+         """Retrieve tensor data from virtual VRAM"""
+         data = self.vram.storage.load_tensor(tensor_id)
+         return torch.from_numpy(data)
+
+     def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+         """Matrix multiplication using tensor cores"""
+         self._init_tensor_cores()
+
+         # Store inputs in VRAM
+         a_id = self._to_vram(a)
+         b_id = self._to_vram(b)
+
+         # Perform matmul using tensor cores
+         result = self.tensor_cores.matmul(
+             self.vram.storage.load_tensor(a_id),
+             self.vram.storage.load_tensor(b_id)
+         )
+
+         # Create new tensor with result
+         return torch.from_numpy(result)
+
+ def to_vgpu(tensor: Union[torch.Tensor, torch.nn.Module], vram: Optional[VirtualVRAM] = None) -> Union[torch.Tensor, torch.nn.Module]:
+     """Move a tensor or module to the vGPU device"""
+     # Get or create vGPU device
+     if not VGPUDevice._VGPU_INSTANCES:
+         device = VGPUDevice(vram)
+     else:
+         device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
+         if vram is not None:
+             device.vram = vram
+
+     # Modules (e.g. pipe.model in the test) are moved through the generated
+     # backend methods rather than round-tripped through numpy
+     if isinstance(tensor, torch.nn.Module):
+         return tensor.to(device._device)
+
+     if not isinstance(tensor, torch.Tensor):
+         tensor = torch.tensor(tensor)
+
+     # Move data to vRAM and back to materialize it under vGPU control
+     tensor_id = device._to_vram(tensor)
+     result = device._from_vram(tensor_id)
+     result.requires_grad = tensor.requires_grad
+
+     # Place the result on the vGPU device
+     result.data = result.data.to(device._device)
+     return result
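
torch_vgpu builds on PyTorch's reserved PrivateUse1 backend. A standalone sketch of just that mechanism, using only public torch.utils APIs (the "vgpu" name mirrors the module above; note that without a real C++ allocator, actually allocating tensors on the device still fails, which is why the module round-trips data through VirtualVRAM):

    import torch

    # Rename the reserved PrivateUse1 backend so "vgpu" becomes a valid
    # device type, then generate helpers such as Tensor.vgpu()/Module.vgpu().
    torch.utils.rename_privateuse1_backend("vgpu")
    torch.utils.generate_methods_for_privateuse1_backend(
        for_tensor=True, for_module=True
    )

    dev = torch.device("vgpu")   # parses now; allocation still needs a backend
    print(dev.type)              # -> "vgpu"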