Spaces:
Running
Running
File size: 8,739 Bytes
851f2ed 695b33f 851f2ed 695b33f 851f2ed db8e436 851f2ed db8e436 695b33f 851f2ed db8e436 851f2ed d9ce50a 851f2ed d9ce50a 851f2ed 695b33f 851f2ed db8e436 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 |
#!/usr/bin/env python3
"""
LightRAG client for interacting with the RAG server
"""
import logging
import os
import re
import time
from typing import Dict, List, Any, Optional

import requests
from dotenv import load_dotenv
# Load environment variables from a local ".env" file. With override=False,
# variables already present in the process environment are left untouched.
load_dotenv(dotenv_path=".env", override=False)
# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# LightRAG configuration (read once, at import time)
# Host name or IP used when building LightRAG server URLs.
LIGHTRAG_HOST = os.getenv("LIGHTRAG_HOST", "127.0.0.1")
# Default server port; int() raises ValueError at import time if
# LIGHTRAG_PORT is set to a non-numeric value.
DEFAULT_LIGHTRAG_PORT = int(os.getenv("LIGHTRAG_PORT", "9621"))
# Optional API key; None when LIGHTRAG_API_KEY is not set.
API_KEY = os.getenv("LIGHTRAG_API_KEY")
def parse_jurisdiction_graphs() -> Dict[str, int]:
    """
    Parse the LIGHTRAG_GRAPHS environment variable into jurisdiction-to-port mappings.

    Expected format: "jurisdiction1:port1,jurisdiction2:port2,..."
    Example: "romania:9621,bahrain:9622"

    Jurisdiction names are stripped and lower-cased. Every entry is validated
    on its own: an entry without a ":" separator, with an empty name, or with a
    non-integer port is logged as a warning and skipped, and the remaining
    entries are still loaded. Blank entries (e.g. from a trailing comma) are
    ignored silently.

    Returns:
        Dictionary mapping jurisdiction names to their respective ports.
        Empty when LIGHTRAG_GRAPHS is unset, empty, or has no valid entries.
    """
    graphs_config = os.getenv("LIGHTRAG_GRAPHS", "")
    jurisdiction_map: Dict[str, int] = {}
    if not graphs_config:
        logger.info("No LIGHTRAG_GRAPHS configured, using default port")
        return jurisdiction_map

    for entry in graphs_config.split(","):
        entry = entry.strip()
        if not entry:
            # Tolerate stray commas such as "a:1,,b:2" or a trailing ",".
            continue

        name, separator, raw_port = entry.partition(":")
        name = name.strip().lower()
        if not separator or not name:
            logger.warning(f"Skipping malformed LIGHTRAG_GRAPHS entry: {entry!r}")
            continue

        # Only the int() conversion can fail here; keep the try narrow so one
        # bad port does not discard the entries that follow it.
        try:
            port = int(raw_port.strip())
        except ValueError:
            logger.warning(f"Skipping LIGHTRAG_GRAPHS entry with invalid port: {entry!r}")
            continue

        jurisdiction_map[name] = port
        logger.info(f"Loaded jurisdiction mapping: {name} → port {port}")

    logger.info(f"Total jurisdictions loaded: {len(jurisdiction_map)}")
    return jurisdiction_map
