# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause
| """NVIDIA Retrieval-Augmented Generation (RAG) service implementation. | |
| Integrates with NVIDIA's Retrieval-Augmented Generation service to enhance responses | |
| by incorporating knowledge from external documents. Features include: | |
| - Document collection management | |
| - Real-time retrieval and citation | |
| - OpenAI-compatible LLM interface | |
| - Configurable retrieval parameters | |
| """ | |

import json

import httpx
from loguru import logger
from openai.types.chat import ChatCompletionMessageParam
from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    ErrorFrame,
    Frame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMMessagesFrame,
    StartInterruptionFrame,
    TextFrame,
    VisionImageRawFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.openai.llm import OpenAILLMService

from nvidia_pipecat.frames.nvidia_rag import NvidiaRAGCitation, NvidiaRAGCitationsFrame, NvidiaRAGSettingsFrame
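
# Typical usage (sketch, not part of this module's API): the service slots in
# as the LLM stage of a pipecat pipeline. The collection name and surrounding
# pipeline stages below are illustrative placeholders.
#
#     rag = NvidiaRAGService(
#         collection_name="nvidia_docs",
#         rag_server_url="http://localhost:8081",
#         suffix_prompt="Answer in one short sentence.",
#     )
#     pipeline = Pipeline([transport.input(), context_aggregator.user(), rag, tts, transport.output()])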


class NvidiaRAGService(OpenAILLMService):
    """Base class for all services that use NVIDIA RAG/GenerativeAIExamples.

    Requires a deployed NVIDIA RAG server. For deployment instructions see:
    https://github.com/NVIDIA-AI-Blueprints/rag/blob/main/docs/quickstart.md

    Attributes:
        collection_name: Document collection identifier.
        rag_server_url: RAG API endpoint URL.
        stop_words: Words that stop LLM generation.
        temperature: Controls response randomness (0-1).
        top_p: Token probability threshold (0-1).
        max_tokens: Maximum response length.
        use_knowledge_base: Whether to use RAG retrieval.
        vdb_top_k: Number of chunks to retrieve.
        reranker_top_k: Number of chunks to rerank.
        enable_citations: Whether to return citations.
        suffix_prompt: Text appended to the last user message.
    """

    _shared_session: httpx.AsyncClient | None = None

    def __init__(
        self,
        collection_name: str,
        rag_server_url: str = "http://localhost:8081",
        stop_words: list | None = None,
        temperature: float = 0.2,
        top_p: float = 0.7,
        max_tokens: int = 1000,
        use_knowledge_base: bool = True,
        vdb_top_k: int = 20,
        reranker_top_k: int = 4,
        enable_citations: bool = True,
        suffix_prompt: str | None = None,
        session: httpx.AsyncClient | None = None,
        **kwargs,
    ):
        """Initialize the NVIDIA RAG service.

        Args:
            collection_name: Document collection identifier.
            rag_server_url: RAG API endpoint URL.
            stop_words: Words that stop LLM generation.
            temperature: Controls response randomness (0-1).
            top_p: Token probability threshold (0-1).
            max_tokens: Maximum response length.
            use_knowledge_base: Whether to use RAG retrieval.
            vdb_top_k: Number of chunks to retrieve.
            reranker_top_k: Number of chunks to rerank.
            enable_citations: Whether to return citations.
            suffix_prompt: Text appended to the last user message.
            session: Optional httpx.AsyncClient. A shared client is created lazily if None.
            **kwargs: Additional arguments passed to OpenAILLMService.
        """
        super().__init__(api_key="", **kwargs)
        self.collection_name = collection_name
        self.rag_server_url = rag_server_url
        self.stop_words = stop_words if stop_words is not None else []
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.use_knowledge_base = use_knowledge_base
        self.vdb_top_k = vdb_top_k
        self.reranker_top_k = reranker_top_k
        self.enable_citations = enable_citations
        self.suffix_prompt = suffix_prompt
        self._external_client_session = session
        self._current_task = None

    @property
    def shared_session(self) -> httpx.AsyncClient:
        """The shared HTTP client session for making HTTP requests.

        Returns the externally provided client if one was passed to the
        constructor; otherwise returns the class-level shared client,
        creating it on first access.
        """
        if self._external_client_session is not None:
            return self._external_client_session
        if NvidiaRAGService._shared_session is None:
            NvidiaRAGService._shared_session = httpx.AsyncClient()
        return NvidiaRAGService._shared_session

    @shared_session.setter
    def shared_session(self, shared_session: httpx.AsyncClient):
        """Set the shared HTTP client session.

        Args:
            shared_session: The httpx.AsyncClient to use for all instances.
        """
        NvidiaRAGService._shared_session = shared_session
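
    # Sharing one httpx.AsyncClient across instances (sketch; the timeout and
    # collection names are illustrative, not required values):
    #
    #     client = httpx.AsyncClient(timeout=httpx.Timeout(30.0))
    #     rag_a = NvidiaRAGService(collection_name="docs_a", session=client)
    #     rag_b = NvidiaRAGService(collection_name="docs_b", session=client)
    #
    # Passing `session=` keeps connection pooling and client lifetime in the
    # caller's hands; otherwise the lazily created class-level client is shared.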

    async def stop(self, frame: EndFrame):
        """Stop the NVIDIA RAG service and clean up resources.

        Args:
            frame: The EndFrame that triggered the stop.
        """
        await super().stop(frame)
        if self._current_task:
            await self.cancel_task(self._current_task)

    async def cancel(self, frame: CancelFrame):
        """Cancel the NVIDIA RAG service and clean up resources.

        Args:
            frame: The CancelFrame that triggered the cancellation.
        """
        await super().cancel(frame)
        if self._current_task:
            await self.cancel_task(self._current_task)

    async def cleanup(self):
        """Clean up resources used by the RAG service.

        Performs parent cleanup and closes the class-level shared HTTP client
        session if it exists.
        """
        await super().cleanup()
        await self._close_client_session()

    async def _close_client_session(self):
        """Close the class-level shared client session if it exists.

        An externally provided session is left open, since its lifetime is
        owned by the caller.
        """
        if NvidiaRAGService._shared_session:
            await NvidiaRAGService._shared_session.aclose()
            NvidiaRAGService._shared_session = None

    async def _get_rag_response(self, request_json: dict):
        """POST the request to the RAG server's /generate endpoint and return the response."""
        return await self.shared_session.post(f"{self.rag_server_url}/generate", json=request_json)
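
    # Response wire format (sketch, inferred from the parsing below): the server
    # answers with SSE-style lines whose payload carries a "data: " prefix, e.g.
    #
    #     data: {"choices": [{"message": {"content": "NVIDIA NIM is ..."}}],
    #            "citations": {"results": [{"document_type": "text", ...}]}}
    #
    # The parser strips the 6-character "data: " prefix before JSON-decoding.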

    async def _process_context(self, context: OpenAILLMContext):
        """Process the LLM context through the RAG pipeline.

        Args:
            context: Contains conversation history and settings.

        Raises:
            Exception: If an unexpected message role is found or the query is empty.
        """
        try:
            messages: list[ChatCompletionMessageParam] = context.get_messages()
            chat_details = []

            for msg in messages:
                if msg["role"] not in ("system", "user", "assistant"):
                    raise Exception(f"Unexpected role {msg['role']} found!")
                chat_details.append({"role": msg["role"], "content": msg["content"]})

            # Append the suffix prompt to the most recent user message, if configured.
            if self.suffix_prompt:
                for i in range(len(chat_details) - 1, -1, -1):
                    if chat_details[i]["role"] == "user":
                        chat_details[i]["content"] += f" {self.suffix_prompt}"
                        break

            logger.debug(f"Chat details: {chat_details}")

            if len(chat_details) == 0 or all(msg["content"] == "" for msg in chat_details) or not self.collection_name:
                raise Exception("No query or collection name provided.")

            # Call the RAG chain server and stream back the response.
            request_json = {
                "messages": chat_details,
                "use_knowledge_base": self.use_knowledge_base,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "max_tokens": self.max_tokens,
                "vdb_top_k": self.vdb_top_k,
                "reranker_top_k": self.reranker_top_k,
                "collection_name": self.collection_name,
                "stop": self.stop_words,
                "enable_citations": self.enable_citations,
            }
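
            # Serialized, the payload looks like (values illustrative):
            #     {"messages": [{"role": "user", "content": "..."}],
            #      "use_knowledge_base": true, "temperature": 0.2, "top_p": 0.7,
            #      "max_tokens": 1000, "vdb_top_k": 20, "reranker_top_k": 4,
            #      "collection_name": "nvidia_docs", "stop": [], "enable_citations": true}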

            await self.start_ttfb_metrics()
            full_response = ""
            resp = await self._get_rag_response(request_json)
            try:
                async for chunk in resp.aiter_lines():
                    await self.stop_ttfb_metrics()
                    citations = []
                    try:
                        chunk = chunk.strip("\n")
                        try:
                            # Skip empty keep-alive lines; payload lines carry a
                            # 6-character "data: " prefix before the JSON body.
                            if len(chunk) > 6:
                                parsed = json.loads(chunk[6:])
                                message = parsed["choices"][0]["message"]["content"]
                                if "citations" in parsed:
                                    for citation in parsed["citations"]["results"]:
                                        citations.append(
                                            NvidiaRAGCitation(
                                                document_type=str(citation["document_type"]),
                                                document_id=str(citation["document_id"]),
                                                document_name=str(citation["document_name"]),
                                                content=str(citation["content"]).encode(),
                                                metadata=str(citation["metadata"]),
                                                score=float(citation["score"]),
                                            )
                                        )
                            else:
                                logger.warning(f"Received empty RAG response chunk '{chunk}'.")
                                message = ""
                        except Exception as e:
                            logger.debug(f"Parsing RAG response chunk failed. Error: {e}")
                            message = ""

                        if not message and not citations:
                            continue

                        full_response += message

                        if citations:
                            scores = [citation.score for citation in citations]
                            types = [citation.document_type for citation in citations]
                            logger.debug(f"Received total {len(citations)} RAG citations")
                            logger.debug(f"Received RAG citation types: {types}")
                            logger.debug(f"Received RAG citation scores: {scores}")
                            await self.push_frame(NvidiaRAGCitationsFrame(citations=citations))

                        if message:
                            await self.push_frame(TextFrame(message))
                    except Exception as e:
                        await self.push_error(ErrorFrame("Internal error in RAG stream: " + str(e)))
            finally:
                await resp.aclose()
                logger.debug(f"Full RAG response: {full_response}")
        except Exception as e:
            logger.error(f"An error occurred in the HTTP request to the RAG endpoint. Error: {e}")
            await self.push_error(ErrorFrame("An error occurred in the HTTP request to the RAG endpoint. Error: " + str(e)))

    async def _update_settings(self, settings):
        """Update service settings at runtime.

        Args:
            settings: Dictionary of setting name-value pairs.
        """
        for setting, value in settings.items():
            logger.debug(f"Updating {setting} to {value} via NvidiaRAGSettingsFrame")
            match setting:
                case "collection_name":
                    self.collection_name = value
                case "rag_server_url":
                    self.rag_server_url = value
                case "stop_words":
                    self.stop_words = value
                case "temperature":
                    self.temperature = value
                case "top_p":
                    self.top_p = value
                case "max_tokens":
                    self.max_tokens = value
                case "use_knowledge_base":
                    self.use_knowledge_base = value
                case "vdb_top_k":
                    self.vdb_top_k = value
                case "reranker_top_k":
                    self.reranker_top_k = value
                case "enable_citations":
                    self.enable_citations = value
                case _:
                    logger.warning(f"Unknown setting for NvidiaRAG service: {setting}")

    async def _process_context_and_frames(self, context: OpenAILLMContext):
        """Process the context and emit start/end frames with metrics."""
        await self.push_frame(LLMFullResponseStartFrame())
        await self.start_processing_metrics()
        await self._process_context(context)
        await self.stop_processing_metrics()
        await self.push_frame(LLMFullResponseEndFrame())

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process pipeline frames.

        Handles settings updates and parent frame processing.

        Args:
            frame: Input frame to process.
            direction: Frame processing direction.
        """
        context = None
        if isinstance(frame, NvidiaRAGSettingsFrame):
            await self._update_settings(frame.settings)

        if isinstance(frame, OpenAILLMContextFrame):
            context: OpenAILLMContext = frame.context
        elif isinstance(frame, LLMMessagesFrame):
            context = OpenAILLMContext.from_messages(frame.messages)
        elif isinstance(frame, VisionImageRawFrame):
            context = OpenAILLMContext()
            context.add_image_frame_message(format=frame.format, size=frame.size, image=frame.image, text=frame.text)
        elif isinstance(frame, StartInterruptionFrame):
            # Cancel any in-flight RAG request before propagating the interruption.
            if self._current_task is not None:
                await self.cancel_task(self._current_task)
            await self._start_interruption()
            await self.stop_all_metrics()
            await self.push_frame(frame)
        else:
            await super().process_frame(frame, direction)

        if context:
            # Only one RAG request runs at a time; a new context cancels the previous task.
            new_task = self.create_task(self._process_context_and_frames(context))
            if self._current_task is not None:
                await self.cancel_task(self._current_task)
            self._current_task = new_task
            self._current_task.add_done_callback(lambda _: setattr(self, "_current_task", None))
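
    # Frame flow (summary): an upstream OpenAILLMContextFrame or LLMMessagesFrame
    # triggers a RAG request; the streamed answer is emitted as TextFrames
    # bracketed by LLMFullResponseStartFrame/LLMFullResponseEndFrame, with any
    # citations delivered in a single NvidiaRAGCitationsFrame. A
    # StartInterruptionFrame cancels the in-flight request before propagating.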