"""Chat completion endpoints for nlproxy.""" from __future__ import annotations import asyncio import json import logging import time from typing import Any, AsyncGenerator, Dict, List, Optional from fastapi import APIRouter, HTTPException, Request, status from fastapi.responses import StreamingResponse from nlproxy.core.restriction import Restriction from nlproxy.firewall.firewall import FirewallAction from nlproxy.llm.client import LLMProvider, LLMProviderError, LLMResponse from nlproxy.server.config import settings from nlproxy.server import dependencies from nlproxy.server.dependencies import get_request_logger from nlproxy.server.schemas import ChatCompletionRequest, ChatCompletionResponse, Message router = APIRouter() logger = logging.getLogger(__name__) def _make_choice(text: str) -> Dict[str, Any]: return { "index": 0, "message": {"role": "assistant", "content": text}, "finish_reason": "stop", } async def compress_prompt( messages: List[Message], aggressiveness: float, mode: str, language: Optional[str], privacy_mode: bool, request_id: str, ) -> tuple[List[Dict[str, Any]], Dict[str, Any]]: if not dependencies.compression_service: raise RuntimeError("Compression service is not initialized") prompt = "\n".join([m.content for m in messages if m.content and m.content.strip()]) if not prompt.strip(): raise ValueError("Prompt is empty after concatenating messages") start_time = time.time() last_error: Optional[Exception] = None for attempt in range(1, settings.compression_max_retries + 1): try: results = await asyncio.wait_for( dependencies.compression_service.compress_batch_async( texts=[prompt], aggressiveness=aggressiveness, mode=mode, nli_active=settings.enable_nli_verification, language=language, privacy_mode=privacy_mode, ), timeout=settings.max_compression_timeout, ) if not results or not isinstance(results, list) or len(results) == 0: raise RuntimeError("compress_batch_async returned invalid results") res = results[0] new_messages = [m.model_dump() for m 
in messages] for i in range(len(new_messages) - 1, -1, -1): if new_messages[i]["role"] in ("user", "system"): new_messages[i]["content"] = res["compressed_text"] break metadata = { "original_tokens": res.get("original_tokens", 0), "compressed_tokens": res.get("compressed_tokens", 0), "tokens_saved": res.get("tokens_saved", 0), "compression_ratio": res.get("compression_ratio", 0.0), "cost_saved_usd": res.get("cost_saved_usd", 0.0), "safety_score": res.get("safety_score", 0.0), "alerts": res.get("alerts", []), "compression_latency_ms": (time.time() - start_time) * 1000, "cache_hit": False, } return new_messages, metadata except asyncio.TimeoutError as exc: last_error = exc if attempt < settings.compression_max_retries: backoff = 0.5 * (2 ** (attempt - 1)) await asyncio.sleep(backoff) continue except Exception as exc: last_error = exc if attempt < settings.compression_max_retries: backoff = 0.5 * (2 ** (attempt - 1)) await asyncio.sleep(backoff) continue raise RuntimeError( f"Compression failed after {settings.compression_max_retries} attempts: {last_error}" ) def _build_response_usage(compressed_tokens: int, response_text: str) -> Dict[str, int]: completion_tokens = len(response_text.split()) return { "prompt_tokens": compressed_tokens, "completion_tokens": completion_tokens, "total_tokens": compressed_tokens + completion_tokens, } async def _generate_stream_response(request: ChatCompletionRequest, prompt_text: str) -> StreamingResponse: async def event_generator() -> AsyncGenerator[str, None]: async for chunk in dependencies.llm_orchestrator.generate_stream( prompt=prompt_text, provider=(LLMProvider(request.provider) if request.provider else None), model=request.model, max_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, top_k=request.top_k, stop_sequences=request.stop, ): yield f"data: {json.dumps(chunk)}\n\n" yield "data: [DONE]\n\n" return StreamingResponse(event_generator(), media_type="text/event-stream") 
def _compressed_prompt_from(messages: List[Dict[str, Any]], fallback: str) -> str:
    """Return the content of the last user/system message, or ``fallback`` if none exists.

    This mirrors where ``compress_prompt`` stores the compressed text.
    """
    for message in reversed(messages):
        if message["role"] in ("user", "system"):
            return message["content"]
    return fallback


