api-gpu / test_multi_model.py
gary-boon
Add Code Llama 7B support with hardware-aware filtering and ICL timeout fixes
ed40a9a
#!/usr/bin/env python3
"""
Test script for multi-model support
Tests model switching and generation with CodeGen and Code-Llama
"""
import requests
import time
import sys
import json
BASE_URL = "http://localhost:8000"
def print_header(text):
    """Display a section banner: blank line, 60-char rule, title, rule."""
    rule = "=" * 60
    print("\n" + rule)
    print(f" {text}")
    print(rule)
def print_result(success, message):
    """Print a PASS/FAIL line and pass the boolean straight through to the caller."""
    if success:
        status = "✅ PASS"
    else:
        status = "❌ FAIL"
    print(f"{status}: {message}")
    return success
def test_health_check():
    """Hit /health to confirm the backend is reachable; report its status fields."""
    print_header("1. Health Check")
    try:
        resp = requests.get(f"{BASE_URL}/health", timeout=5)
        payload = resp.json()
        # Echo the fields the endpoint reports so the operator can eyeball them.
        for label, key in (("Status", "status"),
                           ("Model loaded", "model_loaded"),
                           ("Device", "device")):
            print(f"{label}: {payload.get(key)}")
        return print_result(resp.status_code == 200, "Backend is running")
    except requests.exceptions.ConnectionError:
        return print_result(False, "Cannot connect to backend. Is it running?")
    except Exception as e:
        return print_result(False, f"Health check failed: {e}")
def test_list_models():
    """Query /models, print each entry, and pass if at least two models exist."""
    print_header("2. List Available Models")
    try:
        resp = requests.get(f"{BASE_URL}/models", timeout=5)
        listing = resp.json().get('models', [])
        count = len(listing)
        print(f"Found {count} models:")
        for entry in listing:
            mark = "✓" if entry['available'] else "✗"
            suffix = " (CURRENT)" if entry['is_current'] else ""
            print(f" {mark} {entry['name']} ({entry['size']}) - {entry['architecture']}{suffix}")
        # Two models are expected: CodeGen 350M and Code-Llama 7B.
        return print_result(count >= 2, f"Found {count} models")
    except Exception as e:
        return print_result(False, f"List models failed: {e}")
def test_current_model():
    """Fetch /models/current and display the active model's configuration."""
    print_header("3. Get Current Model Info")
    try:
        resp = requests.get(f"{BASE_URL}/models/current", timeout=5)
        info = resp.json()
        print(f"Current model: {info.get('name')}")
        print(f"Model ID: {info.get('id')}")
        cfg = info.get('config', {})
        print(f"Layers: {cfg.get('num_layers')}")
        print(f"Heads: {cfg.get('num_heads')}")
        print(f"Attention: {cfg.get('attention_type')}")
        return print_result(resp.status_code == 200, "Got current model info")
    except Exception as e:
        return print_result(False, f"Get current model failed: {e}")
def test_generation(model_name, prompt="def fibonacci(n):\n ", max_tokens=30):
    """POST a prompt to /generate and pass if at least one token comes back."""
    print_header(f"4. Test Generation with {model_name}")
    print(f"Prompt: {repr(prompt)}")
    print(f"Generating {max_tokens} tokens...")
    request_body = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0.7,
        "extract_traces": False,  # skipping trace extraction keeps the test fast
    }
    try:
        # Generation can take a while, hence the generous timeout.
        resp = requests.post(f"{BASE_URL}/generate", json=request_body, timeout=60)
        if resp.status_code != 200:
            return print_result(False, f"Generation failed: {resp.status_code}")
        body = resp.json()
        text = body.get('generated_text', '')
        token_list = body.get('tokens', [])
        divider = "-" * 60
        print("\nGenerated text:")
        print(divider)
        print(text)
        print(divider)
        print(f"Token count: {len(token_list)}")
        print(f"Confidence: {body.get('confidence', 0):.3f}")
        print(f"Perplexity: {body.get('perplexity', 0):.3f}")
        return print_result(len(token_list) > 0, f"Generated {len(token_list)} tokens")
    except Exception as e:
        return print_result(False, f"Generation failed: {e}")
