Spaces:
Running
Running
| """Server dependency and lifecycle management. | |
| Author: IntelliDeep Labs Team | |
| License: BSL 1.1 | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import asyncio | |
| import logging | |
| import time | |
| from pathlib import Path | |
| from typing import Optional | |
| from nlproxy.cache.semantic_cache import SemanticLLMCache | |
| from nlproxy.core.corrector import ResponseCorrector | |
| from nlproxy.core.verifier import PostLLMVerifier | |
| from nlproxy.firewall import ( | |
| PromptFirewall, | |
| DEFAULT_FIREWALL_RULES, | |
| SEMANTIC_FIREWALL_CONFIG, | |
| ) | |
| from nlproxy.llm.client import LLMOrchestrator, LLMProvider, LLMClientFactory | |
| from nlproxy.service.compression import CompressionService | |
| from .config import settings | |
| from .logger import get_request_logger | |
# Module-wide logger for the server lifecycle code.
logger = logging.getLogger(__name__)

# Service singletons shared across the server. Every one of them is ``None``
# at import time and is assigned by ``startup()``.

# Closed again by ``shutdown()``.
llm_orchestrator: Optional[LLMOrchestrator] = None
semantic_cache: Optional[SemanticLLMCache] = None

# Compression service and the verifier / corrector that startup() attaches to it.
compression_service: Optional[CompressionService] = None
post_verifier: Optional[PostLLMVerifier] = None
response_corrector: Optional[ResponseCorrector] = None

# Prompt firewall built from the default rule set.
firewall: Optional[PromptFirewall] = None
def get_request_logger(request_id: str) -> logging.LoggerAdapter:
    """Return an adapter around the module logger that carries ``request_id``.

    The adapter's ``extra`` mapping is ``{"request_id": request_id}``.

    NOTE(review): this definition rebinds the ``get_request_logger`` name that
    the module also imports from ``.logger``; confirm which one is intended.
    """
    request_context = {"request_id": request_id}
    return logging.LoggerAdapter(logger, request_context)
async def startup() -> None:
    """Build the module-level service singletons used by the proxy.

    Steps, in order:

    1. Resolve the default LLM provider and model. A non-blank
       ``NLPROXY_DEFAULT_LLM_PROVIDER`` / ``NLPROXY_DEFAULT_LLM_MODEL``
       environment variable wins; otherwise the value from ``settings`` is
       used.
    2. Create the ``CompressionService``.
    3. Create the ``PostLLMVerifier``; if that raises, log a warning and
       build one with NLI disabled instead.
    4. Create the ``ResponseCorrector`` and attach it and the verifier to the
       compression service.
    5. Create the ``LLMOrchestrator`` and check that a client for the default
       provider can be obtained from ``LLMClientFactory``.
    6. Create the ``PromptFirewall`` from ``DEFAULT_FIREWALL_RULES``.
    7. Create the ``SemanticLLMCache`` when enabled; a failure here only
       disables the cache.

    Raises:
        RuntimeError: If the client for the default provider cannot be
            created. The underlying exception is chained as the cause.
    """
    global compression_service, post_verifier, response_corrector
    global llm_orchestrator, firewall, semantic_cache

    logger.info("Starting nlproxy Enterprise Proxy...")
    start_time = time.time()

    # Resolve the runtime overrides exactly once. A blank or whitespace-only
    # environment variable is treated as "not set", so the value that is
    # logged is the value that is actually used below.
    provider_name = (
        os.getenv("NLPROXY_DEFAULT_LLM_PROVIDER", "").strip()
        or settings.default_llm_provider
    )
    default_model_name = (
        os.getenv("NLPROXY_DEFAULT_LLM_MODEL", "").strip()
        or settings.default_llm_model
    )
    gemini_key = os.getenv("GEMINI_API_KEY", "").strip()
    logger.info(
        "Provider from env: %s (GEMINI_KEY=%s)",
        provider_name,
        "present" if gemini_key else "missing",
    )

    # NOTE(review): relative path, so it resolves against the process's
    # current working directory.
    models_dir = Path("nlproxy") / "models"

    # --- Compression service -------------------------------------------
    compression_service = CompressionService(
        use_cache=settings.enable_semantic_cache,
        redis_url=settings.redis_url if settings.enable_semantic_cache else None,
        privacy_mode=settings.privacy_mode_default,
        models_dir=models_dir,
        llm_default_model=settings.default_llm_model,
    )
    logger.info("CompressionService initialized")

    # --- Post-LLM verifier (NLI is optional) ----------------------------
    try:
        post_verifier = PostLLMVerifier(
            mode="general",
            use_nli=settings.enable_nli_verification,
            # Reuse the embedding model the compression service already holds.
            embedding_model=compression_service.segmenter._embedding_model,
            models_dir=models_dir,
        )
        if post_verifier.use_nli:
            compression_service.nli_refinement_fn = (
                post_verifier.get_nli_check_function()
            )
        logger.info("PostLLMVerifier initialized")
    except Exception as exc:
        # Best effort: keep the proxy usable without semantic verification.
        logger.warning(
            "NLI initialization failed: %s. Disabling semantic verification.", exc
        )
        post_verifier = PostLLMVerifier(
            mode="general",
            use_nli=False,
            embedding_model=compression_service.segmenter._embedding_model,
        )
        compression_service.nli_refinement_fn = None

    # --- Response corrector ---------------------------------------------
    response_corrector = ResponseCorrector(mode="general")
    compression_service.response_corrector = response_corrector
    compression_service.post_verifier = post_verifier
    logger.info("ResponseCorrector initialized")

    # --- LLM orchestrator -----------------------------------------------
    if settings.enable_llm_fallback:
        fallback_providers = [
            LLMProvider.OPENAI,
            LLMProvider.CLAUDE,
            LLMProvider.GEMINI,
        ]
    else:
        fallback_providers = []
    llm_orchestrator = LLMOrchestrator(
        default_provider=LLMProvider(provider_name),
        fallback_providers=fallback_providers,
        load_balance=True,
        max_concurrent_requests=20,
        default_model=default_model_name,
    )
    logger.info("LLMOrchestrator initialized")

    # Validate that at least the default provider has credentials configured.
    try:
        LLMClientFactory.get_or_create(
            LLMProvider(provider_name), model=default_model_name
        )
        logger.info("LLM client validated for default provider: %s", provider_name)
    except Exception as exc:
        # Map the provider to the API-key variable named in the error message.
        env_map = {
            "gemini": "GEMINI_API_KEY",
            "claude": "ANTHROPIC_API_KEY",
            "openai": "OPENAI_API_KEY",
            "deepseek": "DEEPSEEK_API_KEY",
            "qwen": "QWEN_API_KEY",
            "kimi": "KIMI_API_KEY",
            "openrouter": "OPENROUTER_API_KEY",
        }
        expected_env = env_map.get(
            provider_name.lower(), "<provider-specific API key env var>"
        )
        logger.error(
            "LLM client validation failed for default provider '%s': %s. Expected environment variable: %s",
            provider_name,
            exc,
            expected_env,
        )
        raise RuntimeError(
            f"LLM provider '{provider_name}' requires credentials. Set the environment variable {expected_env} "
            "for the chosen provider, or pass --api-key-client/--llm-client to the runserver CLI."
        ) from exc

    # --- Prompt firewall ------------------------------------------------
    firewall = PromptFirewall(
        regex_rules=[
            {
                "name": rule.name,
                "pattern": rule.pattern,
                "action": rule.action.value,
                "severity": rule.severity.value,
                "description": rule.description,
            }
            for rule in DEFAULT_FIREWALL_RULES
        ],
        # The semantic layer is only configured when NLI verification is on.
        semantic_config=(
            SEMANTIC_FIREWALL_CONFIG if settings.enable_nli_verification else None
        ),
        default_mode="block",
        models_dir=models_dir,
    )
    logger.info("PromptFirewall initialized")

    # --- Semantic cache (optional) --------------------------------------
    if settings.enable_semantic_cache:
        try:
            semantic_cache = SemanticLLMCache(
                redis_url=settings.redis_url,
                similarity_threshold=settings.cache_similarity_threshold,
                default_ttl=settings.cache_default_ttl,
                dimension=settings.cache_embedding_dim,
                max_connections=settings.redis_max_connections,
                socket_timeout=settings.redis_socket_timeout,
            )
            logger.info("SemanticLLMCache initialized")
        except Exception as exc:
            # A cache failure must not prevent the proxy from starting.
            logger.warning(
                "Semantic cache initialization failed: %s. Disabling cache.", exc
            )
            semantic_cache = None
    else:
        semantic_cache = None
        logger.info("Semantic cache disabled")

    logger.info(
        "Startup complete in %.2fs | compression=%s verifier=%s cache=%s llm_orchestrator=%s",
        time.time() - start_time,
        compression_service is not None,
        post_verifier is not None,
        semantic_cache is not None,
        llm_orchestrator is not None,
    )
async def shutdown() -> None:
    """Release the resources held by the module-level service singletons.

    Closes the LLM orchestrator, then the Redis client exposed as
    ``semantic_cache.redis`` (if any). Each step is isolated: a failure is
    logged with its traceback and the remaining cleanup still runs. Both
    globals are reset to ``None`` afterwards, so calling ``shutdown()`` again
    is a no-op.
    """
    # Function-scope import keeps this block self-contained.
    import inspect

    global llm_orchestrator, semantic_cache

    logger.info("Shutting down nlproxy Enterprise Proxy...")
    start_time = time.time()

    if llm_orchestrator:
        try:
            await llm_orchestrator.close()
            logger.debug("LLM orchestrator closed")
        except Exception:
            # Keep going so the Redis connection is still released.
            logger.exception("Failed to close LLM orchestrator")
        finally:
            llm_orchestrator = None

    redis_client = getattr(semantic_cache, "redis", None) if semantic_cache else None
    if redis_client:
        try:
            # Prefer ``aclose()`` when the client provides it; otherwise fall
            # back to ``close()``. Either may be a coroutine function, so the
            # result is awaited whenever it is awaitable -- calling an async
            # ``close()`` without awaiting it would never run it.
            close_fn = getattr(redis_client, "aclose", None) or redis_client.close
            pending = close_fn()
            if inspect.isawaitable(pending):
                await pending
            logger.debug("Redis connection closed")
        except Exception:
            logger.exception("Failed to close Redis connection")
    semantic_cache = None

    logger.info("Shutdown complete in %.2fs", time.time() - start_time)