Factor Studios committed
Commit 7670c1d · verified · 1 Parent(s): 89dccee

Update test_ai_integration_http.py

Files changed (1):
  1. test_ai_integration_http.py +30 -69
test_ai_integration_http.py CHANGED
@@ -9,12 +9,7 @@ from contextlib import contextmanager
 from typing import Any, Optional
 
 import torch
-import torch.nn.functional as F
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    TextStreamer
-)
+from transformers import pipeline
 from virtual_vram import VirtualVRAM
 from http_storage import HTTPGPUStorage
 from torch_vgpu import VGPUDevice, to_vgpu
@@ -70,12 +65,11 @@ def prepare_prompt(instruction: str) -> str:
     return f"<s>[INST] {instruction} [/INST]"
 
 def test_ai_integration_http():
-    """Test Llama-2-7b-instruct model on vGPU with text generation"""
+    """Test GPT OSS model on vGPU with text generation"""
     logger.info("Starting vGPU text generation test")
 
     status = {
-        'model_loaded': False,
-        'tokenizer_loaded': False,
+        'pipeline_loaded': False,
         'model_on_vgpu': False,
         'generation_complete': False,
         'cleanup_success': False
@@ -91,9 +85,9 @@ def test_ai_integration_http():
     device = setup_vgpu()
     logger.info(f"vGPU initialized with device {device}")
 
-    # Load Llama model and tokenizer
-    model_name = "meta-llama/Llama-2-7b-chat-hf"
-    logger.info(f"Loading {model_name}")
+    # Load model using pipeline
+    model_id = "openai/gpt-oss-20b"
+    logger.info(f"Loading {model_id}")
 
     try:
         # Disable transformers logging temporarily
@@ -102,32 +96,26 @@ def test_ai_integration_http():
         transformers_logger.setLevel(logging.ERROR)
 
         try:
-            # Load tokenizer first
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_name,
-                trust_remote_code=True,
-                use_fast=True
+            # Create pipeline
+            pipe = pipeline(
+                "text-generation",
+                model=model_id,
+                torch_dtype="auto",
+                device=device  # Use our vGPU device
             )
-            status['tokenizer_loaded'] = True
+            status['pipeline_loaded'] = True
 
-            # Load model with full precision
-            model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                trust_remote_code=True,
-                torch_dtype=torch.float32,  # Use full precision
-                device_map=None,  # Don't auto-map devices
-                use_safetensors=True
-            )
-            status['model_loaded'] = True
+            # Move pipeline to vGPU
+            pipe.model = to_vgpu(pipe.model, vram=vram)
+            status['model_on_vgpu'] = True
 
             # Log model details
-            logger.info(f"Tokenizer type: {type(tokenizer).__name__}")
-            logger.info(f"Model type: {type(model).__name__}")
+            logger.info(f"Pipeline created with model: {model_id}")
 
-            # Log model architecture
-            model_size = get_model_size(model)
+            # Log model size
+            model_size = get_model_size(pipe.model)
             logger.info(f"Model loaded: {model_size/1e9:.2f} GB in parameters")
-            logger.info(f"Model architecture: {model.__class__.__name__}")
+            logger.info(f"Model architecture: {pipe.model.__class__.__name__}")
 
         finally:
             # Restore original logging level
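Note: the `pipeline` factory used in this hunk resolves the tokenizer and model in one call, which is why the separate `AutoTokenizer`/`AutoModelForCausalLM` steps could be dropped. A minimal standalone sketch of this loading path, with the project-specific vGPU helpers (`to_vgpu`, `VirtualVRAM`) left out and CPU standing in for the vGPU device:

    from transformers import pipeline

    # Same model id as the commit; any causal-LM id follows the same recipe.
    pipe = pipeline(
        "text-generation",
        model="openai/gpt-oss-20b",
        torch_dtype="auto",  # take the dtype stored in the checkpoint
        device="cpu",        # stand-in for the project's vGPU device
    )

    # Rough equivalent of the repo's get_model_size() helper (assumption:
    # that helper sums parameter bytes, matching the GB log line above).
    size_bytes = sum(p.numel() * p.element_size() for p in pipe.model.parameters())
    print(f"{size_bytes / 1e9:.2f} GB in parameters")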
@@ -137,52 +125,25 @@ def test_ai_integration_http():
         logger.error(f"Model loading failed: {str(e)}")
         raise
 
-    # Move model to vGPU with verification
-    try:
-        model = to_vgpu(model, vram=vram)
-        model.eval()
-        status['model_on_vgpu'] = True
-
-        # Verify model location
-        with torch.device(device):
-            for param in model.parameters():
-                if param.device != device:
-                    raise RuntimeError(f"Model parameter not on vGPU device. Found device: {param.device}")
-
-        current_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
-        logger.info(f"Model memory usage: {(current_mem - initial_mem)/1e9:.2f} GB")
-    except Exception as e:
-        logger.error(f"Model transfer to vGPU failed: {str(e)}")
-        raise
-
     # Run text generation
     logger.info("Running text generation...")
     start = time.time()
     peak_mem = initial_mem
 
     try:
-        # Prepare input prompt
-        instruction = "Explain how virtual GPUs work in simple terms."
-        prompt = prepare_prompt(instruction)
-
-        # Tokenize input
-        inputs = tokenizer(prompt, return_tensors="pt")
-        inputs = {k: to_vgpu(v, vram=vram) for k, v in inputs.items()}
-
-        # Set up streamer for token-by-token output
-        streamer = TextStreamer(tokenizer)
+        # Prepare messages
+        messages = [
+            {"role": "user", "content": "Explain how virtual GPUs work in simple terms."}
+        ]
 
         with torch.no_grad():
             # Generate text
-            outputs = model.generate(
-                **inputs,
-                max_length=512,
+            outputs = pipe(
+                messages,
+                max_new_tokens=256,
                 temperature=0.7,
                 top_p=0.95,
-                top_k=40,
-                num_beams=1,
-                streamer=streamer,
-                pad_token_id=tokenizer.pad_token_id
+                top_k=40
             )
 
         if hasattr(storage, 'get_used_memory'):
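Note: with a chat-style `messages` list as input, recent transformers releases return `generated_text` as the continued conversation (a list of role/content dicts) rather than a flat string, so the `len(...)` logged in the stats hunk below may count messages rather than characters. A hedged sketch of extracting just the assistant reply, reusing the `pipe` and `messages` names from the diff (return shape depends on the installed transformers version):

    result = pipe(messages, max_new_tokens=256)
    generated = result[0]["generated_text"]
    if isinstance(generated, list):        # chat format: list of messages
        reply = generated[-1]["content"]   # last message is the assistant's
    else:                                  # plain-prompt format: one string
        reply = generated
    print(reply)

Also worth noting: `max_new_tokens=256` bounds only the generated continuation, whereas the removed `max_length=512` counted prompt tokens toward the limit as well.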
@@ -195,8 +156,8 @@ def test_ai_integration_http():
         logger.info(f"\nGeneration stats:")
         logger.info(f"- Time: {inference_time:.4f}s")
         logger.info(f"- Memory peak: {(peak_mem - initial_mem)/1e9:.2f} GB")
-        logger.info(f"- Output length: {len(outputs[0])}")
-        logger.info(f"- Output device: {outputs.device}")
+        logger.info(f"- Output length: {len(outputs[0]['generated_text'])}")
+        logger.info(f"- Generated text: {outputs[0]['generated_text']}")
 
     except Exception as e:
         logger.error(f"Text generation failed: {str(e)}")
 
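Taken together, the rewrite collapses the test onto the standard transformers pipeline recipe. A condensed standalone sketch of the new flow, with the vGPU wrapping and the repo's logging/status bookkeeping omitted (the diff passes sampling knobs without `do_sample=True`; recent transformers versions ignore them unless sampling is enabled, so the sketch enables it explicitly):

    import torch
    from transformers import pipeline

    pipe = pipeline("text-generation", model="openai/gpt-oss-20b", torch_dtype="auto")

    messages = [
        {"role": "user", "content": "Explain how virtual GPUs work in simple terms."}
    ]

    # Sampling settings mirror the committed test.
    with torch.no_grad():
        outputs = pipe(messages, max_new_tokens=256, do_sample=True,
                       temperature=0.7, top_p=0.95, top_k=40)

    print(outputs[0]["generated_text"])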