Spaces:
Running
Running
| """Chat completion endpoints for nlproxy.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import time | |
| from typing import Any, AsyncGenerator, Dict, List, Optional | |
| from fastapi import APIRouter, HTTPException, Request, status | |
| from fastapi.responses import StreamingResponse | |
| from nlproxy.core.restriction import Restriction | |
| from nlproxy.firewall.firewall import FirewallAction | |
| from nlproxy.llm.client import LLMProvider, LLMProviderError, LLMResponse | |
| from nlproxy.server.config import settings | |
| from nlproxy.server import dependencies | |
| from nlproxy.server.dependencies import get_request_logger | |
| from nlproxy.server.schemas import ChatCompletionRequest, ChatCompletionResponse, Message | |
# Module-level logger; per-request loggers are obtained via get_request_logger().
logger = logging.getLogger(__name__)
# FastAPI router for this module.
# NOTE(review): nothing in this file attaches a route to it — confirm
# chat_completions is registered elsewhere.
router = APIRouter()
| def _make_choice(text: str) -> Dict[str, Any]: | |
| return { | |
| "index": 0, | |
| "message": {"role": "assistant", "content": text}, | |
| "finish_reason": "stop", | |
| } | |
async def compress_prompt(
    messages: List[Message],
    aggressiveness: float,
    mode: str,
    language: Optional[str],
    privacy_mode: bool,
    request_id: str,
) -> tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Compress the conversation and splice the result back into the messages.

    The contents of all non-blank messages (any role) are joined with
    newlines and compressed as one text by
    ``dependencies.compression_service.compress_batch_async``. The compressed
    text replaces the content of the last ``user`` or ``system`` message in a
    ``model_dump()`` copy of ``messages``; ``messages`` itself is not mutated.

    Each attempt is bounded by ``settings.max_compression_timeout``. Up to
    ``settings.compression_max_retries`` attempts are made (always at least
    one), sleeping 0.5 s, 1 s, 2 s, ... between consecutive attempts.

    Args:
        messages: Conversation messages to compress.
        aggressiveness: Forwarded to the compression service.
        mode: Forwarded to the compression service.
        language: Forwarded to the compression service.
        privacy_mode: Forwarded to the compression service.
        request_id: Used to obtain the request logger for retry warnings.

    Returns:
        ``(new_messages, metadata)``. ``metadata`` holds the token counts,
        savings, safety score and alerts reported by the service (defaulting
        to zero/empty when absent), ``cache_hit=False``, and the compression
        latency in milliseconds measured across all attempts.

    Raises:
        RuntimeError: If the compression service is not initialized, or if
            every attempt failed (the last error is chained as ``__cause__``).
        ValueError: If every message is empty or whitespace-only.
    """
    service = dependencies.compression_service
    if not service:
        raise RuntimeError("Compression service is not initialized")

    prompt = "\n".join(m.content for m in messages if m.content and m.content.strip())
    if not prompt.strip():
        raise ValueError("Prompt is empty after concatenating messages")

    # A non-positive setting would otherwise skip the loop and raise
    # "failed after 0 attempts: None" without ever calling the service.
    max_attempts = max(1, settings.compression_max_retries)
    start_time = time.perf_counter()
    last_error: Optional[Exception] = None

    for attempt in range(1, max_attempts + 1):
        try:
            results = await asyncio.wait_for(
                service.compress_batch_async(
                    texts=[prompt],
                    aggressiveness=aggressiveness,
                    mode=mode,
                    nli_active=settings.enable_nli_verification,
                    language=language,
                    privacy_mode=privacy_mode,
                ),
                timeout=settings.max_compression_timeout,
            )
            if not isinstance(results, list) or not results:
                raise RuntimeError("compress_batch_async returned invalid results")
            res = results[0]

            new_messages = [m.model_dump() for m in messages]
            # Only the last user/system message receives the compressed text.
            for message in reversed(new_messages):
                if message["role"] in ("user", "system"):
                    message["content"] = res["compressed_text"]
                    break

            metadata = {
                "original_tokens": res.get("original_tokens", 0),
                "compressed_tokens": res.get("compressed_tokens", 0),
                "tokens_saved": res.get("tokens_saved", 0),
                "compression_ratio": res.get("compression_ratio", 0.0),
                "cost_saved_usd": res.get("cost_saved_usd", 0.0),
                "safety_score": res.get("safety_score", 0.0),
                "alerts": res.get("alerts", []),
                "compression_latency_ms": (time.perf_counter() - start_time) * 1000,
                "cache_hit": False,
            }
            return new_messages, metadata
        except Exception as exc:  # asyncio.TimeoutError is handled the same way
            last_error = exc
            get_request_logger(request_id).warning(
                "Compression attempt %d/%d failed: %s", attempt, max_attempts, exc
            )
            if attempt < max_attempts:
                await asyncio.sleep(0.5 * (2 ** (attempt - 1)))

    raise RuntimeError(
        f"Compression failed after {max_attempts} attempts: {last_error}"
    ) from last_error