# Parse jurisdiction mappings once, at module load time. Later changes to the
# LIGHTRAG_GRAPHS environment variable are not reflected in this mapping.
JURISDICTION_PORTS = parse_jurisdiction_graphs()
def get_server_url_for_jurisdiction(jurisdiction: Optional[str] = None) -> str:
    """
    Resolve the LightRAG server URL to use for a jurisdiction.

    The lookup in JURISDICTION_PORTS is case-insensitive. When the jurisdiction
    is missing (None or empty) or has no configured port, the URL for
    DEFAULT_LIGHTRAG_PORT is returned instead.

    Args:
        jurisdiction: The jurisdiction name (e.g., "romania", "bahrain")
    Returns:
        Server URL string of the form "http://<host>:<port>"
    """
    # No jurisdiction at all: informational log, default server.
    if not jurisdiction:
        logger.info(f"No jurisdiction specified, using default port {DEFAULT_LIGHTRAG_PORT}")
        return f"http://{LIGHTRAG_HOST}:{DEFAULT_LIGHTRAG_PORT}"

    lookup_key = jurisdiction.lower()

    # Unknown jurisdiction: warn, then fall back to the default server.
    if lookup_key not in JURISDICTION_PORTS:
        logger.warning(f"Jurisdiction '{jurisdiction}' not found in mappings, using default port {DEFAULT_LIGHTRAG_PORT}")
        return f"http://{LIGHTRAG_HOST}:{DEFAULT_LIGHTRAG_PORT}"

    port = JURISDICTION_PORTS[lookup_key]
    logger.info(f"Using jurisdiction-specific server: {jurisdiction} → port {port}")
    return f"http://{LIGHTRAG_HOST}:{port}"
class LightRAGClient:
    """
    HTTP client bound to a single LightRAG server URL.

    Instance attributes (all set in __init__):
        server_url: Base URL that "/health" and "/query" are appended to.
        api_key: Sent as the "X-API-Key" header on queries when truthy.
        timeout: Timeout passed to query requests (300).
    """

    def __init__(self, server_url: Optional[str] = None, api_key: Optional[str] = API_KEY):
        # A missing or empty URL falls back to the default (no-jurisdiction) server.
        if not server_url:
            server_url = get_server_url_for_jurisdiction(None)
        self.server_url = server_url
        self.api_key = api_key
        self.timeout = 300

    def health_check(self, timeout: float = 1.5) -> bool:
        """
        Report whether GET {server_url}/health answers with HTTP 200.

        Any exception raised while performing the request is logged as a
        warning and reported as False.
        """
        try:
            reply = requests.get(f"{self.server_url}/health", timeout=timeout)
        except Exception as exc:
            logger.warning(f"Health check failed: {exc}")
            return False
        return reply.status_code == 200

    def query(
        self,
        query: str,
        mode: str = "hybrid",
        include_references: bool = True,
        conversation_history: Optional[List[Dict[str, str]]] = None,
        max_retries: int = 3
    ) -> Dict[str, Any]:
        """
        POST a query to {server_url}/query, retrying on failure.

        Up to max_retries attempts are made. A non-200 status, a timeout, or
        any other exception counts as a failed attempt and is logged as a
        warning. Between attempts the call sleeps 2 ** attempt seconds
        (1, 2, 4, ...); there is no sleep after the final attempt.

        Returns:
            The decoded JSON body of the first HTTP 200 response, or
            {"error": ...} when every attempt failed.
        """
        request_headers = {"Content-Type": "application/json"}
        if self.api_key:
            request_headers["X-API-Key"] = self.api_key

        history = conversation_history if conversation_history else []
        body = {
            "query": query,
            "mode": mode,
            "include_references": include_references,
            "conversation_history": history,
        }

        final_attempt = max_retries - 1
        for attempt in range(max_retries):
            attempt_no = attempt + 1
            try:
                reply = requests.post(
                    f"{self.server_url}/query",
                    json=body,
                    headers=request_headers,
                    timeout=self.timeout,
                )
                if reply.status_code == 200:
                    logger.info("Query successful")
                    return reply.json()
                logger.warning(f"Query failed with status {reply.status_code}, attempt {attempt_no}")
            except requests.exceptions.Timeout:
                logger.warning(f"Query timeout, attempt {attempt_no}")
            except Exception as exc:
                logger.warning(f"Query error: {exc}, attempt {attempt_no}")

            # Exponential backoff, skipped once no further attempt will follow.
            if attempt != final_attempt:
                time.sleep(2 ** attempt)

        return {"error": f"Query failed after {max_retries} attempts"}

    def get_references(self, response_data: Dict[str, Any]) -> List[str]:
        """
        Extract reference file names from a LightRAG response.

        Only the first five entries of response_data["references"] are used.
        For each one, the "file_path" value (or "Unknown file" when the key is
        absent) is converted to a string and reduced to the part after its
        last "/".
        """
        top_references = (response_data.get("references", []) or [])[:5]
        return [
            str(entry.get("file_path", "Unknown file")).rsplit("/", 1)[-1]
            for entry in top_references
        ]
