"""Server dependency and lifecycle management. Author: IntelliDeep Labs Team License: BSL 1.1 """ from __future__ import annotations import os import asyncio import logging import time from pathlib import Path from typing import Optional from nlproxy.cache.semantic_cache import SemanticLLMCache from nlproxy.core.corrector import ResponseCorrector from nlproxy.core.verifier import PostLLMVerifier from nlproxy.firewall import ( PromptFirewall, DEFAULT_FIREWALL_RULES, SEMANTIC_FIREWALL_CONFIG, ) from nlproxy.llm.client import LLMOrchestrator, LLMProvider, LLMClientFactory from nlproxy.service.compression import CompressionService from .config import settings from .logger import get_request_logger logger = logging.getLogger(__name__) compression_service: Optional[CompressionService] = None post_verifier: Optional[PostLLMVerifier] = None response_corrector: Optional[ResponseCorrector] = None llm_orchestrator: Optional[LLMOrchestrator] = None firewall: Optional[PromptFirewall] = None semantic_cache: Optional[SemanticLLMCache] = None def get_request_logger(request_id: str) -> logging.LoggerAdapter: return logging.LoggerAdapter(logger, {"request_id": request_id}) async def startup() -> None: global compression_service, post_verifier, response_corrector global llm_orchestrator, firewall, semantic_cache logger.info("Starting nlproxy Enterprise Proxy...") start_time = time.time() provider_name = os.getenv("NLPROXY_DEFAULT_LLM_PROVIDER", "").strip() if not provider_name: provider_name = settings.default_llm_provider # fallback gemini_key = os.getenv("GEMINI_API_KEY", "").strip() openai_key = os.getenv("OPENAI_API_KEY", "").strip() logger.info(f"Provider from env: {provider_name} (GEMINI_KEY={'present' if gemini_key else 'missing'})") compression_service = CompressionService( use_cache=settings.enable_semantic_cache, redis_url=settings.redis_url if settings.enable_semantic_cache else None, privacy_mode=settings.privacy_mode_default, models_dir=Path("nlproxy") / "models", llm_default_model=settings.default_llm_model, ) logger.info("CompressionService initialized") try: post_verifier = PostLLMVerifier( mode="general", use_nli=settings.enable_nli_verification, embedding_model=compression_service.segmenter._embedding_model, models_dir=Path("nlproxy") / "models", ) if post_verifier.use_nli: compression_service.nli_refinement_fn = post_verifier.get_nli_check_function() logger.info("PostLLMVerifier initialized") except Exception as exc: logger.warning( "NLI initialization failed: %s. Disabling semantic verification.", exc ) post_verifier = PostLLMVerifier( mode="general", use_nli=False, embedding_model=compression_service.segmenter._embedding_model, ) compression_service.nli_refinement_fn = None response_corrector = ResponseCorrector(mode="general") compression_service.response_corrector = response_corrector compression_service.post_verifier = post_verifier logger.info("ResponseCorrector initialized") # Allow runtime overrides from environment or CLI (set before uvicorn starts) provider_name = os.getenv("NLPROXY_DEFAULT_LLM_PROVIDER", settings.default_llm_provider) default_model_name = os.getenv("NLPROXY_DEFAULT_LLM_MODEL", settings.default_llm_model) llm_orchestrator = LLMOrchestrator( default_provider=LLMProvider(provider_name), fallback_providers=[LLMProvider.OPENAI, LLMProvider.CLAUDE, LLMProvider.GEMINI] if settings.enable_llm_fallback else [], load_balance=True, max_concurrent_requests=20, default_model=default_model_name, ) logger.info("LLMOrchestrator initialized") # Validate that at least the default provider has credentials configured. try: LLMClientFactory.get_or_create(LLMProvider(provider_name), model=default_model_name) logger.info("LLM client validated for default provider: %s", provider_name) except Exception as exc: provider_name = provider_name if 'provider_name' in locals() else settings.default_llm_provider env_map = { "gemini": "GEMINI_API_KEY", "claude": "ANTHROPIC_API_KEY", "openai": "OPENAI_API_KEY", "deepseek": "DEEPSEEK_API_KEY", "qwen": "QWEN_API_KEY", "kimi": "KIMI_API_KEY", "openrouter": "OPENROUTER_API_KEY", } expected_env = env_map.get(provider_name.lower(), "") logger.error( "LLM client validation failed for default provider '%s': %s. Expected environment variable: %s", provider_name, exc, expected_env, ) raise RuntimeError( f"LLM provider '{provider_name}' requires credentials. Set the environment variable {expected_env} " f"or the chosen provider, or pass --api-key-client/--llm-client to the runserver CLI." ) from exc firewall = PromptFirewall( regex_rules=[ { "name": rule.name, "pattern": rule.pattern, "action": rule.action.value, "severity": rule.severity.value, "description": rule.description, } for rule in DEFAULT_FIREWALL_RULES ], semantic_config=SEMANTIC_FIREWALL_CONFIG if settings.enable_nli_verification else None, default_mode="block", models_dir=Path("nlproxy") / "models", ) logger.info("PromptFirewall initialized") if settings.enable_semantic_cache: try: semantic_cache = SemanticLLMCache( redis_url=settings.redis_url, similarity_threshold=settings.cache_similarity_threshold, default_ttl=settings.cache_default_ttl, dimension=settings.cache_embedding_dim, max_connections=settings.redis_max_connections, socket_timeout=settings.redis_socket_timeout, ) logger.info("SemanticLLMCache initialized") except Exception as exc: logger.warning( "Semantic cache initialization failed: %s. Disabling cache.", exc ) semantic_cache = None else: semantic_cache = None logger.info("Semantic cache disabled") logger.info( "Startup complete in %.2fs | compression=%s verifier=%s cache=%s llm_orchestrator=%s", time.time() - start_time, compression_service is not None, post_verifier is not None, semantic_cache is not None, llm_orchestrator is not None, ) async def shutdown() -> None: global llm_orchestrator, semantic_cache logger.info("Shutting down nlproxy Enterprise Proxy...") start_time = time.time() if llm_orchestrator: await llm_orchestrator.close() logger.debug("LLM orchestrator closed") if semantic_cache and hasattr(semantic_cache, "redis") and semantic_cache.redis: semantic_cache.redis.close() if hasattr(semantic_cache.redis, "aclose"): await semantic_cache.redis.aclose() logger.debug("Redis connection closed") logger.info("Shutdown complete in %.2fs", time.time() - start_time)