| def _build_response_usage(compressed_tokens: int, response_text: str) -> Dict[str, int]: | |
| completion_tokens = len(response_text.split()) | |
| return { | |
| "prompt_tokens": compressed_tokens, | |
| "completion_tokens": completion_tokens, | |
| "total_tokens": compressed_tokens + completion_tokens, | |
| } | |
async def _generate_stream_response(request: ChatCompletionRequest, prompt_text: str) -> StreamingResponse:
    """Stream an LLM completion for *prompt_text* as Server-Sent Events.

    Every chunk yielded by ``dependencies.llm_orchestrator.generate_stream``
    is JSON-encoded (so it must be ``json.dumps``-serializable) into a
    ``data:`` event, and a successful stream ends with ``data: [DONE]``.

    The provider is resolved before the response object is built. An invalid
    ``request.provider`` therefore raises here, while the caller can still
    turn it into an error status. Resolving it inside the body generator
    would fail only after the 200 status line has already been sent.
    """
    provider = LLMProvider(request.provider) if request.provider else None
    orchestrator = dependencies.llm_orchestrator

    async def event_generator() -> AsyncGenerator[str, None]:
        try:
            async for chunk in orchestrator.generate_stream(
                prompt=prompt_text,
                provider=provider,
                model=request.model,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
                stop_sequences=request.stop,
            ):
                yield f"data: {json.dumps(chunk)}\n\n"
        except Exception:
            # Headers are already committed, so the client only sees a
            # truncated stream; record the cause server-side, then propagate.
            logger.exception("LLM stream failed mid-response")
            raise
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
def _resolve_provider(provider_name: Any) -> Optional[LLMProvider]:
    """Convert the request's ``provider`` field to ``LLMProvider``.

    Returns ``None`` when no provider was requested.

    Raises:
        HTTPException: 400 when ``LLMProvider`` rejects the value with
            ``ValueError``. This assumes ``LLMProvider`` behaves like an
            ``Enum`` lookup; confirm against ``nlproxy.llm.client``.
    """
    if not provider_name:
        return None
    try:
        return LLMProvider(provider_name)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported provider: {provider_name}",
        ) from exc


def _apply_firewall(request: ChatCompletionRequest, request_logger: Any) -> str:
    """Check the joined user messages against the firewall.

    Returns the user prompt, rewritten by the firewall when it requests a
    REWRITE. In that case the rewritten prompt is stored in the first user
    message and every other user message is emptied. ``compress_prompt``
    concatenates all non-blank messages, so leaving the other user messages
    intact would send their un-rewritten text to the LLM anyway.

    Raises:
        HTTPException: 400 if there is no non-empty user message, 403 if the
            firewall blocks the prompt.
    """
    user_messages = [m.content for m in request.messages if m.role == "user" and m.content]
    if not user_messages:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Request must contain at least one user message")
    user_prompt = "\n".join(user_messages)

    action, violations = dependencies.firewall.check_prompt(user_prompt)
    if action == FirewallAction.BLOCK:
        request_logger.warning("Prompt blocked by firewall: %s", violations)
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Request blocked by security policy")
    if action == FirewallAction.ALERT:
        request_logger.warning("Firewall alert triggered: %s", violations)
    if action == FirewallAction.REWRITE:
        request_logger.info("Rewriting prompt due to firewall violations")
        user_prompt = dependencies.firewall.rewrite_prompt(user_prompt, violations)
        placed = False
        for message in request.messages:
            if message.role != "user":
                continue
            # The rewritten prompt already covers all user text (it was built
            # from the joined user messages), so later user turns are blanked.
            message.content = user_prompt if not placed else ""
            placed = True
    return user_prompt


