CyberLegalAIendpoint / utils /lightrag_client.py
Charles Grandjean
fix graph by jurisdiction
db8e436
#!/usr/bin/env python3
"""
LightRAG client for interacting with the RAG server
"""
import logging
import os
import re
import time
from typing import Any, Dict, List, Optional

import requests
from dotenv import load_dotenv
# Load environment variables from ./.env (relative to the current working
# directory). override=False keeps values already set in the process environment.
load_dotenv(dotenv_path=".env", override=False)
# Configure logging
# NOTE(review): basicConfig at import time configures the root logger for the
# whole process; it does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)
# LightRAG configuration
# Host used for every LightRAG server URL; only the port varies per jurisdiction.
LIGHTRAG_HOST: str = os.getenv("LIGHTRAG_HOST", "127.0.0.1")
# Fallback port when no jurisdiction is given or it has no mapping.
# int() raises ValueError at import time if LIGHTRAG_PORT is not an integer.
DEFAULT_LIGHTRAG_PORT: int = int(os.getenv("LIGHTRAG_PORT", "9621"))
# Optional API key; when set, LightRAGClient sends it as the X-API-Key header.
API_KEY: Optional[str] = os.getenv("LIGHTRAG_API_KEY")
def parse_jurisdiction_graphs(graphs_config: Optional[str] = None) -> Dict[str, int]:
    """
    Parse jurisdiction-to-port mappings for the LightRAG servers.

    Expected format: "jurisdiction1:port1,jurisdiction2:port2,..."
    Example: "romania:9621,bahrain:9622"

    Jurisdiction names are stripped and lower-cased. Each entry is validated
    on its own, so one malformed entry is logged and skipped without discarding
    the valid entries that follow it. An entry is malformed when it has no ":",
    an empty name, or a port that is not an integer in 1..65535. If a
    jurisdiction appears twice, the later port wins and a warning is logged.

    Args:
        graphs_config: Mapping string to parse. When None (the default), the
            LIGHTRAG_GRAPHS environment variable is read instead.
    Returns:
        Dictionary mapping jurisdiction names to their respective ports
        (empty when nothing valid is configured).
    """
    if graphs_config is None:
        graphs_config = os.getenv("LIGHTRAG_GRAPHS", "")

    jurisdiction_map: Dict[str, int] = {}
    if not graphs_config.strip():
        logger.info("No LIGHTRAG_GRAPHS configured, using default port")
        return jurisdiction_map

    for raw_entry in graphs_config.split(","):
        entry = raw_entry.strip()
        if not entry:
            continue  # tolerate trailing or doubled commas

        name, separator, port_text = entry.partition(":")
        jurisdiction = name.strip().lower()
        if not separator or not jurisdiction:
            logger.warning(f"Ignoring malformed LIGHTRAG_GRAPHS entry: '{entry}'")
            continue

        # Validate per entry so a bad port only drops this one mapping.
        try:
            port = int(port_text.strip())
        except ValueError:
            logger.error(f"Error parsing LIGHTRAG_GRAPHS entry '{entry}': invalid port")
            continue
        if not 1 <= port <= 65535:
            logger.error(f"Error parsing LIGHTRAG_GRAPHS entry '{entry}': port out of range")
            continue

        if jurisdiction in jurisdiction_map:
            logger.warning(
                f"Duplicate jurisdiction '{jurisdiction}' in LIGHTRAG_GRAPHS; "
                f"port {jurisdiction_map[jurisdiction]} replaced by {port}"
            )
        jurisdiction_map[jurisdiction] = port
        logger.info(f"Loaded jurisdiction mapping: {jurisdiction} → port {port}")

    logger.info(f"Total jurisdictions loaded: {len(jurisdiction_map)}")
    return jurisdiction_map
# Parse jurisdiction mappings once, at module load time. The lookup helpers
# below read this dict, so later changes to LIGHTRAG_GRAPHS are not picked up.
JURISDICTION_PORTS: Dict[str, int] = parse_jurisdiction_graphs()
def get_server_url_for_jurisdiction(jurisdiction: Optional[str] = None) -> str:
    """
    Resolve the LightRAG server URL for a jurisdiction.

    The name is matched case-insensitively against JURISDICTION_PORTS.
    A missing/empty jurisdiction, or one without a mapping, falls back to
    DEFAULT_LIGHTRAG_PORT (the unmapped case is logged as a warning).

    Args:
        jurisdiction: The jurisdiction name (e.g., "romania", "bahrain"), or None.
    Returns:
        URL of the form "http://<LIGHTRAG_HOST>:<port>".
    """
    if not jurisdiction:
        logger.info(f"No jurisdiction specified, using default port {DEFAULT_LIGHTRAG_PORT}")
        return f"http://{LIGHTRAG_HOST}:{DEFAULT_LIGHTRAG_PORT}"

    lookup_key = jurisdiction.lower()
    if lookup_key not in JURISDICTION_PORTS:
        logger.warning(f"Jurisdiction '{jurisdiction}' not found in mappings, using default port {DEFAULT_LIGHTRAG_PORT}")
        return f"http://{LIGHTRAG_HOST}:{DEFAULT_LIGHTRAG_PORT}"

    port = JURISDICTION_PORTS[lookup_key]
    logger.info(f"Using jurisdiction-specific server: {jurisdiction} → port {port}")
    return f"http://{LIGHTRAG_HOST}:{port}"
