Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| LightRAG client for interacting with the RAG server | |
| """ | |
| import os | |
| import requests | |
| import time | |
| from typing import Dict, List, Any, Optional | |
| from dotenv import load_dotenv | |
| import logging | |
# Load environment variables from ./.env (resolved against the current working
# directory); values already present in the process environment are kept
# because override=False.
load_dotenv(dotenv_path=".env", override=False)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# LightRAG configuration
LIGHTRAG_HOST = os.getenv("LIGHTRAG_HOST", "127.0.0.1")


def _read_port_env(var_name: str, fallback: int) -> int:
    """
    Read a TCP port number from environment variable *var_name*.

    Returns *fallback* when the variable is unset or blank. When the value is
    not an integer in 1-65535, a warning is logged and *fallback* is returned,
    so a bad setting does not make importing this module fail.
    """
    raw = os.getenv(var_name, "").strip()
    if not raw:
        return fallback
    try:
        port = int(raw)
    except ValueError:
        logger.warning(f"Invalid {var_name} value {raw!r}; using default port {fallback}")
        return fallback
    if not 1 <= port <= 65535:
        logger.warning(f"{var_name} value {port} is out of range; using default port {fallback}")
        return fallback
    return port


DEFAULT_LIGHTRAG_PORT = _read_port_env("LIGHTRAG_PORT", 9621)
# Optional API key for the LightRAG server; None when LIGHTRAG_API_KEY is unset.
API_KEY = os.getenv("LIGHTRAG_API_KEY")
def parse_jurisdiction_graphs() -> Dict[str, int]:
    """
    Parse the LIGHTRAG_GRAPHS environment variable into jurisdiction-to-port mappings.

    Expected format: "jurisdiction1:port1,jurisdiction2:port2,..."
    Example: "romania:9621,bahrain:9622"

    Jurisdiction names are stripped and lower-cased. Malformed entries are
    skipped with a warning, so one bad entry does not discard the valid ones.
    Malformed means a missing ":", an empty name, a non-integer port or a port
    outside 1-65535. Empty entries (e.g. from a trailing comma) are ignored.
    If a name appears more than once, the last entry wins.

    Returns:
        Dictionary mapping jurisdiction names to their respective ports
        (empty when the variable is unset or blank).
    """
    graphs_config = os.getenv("LIGHTRAG_GRAPHS", "")
    jurisdiction_map: Dict[str, int] = {}
    if not graphs_config.strip():
        logger.info("No LIGHTRAG_GRAPHS configured, using default port")
        return jurisdiction_map

    for entry in graphs_config.split(","):
        entry = entry.strip()
        if not entry:
            continue  # tolerate trailing or doubled commas
        name, sep, port_text = entry.partition(":")
        name = name.strip().lower()
        if not sep or not name:
            logger.warning(f"Ignoring malformed LIGHTRAG_GRAPHS entry: {entry!r}")
            continue
        try:
            port = int(port_text.strip())
        except ValueError:
            logger.warning(f"Ignoring LIGHTRAG_GRAPHS entry with invalid port: {entry!r}")
            continue
        if not 1 <= port <= 65535:
            logger.warning(f"Ignoring LIGHTRAG_GRAPHS entry with out-of-range port: {entry!r}")
            continue
        jurisdiction_map[name] = port
        logger.info(f"Loaded jurisdiction mapping: {name} → port {port}")

    logger.info(f"Total jurisdictions loaded: {len(jurisdiction_map)}")
    return jurisdiction_map
# Build the jurisdiction → port lookup table once, when this module is imported.
JURISDICTION_PORTS: Dict[str, int] = parse_jurisdiction_graphs()
def get_server_url_for_jurisdiction(jurisdiction: Optional[str] = None) -> str:
    """
    Get the appropriate server URL based on jurisdiction.

    The name is normalised the same way parse_jurisdiction_graphs normalises
    configured names (surrounding whitespace stripped, lower-cased). Input such
    as " Romania " therefore still finds the "romania" mapping.

    Args:
        jurisdiction: The jurisdiction name (e.g., "romania", "bahrain"), or
            None/empty to select the default server.

    Returns:
        "http://{LIGHTRAG_HOST}:{port}". The port is the jurisdiction's mapped
        port, or DEFAULT_LIGHTRAG_PORT when no mapping matches.
    """
    key = jurisdiction.strip().lower() if jurisdiction else ""
    port = JURISDICTION_PORTS.get(key) if key else None
    if port is not None:
        logger.info(f"Using jurisdiction-specific server: {jurisdiction} → port {port}")
        return f"http://{LIGHTRAG_HOST}:{port}"

    if jurisdiction:
        logger.warning(f"Jurisdiction '{jurisdiction}' not found in mappings, using default port {DEFAULT_LIGHTRAG_PORT}")
    else:
        logger.info(f"No jurisdiction specified, using default port {DEFAULT_LIGHTRAG_PORT}")
    return f"http://{LIGHTRAG_HOST}:{DEFAULT_LIGHTRAG_PORT}"
class LightRAGClient:
    """
    Client for interacting with a single LightRAG server over HTTP.

    Args:
        server_url: Base URL of the server, e.g. "http://127.0.0.1:9621". When
            omitted, the default (non-jurisdiction) server URL is used.
        api_key: Sent as the ``X-API-Key`` header on every request. No header
            is sent when it is empty or None.
        timeout: Timeout in seconds for ``query`` requests (default 300).
    """

    def __init__(
        self,
        server_url: Optional[str] = None,
        api_key: Optional[str] = API_KEY,
        timeout: float = 300,
    ):
        self.server_url = server_url or get_server_url_for_jurisdiction(None)
        self.api_key = api_key
        self.timeout = timeout

    def _auth_headers(self) -> Dict[str, str]:
        """Return the API-key header when a key is configured, else an empty dict."""
        return {"X-API-Key": self.api_key} if self.api_key else {}

    @staticmethod
    def _is_retryable_status(status_code: int) -> bool:
        """Return True for statuses a retry may fix: 408, 429 and any 5xx."""
        return status_code >= 500 or status_code in (408, 429)

    def health_check(self, timeout: float = 1.5) -> bool:
        """
        Check if the LightRAG server is healthy.

        The API key (when configured) is sent in case the server protects
        /health. Any error is logged and reported as unhealthy rather than
        raised.

        Returns:
            True only when GET {server_url}/health answers with HTTP 200.
        """
        try:
            response = requests.get(
                f"{self.server_url}/health",
                headers=self._auth_headers(),
                timeout=timeout,
            )
        except Exception as e:  # best-effort probe: never raise to the caller
            logger.warning(f"Health check failed: {e}")
            return False
        return response.status_code == 200

    def query(
        self,
        query: str,
        mode: str = "hybrid",
        include_references: bool = True,
        conversation_history: Optional[List[Dict[str, str]]] = None,
        max_retries: int = 3
    ) -> Dict[str, Any]:
        """
        Query the LightRAG server, retrying transient failures with backoff.

        Args:
            query: Question text, sent as the "query" field.
            mode: Query mode forwarded as the "mode" field (e.g. "hybrid").
            include_references: Forwarded as the "include_references" field.
            conversation_history: Prior turns, forwarded as
                "conversation_history"; defaults to an empty list.
            max_retries: Total number of attempts; at least one is always made.

        Returns:
            On HTTP 200, the decoded JSON body. Otherwise a dict with an
            "error" message:
            - For a non-retryable HTTP status (e.g. 400/401), it is returned
              immediately with "status_code".
            - After retries run out, it includes "detail" describing the
              last failure.
        """
        headers = {"Content-Type": "application/json", **self._auth_headers()}
        payload = {
            "query": query,
            "mode": mode,
            "include_references": include_references,
            "conversation_history": conversation_history or [],
        }
        attempts = max(1, max_retries)
        last_error = ""
        for attempt in range(attempts):
            try:
                response = requests.post(
                    f"{self.server_url}/query",
                    json=payload,
                    headers=headers,
                    timeout=self.timeout,
                )
                status = response.status_code
                if status == 200:
                    logger.info("Query successful")
                    return response.json()
                logger.warning(f"Query failed with status {status}, attempt {attempt + 1}")
                if not self._is_retryable_status(status):
                    # Client errors (bad request, auth failure, ...) will not
                    # succeed on a retry, so fail fast instead of backing off.
                    return {"error": f"Query failed with status {status}", "status_code": status}
                last_error = f"HTTP status {status}"
            except requests.exceptions.Timeout:
                logger.warning(f"Query timeout, attempt {attempt + 1}")
                last_error = "timeout"
            except Exception as e:
                logger.warning(f"Query error: {e}, attempt {attempt + 1}")
                last_error = str(e)
            if attempt < attempts - 1:
                time.sleep(2 ** attempt)  # Exponential backoff: 1s, 2s, 4s, ...
        return {"error": f"Query failed after {attempts} attempts", "detail": last_error}

    def get_references(self, response_data: Dict[str, Any], limit: int = 5) -> List[str]:
        """
        Extract referenced file names from a LightRAG response.

        Args:
            response_data: A response dict; its "references" entry may be
                missing or None.
            limit: Maximum number of references to return (default 5).

        Returns:
            The bare file names (directory part removed, "/" or "\\"
            separators) of the first *limit* references. A reference may be a
            dict with "file_path" or a plain path string. A missing or empty
            path yields "Unknown file".
        """
        references = response_data.get("references") or []
        names: List[str] = []
        for ref in references[:limit]:
            path = ref.get("file_path") if isinstance(ref, dict) else ref
            if not path:
                path = "Unknown file"
            # Normalise Windows separators so only the bare file name remains.
            names.append(str(path).replace("\\", "/").rsplit("/", 1)[-1])
        return names
class ResponseProcessor:
    """
    Process and enhance LightRAG responses.

    Every helper is a static method. It can be called on the class
    (``ResponseProcessor.format_references(refs)``) or on an instance.
    """

    # Regulation names detected by extract_key_entities, in reporting order.
    KNOWN_REGULATIONS = ("GDPR", "NIS2", "DORA", "CRA", "eIDAS", "Cyber Resilience Act")

    @staticmethod
    def extract_main_content(response: Dict[str, Any]) -> str:
        """
        Extract the main response text.

        Returns "No response available." when the "response" key is missing
        or None.
        """
        content = response.get("response")
        return "No response available." if content is None else content

    @staticmethod
    def format_references(references: List[str]) -> str:
        """
        Format a reference list as a Markdown "References" section.

        Returns an empty string when there are no references. Otherwise it
        returns a heading followed by one "• name" line per reference.
        """
        if not references:
            return ""
        bullets = "".join(f"• {ref}\n" for ref in references)
        return "\n\n**📚 References:**\n" + bullets

    @staticmethod
    def extract_key_entities(
        response: Dict[str, Any],
        regulations: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Find which known regulation names are mentioned in the response text.

        Matching is case-insensitive and on whole words only, so "CRA" does
        not match inside "democracy" and "DORA" does not match "Pandora".

        Args:
            response: A response dict; its "response" text may be missing or None.
            regulations: Names to look for; defaults to KNOWN_REGULATIONS.

        Returns:
            The names found, without duplicates, in the order they are listed
            in *regulations*.
        """
        import re

        content = str(response.get("response") or "")
        candidates = ResponseProcessor.KNOWN_REGULATIONS if regulations is None else regulations
        found: List[str] = []
        for reg in candidates:
            # Lookarounds instead of \b so names that start/end with a
            # non-word character still get whole-word matching.
            pattern = rf"(?<!\w){re.escape(reg)}(?!\w)"
            if reg not in found and re.search(pattern, content, re.IGNORECASE):
                found.append(reg)
        return found
