| """Test script to verify model server connection and CUDA availability.""" | |

import sys
from pathlib import Path

# Add the project root to sys.path so the `agent` package imports further
# down resolve when this file is run as a standalone script
project_root = Path(__file__).resolve().parents[1]
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

import httpx
import yaml

def test_model_server():
    """Test connection to model server."""
    # Resolve the config relative to the project root so the script works
    # regardless of the current working directory
    config_path = project_root / "agent" / "config.yaml"
    if not config_path.exists():
        print(f"✗ Config file not found: {config_path}")
        return False

    with config_path.open() as f:
        # Guard against an empty file: yaml.safe_load returns None for it
        config = yaml.safe_load(f) or {}

    model_server_cfg = config.get("model_server", {})
    base_url = model_server_cfg.get("base_url", "http://127.0.0.1:8001")
    enabled = model_server_cfg.get("enabled", False)
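
    # For reference, the config is expected to be shaped roughly like this
    # (keys inferred from the lookups above; the values are illustrative):
    #
    #   model_server:
    #     enabled: true
    #     base_url: "http://127.0.0.1:8001"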
| print("=" * 80) | |
| print("Testing Model Server Connection") | |
| print("=" * 80) | |
| print(f"Server URL: {base_url}") | |
| print(f"Remote model enabled: {enabled}") | |
| print() | |
    # Test health endpoint
    try:
        print("1. Testing health endpoint...")
        response = httpx.get(f"{base_url}/health", timeout=5.0)
        response.raise_for_status()
        health = response.json()
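        # For reference, a healthy payload looks roughly like this (field
        # names taken from the lookups below; the values are illustrative):
        #   {"status": "ok", "model_loaded": true, "gpu_available": true,
        #    "model_device": "cuda:0", "gpu_name": "<device name>",
        #    "gpu_memory_allocated_gb": 1.2, "gpu_memory_total_gb": 24.0}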
| print(f" β Health check passed") | |
| print(f" Status: {health.get('status')}") | |
| print(f" Model loaded: {health.get('model_loaded')}") | |
| print(f" GPU available: {health.get('gpu_available')}") | |
| if health.get("model_device"): | |
| print(f" Model device: {health.get('model_device')}") | |
| if health.get("gpu_name"): | |
| print(f" GPU: {health.get('gpu_name')}") | |
| if health.get("gpu_memory_allocated_gb"): | |
| print(f" GPU Memory: {health.get('gpu_memory_allocated_gb')} GB / {health.get('gpu_memory_total_gb')} GB") | |
| print() | |
    except httpx.RequestError as e:
        print(f"   ✗ Failed to connect to model server: {e}")
        print("   Make sure the model server is running in the agent tmux session!")
        return False
    except Exception as e:
        print(f"   ✗ Error: {e}")
        return False

    # Test generation
    try:
        print("2. Testing generation...")
        test_prompt = {
            "messages": [
                {"role": "user", "content": "What is 2+2? Answer in one word."}
            ],
            "max_new_tokens": 10,
        }
        response = httpx.post(f"{base_url}/generate", json=test_prompt, timeout=30.0)
        response.raise_for_status()
        result = response.json()
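        # The server is expected to return the completion under a "text" key,
        # e.g. {"text": "Four."}; the exact wording will vary by model.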
| print(f" β Generation successful") | |
| print(f" Response: {result.get('text', '')[:100]}") | |
| print() | |
| except Exception as e: | |
| print(f" β Generation failed: {e}") | |
| return False | |

    # Test remote client
    try:
        print("3. Testing remote client...")
        from agent.client_llm_remote import RemoteNexaSciClient
        from agent.client_llm import Message

        client = RemoteNexaSciClient(base_url=base_url)
        messages = [Message(role="user", content="Say 'hello' in one word.")]
        response = client.generate(messages, max_new_tokens=5)
| print(f" β Remote client works") | |
| print(f" Response: {response[:50]}") | |
| print() | |
| except Exception as e: | |
| print(f" β Remote client failed: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return False | |
| print("=" * 80) | |
| print("β All tests passed! Model server is working correctly.") | |
| print("=" * 80) | |
| return True | |


if __name__ == "__main__":
    success = test_model_server()
    sys.exit(0 if success else 1)