# Nexa_Labs/examples/test_model_server.py
"""Test script to verify model server connection and CUDA availability."""
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).resolve().parents[1]
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
import httpx
import yaml
def test_model_server() -> bool:
    """Smoke-test the local model server in three stages.

    Stage 1 hits ``GET /health``, stage 2 posts a tiny prompt to
    ``POST /generate``, and stage 3 exercises the project's
    ``RemoteNexaSciClient`` wrapper end to end.

    Returns:
        True when every stage succeeds, False on the first failure.
    """
    # Server URL and enablement flag come from the agent config file,
    # resolved relative to the current working directory.
    config_path = Path("agent/config.yaml")
    if not config_path.exists():
        print("❌ Config file not found!")
        return False
    with config_path.open() as f:
        config = yaml.safe_load(f)

    model_server_cfg = config.get("model_server", {})
    base_url = model_server_cfg.get("base_url", "http://127.0.0.1:8001")
    enabled = model_server_cfg.get("enabled", False)

    print("=" * 80)
    print("Testing Model Server Connection")
    print("=" * 80)
    print(f"Server URL: {base_url}")
    print(f"Remote model enabled: {enabled}")
    print()

    # --- Stage 1: health endpoint -----------------------------------------
    try:
        print("1. Testing health endpoint...")
        response = httpx.get(f"{base_url}/health", timeout=5.0)
        response.raise_for_status()
        health = response.json()
        print(" βœ“ Health check passed")
        print(f" Status: {health.get('status')}")
        print(f" Model loaded: {health.get('model_loaded')}")
        print(f" GPU available: {health.get('gpu_available')}")
        # Optional diagnostics: only printed when the server reports them.
        if health.get("model_device"):
            print(f" Model device: {health.get('model_device')}")
        if health.get("gpu_name"):
            print(f" GPU: {health.get('gpu_name')}")
        if health.get("gpu_memory_allocated_gb"):
            print(f" GPU Memory: {health.get('gpu_memory_allocated_gb')} GB / {health.get('gpu_memory_total_gb')} GB")
        print()
    except httpx.RequestError as e:
        # Transport-level failure: the server is not reachable at all.
        print(f" ❌ Failed to connect to model server: {e}")
        print(" Make sure model server is running in agent tmux!")
        return False
    except Exception as e:
        # Any other failure (HTTP error status, invalid JSON payload, ...).
        print(f" ❌ Error: {e}")
        return False

    # --- Stage 2: generation endpoint -------------------------------------
    try:
        print("2. Testing generation...")
        test_prompt = {
            "messages": [
                {"role": "user", "content": "What is 2+2? Answer in one word."}
            ],
            "max_new_tokens": 10,
        }
        response = httpx.post(f"{base_url}/generate", json=test_prompt, timeout=30.0)
        response.raise_for_status()
        result = response.json()
        print(" βœ“ Generation successful")
        print(f" Response: {result.get('text', '')[:100]}")
        print()
    except Exception as e:
        print(f" ❌ Generation failed: {e}")
        return False

    # --- Stage 3: project remote-client wrapper ---------------------------
    try:
        print("3. Testing remote client...")
        # Imported lazily so stages 1-2 can still run when the agent
        # package itself is broken.
        from agent.client_llm_remote import RemoteNexaSciClient
        from agent.client_llm import Message

        client = RemoteNexaSciClient(base_url=base_url)
        messages = [Message(role="user", content="Say 'hello' in one word.")]
        response = client.generate(messages, max_new_tokens=5)
        print(" βœ“ Remote client works")
        print(f" Response: {response[:50]}")
        print()
    except Exception as e:
        print(f" ❌ Remote client failed: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("=" * 80)
    print("βœ“ All tests passed! Model server is working correctly.")
    print("=" * 80)
    return True
if __name__ == "__main__":
success = test_model_server()
sys.exit(0 if success else 1)