# One client per configured jurisdiction (keyed by normalised name). The
# ``None`` key holds the default-server client, shared by "no jurisdiction"
# and by any jurisdiction that is not configured.
_lightrag_client_cache: Dict[Optional[str], LightRAGClient] = {}


def get_lightrag_client(jurisdiction: Optional[str] = None) -> LightRAGClient:
    """
    Get or create a LightRAG client for the specified jurisdiction.

    Clients are cached so each server gets a single client instance. Names are
    stripped and lower-cased before lookup. Every unknown name resolves to the
    default server and shares the default slot, so arbitrary input cannot grow
    the cache without bound.

    Args:
        jurisdiction: The jurisdiction name (e.g., "romania", "bahrain").
            If None, empty or not configured, the default port is used.

    Returns:
        LightRAGClient instance configured for the jurisdiction
    """
    normalized = jurisdiction.strip().lower() if jurisdiction else ""
    cache_key = normalized if normalized and normalized in JURISDICTION_PORTS else None
    label = cache_key or "default"

    cached = _lightrag_client_cache.get(cache_key)
    if cached is not None:
        logger.debug(f"Using cached LightRAG client for jurisdiction: {label}")
        return cached

    # Pass the normalised name so lookup succeeds regardless of input spacing/case.
    server_url = get_server_url_for_jurisdiction(normalized or None)
    client = LightRAGClient(server_url=server_url, api_key=API_KEY)
    _lightrag_client_cache[cache_key] = client
    logger.info(f"Created and cached LightRAG client for jurisdiction: {label} → {server_url}")
    return client
def validate_jurisdiction(jurisdiction: Optional[str]) -> bool:
    """
    Validate if a jurisdiction is supported.

    The name is stripped and lower-cased, matching how configured names are
    normalised, so " Romania " is accepted when "romania" is configured.

    Args:
        jurisdiction: The jurisdiction name to validate (None/empty is allowed).

    Returns:
        True if the jurisdiction is configured, False otherwise (including
        for None or blank input).
    """
    key = jurisdiction.strip().lower() if jurisdiction else ""
    return bool(key) and key in JURISDICTION_PORTS
def get_available_jurisdictions() -> List[str]:
    """
    List the jurisdiction names that have a dedicated LightRAG server.

    Returns:
        A new list of the keys of JURISDICTION_PORTS, in the dict's
        insertion order.
    """
    names: List[str] = [*JURISDICTION_PORTS]
    return names