Spaces:
Sleeping
Sleeping
| """ | |
| HTTP Client for Federated Learning | |
| Handles communication with the federated server | |
| """ | |
| import requests | |
| import json | |
| import logging | |
| import time | |
| from typing import Dict, Any, Optional, List | |
| logger = logging.getLogger(__name__) | |
| class FederatedHTTPClient: | |
| def __init__(self, server_url: str, client_id: str, timeout: int = 30): | |
| self.server_url = server_url.rstrip('/') | |
| self.client_id = client_id | |
| self.timeout = timeout | |
| self.session = requests.Session() | |
| def register(self, client_info: Dict[str, Any] = None) -> Dict[str, Any]: | |
| """Register this client with the server""" | |
| try: | |
| payload = { | |
| 'client_id': self.client_id, | |
| 'client_info': client_info or {} | |
| } | |
| response = self.session.post( | |
| f"{self.server_url}/register", | |
| json=payload, | |
| timeout=self.timeout | |
| ) | |
| response.raise_for_status() | |
| result = response.json() | |
| logger.info(f"Client {self.client_id} registered successfully") | |
| return result | |
| except requests.exceptions.RequestException as e: | |
| logger.error(f"Failed to register client {self.client_id}: {str(e)}") | |
| raise | |
| def get_global_model(self) -> Dict[str, Any]: | |
| """Get the current global model from server""" | |
| try: | |
| payload = {'client_id': self.client_id} | |
| response = self.session.post( | |
| f"{self.server_url}/get_model", | |
| json=payload, | |
| timeout=self.timeout | |
| ) | |
| response.raise_for_status() | |
| result = response.json() | |
| logger.debug(f"Retrieved global model for round {result.get('round', 'unknown')}") | |
| return result | |
| except requests.exceptions.RequestException as e: | |
| logger.error(f"Failed to get global model: {str(e)}") | |
| raise | |
| def submit_model_update(self, model_weights: List, metrics: Dict[str, Any] = None) -> Dict[str, Any]: | |
| """Submit model update to server""" | |
| try: | |
| payload = { | |
| 'client_id': self.client_id, | |
| 'model_weights': model_weights, | |
| 'metrics': metrics or {} | |
| } | |
| response = self.session.post( | |
| f"{self.server_url}/submit_update", | |
| json=payload, | |
| timeout=self.timeout | |
| ) | |
| response.raise_for_status() | |
| result = response.json() | |
| logger.info(f"Model update submitted successfully by client {self.client_id}") | |
| return result | |
| except requests.exceptions.RequestException as e: | |
| logger.error(f"Failed to submit model update: {str(e)}") | |
| raise | |
| def get_training_status(self) -> Dict[str, Any]: | |
| """Get current training status from server""" | |
| try: | |
| response = self.session.get( | |
| f"{self.server_url}/training_status", | |
| timeout=self.timeout | |
| ) | |
| response.raise_for_status() | |
| return response.json() | |
| except requests.exceptions.RequestException as e: | |
| logger.error(f"Failed to get training status: {str(e)}") | |
| raise | |
| def health_check(self) -> bool: | |
| """Check if server is healthy""" | |
| try: | |
| response = self.session.get( | |
| f"{self.server_url}/health", | |
| timeout=5 # Short timeout for health checks | |
| ) | |
| response.raise_for_status() | |
| result = response.json() | |
| return result.get('status') == 'healthy' | |
| except requests.exceptions.RequestException: | |
| return False | |
| def wait_for_server(self, max_wait: int = 60, check_interval: int = 5) -> bool: | |
| """Wait for server to become available""" | |
| start_time = time.time() | |
| while time.time() - start_time < max_wait: | |
| if self.health_check(): | |
| logger.info(f"Server is available at {self.server_url}") | |
| return True | |
| logger.info(f"Waiting for server at {self.server_url}...") | |
| time.sleep(check_interval) | |
| logger.error(f"Server not available after {max_wait} seconds") | |
| return False | |
| def rag_query(self, query: str) -> Dict[str, Any]: | |
| """Submit a RAG query to the server""" | |
| try: | |
| payload = { | |
| 'query': query, | |
| 'client_id': self.client_id | |
| } | |
| response = self.session.post( | |
| f"{self.server_url}/rag/query", | |
| json=payload, | |
| timeout=self.timeout | |
| ) | |
| response.raise_for_status() | |
| return response.json() | |
| except requests.exceptions.RequestException as e: | |
| logger.error(f"Failed to submit RAG query: {str(e)}") | |
| raise | |
| def close(self): | |
| """Close the HTTP session""" | |
| self.session.close() | |