Factor Studios committed on
Commit
556bf3a
·
verified ·
1 Parent(s): 54aca07

Update test_ai_integration_http.py

Browse files
Files changed (1) hide show
  1. test_ai_integration_http.py +22 -14
test_ai_integration_http.py CHANGED
@@ -17,10 +17,6 @@ from torch_vgpu import VGPUDevice, to_vgpu
17
  def setup_vgpu():
18
  """Setup vGPU device"""
19
  try:
20
- # Register vGPU device type
21
- if not hasattr(torch, 'vgpu'):
22
- torch.register_privateuseone_backend()
23
-
24
  # Create and register vGPU device
25
  vgpu = VGPUDevice()
26
  device = vgpu.device()
@@ -100,18 +96,21 @@ def test_ai_integration_http():
100
  transformers_logger.setLevel(logging.ERROR)
101
 
102
  try:
103
- # Create pipeline and manually move to vGPU
 
104
  pipe = pipeline(
105
  "text-generation",
106
  model=model_id,
107
- torch_dtype=torch.float32,
108
- device_map=None # Don't auto-place on devices
 
 
109
  )
110
  status['pipeline_loaded'] = True
111
 
112
- # Move model to vGPU
113
- pipe.model = pipe.model.to(device)
114
  pipe.model = to_vgpu(pipe.model, vram=vram)
 
115
  status['model_on_vgpu'] = True
116
 
117
  # Log model details
@@ -121,6 +120,11 @@ def test_ai_integration_http():
121
  model_size = get_model_size(pipe.model)
122
  logger.info(f"Model loaded: {model_size/1e9:.2f} GB in parameters")
123
  logger.info(f"Model architecture: {pipe.model.__class__.__name__}")
 
 
 
 
 
124
 
125
  finally:
126
  # Restore original logging level
@@ -129,6 +133,9 @@ def test_ai_integration_http():
129
  except Exception as e:
130
  logger.error(f"Model loading failed: {str(e)}")
131
  raise
 
 
 
132
 
133
  # Run text generation
134
  logger.info("Running text generation...")
@@ -136,18 +143,20 @@ def test_ai_integration_http():
136
  peak_mem = initial_mem
137
 
138
  try:
139
- # Prepare input text
140
- text = "Explain how virtual GPUs work in simple terms."
141
 
142
  with torch.no_grad():
143
  # Generate text
144
  outputs = pipe(
145
- text,
146
  max_new_tokens=256,
147
  temperature=0.7,
148
  top_p=0.95,
149
  top_k=40,
150
- do_sample=True
 
 
151
  )
152
 
153
  if hasattr(storage, 'get_used_memory'):
@@ -160,7 +169,6 @@ def test_ai_integration_http():
160
  logger.info(f"\nGeneration stats:")
161
  logger.info(f"- Time: {inference_time:.4f}s")
162
  logger.info(f"- Memory peak: {(peak_mem - initial_mem)/1e9:.2f} GB")
163
- logger.info(f"- Output length: {len(outputs[0]['generated_text'])}")
164
  logger.info(f"- Generated text: {outputs[0]['generated_text']}")
165
 
166
  except Exception as e:
 
17
  def setup_vgpu():
18
  """Setup vGPU device"""
19
  try:
 
 
 
 
20
  # Create and register vGPU device
21
  vgpu = VGPUDevice()
22
  device = vgpu.device()
 
96
  transformers_logger.setLevel(logging.ERROR)
97
 
98
  try:
99
+ # Create pipeline
100
+ # Create pipeline with vGPU device
101
  pipe = pipeline(
102
  "text-generation",
103
  model=model_id,
104
+ torch_dtype=torch.float32, # Use full precision
105
+ device=device, # Use our vGPU device
106
+ use_safetensors=True,
107
+ trust_remote_code=True
108
  )
109
  status['pipeline_loaded'] = True
110
 
111
+ # Move pipeline model to vGPU
 
112
  pipe.model = to_vgpu(pipe.model, vram=vram)
113
+ pipe.model.eval()
114
  status['model_on_vgpu'] = True
115
 
116
  # Log model details
 
120
  model_size = get_model_size(pipe.model)
121
  logger.info(f"Model loaded: {model_size/1e9:.2f} GB in parameters")
122
  logger.info(f"Model architecture: {pipe.model.__class__.__name__}")
123
+
124
+ # Verify model location
125
+ with torch.device(device):
126
+ current_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
127
+ logger.info(f"Model memory usage: {(current_mem - initial_mem)/1e9:.2f} GB")
128
 
129
  finally:
130
  # Restore original logging level
 
133
  except Exception as e:
134
  logger.error(f"Model loading failed: {str(e)}")
135
  raise
136
+ except Exception as e:
137
+ logger.error(f"Model transfer to vGPU failed: {str(e)}")
138
+ raise
139
 
140
  # Run text generation
141
  logger.info("Running text generation...")
 
143
  peak_mem = initial_mem
144
 
145
  try:
146
+ # Prepare input prompt
147
+ prompt = "Explain how virtual GPUs work in simple terms."
148
 
149
  with torch.no_grad():
150
  # Generate text
151
  outputs = pipe(
152
+ prompt,
153
  max_new_tokens=256,
154
  temperature=0.7,
155
  top_p=0.95,
156
  top_k=40,
157
+ num_beams=1,
158
+ do_sample=True,
159
+ return_full_text=True
160
  )
161
 
162
  if hasattr(storage, 'get_used_memory'):
 
169
  logger.info(f"\nGeneration stats:")
170
  logger.info(f"- Time: {inference_time:.4f}s")
171
  logger.info(f"- Memory peak: {(peak_mem - initial_mem)/1e9:.2f} GB")
 
172
  logger.info(f"- Generated text: {outputs[0]['generated_text']}")
173
 
174
  except Exception as e: