import aiohttp import logging from lpm_kernel.api.domains.upload.TrainingTags import TrainingTags from lpm_kernel.configs import config import websockets import json import asyncio from lpm_kernel.configs.config import Config import time import requests from lpm_kernel.api.common.responses import ResponseHandler from lpm_kernel.api.domains.loads.load_service import LoadService from typing import Optional, List, Dict logger = logging.getLogger(__name__) class HeartbeatConfig: """Heartbeat Configuration Class""" def __init__( self, interval: int = 30, # Heartbeat interval (seconds) timeout: int = 10, # Heartbeat timeout (seconds) max_retries: int = 3, # Maximum retry count retry_interval: int = 5 # Retry interval (seconds) ): self.interval = interval self.timeout = timeout self.max_retries = max_retries self.retry_interval = retry_interval class RegistryClient: def __init__(self, heartbeat_config: HeartbeatConfig = None): config = Config.from_env() self.server_url = config.get("REGISTRY_SERVICE_URL") # Convert HTTP URL to WebSocket URL self.ws_url = self.server_url.replace('http://', 'ws://').replace('https://', 'wss://') # Store all active WebSocket connections self.active_connections = {} # Heartbeat configuration self.heartbeat_config = heartbeat_config or HeartbeatConfig() def _get_auth_header(self): """ Get the Authorization header for authenticated requests Returns: dict: Authorization header or empty dict if no credentials """ current_load, error, _ = LoadService.get_current_load(with_password=True) if not current_load or not current_load.instance_id or not current_load.instance_password: logger.info("No credentials found for auth") return {} instance_id = current_load.instance_id instance_password = current_load.instance_password logger.info(f"Using credentials for auth: {instance_id}:{instance_password}") return { "Authorization": f"Bearer {instance_id}:{instance_password}" } def get_ws_url(self, instance_id: str, instance_password: str) -> str: """ Generate WebSocket URL for the specified instance Args: instance_id: Instance ID instance_password: Instance password Returns: str: WebSocket URL """ return f"{self.ws_url}/api/ws/{instance_id}?password={instance_password}" def register_upload(self, upload_name: str, instance_id: str = None, description: str = None, email: str = None, tags: TrainingTags = None): """ Register Upload instance with the registry center Args: upload_name: Upload name instance_id: Instance ID (optional) description: Description (optional) email: User email (optional) Returns: Registration data """ headers = self._get_auth_header() tags_dict = tags.model_dump() if tags else None response = requests.post( f"{self.server_url}/api/upload/register", headers=headers, json={ "upload_name": upload_name, "instance_id": instance_id, "description": description, "email": email, "tags": tags_dict } ) return ResponseHandler.handle_response( response, success_log=f"Upload {upload_name} registered successfully in registry center, instance ID: {instance_id}", error_prefix="Registration" ) def unregister_upload(self, instance_id: str): """Unregister Upload instance from registry center Args: instance_id: Instance ID Returns: dict: Unregistration result """ headers = self._get_auth_header() response = requests.delete( f"{self.server_url}/api/upload/{instance_id}", headers=headers ) return ResponseHandler.handle_response( response, success_log=f"Upload instance {instance_id} unregistered successfully from registry center", error_prefix="Unregistration" ) async def connect_websocket(self, instance_id: str, instance_password: str): """Connect to registry center WebSocket and start keep-alive Args: instance_id: Instance ID instance_password: Instance password Returns: websockets.WebSocketClientProtocol: WebSocket connection """ # Check if connection already exists and is active connection_key = f"{instance_id}" if connection_key in self.active_connections: existing_ws = self.active_connections[connection_key] try: # Check if connection is still active and send heartbeat if await self.send_heartbeat(existing_ws): logger.info(f"Using existing WebSocket connection: {connection_key}") return existing_ws raise Exception("Heartbeat failed") except Exception: # If heartbeat fails, connection is disconnected, remove from active connections logger.warning(f"Existing WebSocket connection is disconnected, creating new connection: {connection_key}") del self.active_connections[connection_key] # Create new connection ws_uri = self.get_ws_url(instance_id, instance_password) try: logger.info(f"Connecting to WebSocket: {ws_uri}") websocket = await websockets.connect(ws_uri) logger.info(f"WebSocket connection established: {ws_uri}") # Add additional attributes to WebSocket connection websocket.instance_id = instance_id websocket.connection_key = connection_key # Store new connection self.active_connections[connection_key] = websocket # Add lock to prevent concurrent message reception websocket.recv_lock = asyncio.Lock() # Start heartbeat task websocket.heartbeat_task = asyncio.create_task( self._keep_alive_with_ping(websocket, instance_id), name=f"heartbeat_{connection_key}" ) await self.handle_messages(websocket) return websocket except Exception as e: logger.error(f"WebSocket connection failed: {str(e)}", exc_info=True) raise async def _keep_alive(self, websocket, instance_id: str): """Keep WebSocket connection alive Args: websocket: WebSocket connection instance_id: Instance ID """ connection_key = f"{instance_id}" logger.info(f"Starting heartbeat task: {connection_key}") retry_count = 0 last_success_time = time.time() try: while True: try: # Send heartbeat at configured interval await asyncio.sleep(self.heartbeat_config.interval) # Check last successful heartbeat time if time.time() - last_success_time > self.heartbeat_config.interval * 2: logger.warning(f"Upload (ID: {instance_id}) heartbeat timeout") raise websockets.exceptions.ConnectionClosed(1006, "Heartbeat timeout") success = await self.send_heartbeat(websocket) if success: retry_count = 0 # Reset retry count last_success_time = time.time() # logger.info(f"Upload (ID: {instance_id}) heartbeat sent") else: retry_count += 1 if retry_count >= self.heartbeat_config.max_retries: logger.error(f"Upload (ID: {instance_id}) heartbeat retry count exceeded") raise websockets.exceptions.ConnectionClosed(1006, "Heartbeat retry count exceeded") logger.warning(f"Upload (ID: {instance_id}) heartbeat send failed, retrying {retry_count} times") await asyncio.sleep(self.heartbeat_config.retry_interval) continue except websockets.exceptions.ConnectionClosed as e: logger.warning(f"Upload (ID: {instance_id}) WebSocket connection closed: {str(e)}") # Clean up connection if connection_key in self.active_connections: del self.active_connections[connection_key] # Cancel related tasks if hasattr(websocket, 'message_task'): websocket.message_task.cancel() break except Exception as e: logger.error(f"Upload (ID: {instance_id}) send heartbeat failed: {str(e)}", exc_info=True) retry_count += 1 if retry_count >= self.heartbeat_config.max_retries: logger.error(f"Upload (ID: {instance_id}) heartbeat retry count exceeded") raise await asyncio.sleep(self.heartbeat_config.retry_interval) except asyncio.CancelledError: logger.info(f"Heartbeat task cancelled: {connection_key}") raise except Exception as e: logger.error(f"Upload (ID: {instance_id}) keep alive task failed: {str(e)}") # Clean up connection if connection_key in self.active_connections: del self.active_connections[connection_key] raise async def _keep_alive_with_ping(self, websocket, instance_id: str): """Keep WebSocket connection alive using native ping/pong Args: websocket: WebSocket connection instance_id: Instance ID """ connection_key = f"{instance_id}" logger.info(f"Starting ping task: {connection_key}") try: while True: try: await asyncio.sleep(self.heartbeat_config.interval) await websocket.ping() # logger.debug(f"Ping sent successfully for {instance_id}") except websockets.exceptions.ConnectionClosed as e: logger.warning(f"Upload (ID: {instance_id}) WebSocket connection closed: {str(e)}") if connection_key in self.active_connections: del self.active_connections[connection_key] if hasattr(websocket, 'message_task'): websocket.message_task.cancel() break except Exception as e: logger.error(f"Upload (ID: {instance_id}) ping failed: {str(e)}") if connection_key in self.active_connections: del self.active_connections[connection_key] raise except asyncio.CancelledError: logger.info(f"Ping task cancelled: {connection_key}") raise except Exception as e: logger.error(f"Upload (ID: {instance_id}) keep alive task failed: {str(e)}") if connection_key in self.active_connections: del self.active_connections[connection_key] raise async def send_heartbeat(self, websocket): """Send heartbeat message Args: websocket: WebSocket connection Returns: bool: Whether heartbeat was sent successfully """ try: heartbeat_message = json.dumps({ "type": "heartbeat", "data": { "timestamp": int(time.time()), "instance_id": websocket.instance_id if hasattr(websocket, 'instance_id') else 'unknown', "status": "alive" }, "version": "1.0" }) # logger.info(f"Preparing to send heartbeat message: {heartbeat_message}") # Set send timeout async with asyncio.timeout(self.heartbeat_config.timeout): await websocket.send(heartbeat_message) # logger.info("Heartbeat message sent successfully") return True except asyncio.TimeoutError: logger.error("Sending heartbeat message timed out") return False except Exception as e: logger.error(f"Sending heartbeat failed: {str(e)}", exc_info=True) return False async def handle_messages(self, websocket): """Handle received WebSocket messages""" try: while True: try: # Use lock to ensure only one coroutine calls recv at a time async with websocket.recv_lock: message = await websocket.recv() data = json.loads(message) message_type = data.get("type") if message_type == "heartbeat_ack": continue elif message_type == "chat": # Handle chat request try: request_data = data.get("request", {}) logger.info(f"[Request details: {json.dumps(request_data, ensure_ascii=False)}") # Call chat interface async with aiohttp.ClientSession() as session: logger.info(f"Preparing to send request to chat interface") config = Config.from_env() kernel2_url = f"{config.KERNEL2_SERVICE_URL}/api/kernel2/chat" async with session.post( kernel2_url, json=request_data, headers={ "Content-Type": "application/json", "Accept": "text/event-stream", # Specify to accept SSE response "Cache-Control": "no-cache", "Connection": "keep-alive" }, timeout=aiohttp.ClientTimeout(total=None), # Disable timeout chunked=True # Enable chunked transfer ) as response: # Check response status logger.info(f"Response status code: {response.status}") if response.status != 200: error_text = await response.text() logger.error(f"[request_id: {data.get('request_id')}] Failed to call chat interface: {error_text}") await websocket.send(json.dumps({ "type": "chat_response", "request_id": data.get("request_id"), "error": f"Failed to call chat interface: {error_text}" })) continue logger.debug(f"Starting to read streaming response") message_count = 0 # Direct forwarding of streaming response async for line in response.content: if line: try: # Convert bytes to string decoded_line = line.decode('utf-8') logger.debug(f"[request_id: {data.get('request_id')}] Received raw data: {decoded_line.strip()}") # Check if it's SSE format data if decoded_line.startswith("data: "): message_count += 1 data_content = decoded_line[6:].strip() # logger.info(f"[request_id: {data.get('request_id')}] Processing message {message_count}") # Check if it's a completion marker if data_content == "[DONE]": logger.info(f"[request_id: {data.get('request_id')}] Received completion marker, processed {message_count} messages in total") await websocket.send(json.dumps({ "type": "chat_response", "request_id": data.get("request_id"), "done": True })) continue # Directly forward original SSE data await websocket.send(json.dumps({ "type": "chat_response", "request_id": data.get("request_id"), "raw_sse": data_content, # Contains original SSE data "done": False })) logger.debug(f"[requestId: {data.get('request_id')}] Forwarded SSE message #{message_count}") except UnicodeDecodeError as e: logger.error(f"[requestId: {data.get('request_id')}] Failed to decode response data: {str(e)}") except Exception as e: logger.error(f"[requestId: {data.get('request_id')}] Error processing response data: {str(e)}, type: {type(e).__name__}") except Exception as e: logger.error(f"Failed to process chat request: {str(e)}") await websocket.send(json.dumps({ "type": "chat_response", "request_id": data.get("request_id"), "error": f"Error processing chat request: {str(e)}" })) else: logger.debug(f"Received unknown message type: {message}") except websockets.exceptions.ConnectionClosed: logger.error("WebSocket connection closed") break except json.JSONDecodeError: logger.error(f"Invalid JSON message: {message}") except Exception as e: logger.error(f"Failed to process message: {str(e)}") except Exception as e: logger.error(f"Message processing loop failed: {str(e)}") raise def list_uploads(self, page_no: int = 1, page_size: int = 10, status: Optional[List[str]] = None): """Get list of registered Upload instances with pagination and status filter Args: page_no (int): Page number, starting from 1 page_size (int): Number of items per page status (Optional[List[str]]): List of status to filter by Returns: dict: Dictionary containing information about Upload instances """ # headers = self._get_auth_header() params = { "page_no": page_no, "page_size": page_size } if status: params["status"] = status response = requests.get( f"{self.server_url}/api/upload/list", # headers=headers, params=params ) return ResponseHandler.handle_response( response, error_prefix="Failed to retrieve list" ) def count_uploads(self): """Get count of all registered Upload instances Returns: dict: Dictionary containing count of Upload instances """ response = requests.get( f"{self.server_url}/api/upload/count", ) return ResponseHandler.handle_response( response, error_prefix="Failed to retrieve count" ) def get_upload_detail(self, instance_id: str) -> Dict: """Get detailed information of an Upload instance Args: instance_id (str): Instance ID of the Upload Returns: dict: Dictionary containing instance information with the following fields: upload_name (str): Name of the upload instance_id (str): Instance ID status (str): Current status of the upload description (str, optional): Description of the upload email (str, optional): Associated email address registration_time (datetime): Time when the instance was registered last_heartbeat (datetime, optional): Time of the last heartbeat is_connected (bool, optional): Connection status, defaults to False instance_password (str, optional): Password for instance registration """ headers = self._get_auth_header() response = requests.get( f"{self.server_url}/api/upload/{instance_id}", headers=headers ) return ResponseHandler.handle_response( response, error_prefix="Failed to retrieve upload details" ) def update_upload(self, instance_id: str, upload_name: str = None, capabilities: dict = None, email: str = None): """Update Upload instance information in the registry center Args: instance_id: Instance ID upload_name: New upload name (optional) capabilities: New capability set (optional) email: New user email (optional) Returns: dict: Update result """ update_data = {} if upload_name is not None: update_data["upload_name"] = upload_name if capabilities is not None: update_data["capabilities"] = capabilities if email is not None: update_data["email"] = email if not update_data: logger.warning("No update data provided for update_upload") return {"message": "No update data provided"} headers = self._get_auth_header() response = requests.put( f"{self.server_url}/api/upload/{instance_id}", headers=headers, json=update_data ) return ResponseHandler.handle_response( response, success_log=f"Upload instance {instance_id} updated successfully", error_prefix="Update" ) def create_role(self, role_id, name, description, system_prompt, icon, instance_id, is_active=True, enable_l0_retrieval=True, enable_l1_retrieval=True): """Create a new role in the registry center Args: role_id: Role UUID name: Role name description: Role description system_prompt: System prompt icon: Icon URL instance_id: Instance ID enable_l0_retrieval: Enable L0 retrieval enable_l1_retrieval: Enable L1 retrieval Returns: dict: Created role data """ headers = self._get_auth_header() response = requests.post( f"{self.server_url}/api/roles", headers=headers, json={ "role_id": role_id, "instance_id": instance_id, "name": name, "description": description, "system_prompt": system_prompt, "is_active": is_active, "icon": icon, "enable_l0_retrieval": enable_l0_retrieval, "enable_l1_retrieval": enable_l1_retrieval } ) return ResponseHandler.handle_response( response, success_log=f"Role {name} created successfully in registry center", error_prefix="Role creation" )