class LightRAGClient:
    """
    Client for interacting with a single LightRAG server.

    Each instance targets one base URL (one jurisdiction graph). Requests go
    through the module-level ``requests.get`` / ``requests.post`` functions,
    so no HTTP session or connection pool is reused between calls.
    """

    def __init__(
        self,
        server_url: Optional[str] = None,
        api_key: Optional[str] = API_KEY,
        timeout: float = 300,
    ):
        """
        Args:
            server_url: Base URL of the LightRAG server. Defaults to the
                default-port server from get_server_url_for_jurisdiction(None).
            api_key: Sent as the X-API-Key header when truthy.
            timeout: Per-request timeout in seconds used by query().
        """
        self.server_url = server_url or get_server_url_for_jurisdiction(None)
        self.api_key = api_key
        self.timeout = timeout

    def _auth_headers(self) -> Dict[str, str]:
        """Return the X-API-Key header when an API key is configured, else {}."""
        return {"X-API-Key": self.api_key} if self.api_key else {}

    @staticmethod
    def _is_retryable_status(status_code: int) -> bool:
        """Return True for statuses a retry may fix: any 5xx, 408 and 429."""
        return status_code >= 500 or status_code in (408, 429)

    def health_check(self, timeout: float = 1.5) -> bool:
        """
        Check whether the LightRAG server answers GET /health with HTTP 200.

        The same API-key header as query() is sent, so the check also works
        if the server protects /health behind authentication.

        Args:
            timeout: Request timeout in seconds.
        Returns:
            True on HTTP 200; False on any other status or on any error.
        """
        try:
            response = requests.get(
                f"{self.server_url}/health",
                headers=self._auth_headers(),
                timeout=timeout,
            )
            return response.status_code == 200
        except Exception as e:
            # Deliberately best-effort: a health probe reports False, never raises.
            logger.warning(f"Health check failed: {e}")
            return False

    def query(
        self,
        query: str,
        mode: str = "hybrid",
        include_references: bool = True,
        conversation_history: Optional[List[Dict[str, str]]] = None,
        max_retries: int = 3
    ) -> Dict[str, Any]:
        """
        POST a query to the server's /query endpoint, retrying on failure.

        Timeouts, request errors, undecodable bodies and retryable statuses
        (5xx, 408, 429) are retried with exponential backoff (1s, 2s, 4s, ...).
        Other non-200 statuses (e.g. 401, 403, 422) return immediately,
        because sending the same request again cannot succeed.
        This method does not raise; failures come back as an error dict.

        Args:
            query: The question text.
            mode: LightRAG query mode, passed through unchanged.
            include_references: Ask the server to include references.
            conversation_history: Prior turns, passed through; defaults to [].
            max_retries: Maximum number of attempts; values below 1 mean 1.
        Returns:
            The decoded JSON body on HTTP 200, otherwise {"error": "<reason>"}.
        """
        headers = {"Content-Type": "application/json", **self._auth_headers()}
        payload = {
            "query": query,
            "mode": mode,
            "include_references": include_references,
            "conversation_history": conversation_history or [],
        }
        # Always make at least one attempt, even if max_retries <= 0.
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            try:
                response = requests.post(
                    f"{self.server_url}/query",
                    json=payload,
                    headers=headers,
                    timeout=self.timeout
                )
                if response.status_code == 200:
                    result = response.json()
                    logger.info("Query successful")
                    return result
                logger.warning(f"Query failed with status {response.status_code}, attempt {attempt + 1}")
                if not self._is_retryable_status(response.status_code):
                    return {"error": f"Query failed with status {response.status_code}"}
            except requests.exceptions.Timeout:
                logger.warning(f"Query timeout, attempt {attempt + 1}")
            except Exception as e:
                # Best-effort contract: report errors via the return value.
                logger.warning(f"Query error: {e}, attempt {attempt + 1}")
            if attempt < attempts - 1:
                time.sleep(2 ** attempt)  # Exponential backoff
        return {"error": f"Query failed after {attempts} attempts"}

    def get_references(self, response_data: Dict[str, Any], limit: int = 5) -> List[str]:
        """
        Extract referenced file names from a LightRAG query response.

        Args:
            response_data: A response dict from query(). Its optional
                "references" entry is expected to be a list of dicts with a
                "file_path" key.
            limit: Maximum number of references to return (default 5).
        Returns:
            The base name (text after the last "/") of each of the first
            ``limit`` references. "Unknown file" is used for an entry that is
            not a dict or has no usable file_path.
        """
        references = response_data.get("references") or []
        file_names = []
        for ref in references[:limit]:
            file_path = ref.get("file_path") if isinstance(ref, dict) else None
            file_names.append(str(file_path or "Unknown file").split("/")[-1])
        return file_names