@router.post("/v1/chat/completions", response_model=ChatCompletionResponse, tags=["Chat"])
async def chat_completions(request: ChatCompletionRequest, http_request: Request) -> Any:
    """Handle a chat completion request end to end.

    Steps, in order: readiness check, firewall check (block / alert /
    rewrite), prompt compression (falling back to the uncompressed user
    prompt on failure), safety validation, then either a streaming response
    or a single LLM generation followed by optional entity re-injection,
    response correction, post-verification and optional auto-correction.

    Args:
        request: Parsed chat completion request body.
        http_request: The underlying HTTP request (currently unused).

    Returns:
        A ``StreamingResponse`` when ``request.stream`` is set, otherwise a
        ``ChatCompletionResponse`` whose ``nlproxy`` field carries the
        collected metadata.

    Raises:
        HTTPException: 503 if a server component is missing; 400 for an
            unknown provider or a request without user content; 403 if the
            firewall blocks the prompt; 500 if safety validation fails;
            504/502 on LLM timeout/provider error; 409 if the response stays
            below ``request.min_confidence``.
    """
    request_id = str(time.time_ns())
    request_logger = get_request_logger(request_id)
    # Monotonic clock: only used to compute the elapsed time below.
    start_time = time.perf_counter()
    request_logger.info(
        "Received chat request: model=%s messages=%s mode=%s",
        request.model,
        len(request.messages),
        request.mode,
    )

    components = (
        dependencies.firewall,
        dependencies.compression_service,
        dependencies.llm_orchestrator,
        dependencies.response_corrector,
        dependencies.post_verifier,
    )
    if not all(components):
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Server is not ready")

    # Validate the provider once, up front, so a bad value is reported as a
    # client error instead of an unhandled exception at generation time.
    # NOTE(review): assumes LLMProvider raises ValueError for unknown values
    # (Enum-style lookup) -- confirm against nlproxy.llm.client.
    try:
        provider = LLMProvider(request.provider) if request.provider else None
    except ValueError as exc:
        request_logger.warning("Unknown LLM provider requested: %s", request.provider)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Unknown LLM provider",
        ) from exc

    if request.stream:
        request_logger.info("Stream enabled for this request")

    user_messages = [m.content for m in request.messages if m.role == "user" and m.content]
    if not user_messages:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request must contain at least one user message",
        )
    user_prompt = "\n".join(user_messages)

    action, violations = dependencies.firewall.check_prompt(user_prompt)
    if action == FirewallAction.BLOCK:
        request_logger.warning("Prompt blocked by firewall: %s", violations)
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Request blocked by security policy")
    if action == FirewallAction.ALERT:
        request_logger.warning("Firewall alert triggered: %s", violations)
    if action == FirewallAction.REWRITE:
        request_logger.info("Rewriting prompt due to firewall violations")
        user_prompt = dependencies.firewall.rewrite_prompt(user_prompt, violations)
        # NOTE(review): only the first user message is overwritten (with the
        # rewrite of ALL user messages joined); later user messages keep their
        # original text and are still concatenated by compress_prompt. Confirm
        # whether they should be cleared or rewritten individually.
        for message in request.messages:
            if message.role == "user":
                message.content = user_prompt
                break

    try:
        new_messages, metadata = await compress_prompt(
            messages=request.messages,
            aggressiveness=request.aggressiveness,
            mode=request.mode,
            language=request.language,
            privacy_mode=request.privacy_mode,
            request_id=request_id,
        )
        # compress_prompt stores the compressed text in the last user/system
        # message, which is not necessarily the last message overall.
        prompt_text = _compressed_prompt_from(new_messages, user_prompt)
    except Exception as exc:
        # Best effort: continue with the uncompressed prompt.
        request_logger.error("Compression failed: %s", exc)
        prompt_text = user_prompt
        metadata = {
            "original_tokens": 0,
            "compressed_tokens": 0,
            "tokens_saved": 0,
            "compression_ratio": 0.0,
            "cost_saved_usd": 0.0,
            "safety_score": 0.0,
            "alerts": [f"Compression failed: {exc}"],
            "compression_latency_ms": 0,
            "cache_hit": False,
            "compression_failed": True,
        }

    try:
        manual_restrictions = None
        if request.manual_restrictions:
            manual_restrictions = [Restriction(**r) for r in request.manual_restrictions]
        shield_result = dependencies.compression_service._shield_with_cache(
            text=user_prompt,
            manual_restrictions=manual_restrictions,
            mode=request.mode,
            privacy_mode=request.privacy_mode,
        )
        if shield_result is None:
            raise RuntimeError("Shield result is None")
        sentences = dependencies.compression_service.segmenter.split_sentences(
            shield_result.shielded_text,
            language=request.language,
        )
        safety_report = dependencies.compression_service.safety.validate(
            original_text=user_prompt,
            compressed_text=prompt_text,
            shield_result=shield_result,
            original_sentences=sentences,
            compressed_indices=None,
            mode=request.mode,
            use_perplexity=request.use_perplexity,
        )
        prompt_text = safety_report.final_text
        metadata["safety_score"] = safety_report.safety_score
        metadata["forced_sentences_added"] = safety_report.forced_sentences_added
        if safety_report.perplexity is not None:
            metadata["perplexity"] = safety_report.perplexity
    except Exception as exc:
        request_logger.error("Safety validation failed: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Prompt safety validation failed",
        ) from exc

    if request.stream:
        # The streaming path returns here; none of the post-processing below runs.
        return await _generate_stream_response(request, prompt_text)

    try:
        generated = await asyncio.wait_for(
            dependencies.llm_orchestrator.generate(
                prompt=prompt_text,
                provider=provider,
                model=request.model,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
                stop_sequences=request.stop,
            ),
            timeout=settings.llm_request_timeout,
        )
    except asyncio.TimeoutError:
        request_logger.error("LLM generation timeout")
        raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="LLM generation timed out")
    except LLMProviderError as exc:
        request_logger.error("LLM provider error: %s", exc)
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc))

    response_text = generated.text if isinstance(generated, LLMResponse) else str(generated)
    if request.privacy_mode:
        response_text = dependencies.compression_service.reconstructor._reinject_entities(
            response_text,
            shield_result.placeholder_map,
        )

    final_response = dependencies.response_corrector.correct(response_text, shield_result)
    verification = dependencies.post_verifier.verify(final_response, shield_result)
    metadata.update({
        "post_llm_confidence": verification.confidence_score,
        "post_llm_violations": verification.violations,
        "cache_hit": False,
    })

    if verification.confidence_score < request.min_confidence and not request.auto_correct:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Response does not meet confidence threshold",
        )

    final_response_text = final_response
    correction_attempts = 0
    # Retry slightly cooler; leave the temperature unset if the request did not set one.
    correction_temperature = request.temperature * 0.8 if request.temperature is not None else None
    while (
        request.auto_correct
        and verification.confidence_score < request.min_confidence
        and correction_attempts < settings.max_regeneration_attempts
    ):
        correction_attempts += 1
        try:
            # NOTE(review): in privacy mode `final_response` already had the
            # placeholders re-injected, so this sends that text back to the
            # LLM -- confirm this is acceptable for privacy_mode requests.
            # Same timeout as the primary generation so a stalled provider
            # cannot hang the request during correction.
            corrected = await asyncio.wait_for(
                dependencies.llm_orchestrator.generate(
                    prompt=f"Correct the following response to satisfy policy: {final_response}",
                    provider=provider,
                    model=request.model,
                    max_tokens=request.max_tokens,
                    temperature=correction_temperature,
                    top_p=request.top_p,
                    top_k=request.top_k,
                    stop_sequences=request.stop,
                ),
                timeout=settings.llm_request_timeout,
            )
            final_response_text = corrected.text if isinstance(corrected, LLMResponse) else str(corrected)
            verification = dependencies.post_verifier.verify(final_response_text, shield_result)
            if verification.confidence_score >= request.min_confidence:
                metadata["regenerated"] = True
                metadata["regeneration_attempts"] = correction_attempts
                break
        except Exception as exc:
            # Includes asyncio.TimeoutError; the threshold check below decides the outcome.
            request_logger.warning("Auto-correction attempt failed: %s", exc)
            break

    metadata["auto_corrected"] = metadata.get("regenerated", False)
    metadata["regeneration_attempts"] = metadata.get("regeneration_attempts", 0)

    if verification.confidence_score < request.min_confidence:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Response does not meet security confidence threshold after correction",
        )

    # Report the verification that belongs to the text actually returned,
    # not the one taken before auto-correction.
    metadata["post_llm_confidence"] = verification.confidence_score
    metadata["post_llm_violations"] = verification.violations

    usage = _build_response_usage(metadata.get("compressed_tokens", 0), final_response_text)
    metadata["total_latency_ms"] = round((time.perf_counter() - start_time) * 1000, 2)

    return ChatCompletionResponse(
        model=request.model,
        choices=[_make_choice(final_response_text)],
        usage=usage,
        nlproxy=metadata,
    )