Spaces:
Sleeping
Sleeping
import hashlib
import json
import logging
import threading
import time
from typing import Any, Dict, Optional, Union

import numpy as np
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
class HTTPGPUStorage:
    """
    HTTP-based GPU storage client that replaces WebSocket functionality.
    Maintains the same interface as WebSocketGPUStorage for backward compatibility.

    The class is a process-wide singleton: the first construction builds the
    instance and tries to open a server session; later constructions return
    that same object and their ``base_url`` argument is ignored.

    Public request methods do not raise on transport or server errors. They
    log the failure and return ``False`` / ``None`` instead.
    """

    # Singleton instance
    _instance = None
    # Guards creation of the singleton instance
    _lock = threading.Lock()
    # Per-request timeout (seconds) used when the caller does not supply one
    _DEFAULT_TIMEOUT = 30

    def __new__(cls, base_url: str = "https://factorst-intiv.hf.space"):
        """Return the singleton, creating and initializing it on first use."""
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._init_singleton(base_url)
                # Publish only after initialization finished, so a failure
                # above cannot leave a half-built singleton behind.
                cls._instance = instance
            return cls._instance

    def _init_singleton(self, base_url: str):
        """Initialize the singleton instance"""
        if hasattr(self, 'initialized'):
            return
        self.base_url = base_url.rstrip('/')
        self.api_base = f"{self.base_url}/api/v1"
        self.session_token = None
        self.session_id = None
        # Protects the registries and the resource monitor
        self.lock = threading.Lock()
        self._closing = False
        self.error_count = 0
        self.last_error_time = 0
        self.max_retries = 5

        # Tensor and model registries (maintained for compatibility)
        self.tensor_registry: Dict[str, Dict[str, Any]] = {}
        self.model_registry: Dict[str, Dict[str, Any]] = {}
        self.resource_monitor = {
            'vram_used': 0,
            'active_tensors': 0,
            'loaded_models': set()
        }
        # Bytes already added to resource_monitor per tensor id by
        # store_tensor, so re-storing an id does not count it twice.
        self._tracked_sizes: Dict[str, int] = {}

        # Configure HTTP session with connection pooling and retries
        self.http_session = requests.Session()

        # Configure retry strategy
        retry_strategy = Retry(
            total=3,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "OPTIONS", "POST", "PUT", "DELETE"],  # Updated parameter name
            backoff_factor=1
        )
        adapter = HTTPAdapter(
            max_retries=retry_strategy,
            pool_connections=10,
            pool_maxsize=20
        )
        self.http_session.mount("http://", adapter)
        self.http_session.mount("https://", adapter)

        # Set default headers
        self.http_session.headers.update({
            'Content-Type': 'application/json',
            'User-Agent': 'VirtualGPU-HTTP-Client/2.0'
        })

        # Initialize session
        self._create_session()
        self.initialized = True

    def __init__(self, base_url: str = "https://factorst-intiv.hf.space"):
        """This will actually just return the singleton instance"""
        pass

    @staticmethod
    def _is_success(response: Any) -> bool:
        """Return True when *response* is a dict whose status is 'success'."""
        return isinstance(response, dict) and response.get('status') == 'success'

    @staticmethod
    def _error_message(response: Any) -> str:
        """Extract the 'message' field of a reply, tolerating None / non-dict replies."""
        if isinstance(response, dict):
            return response.get('message', 'Unknown error')
        return 'Unknown error'

    def _record_error(self) -> None:
        """Bump the error counter and remember when the error happened."""
        self.error_count += 1
        self.last_error_time = time.time()

    def _create_session(self, timeout: float = _DEFAULT_TIMEOUT) -> bool:
        """Create HTTP session with the server.

        Posts to ``/sessions``, stores the returned token and id, and installs
        the token as a Bearer ``Authorization`` header on the HTTP session.

        Args:
            timeout: Timeout in seconds passed to the HTTP request.

        Returns:
            True on success, False on any failure (which is logged and counted).
        """
        try:
            response = self.http_session.post(
                f"{self.api_base}/sessions",
                json={"client_id": "virtual_gpu_client"},
                timeout=timeout
            )
            response.raise_for_status()
            session_data = response.json()
            self.session_token = session_data['session_token']
            self.session_id = session_data['session_id']

            # Update session headers
            self.http_session.headers.update({
                'Authorization': f'Bearer {self.session_token}'
            })
            logging.info(f"HTTP session created: {self.session_id}")
            return True
        except Exception as e:
            logging.error(f"Failed to create HTTP session: {e}")
            self._record_error()
            return False

    def _make_request(self, method: str, endpoint: str, **kwargs) -> Optional[Dict[str, Any]]:
        """Make HTTP request with error handling and retries.

        Args:
            method: HTTP verb.
            endpoint: Path appended to ``api_base`` (should start with '/').
            **kwargs: Forwarded to ``Session.request``. ``timeout`` defaults
                to 30 seconds when not supplied.

        Returns:
            The decoded JSON reply on success, otherwise a dict of the form
            ``{"status": "error", "message": ...}``. Exceptions are caught
            and reported through that error dict.
        """
        if self._closing:
            return {"status": "error", "message": "HTTP client is closing"}

        url = f"{self.api_base}{endpoint}"
        # Use a default instead of a hard-coded keyword so a caller-supplied
        # timeout does not collide with it (duplicate keyword TypeError).
        kwargs.setdefault('timeout', self._DEFAULT_TIMEOUT)

        try:
            # Ensure we have a valid session
            if not self.session_token and not self._create_session():
                return {"status": "error", "message": "Failed to create session"}

            response = self.http_session.request(method, url, **kwargs)

            # Handle authentication errors by recreating session
            if response.status_code == 401:
                logging.warning("Session expired, recreating...")
                if not self._create_session():
                    return {"status": "error", "message": "Failed to recreate session"}
                response = self.http_session.request(method, url, **kwargs)

            response.raise_for_status()

            # Reset error count on successful request
            self.error_count = 0
            return response.json()
        except requests.exceptions.RequestException as e:
            self._record_error()
            logging.error(f"HTTP request failed: {e}")
            return {"status": "error", "message": f"HTTP request failed: {str(e)}"}
        except Exception as e:
            self._record_error()
            logging.error(f"Unexpected error in HTTP request: {e}")
            return {"status": "error", "message": f"Unexpected error: {str(e)}"}

    def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
        """Store tensor data via HTTP API.

        Args:
            tensor_id: Identifier used in the ``/vram/blocks/{tensor_id}`` path.
            data: Tensor to upload; any array-like accepted by ``np.asarray``.
            model_size: Optional size hint; sent as -1 when omitted.

        Returns:
            True when the server reports success, False otherwise.
        """
        try:
            if data is None:
                raise ValueError("Cannot store None tensor")

            array = np.asarray(data)

            # Calculate tensor metadata
            tensor_shape = array.shape
            tensor_dtype = str(array.dtype)
            tensor_size = array.nbytes

            request_data = {
                "data": array.tolist(),
                "metadata": {
                    'shape': tensor_shape,
                    'dtype': tensor_dtype,
                    'size': tensor_size,
                    'timestamp': time.time()
                },
                "model_size": model_size if model_size is not None else -1
            }

            response = self._make_request(
                'POST',
                f'/vram/blocks/{tensor_id}',
                json=request_data
            )

            if not self._is_success(response):
                logging.error(f"Failed to store tensor {tensor_id}: {self._error_message(response)}")
                return False

            # Update tensor registry
            with self.lock:
                self.tensor_registry[tensor_id] = {
                    'shape': tensor_shape,
                    'dtype': tensor_dtype,
                    'size': tensor_size,
                    'timestamp': time.time()
                }
                # Overwriting an id replaces its contribution instead of
                # adding to it, so the counters are not double-counted.
                previous_size = self._tracked_sizes.get(tensor_id)
                if previous_size is None:
                    self.resource_monitor['active_tensors'] += 1
                else:
                    self.resource_monitor['vram_used'] -= previous_size
                self.resource_monitor['vram_used'] += tensor_size
                self._tracked_sizes[tensor_id] = tensor_size
            return True
        except Exception as e:
            logging.error(f"Error storing tensor {tensor_id}: {str(e)}")
            return False

    def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
        """Load tensor data via HTTP API.

        The reply's ``metadata`` supplies the dtype (default ``float32``) and
        the shape the data is reshaped to.

        Returns:
            The tensor as a numpy array, or None on any failure.
        """
        try:
            # Check tensor registry first
            if tensor_id not in self.tensor_registry:
                logging.warning(f"Tensor {tensor_id} not registered in VRAM")
                # Still try to load it in case it exists on server

            response = self._make_request('GET', f'/vram/blocks/{tensor_id}')

            if not self._is_success(response):
                logging.error(f"Failed to load tensor {tensor_id}: {self._error_message(response)}")
                return None

            data = response.get('data')
            # Tolerate an explicit null metadata field in the reply
            metadata = response.get('metadata') or {}

            if data is None:
                logging.error(f"No data found for tensor {tensor_id}")
                return None

            try:
                # Convert to numpy array with correct dtype
                expected_dtype = metadata.get('dtype', 'float32')
                expected_shape = metadata.get('shape')
                arr = np.array(data, dtype=np.dtype(expected_dtype))
                if expected_shape and arr.shape != tuple(expected_shape):
                    arr = arr.reshape(expected_shape)

                # Update registry if not present (checked and inserted under
                # the lock so a concurrent store is not overwritten)
                with self.lock:
                    self.tensor_registry.setdefault(tensor_id, metadata)
                return arr
            except Exception as e:
                logging.error(f"Error converting tensor data: {str(e)}")
                return None
        except Exception as e:
            logging.error(f"Error loading tensor {tensor_id}: {str(e)}")
            return None

    def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
        """Store component state via HTTP API. Returns True on success."""
        try:
            request_data = {
                "data": state_data,
                "timestamp": time.time()
            }
            response = self._make_request(
                'POST',
                f'/state/{component}/{state_id}',
                json=request_data
            )
            if self._is_success(response):
                return True
            logging.error(f"Failed to store state for {component}/{state_id}: {self._error_message(response)}")
            return False
        except Exception as e:
            logging.error(f"Error storing state for {component}/{state_id}: {str(e)}")
            return False

    def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
        """Load component state via HTTP API. Returns the 'data' field or None."""
        try:
            response = self._make_request('GET', f'/state/{component}/{state_id}')
            if self._is_success(response):
                return response.get('data')
            logging.error(f"Failed to load state for {component}/{state_id}: {self._error_message(response)}")
            return None
        except Exception as e:
            logging.error(f"Error loading state for {component}/{state_id}: {str(e)}")
            return None

    def cache_data(self, key: str, data: Any) -> bool:
        """Cache data via HTTP API. Always returns a real bool."""
        try:
            response = self._make_request(
                'POST',
                f'/cache/{key}',
                json={"data": data}
            )
            return self._is_success(response)
        except Exception as e:
            logging.error(f"Error caching data for key {key}: {str(e)}")
            return False

    def get_cached_data(self, key: str) -> Optional[Any]:
        """Get cached data via HTTP API. Returns the 'data' field or None."""
        try:
            response = self._make_request('GET', f'/cache/{key}')
            if self._is_success(response):
                return response.get('data')
            return None
        except Exception as e:
            logging.error(f"Error getting cached data for key {key}: {str(e)}")
            return None

    def is_model_loaded(self, model_name: str) -> bool:
        """Check if a model is loaded via HTTP API (reply status == 'loaded')."""
        try:
            response = self._make_request('GET', f'/models/{model_name}/status')
            return isinstance(response, dict) and response.get('status') == 'loaded'
        except Exception as e:
            logging.error(f"Error checking model status for {model_name}: {str(e)}")
            return False

    def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
        """Load a model via HTTP API.

        Returns True immediately when the server already reports the model as
        loaded. Otherwise posts the model data/path (plus a SHA256 of the file
        at ``model_path`` when given) and records the model in the registry.
        """
        try:
            # Check if model is already loaded
            if self.is_model_loaded(model_name):
                logging.info(f"Model {model_name} already loaded")
                return True

            # Calculate model hash if path provided
            model_hash = None
            if model_path:
                model_hash = self._calculate_model_hash(model_path)

            request_data = {
                "model_data": model_data,
                "model_path": model_path,
                "model_hash": model_hash
            }
            response = self._make_request(
                'POST',
                f'/models/{model_name}/load',
                json=request_data
            )

            if not self._is_success(response):
                logging.error(f"Failed to load model {model_name}: {self._error_message(response)}")
                return False

            with self.lock:
                self.model_registry[model_name] = {
                    'hash': model_hash,
                    'timestamp': time.time(),
                    'model_data': model_data
                }
                self.resource_monitor['loaded_models'].add(model_name)
            logging.info(f"Successfully loaded model {model_name}")
            return True
        except Exception as e:
            logging.error(f"Error loading model {model_name}: {str(e)}")
            return False

    def _calculate_model_hash(self, model_path: str) -> str:
        """Calculate SHA256 hash of model file; returns "" when it cannot be read."""
        try:
            sha256_hash = hashlib.sha256()
            with open(model_path, "rb") as f:
                # Read in fixed-size chunks so large files are not held in memory
                for byte_block in iter(lambda: f.read(4096), b""):
                    sha256_hash.update(byte_block)
            return sha256_hash.hexdigest()
        except Exception as e:
            logging.error(f"Error calculating model hash: {str(e)}")
            return ""

    def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
        """Start inference with a loaded model via HTTP API.

        Returns:
            A dict with 'output' (numpy array or None), 'metrics' and
            'model_info', or None when the model is not loaded or the call fails.
        """
        try:
            if not self.is_model_loaded(model_name):
                logging.error(f"Model {model_name} not loaded. Please load the model first.")
                return None

            request_data = {
                "input_data": input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
            }
            response = self._make_request(
                'POST',
                f'/models/{model_name}/inference',
                json=request_data
            )

            if not self._is_success(response):
                logging.error(f"Inference failed for model {model_name}: {self._error_message(response)}")
                return None

            return {
                'output': np.array(response['output']) if 'output' in response else None,
                'metrics': response.get('metrics', {}),
                'model_info': self.model_registry.get(model_name, {})
            }
        except Exception as e:
            logging.error(f"Error during inference for model {model_name}: {str(e)}")
            return None

    def wait_for_connection(self, timeout: float = 30.0) -> bool:
        """Wait for HTTP connection to be established (compatibility method).

        Args:
            timeout: Timeout in seconds applied to the HTTP request made here.

        Returns:
            False when the client is closing or the server cannot be reached.
            When no session exists yet, the result of creating one. Otherwise
            True as soon as the probe request receives any HTTP response.
        """
        if self._closing:
            return False
        try:
            if not self.session_token:
                return self._create_session(timeout=timeout)

            # Test connection with a simple request. Receiving any response
            # (even a 404 for the probe key) shows the server is reachable;
            # connection failures raise and are reported as False below.
            self.http_session.get(
                f"{self.api_base}/cache/connection_test",
                timeout=timeout
            )
            return True
        except Exception as e:
            self._record_error()
            logging.error(f"Connection test failed: {e}")
            return False

    def is_connected(self) -> bool:
        """Check if HTTP connection is active (compatibility method)"""
        return self.session_token is not None and not self._closing

    def get_connection_status(self) -> Dict[str, Any]:
        """Get detailed connection status"""
        return {
            "connected": self.is_connected(),
            "closing": self._closing,
            "error_count": self.error_count,
            "base_url": self.base_url,
            "last_error_time": self.last_error_time,
            "loaded_models": list(self.resource_monitor['loaded_models']),
            "session_id": self.session_id
        }

    def set_keep_alive(self, enabled: bool):
        """Set keep-alive mode (compatibility method for HTTP)"""
        # HTTP connections are stateless, so this is a no-op
        pass

    def reconnect(self):
        """Reconnect to server (recreate session for HTTP)"""
        self.session_token = None
        self.session_id = None
        return self._create_session()

    def close(self):
        """Close HTTP client.

        Sets the closing flag, after which ``_make_request`` returns an error
        dict for every call, and closes the underlying HTTP session.
        """
        self._closing = True
        if self.http_session:
            self.http_session.close()

    # Additional methods for multi-chip coordination
    def transfer_between_chips(self, src_chip: int, dst_chip: int, data_id: str) -> Optional[str]:
        """Transfer data between chips via HTTP API; returns the new data id or None."""
        try:
            response = self._make_request(
                'POST',
                f'/chips/{src_chip}/transfer/{dst_chip}',
                json={"data_id": data_id}
            )
            if self._is_success(response):
                return response.get('new_data_id')
            logging.error(f"Chip transfer failed: {self._error_message(response)}")
            return None
        except Exception as e:
            logging.error(f"Error in chip transfer: {str(e)}")
            return None

    def create_sync_barrier(self, barrier_id: str, num_participants: int) -> bool:
        """Create synchronization barrier via HTTP API. Always returns a real bool."""
        try:
            response = self._make_request(
                'POST',
                f'/sync/barrier/{barrier_id}',
                json={"num_participants": num_participants}
            )
            return self._is_success(response)
        except Exception as e:
            logging.error(f"Error creating sync barrier: {str(e)}")
            return False

    def wait_sync_barrier(self, barrier_id: str) -> bool:
        """Wait at synchronization barrier via HTTP API.

        Makes a single request: returns True when the barrier is 'released';
        when it is still 'waiting', sleeps briefly and returns False so the
        caller can poll again.
        """
        try:
            response = self._make_request('PUT', f'/sync/barrier/{barrier_id}/wait')
            if not isinstance(response, dict):
                return False
            status = response.get('status')
            if status == 'released':
                return True
            if status == 'waiting':
                # In a real implementation, this might poll or use long-polling
                time.sleep(0.1)  # Brief delay before next check
            return False
        except Exception as e:
            logging.error(f"Error waiting at sync barrier: {str(e)}")
            return False
# Compatibility alias for existing code that still refers to WebSocketGPUStorage
WebSocketGPUStorage = HTTPGPUStorage