Spaces:
Sleeping
Sleeping
| import aiohttp | |
| import logging | |
| from lpm_kernel.api.domains.upload.TrainingTags import TrainingTags | |
| from lpm_kernel.configs import config | |
| import websockets | |
| import json | |
| import asyncio | |
| from lpm_kernel.configs.config import Config | |
| import time | |
| import requests | |
| from lpm_kernel.api.common.responses import ResponseHandler | |
| from lpm_kernel.api.domains.loads.load_service import LoadService | |
| from typing import Optional, List, Dict | |
| logger = logging.getLogger(__name__) | |
class HeartbeatConfig:
    """Heartbeat timing and retry settings.

    The four values are stored on the instance exactly as given; no
    validation or conversion is performed.
    """

    def __init__(
        self,
        interval: int = 30,
        timeout: int = 10,
        max_retries: int = 3,
        retry_interval: int = 5
    ):
        """Record the heartbeat settings.

        Args:
            interval: Heartbeat interval (seconds).
            timeout: Heartbeat timeout (seconds).
            max_retries: Maximum retry count.
            retry_interval: Retry interval (seconds).
        """
        self.interval, self.timeout = interval, timeout
        self.max_retries, self.retry_interval = max_retries, retry_interval
class RegistryClient:
    """Client for the registry center's HTTP and WebSocket endpoints."""

    def __init__(self, heartbeat_config: HeartbeatConfig = None):
        """Read the registry URL from the environment config and set up state.

        Args:
            heartbeat_config: Heartbeat settings; a default HeartbeatConfig
                is created when this is falsy.
        """
        env_config = Config.from_env()
        self.server_url = env_config.get("REGISTRY_SERVICE_URL")

        # Derive the WebSocket base URL: http:// -> ws://, https:// -> wss://
        ws_base = self.server_url.replace('http://', 'ws://')
        self.ws_url = ws_base.replace('https://', 'wss://')

        # Active WebSocket connections, keyed by instance id
        self.active_connections = {}

        self.heartbeat_config = heartbeat_config if heartbeat_config else HeartbeatConfig()
