Factor Studios commited on
Commit
de03e5f
·
verified ·
1 Parent(s): 01c7c6f

Update test_ai_integration_http.py

Browse files
Files changed (1) hide show
  1. test_ai_integration_http.py +15 -12
test_ai_integration_http.py CHANGED
@@ -12,7 +12,7 @@ from typing import Any, Optional
12
  import torch
13
  from torch import nn
14
  import torch.nn.functional as F
15
- from torch.overrides import TorchFunctionMode
16
  from PIL import Image
17
  from transformers import (
18
  AutoTokenizer,
@@ -24,13 +24,16 @@ from http_storage import HTTPGPUStorage
24
  from torch_vgpu import VGPUDevice, to_vgpu
25
 
26
  def setup_vgpu():
27
- """Setup vGPU device and mode"""
28
  try:
29
- # Create device and get its mode
30
- device = VGPUDevice()
31
- mode = device.mode()
32
 
33
- return mode, device
 
 
 
34
 
35
  except Exception as e:
36
  logging.error(f"vGPU setup failed: {str(e)}")
@@ -101,9 +104,9 @@ def test_ai_integration_http():
101
  initial_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
102
  vram = VirtualVRAM(size_gb=None, storage=storage)
103
 
104
- # Initialize vGPU device and mode
105
- vgpu_mode, vgpu_device = setup_vgpu()
106
- logger.info(f"vGPU initialized with device {vgpu_device}")
107
 
108
  # Load Florence model and processor
109
  model_name = "microsoft/florence-2-large"
@@ -135,10 +138,10 @@ def test_ai_integration_http():
135
  model.eval()
136
  status['model_on_vgpu'] = True
137
 
138
- # Verify model location and device mode
139
- with vgpu_mode:
140
  for param in model.parameters():
141
- if not str(param.device).startswith('vgpu'):
142
  raise RuntimeError(f"Model parameter not on vGPU device. Found device: {param.device}")
143
 
144
  current_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
 
12
  import torch
13
  from torch import nn
14
  import torch.nn.functional as F
15
+ from torch.utils._python_dispatch import TorchFunctionMode
16
  from PIL import Image
17
  from transformers import (
18
  AutoTokenizer,
 
24
  from torch_vgpu import VGPUDevice, to_vgpu
25
 
26
  def setup_vgpu():
27
+ """Setup vGPU device"""
28
  try:
29
+ # Create and register vGPU device
30
+ vgpu = VGPUDevice()
31
+ device = vgpu.device()
32
 
33
+ # Set as default device for tensor operations
34
+ torch.set_default_device(device)
35
+
36
+ return device
37
 
38
  except Exception as e:
39
  logging.error(f"vGPU setup failed: {str(e)}")
 
104
  initial_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
105
  vram = VirtualVRAM(size_gb=None, storage=storage)
106
 
107
+ # Initialize vGPU device
108
+ device = setup_vgpu()
109
+ logger.info(f"vGPU initialized with device {device}")
110
 
111
  # Load Florence model and processor
112
  model_name = "microsoft/florence-2-large"
 
138
  model.eval()
139
  status['model_on_vgpu'] = True
140
 
141
+ # Verify model location
142
+ with torch.device(device):
143
  for param in model.parameters():
144
+ if param.device != device:
145
  raise RuntimeError(f"Model parameter not on vGPU device. Found device: {param.device}")
146
 
147
  current_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0