def _postprocess_response(raw_text: str, privacy_mode: bool, shield_result: Any) -> str:
    """Convert raw LLM output into the text returned to the client.

    In privacy mode, entities are first re-injected using
    ``shield_result.placeholder_map``. The text is then passed through
    ``dependencies.response_corrector.correct``.
    """
    text = raw_text
    if privacy_mode:
        text = dependencies.compression_service.reconstructor._reinject_entities(text, shield_result.placeholder_map)
    return dependencies.response_corrector.correct(text, shield_result)


async def chat_completions(request: ChatCompletionRequest, http_request: Request) -> Any:
    """Handle an OpenAI-style chat completion through the nlproxy pipeline.

    Pipeline:

    1. Readiness check.
    2. Firewall check on the joined user messages.
    3. Prompt compression. On failure this falls back to the uncompressed
       user prompt and marks ``compression_failed`` in the metadata.
    4. Shielding and safety validation.
    5. LLM generation. When ``request.stream`` is set, the response is
       streamed via ``_generate_stream_response``.
    6. Entity re-injection (privacy mode), response correction and
       post-verification.
    7. Optional auto-correction regeneration.

    ``http_request`` is accepted but not used.

    Raises:
        HTTPException:
            - 503 if a dependency is not initialized.
            - 400 for a request without a user message or with an
              unsupported provider.
            - 403 if the firewall blocks the prompt.
            - 500 if safety validation fails.
            - 504 on LLM timeout.
            - 502 on provider errors.
            - 409 if the response stays below ``request.min_confidence``.
    """
    request_id = str(time.time_ns())
    request_logger = get_request_logger(request_id)
    start_time = time.perf_counter()
    request_logger.info(
        "Received chat request: model=%s messages=%s mode=%s",
        request.model,
        len(request.messages),
        request.mode,
    )
    if (
        not dependencies.firewall
        or not dependencies.compression_service
        or not dependencies.llm_orchestrator
        or not dependencies.response_corrector
        or not dependencies.post_verifier
    ):
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Server is not ready")
    if request.stream:
        request_logger.info("Stream enabled for this request")

    user_prompt = _apply_firewall(request, request_logger)
    provider = _resolve_provider(request.provider)

    # --- Compression (best effort) ---
    try:
        new_messages, metadata = await compress_prompt(
            messages=request.messages,
            aggressiveness=request.aggressiveness,
            mode=request.mode,
            language=request.language,
            privacy_mode=request.privacy_mode,
            request_id=request_id,
        )
        # compress_prompt writes the compressed text into the last user/system
        # message, which need not be the last message overall (e.g. a trailing
        # assistant turn), so select that message rather than new_messages[-1].
        prompt_text = next(
            (m["content"] for m in reversed(new_messages) if m["role"] in ("user", "system")),
            user_prompt,
        )
    except Exception as exc:
        request_logger.error("Compression failed: %s", exc)
        prompt_text = user_prompt
        metadata = {
            "original_tokens": 0,
            "compressed_tokens": 0,
            "tokens_saved": 0,
            "compression_ratio": 0.0,
            "cost_saved_usd": 0.0,
            "safety_score": 0.0,
            "alerts": [f"Compression failed: {exc}"],
            "compression_latency_ms": 0,
            "cache_hit": False,
            "compression_failed": True,
        }

    # --- Shielding and safety validation (mandatory) ---
    try:
        manual_restrictions = None
        if request.manual_restrictions:
            manual_restrictions = [Restriction(**r) for r in request.manual_restrictions]
        shield_result = dependencies.compression_service._shield_with_cache(
            text=user_prompt,
            manual_restrictions=manual_restrictions,
            mode=request.mode,
            privacy_mode=request.privacy_mode,
        )
        if shield_result is None:
            raise RuntimeError("Shield result is None")
        sentences = dependencies.compression_service.segmenter.split_sentences(
            shield_result.shielded_text,
            language=request.language,
        )
        safety_report = dependencies.compression_service.safety.validate(
            original_text=user_prompt,
            compressed_text=prompt_text,
            shield_result=shield_result,
            original_sentences=sentences,
            compressed_indices=None,
            mode=request.mode,
            use_perplexity=request.use_perplexity,
        )
        prompt_text = safety_report.final_text
        metadata["safety_score"] = safety_report.safety_score
        metadata["forced_sentences_added"] = safety_report.forced_sentences_added
        if safety_report.perplexity is not None:
            metadata["perplexity"] = safety_report.perplexity
    except Exception as exc:
        request_logger.error("Safety validation failed: %s", exc)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Prompt safety validation failed") from exc

    if request.stream:
        return await _generate_stream_response(request, prompt_text)

    # --- Generation ---
    try:
        generated = await asyncio.wait_for(
            dependencies.llm_orchestrator.generate(
                prompt=prompt_text,
                provider=provider,
                model=request.model,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
                stop_sequences=request.stop,
            ),
            timeout=settings.llm_request_timeout,
        )
    except asyncio.TimeoutError as exc:
        request_logger.error("LLM generation timeout")
        raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="LLM generation timed out") from exc
    except LLMProviderError as exc:
        request_logger.error("LLM provider error: %s", exc)
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc

    raw_text = generated.text if isinstance(generated, LLMResponse) else str(generated)
    final_response_text = _postprocess_response(raw_text, request.privacy_mode, shield_result)
    verification = dependencies.post_verifier.verify(final_response_text, shield_result)
    metadata.update({
        "post_llm_confidence": verification.confidence_score,
        "post_llm_violations": verification.violations,
        "cache_hit": False,
    })
    if verification.confidence_score < request.min_confidence and not request.auto_correct:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Response does not meet confidence threshold",
        )

    # --- Auto-correction ---
    # In privacy mode final_response_text carries the re-injected entities;
    # only the placeholder form that the provider already produced may be
    # sent back to it.
    correction_source = raw_text if request.privacy_mode else final_response_text
    correction_temperature = None if request.temperature is None else request.temperature * 0.8
    correction_attempts = 0
    while (
        request.auto_correct
        and verification.confidence_score < request.min_confidence
        and correction_attempts < settings.max_regeneration_attempts
    ):
        correction_attempts += 1
        try:
            corrected = await asyncio.wait_for(
                dependencies.llm_orchestrator.generate(
                    prompt=f"Correct the following response to satisfy policy: {correction_source}",
                    provider=provider,
                    model=request.model,
                    max_tokens=request.max_tokens,
                    temperature=correction_temperature,
                    top_p=request.top_p,
                    top_k=request.top_k,
                    stop_sequences=request.stop,
                ),
                timeout=settings.llm_request_timeout,
            )
            corrected_raw = corrected.text if isinstance(corrected, LLMResponse) else str(corrected)
            # Regenerated output goes through the same re-injection/correction
            # as the first response before it is verified or returned.
            final_response_text = _postprocess_response(corrected_raw, request.privacy_mode, shield_result)
            verification = dependencies.post_verifier.verify(final_response_text, shield_result)
        except Exception as exc:
            request_logger.warning("Auto-correction attempt failed: %s", exc)
            break
        if verification.confidence_score >= request.min_confidence:
            metadata["regenerated"] = True
            metadata["regeneration_attempts"] = correction_attempts
            break

    metadata["auto_corrected"] = metadata.get("regenerated", False)
    metadata["regeneration_attempts"] = metadata.get("regeneration_attempts", 0)
    if verification.confidence_score < request.min_confidence:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Response does not meet security confidence threshold after correction",
        )

    usage = _build_response_usage(metadata.get("compressed_tokens", 0), final_response_text)
    metadata["total_latency_ms"] = round((time.perf_counter() - start_time) * 1000, 2)
    return ChatCompletionResponse(
        model=request.model,
        choices=[_make_choice(final_response_text)],
        usage=usage,
        nlproxy=metadata,
    )