| def _get_auth_header(self): | |
| """ | |
| Get the Authorization header for authenticated requests | |
| Returns: | |
| dict: Authorization header or empty dict if no credentials | |
| """ | |
| current_load, error, _ = LoadService.get_current_load(with_password=True) | |
| if not current_load or not current_load.instance_id or not current_load.instance_password: | |
| logger.info("No credentials found for auth") | |
| return {} | |
| instance_id = current_load.instance_id | |
| instance_password = current_load.instance_password | |
| logger.info(f"Using credentials for auth: {instance_id}:{instance_password}") | |
| return { | |
| "Authorization": f"Bearer {instance_id}:{instance_password}" | |
| } | |
| def get_ws_url(self, instance_id: str, instance_password: str) -> str: | |
| """ | |
| Generate WebSocket URL for the specified instance | |
| Args: | |
| instance_id: Instance ID | |
| instance_password: Instance password | |
| Returns: | |
| str: WebSocket URL | |
| """ | |
| return f"{self.ws_url}/api/ws/{instance_id}?password={instance_password}" | |
| def register_upload(self, upload_name: str, instance_id: str = None, description: str = None, email: str = None, tags: TrainingTags = None): | |
| """ | |
| Register Upload instance with the registry center | |
| Args: | |
| upload_name: Upload name | |
| instance_id: Instance ID (optional) | |
| description: Description (optional) | |
| email: User email (optional) | |
| Returns: | |
| Registration data | |
| """ | |
| headers = self._get_auth_header() | |
| tags_dict = tags.model_dump() if tags else None | |
| response = requests.post( | |
| f"{self.server_url}/api/upload/register", | |
| headers=headers, | |
| json={ | |
| "upload_name": upload_name, | |
| "instance_id": instance_id, | |
| "description": description, | |
| "email": email, | |
| "tags": tags_dict | |
| } | |
| ) | |
| return ResponseHandler.handle_response( | |
| response, | |
| success_log=f"Upload {upload_name} registered successfully in registry center, instance ID: {instance_id}", | |
| error_prefix="Registration" | |
| ) | |
| def unregister_upload(self, instance_id: str): | |
| """Unregister Upload instance from registry center | |
| Args: | |
| instance_id: Instance ID | |
| Returns: | |
| dict: Unregistration result | |
| """ | |
| headers = self._get_auth_header() | |
| response = requests.delete( | |
| f"{self.server_url}/api/upload/{instance_id}", | |
| headers=headers | |
| ) | |
| return ResponseHandler.handle_response( | |
| response, | |
| success_log=f"Upload instance {instance_id} unregistered successfully from registry center", | |
| error_prefix="Unregistration" | |
| ) | |
| async def connect_websocket(self, instance_id: str, instance_password: str): | |
| """Connect to registry center WebSocket and start keep-alive | |
| Args: | |
| instance_id: Instance ID | |
| instance_password: Instance password | |
| Returns: | |
| websockets.WebSocketClientProtocol: WebSocket connection | |
| """ | |
| # Check if connection already exists and is active | |
| connection_key = f"{instance_id}" | |
| if connection_key in self.active_connections: | |
| existing_ws = self.active_connections[connection_key] | |
| try: | |
| # Check if connection is still active and send heartbeat | |
| if await self.send_heartbeat(existing_ws): | |
| logger.info(f"Using existing WebSocket connection: {connection_key}") | |
| return existing_ws | |
| raise Exception("Heartbeat failed") | |
| except Exception: | |
| # If heartbeat fails, connection is disconnected, remove from active connections | |
| logger.warning(f"Existing WebSocket connection is disconnected, creating new connection: {connection_key}") | |
| del self.active_connections[connection_key] | |
| # Create new connection | |
| ws_uri = self.get_ws_url(instance_id, instance_password) | |
| try: | |
| logger.info(f"Connecting to WebSocket: {ws_uri}") | |
| websocket = await websockets.connect(ws_uri) | |
| logger.info(f"WebSocket connection established: {ws_uri}") | |
| # Add additional attributes to WebSocket connection | |
| websocket.instance_id = instance_id | |
| websocket.connection_key = connection_key | |
| # Store new connection | |
| self.active_connections[connection_key] = websocket | |
| # Add lock to prevent concurrent message reception | |
| websocket.recv_lock = asyncio.Lock() | |
| # Start heartbeat task | |
| websocket.heartbeat_task = asyncio.create_task( | |
| self._keep_alive_with_ping(websocket, instance_id), | |
| name=f"heartbeat_{connection_key}" | |
| ) | |
| await self.handle_messages(websocket) | |
| return websocket | |
| except Exception as e: | |
| logger.error(f"WebSocket connection failed: {str(e)}", exc_info=True) | |
| raise | |
| async def _keep_alive(self, websocket, instance_id: str): | |
| """Keep WebSocket connection alive | |
| Args: | |
| websocket: WebSocket connection | |
| instance_id: Instance ID | |
| """ | |
| connection_key = f"{instance_id}" | |
| logger.info(f"Starting heartbeat task: {connection_key}") | |
| retry_count = 0 | |
| last_success_time = time.time() | |
| try: | |
| while True: | |
| try: | |
| # Send heartbeat at configured interval | |
| await asyncio.sleep(self.heartbeat_config.interval) | |
| # Check last successful heartbeat time | |
| if time.time() - last_success_time > self.heartbeat_config.interval * 2: | |
| logger.warning(f"Upload (ID: {instance_id}) heartbeat timeout") | |
| raise websockets.exceptions.ConnectionClosed(1006, "Heartbeat timeout") | |
| success = await self.send_heartbeat(websocket) | |
| if success: | |
| retry_count = 0 # Reset retry count | |
| last_success_time = time.time() | |
| # logger.info(f"Upload (ID: {instance_id}) heartbeat sent") | |
| else: | |
| retry_count += 1 | |
| if retry_count >= self.heartbeat_config.max_retries: | |
| logger.error(f"Upload (ID: {instance_id}) heartbeat retry count exceeded") | |
| raise websockets.exceptions.ConnectionClosed(1006, "Heartbeat retry count exceeded") | |
| logger.warning(f"Upload (ID: {instance_id}) heartbeat send failed, retrying {retry_count} times") | |
| await asyncio.sleep(self.heartbeat_config.retry_interval) | |
| continue | |
| except websockets.exceptions.ConnectionClosed as e: | |
| logger.warning(f"Upload (ID: {instance_id}) WebSocket connection closed: {str(e)}") | |
| # Clean up connection | |
| if connection_key in self.active_connections: | |
| del self.active_connections[connection_key] | |
| # Cancel related tasks | |
| if hasattr(websocket, 'message_task'): | |
| websocket.message_task.cancel() | |
| break | |
| except Exception as e: | |
| logger.error(f"Upload (ID: {instance_id}) send heartbeat failed: {str(e)}", exc_info=True) | |
| retry_count += 1 | |
| if retry_count >= self.heartbeat_config.max_retries: | |
| logger.error(f"Upload (ID: {instance_id}) heartbeat retry count exceeded") | |
| raise | |
| await asyncio.sleep(self.heartbeat_config.retry_interval) | |
| except asyncio.CancelledError: | |
| logger.info(f"Heartbeat task cancelled: {connection_key}") | |
| raise | |
| except Exception as e: | |
| logger.error(f"Upload (ID: {instance_id}) keep alive task failed: {str(e)}") | |
| # Clean up connection | |
| if connection_key in self.active_connections: | |
| del self.active_connections[connection_key] | |
| raise | |
| async def _keep_alive_with_ping(self, websocket, instance_id: str): | |
| """Keep WebSocket connection alive using native ping/pong | |
| Args: | |
| websocket: WebSocket connection | |
| instance_id: Instance ID | |
| """ | |
| connection_key = f"{instance_id}" | |
| logger.info(f"Starting ping task: {connection_key}") | |
| try: | |
| while True: | |
| try: | |
| await asyncio.sleep(self.heartbeat_config.interval) | |
| await websocket.ping() | |
| # logger.debug(f"Ping sent successfully for {instance_id}") | |
| except websockets.exceptions.ConnectionClosed as e: | |
| logger.warning(f"Upload (ID: {instance_id}) WebSocket connection closed: {str(e)}") | |
| if connection_key in self.active_connections: | |
| del self.active_connections[connection_key] | |
| if hasattr(websocket, 'message_task'): | |
| websocket.message_task.cancel() | |
| break | |
| except Exception as e: | |
| logger.error(f"Upload (ID: {instance_id}) ping failed: {str(e)}") | |
| if connection_key in self.active_connections: | |
| del self.active_connections[connection_key] | |
| raise | |
| except asyncio.CancelledError: | |
| logger.info(f"Ping task cancelled: {connection_key}") | |
| raise | |
| except Exception as e: | |
| logger.error(f"Upload (ID: {instance_id}) keep alive task failed: {str(e)}") | |
| if connection_key in self.active_connections: | |
| del self.active_connections[connection_key] | |
| raise | |
| async def send_heartbeat(self, websocket): | |
| """Send heartbeat message | |
| Args: | |
| websocket: WebSocket connection | |
| Returns: | |
| bool: Whether heartbeat was sent successfully | |
| """ | |
| try: | |
| heartbeat_message = json.dumps({ | |
| "type": "heartbeat", | |
| "data": { | |
| "timestamp": int(time.time()), | |
| "instance_id": websocket.instance_id if hasattr(websocket, 'instance_id') else 'unknown', | |
| "status": "alive" | |
| }, | |
| "version": "1.0" | |
| }) | |
| # logger.info(f"Preparing to send heartbeat message: {heartbeat_message}") | |
| # Set send timeout | |
| async with asyncio.timeout(self.heartbeat_config.timeout): | |
| await websocket.send(heartbeat_message) | |
| # logger.info("Heartbeat message sent successfully") | |
| return True | |
| except asyncio.TimeoutError: | |
| logger.error("Sending heartbeat message timed out") | |
| return False | |
| except Exception as e: | |
| logger.error(f"Sending heartbeat failed: {str(e)}", exc_info=True) | |
| return False | |
| async def handle_messages(self, websocket): | |
| """Handle received WebSocket messages""" | |
| try: | |
| while True: | |
| try: | |
| # Use lock to ensure only one coroutine calls recv at a time | |
| async with websocket.recv_lock: | |
| message = await websocket.recv() | |
| data = json.loads(message) | |
| message_type = data.get("type") | |
| if message_type == "heartbeat_ack": | |
| continue | |
| elif message_type == "chat": | |
| # Handle chat request | |
| try: | |
| request_data = data.get("request", {}) | |
| logger.info(f"[Request details: {json.dumps(request_data, ensure_ascii=False)}") | |
| # Call chat interface | |
| async with aiohttp.ClientSession() as session: | |
| logger.info(f"Preparing to send request to chat interface") | |
| config = Config.from_env() | |
| kernel2_url = f"{config.KERNEL2_SERVICE_URL}/api/kernel2/chat" | |
| async with session.post( | |
| kernel2_url, | |
| json=request_data, | |
| headers={ | |
| "Content-Type": "application/json", | |
| "Accept": "text/event-stream", # Specify to accept SSE response | |
| "Cache-Control": "no-cache", | |
| "Connection": "keep-alive" | |
| }, | |
| timeout=aiohttp.ClientTimeout(total=None), # Disable timeout | |
| chunked=True # Enable chunked transfer | |
| ) as response: | |
| # Check response status | |
| logger.info(f"Response status code: {response.status}") | |
| if response.status != 200: | |
| error_text = await response.text() | |
| logger.error(f"[request_id: {data.get('request_id')}] Failed to call chat interface: {error_text}") | |
| await websocket.send(json.dumps({ | |
| "type": "chat_response", | |
| "request_id": data.get("request_id"), | |
| "error": f"Failed to call chat interface: {error_text}" | |
| })) | |
| continue | |
| logger.debug(f"Starting to read streaming response") | |
| message_count = 0 | |
| # Direct forwarding of streaming response | |
| async for line in response.content: | |
| if line: | |
| try: | |
| # Convert bytes to string | |
| decoded_line = line.decode('utf-8') | |
| logger.debug(f"[request_id: {data.get('request_id')}] Received raw data: {decoded_line.strip()}") | |
| # Check if it's SSE format data | |
| if decoded_line.startswith("data: "): | |
| message_count += 1 | |
| data_content = decoded_line[6:].strip() | |
| # logger.info(f"[request_id: {data.get('request_id')}] Processing message {message_count}") | |
| # Check if it's a completion marker | |
| if data_content == "[DONE]": | |
| logger.info(f"[request_id: {data.get('request_id')}] Received completion marker, processed {message_count} messages in total") | |
| await websocket.send(json.dumps({ | |
| "type": "chat_response", | |
| "request_id": data.get("request_id"), | |
| "done": True | |
| })) | |
| continue | |
| # Directly forward original SSE data | |
| await websocket.send(json.dumps({ | |
| "type": "chat_response", | |
| "request_id": data.get("request_id"), | |
| "raw_sse": data_content, # Contains original SSE data | |
| "done": False | |
| })) | |
| logger.debug(f"[requestId: {data.get('request_id')}] Forwarded SSE message #{message_count}") | |
| except UnicodeDecodeError as e: | |
| logger.error(f"[requestId: {data.get('request_id')}] Failed to decode response data: {str(e)}") | |
| except Exception as e: | |
| logger.error(f"[requestId: {data.get('request_id')}] Error processing response data: {str(e)}, type: {type(e).__name__}") | |
| except Exception as e: | |
| logger.error(f"Failed to process chat request: {str(e)}") | |
| await websocket.send(json.dumps({ | |
| "type": "chat_response", | |
| "request_id": data.get("request_id"), | |
| "error": f"Error processing chat request: {str(e)}" | |
| })) | |
| else: | |
| logger.debug(f"Received unknown message type: {message}") | |
| except websockets.exceptions.ConnectionClosed: | |
| logger.error("WebSocket connection closed") | |
| break | |
| except json.JSONDecodeError: | |
| logger.error(f"Invalid JSON message: {message}") | |
| except Exception as e: | |
| logger.error(f"Failed to process message: {str(e)}") | |
| except Exception as e: | |
| logger.error(f"Message processing loop failed: {str(e)}") | |
| raise | |
| def list_uploads(self, page_no: int = 1, page_size: int = 10, status: Optional[List[str]] = None): | |
| """Get list of registered Upload instances with pagination and status filter | |
| Args: | |
| page_no (int): Page number, starting from 1 | |
| page_size (int): Number of items per page | |
| status (Optional[List[str]]): List of status to filter by | |
| Returns: | |
| dict: Dictionary containing information about Upload instances | |
| """ | |
| # headers = self._get_auth_header() | |
| params = { | |
| "page_no": page_no, | |
| "page_size": page_size | |
| } | |
| if status: | |
| params["status"] = status | |
| response = requests.get( | |
| f"{self.server_url}/api/upload/list", | |
| # headers=headers, | |
| params=params | |
| ) | |
| return ResponseHandler.handle_response( | |
| response, | |
| error_prefix="Failed to retrieve list" | |
| ) | |
| def count_uploads(self): | |
| """Get count of all registered Upload instances | |
| Returns: | |
| dict: Dictionary containing count of Upload instances | |
| """ | |
| response = requests.get( | |
| f"{self.server_url}/api/upload/count", | |
| ) | |
| return ResponseHandler.handle_response( | |
| response, | |
| error_prefix="Failed to retrieve count" | |
| ) | |
| def get_upload_detail(self, instance_id: str) -> Dict: | |
| """Get detailed information of an Upload instance | |
| Args: | |
| instance_id (str): Instance ID of the Upload | |
| Returns: | |
| dict: Dictionary containing instance information with the following fields: | |
| upload_name (str): Name of the upload | |
| instance_id (str): Instance ID | |
| status (str): Current status of the upload | |
| description (str, optional): Description of the upload | |
| email (str, optional): Associated email address | |
| registration_time (datetime): Time when the instance was registered | |
| last_heartbeat (datetime, optional): Time of the last heartbeat | |
| is_connected (bool, optional): Connection status, defaults to False | |
| instance_password (str, optional): Password for instance registration | |
| """ | |
| headers = self._get_auth_header() | |
| response = requests.get( | |
| f"{self.server_url}/api/upload/{instance_id}", | |
| headers=headers | |
| ) | |
| return ResponseHandler.handle_response( | |
| response, | |
| error_prefix="Failed to retrieve upload details" | |
| ) | |
| def update_upload(self, instance_id: str, upload_name: str = None, capabilities: dict = None, email: str = None): | |
| """Update Upload instance information in the registry center | |
| Args: | |
| instance_id: Instance ID | |
| upload_name: New upload name (optional) | |
| capabilities: New capability set (optional) | |
| email: New user email (optional) | |
| Returns: | |
| dict: Update result | |
| """ | |
| update_data = {} | |
| if upload_name is not None: | |
| update_data["upload_name"] = upload_name | |
| if capabilities is not None: | |
| update_data["capabilities"] = capabilities | |
| if email is not None: | |
| update_data["email"] = email | |
| if not update_data: | |
| logger.warning("No update data provided for update_upload") | |
| return {"message": "No update data provided"} | |
| headers = self._get_auth_header() | |
| response = requests.put( | |
| f"{self.server_url}/api/upload/{instance_id}", | |
| headers=headers, | |
| json=update_data | |
| ) | |
| return ResponseHandler.handle_response( | |
| response, | |
| success_log=f"Upload instance {instance_id} updated successfully", | |
| error_prefix="Update" | |
| ) | |
| def create_role(self, role_id, name, description, system_prompt, icon, instance_id, is_active=True, | |
| enable_l0_retrieval=True, enable_l1_retrieval=True): | |
| """Create a new role in the registry center | |
| Args: | |
| role_id: Role UUID | |
| name: Role name | |
| description: Role description | |
| system_prompt: System prompt | |
| icon: Icon URL | |
| instance_id: Instance ID | |
| enable_l0_retrieval: Enable L0 retrieval | |
| enable_l1_retrieval: Enable L1 retrieval | |
| Returns: | |
| dict: Created role data | |
| """ | |
| headers = self._get_auth_header() | |
| response = requests.post( | |
| f"{self.server_url}/api/roles", | |
| headers=headers, | |
| json={ | |
| "role_id": role_id, | |
| "instance_id": instance_id, | |
| "name": name, | |
| "description": description, | |
| "system_prompt": system_prompt, | |
| "is_active": is_active, | |
| "icon": icon, | |
| "enable_l0_retrieval": enable_l0_retrieval, | |
| "enable_l1_retrieval": enable_l1_retrieval | |
| } | |
| ) | |
| return ResponseHandler.handle_response( | |
| response, | |
| success_log=f"Role {name} created successfully in registry center", | |
| error_prefix="Role creation" | |
| ) |