class ResponseProcessor:
    """
    Process and enhance LightRAG responses
    """

    # Regulation names/acronyms searched for by extract_key_entities.
    # Results are reported in this order.
    _REGULATIONS = ("GDPR", "NIS2", "DORA", "CRA", "eIDAS", "Cyber Resilience Act")

    @staticmethod
    def extract_main_content(response: Dict[str, Any]) -> str:
        """
        Extract the main response content.

        Returns the value stored under "response", or the placeholder
        "No response available." when that key is missing or null.
        """
        content = response.get("response")
        # Treat an explicit null like a missing key so the result is never None.
        return "No response available." if content is None else content

    @staticmethod
    def format_references(references: List[str]) -> str:
        """
        Format reference list for display.

        Returns an empty string for an empty list; otherwise a markdown
        heading followed by one "• name" bullet line per reference.
        """
        if not references:
            return ""
        bullet_lines = "".join(f"• {ref}\n" for ref in references)
        return f"\n\n**📚 References:**\n{bullet_lines}"

    @staticmethod
    def extract_key_entities(response: Dict[str, Any]) -> List[str]:
        """
        Extract key entities mentioned in the response.

        Looks for a fixed list of regulation names in the "response" text,
        case-insensitively and on word boundaries, so that e.g. "CRA" does not
        match inside "aircraft". A missing, null, or empty "response" yields
        an empty list.

        Returns:
            The matched regulation names, without duplicates, in the fixed
            order of the regulation list (deterministic across runs).
        """
        # This could be enhanced if LightRAG provides entity information
        content = response.get("response") or ""
        return [
            regulation
            for regulation in ResponseProcessor._REGULATIONS
            if re.search(rf"\b{re.escape(regulation)}\b", content, re.IGNORECASE)
        ]
# Module-level cache of clients, keyed by lower-cased jurisdiction name
# ("default" when no jurisdiction is given).
_lightrag_client_cache: Dict[str, LightRAGClient] = {}


def get_lightrag_client(jurisdiction: Optional[str] = None) -> LightRAGClient:
    """
    Return the LightRAG client for a jurisdiction, creating it on first use.

    Clients are stored in a module-level cache, so repeated calls with the
    same jurisdiction (compared case-insensitively) return the same instance.

    Args:
        jurisdiction: The jurisdiction name (e.g., "romania", "bahrain").
            If None or empty, the cache key "default" is used.
    Returns:
        LightRAGClient instance configured for the jurisdiction
    """
    cache_key = "default" if not jurisdiction else jurisdiction.lower()

    if cache_key not in _lightrag_client_cache:
        # First request for this key: resolve the URL, build and remember the client.
        server_url = get_server_url_for_jurisdiction(jurisdiction)
        new_client = LightRAGClient(server_url=server_url, api_key=API_KEY)
        _lightrag_client_cache[cache_key] = new_client
        logger.info(f"Created and cached LightRAG client for jurisdiction: {cache_key} → {server_url}")
        return new_client

    logger.debug(f"Using cached LightRAG client for jurisdiction: {cache_key}")
    return _lightrag_client_cache[cache_key]
def validate_jurisdiction(jurisdiction: str) -> bool:
    """
    Tell whether a jurisdiction has a configured port mapping.

    The name is lower-cased before being looked up in JURISDICTION_PORTS.

    Args:
        jurisdiction: The jurisdiction name to validate
    Returns:
        True if jurisdiction is configured, False otherwise
    """
    normalized_name = jurisdiction.lower()
    is_configured = normalized_name in JURISDICTION_PORTS
    return is_configured
def get_available_jurisdictions() -> List[str]:
    """
    List the jurisdictions that have a configured port mapping.

    Returns:
        A new list holding the keys of JURISDICTION_PORTS
    """
    return [name for name in JURISDICTION_PORTS]
|