def test_model_switch(model_id, model_name):
    """Ask the backend to load a different model, then verify the switch took effect."""
    print_header(f"5. Switch to {model_name}")
    print(f"Switching to model: {model_id}")
    print("⏳ This may take a while (downloading + loading model)...")
    try:
        resp = requests.post(
            f"{BASE_URL}/models/switch",
            json={"model_id": model_id},
            timeout=300,  # allow up to 5 minutes for download + loading
        )
        if resp.status_code != 200:
            return print_result(False, f"Switch failed: {resp.status_code}")
        print(f"Message: {resp.json().get('message')}")
        # Confirm the backend now reports the requested model as current.
        check = requests.get(f"{BASE_URL}/models/current", timeout=5)
        if check.json().get('id') == model_id:
            return print_result(True, f"Switched to {model_name}")
        return print_result(False, "Switch verification failed")
    except requests.exceptions.Timeout:
        return print_result(False, "Switch timeout - model download may be in progress")
    except Exception as e:
        return print_result(False, f"Switch failed: {e}")
def test_model_info():
    """Pull /model/info and display the full architecture summary."""
    print_header("6. Get Detailed Model Info")
    try:
        resp = requests.get(f"{BASE_URL}/model/info", timeout=5)
        details = resp.json()
        print(f"Model: {details.get('name')}")
        print(f"Architecture: {details.get('architecture')}")
        print(f"Parameters: {details.get('totalParams'):,}")
        print(f"Layers: {details.get('layers')}")
        print(f"Heads: {details.get('heads')}")
        kv = details.get('kv_heads')
        # Only GQA-style models report a separate KV-head count.
        if kv:
            print(f"KV Heads: {kv} (GQA)")
        print(f"Attention type: {details.get('attention_type')}")
        print(f"Vocab size: {details.get('vocabSize'):,}")
        print(f"Context length: {details.get('maxPositions'):,}")
        return print_result(resp.status_code == 200, "Got detailed model info")
    except Exception as e:
        return print_result(False, f"Get model info failed: {e}")
def main():
    """Drive the full test sequence interactively and return a shell exit code."""
    print("\n🧪 Multi-Model Support Test Suite")
    print("This will test model switching between CodeGen 350M and Code-Llama 7B")
    print("\nIMPORTANT: Make sure the backend is running:")
    print(" cd /Users/garyboon/Development/VisualisableAI/visualisable-ai-backend")
    print(" python -m uvicorn backend.model_service:app --reload --port 8000")
    input("\nPress Enter to start tests...")
    outcomes = []
    # Test 1: the backend must be up before anything else is worth trying.
    outcomes.append(test_health_check())
    if not outcomes[-1]:
        print("\n❌ Backend not running. Exiting.")
        sys.exit(1)
    time.sleep(1)
    # Tests 2-4: read-only endpoints against the default model.
    outcomes.append(test_list_models())
    time.sleep(1)
    outcomes.append(test_current_model())
    time.sleep(1)
    outcomes.append(test_model_info())
    time.sleep(1)
    # Test 5: generation with the default CodeGen model.
    outcomes.append(test_generation("CodeGen 350M"))
    time.sleep(2)
    # Tests 6-10: optional Code-Llama round trip — gated because of the large download.
    print("\n⚠️ WARNING: Next test will download Code-Llama 7B (~14GB)")
    print("This may take 5-10 minutes depending on your internet connection.")
    proceed = input("Proceed with Code-Llama test? (y/n): ").lower()
    if proceed != 'y':
        print("\nSkipping Code-Llama tests.")
    else:
        outcomes.append(test_model_switch("code-llama-7b", "Code-Llama 7B"))
        if outcomes[-1]:
            time.sleep(2)
            # Tests 7-8: info + generation on the newly loaded model.
            outcomes.append(test_model_info())
            time.sleep(1)
            outcomes.append(test_generation("Code-Llama 7B"))
            time.sleep(2)
            # Test 9: switch back to the original model.
            outcomes.append(test_model_switch("codegen-350m", "CodeGen 350M"))
            if outcomes[-1]:
                time.sleep(2)
                # Test 10: confirm CodeGen still generates after the round trip.
                outcomes.append(test_generation("CodeGen 350M (after switch back)"))
    # Summary
    print_header("Test Summary")
    passed = sum(outcomes)
    total = len(outcomes)
    print(f"Passed: {passed}/{total} tests")
    if passed == total:
        print("\n🎉 All tests passed! Multi-model support is working correctly.")
        return 0
    print(f"\n⚠️ {total - passed} test(s) failed. Check output above for details.")
    return 1
if __name__ == "__main__":
    # Propagate main()'s status code (0 = all passed, 1 = failures) to the shell.
    sys.exit(main())