Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Test script for the complete federated learning system. | |
| This script tests the server, client, and web app integration. | |
| """ | |
| import requests | |
| import time | |
| import json | |
| import numpy as np | |
| from pathlib import Path | |
| import subprocess | |
| import sys | |
| import threading | |
def test_server_health(server_url="http://localhost:8080"):
    """Probe ``GET {server_url}/health`` and report whether the server is up.

    Args:
        server_url: Base URL of the server to probe.

    Returns:
        True when the health endpoint answers with HTTP 200; False for any
        other status code or when the request raises (for example because
        nothing is listening at ``server_url``). The outcome is printed in
        every case.
    """
    try:
        status = requests.get(f"{server_url}/health", timeout=5).status_code
    except Exception as exc:
        # Any failure to even get a response is reported as "cannot connect".
        print(f"β Cannot connect to server: {exc}")
        return False

    healthy = status == 200
    if healthy:
        print("β Server health check passed")
    else:
        print(f"β Server health check failed: {status}")
    return healthy
def test_prediction(server_url="http://localhost:8080", num_features=32):
    """Smoke-test the server's ``POST /predict`` endpoint.

    Sends one random feature vector and checks that the server answers with
    HTTP 200 and a numeric ``prediction`` value. The prediction's value itself
    is not checked; the features are unseeded random normals.

    Args:
        server_url: Base URL of the server under test.
        num_features: Length of the random feature vector to send. Defaults to
            32, the size this script previously hard-coded. Presumably it must
            match the served model's input dimension; confirm against the
            server configuration.

    Returns:
        True if the request succeeded and a numeric prediction came back,
        False otherwise. Failures are printed rather than raised.
    """
    # Keep the try body to the I/O and parsing steps, so that a malformed
    # payload is reported as such rather than as a generic "error".
    try:
        features = np.random.randn(num_features).tolist()
        response = requests.post(
            f"{server_url}/predict",
            json={"features": features},
            timeout=10,
        )
        if response.status_code != 200:
            print(f"β Prediction test failed: {response.status_code}")
            return False
        payload = response.json()
    except Exception as e:
        print(f"β Prediction test error: {e}")
        return False

    # A non-object JSON body has no "prediction" key to look up.
    prediction = payload.get("prediction") if isinstance(payload, dict) else None
    if prediction is None:
        print("β Prediction test failed: response contains no 'prediction' value")
        return False
    # Only numbers accept the ``.4f`` format spec. Report anything else (e.g. a
    # list or string) with its actual value instead of a format-spec crash.
    if not isinstance(prediction, (int, float)):
        print(f"β Prediction test failed: non-numeric prediction {prediction!r}")
        return False

    print(f"β Prediction test passed: {prediction:.4f}")
    return True
def test_training_status(server_url="http://localhost:8080"):
    """Query ``GET {server_url}/training_status`` and print the current round.

    Args:
        server_url: Base URL of the server under test.

    Returns:
        True if the endpoint answers with HTTP 200 and a JSON object body. The
        round is read from ``current_round``, falling back to 0 when absent.
        False on any other status code or on any exception (connection
        failure, body that is not a JSON object, ...).
    """
    try:
        reply = requests.get(f"{server_url}/training_status", timeout=5)
        if reply.status_code != 200:
            print(f"β Training status test failed: {reply.status_code}")
            return False
        current_round = reply.json().get('current_round', 0)
        print(f"β Training status test passed: Round {current_round}")
        return True
    except Exception as exc:
        print(f"β Training status test error: {exc}")
        return False
def test_client_registration(server_url="http://localhost:8080"):
    """Register a dummy client through ``POST {server_url}/register``.

    The request body carries the fixed client id ``'test_client'`` and a
    small, fixed ``client_info`` description. Only the HTTP status is checked.

    Args:
        server_url: Base URL of the server under test.

    Returns:
        True on HTTP 200; False on any other status code or if the request
        raises. The outcome is printed in every case.
    """
    payload = {
        'client_id': 'test_client',
        'client_info': {
            'dataset_size': 100,
            'model_params': 10000,
            'capabilities': ['training', 'inference'],
        },
    }
    try:
        reply = requests.post(f"{server_url}/register", json=payload, timeout=10)
        accepted = reply.status_code == 200
        if accepted:
            print("β Client registration test passed")
        else:
            print(f"β Client registration test failed: {reply.status_code}")
        return accepted
    except Exception as exc:
        print(f"β Client registration test error: {exc}")
        return False
def run_complete_test(server_url="http://localhost:8080"):
    """Run the end-to-end smoke checks against a running server.

    Checks run in this order: health, client registration, training status,
    prediction. The suite stops at the first check that fails and prints a
    short failure line (plus a start-up hint when the health check fails).

    Args:
        server_url: Base URL of the server under test. It defaults to the
            address this function previously hard-coded, so existing callers
            are unaffected. It is forwarded to every individual check.

    Returns:
        True if every check passed, False as soon as one fails.
    """
    print("π Testing Complete Federated Learning System")
    print("=" * 50)

    # The health check gets its own handling because it prints a hint on how
    # to start the server, not just a one-line failure message.
    if not test_server_health(server_url):
        print("\nβ Server is not running. Please start the server first:")
        print("python -m src.main --mode server --config config/server_config.yaml")
        return False

    # Remaining checks as (check function, message printed if it fails), in
    # the order they are run.
    checks = (
        (test_client_registration, "\nβ Client registration failed"),
        (test_training_status, "\nβ Training status failed"),
        (test_prediction, "\nβ Prediction failed"),
    )
    for check, failure_message in checks:
        if not check(server_url):
            print(failure_message)
            return False

    print("\nπ All tests passed! The federated learning system is working correctly.")
    print("\nNext steps:")
    print("1. Start the web app: streamlit run webapp/streamlit_app.py")
    print("2. Start additional clients: python -m src.main --mode client --config config/client_config.yaml")
    print("3. Use the web interface to interact with the system")
    return True
if __name__ == "__main__":
    # Map the suite result onto the process exit status (0 = every check
    # passed, 1 = a check failed) so a calling shell or CI job can detect it.
    sys.exit(0 if run_complete_test() else 1)