class ResponseProcessor:
    """
    Helpers for post-processing LightRAG query responses.
    """

    # Regulations recognized by extract_key_entities, in output order.
    LEGAL_REGULATIONS = ("GDPR", "NIS2", "DORA", "CRA", "eIDAS", "Cyber Resilience Act")
    # Whole-word, case-insensitive patterns. Plain substring tests matched short
    # acronyms inside ordinary words ("CRA" in "democracy", "DORA" in "adorable").
    _REGULATION_PATTERNS = tuple(
        (name, re.compile(rf"\b{re.escape(name)}\b", re.IGNORECASE))
        for name in LEGAL_REGULATIONS
    )

    @staticmethod
    def extract_main_content(response: Dict[str, Any]) -> str:
        """
        Return the answer text of a LightRAG response.

        Falls back to "No response available." when the "response" key is
        missing or its value is None.
        """
        content = response.get("response")
        return "No response available." if content is None else content

    @staticmethod
    def format_references(references: List[str]) -> str:
        """
        Render reference names as a bulleted Markdown section.

        Returns "" for an empty list, so callers can append the result
        unconditionally.
        """
        if not references:
            return ""
        bullets = "".join(f"• {ref}\n" for ref in references)
        return f"\n\n**📚 References:**\n{bullets}"

    @staticmethod
    def extract_key_entities(response: Dict[str, Any]) -> List[str]:
        """
        List the known regulations mentioned in a response's text.

        Matching is whole-word and case-insensitive. Each regulation appears at
        most once, in LEGAL_REGULATIONS order, so the output is deterministic.
        A missing or None "response" yields [].
        """
        # This could be enhanced if LightRAG provides entity information
        content = response.get("response") or ""
        return [
            name
            for name, pattern in ResponseProcessor._REGULATION_PATTERNS
            if pattern.search(content)
        ]
# Process-wide client cache keyed by lower-cased configured jurisdiction,
# or "default" for the fallback (default-port) server.
_lightrag_client_cache: Dict[str, LightRAGClient] = {}


def get_lightrag_client(jurisdiction: Optional[str] = None) -> LightRAGClient:
    """
    Get or create the LightRAG client for the specified jurisdiction.

    Clients are cached so that each one is constructed only once. Only
    jurisdictions present in JURISDICTION_PORTS (case-insensitive) get their
    own cache entry. None, "" and unknown names all resolve to the
    default-port server, so they share the "default" client. This keeps the
    cache bounded even when jurisdiction names come from untrusted input.

    Args:
        jurisdiction: The jurisdiction name (e.g., "romania", "bahrain").
            If None or not configured, the default-port client is returned.
    Returns:
        LightRAGClient instance configured for the jurisdiction
    """
    normalized = jurisdiction.lower() if jurisdiction else None
    cache_key = normalized if normalized in JURISDICTION_PORTS else "default"

    cached_client = _lightrag_client_cache.get(cache_key)
    if cached_client is not None:
        logger.debug(f"Using cached LightRAG client for jurisdiction: {cache_key}")
        return cached_client

    # get_server_url_for_jurisdiction logs a warning for unmapped names.
    server_url = get_server_url_for_jurisdiction(jurisdiction)
    client = LightRAGClient(server_url=server_url, api_key=API_KEY)
    _lightrag_client_cache[cache_key] = client
    logger.info(f"Created and cached LightRAG client for jurisdiction: {cache_key} → {server_url}")
    return client
def validate_jurisdiction(jurisdiction: Optional[str]) -> bool:
    """
    Validate if a jurisdiction has a dedicated LightRAG server configured.

    The comparison is case-insensitive, like get_server_url_for_jurisdiction.
    A missing or empty value is reported as unsupported instead of raising.
    get_server_url_for_jurisdiction likewise treats such values as "no
    jurisdiction" and routes them to the default server.

    Args:
        jurisdiction: The jurisdiction name to validate (None is accepted).
    Returns:
        True if jurisdiction is configured, False otherwise
    """
    if not jurisdiction:
        return False
    return jurisdiction.lower() in JURISDICTION_PORTS
def get_available_jurisdictions() -> List[str]:
    """
    List the jurisdictions that have a dedicated LightRAG server.

    Returns:
        A new list of the JURISDICTION_PORTS keys (lower-cased names as
        parsed from LIGHTRAG_GRAPHS), in dict insertion order.
    """
    return list(JURISDICTION_PORTS)