import os
import time
from typing import Any, Dict, List, Optional

from llama_index.core import VectorStoreIndex
from llama_index.core.bridge.pydantic import Field
from llama_index.core.llms import ChatMessage
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_pipeline import (
    QueryPipeline,
    InputComponent,
    ArgPackComponent,
    CustomQueryComponent,
)
from llama_index.core.schema import NodeWithScore
from llama_index.llms.openai import OpenAI
from llama_index.postprocessor.colbert_rerank import ColbertRerank
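# The imports above assume (not verified here) that these packages are installed:
#   pip install llama-index llama-index-postprocessor-colbert-rerank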
llm = OpenAI(
    model="gpt-3.5-turbo-0125",
    api_key=os.getenv("OPENAI_API_KEY"),
)

# First, we create an input component to capture the user query
input_component = InputComponent()

# Next, we use the LLM to rewrite the user query
rewrite = (
    "Please write a query to a semantic search engine using the current conversation.\n"
    "\n"
    "\n"
    "{chat_history_str}"
    "\n"
    "\n"
    "Latest message: {query_str}\n"
    'Query:"""\n'
)
rewrite_template = PromptTemplate(rewrite)

# we will retrieve two times, so we need to pack the retrieved nodes into a single list
argpack_component = ArgPackComponent()

# then postprocess/rerank with ColBERT
reranker = ColbertRerank(top_n=3)
DEFAULT_CONTEXT_PROMPT = (
    "Here is some context that may be relevant:\n"
    "-----\n"
    "{node_context}\n"
    "-----\n"
    "Please write a response to the following question, using the above context:\n"
    "{query_str}\n"
    "Please format your response in the following way:\n"
    "Your answer here.\n"
    "Reference:\n"
    "  Your references here (e.g. page numbers, titles, etc.).\n"
)
class ResponseWithChatHistory(CustomQueryComponent):
    llm: OpenAI = Field(..., description="OpenAI LLM")
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt to use for the LLM"
    )
    context_prompt: str = Field(
        default=DEFAULT_CONTEXT_PROMPT,
        description="Context prompt to use for the LLM",
    )

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        # NOTE: this is OPTIONAL but we show you where to do validation as an example
        return input
    @property
    def _input_keys(self) -> set:
        """Input keys dict."""
        # NOTE: These are required inputs. If you have optional inputs please override
        # `optional_input_keys_dict`
        return {"chat_history", "nodes", "query_str"}

    @property
    def _output_keys(self) -> set:
        return {"response"}
    def _prepare_context(
        self,
        chat_history: List[ChatMessage],
        nodes: List[NodeWithScore],
        query_str: str,
    ) -> List[ChatMessage]:
        node_context = ""
        for idx, node in enumerate(nodes):
            node_text = node.get_content(metadata_mode="llm")
            node_context += f"Context Chunk {idx}:\n{node_text}\n\n"

        formatted_context = self.context_prompt.format(
            node_context=node_context, query_str=query_str
        )
        user_message = ChatMessage(role="user", content=formatted_context)
        chat_history.append(user_message)

        if self.system_prompt is not None:
            chat_history = [
                ChatMessage(role="system", content=self.system_prompt)
            ] + chat_history

        return chat_history
    def _run_component(self, **kwargs) -> Dict[str, Any]:
        """Run the component."""
        chat_history = kwargs["chat_history"]
        nodes = kwargs["nodes"]
        query_str = kwargs["query_str"]

        prepared_context = self._prepare_context(chat_history, nodes, query_str)
        # use the component's own LLM field rather than the module-level `llm`
        response = self.llm.chat(prepared_context)

        return {"response": response}

    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run the component asynchronously."""
        # NOTE: Optional, but async LLM calls are easy to implement
        chat_history = kwargs["chat_history"]
        nodes = kwargs["nodes"]
        query_str = kwargs["query_str"]

        prepared_context = self._prepare_context(chat_history, nodes, query_str)
        response = await self.llm.achat(prepared_context)

        return {"response": response}
class LlamaCustomV2:
    response_component = ResponseWithChatHistory(
        llm=llm,
        system_prompt=(
            "You are a Q&A system. You will be provided with the previous chat history, "
            "as well as possibly relevant context, to assist in answering a user message."
        ),
    )

    def __init__(self, model_name: str, index: VectorStoreIndex):
        self.model_name = model_name
        self.index = index
        self.retriever = index.as_retriever()
        self.chat_mode = "condense_plus_context"
        self.memory = ChatMemoryBuffer.from_defaults()
        self.verbose = True
        self._build_pipeline()

    def _build_pipeline(self):
        self.pipeline = QueryPipeline(
            modules={
                "input": input_component,
                "rewrite_template": rewrite_template,
                "llm": llm,
                "rewrite_retriever": self.retriever,
                "query_retriever": self.retriever,
                "join": argpack_component,
                "reranker": reranker,
                "response_component": self.response_component,
            },
            verbose=self.verbose,
        )
        # run both retrievers -- once with the LLM-rewritten query, once with the raw user query
        self.pipeline.add_link(
            "input", "rewrite_template", src_key="query_str", dest_key="query_str"
        )
        self.pipeline.add_link(
            "input",
            "rewrite_template",
            src_key="chat_history_str",
            dest_key="chat_history_str",
        )
        self.pipeline.add_link("rewrite_template", "llm")
        self.pipeline.add_link("llm", "rewrite_retriever")
        self.pipeline.add_link("input", "query_retriever", src_key="query_str")

        # each input to the argpack component needs a dest key -- it can be anything
        # then, the argpack component will pack all the inputs into a single list
        self.pipeline.add_link("rewrite_retriever", "join", dest_key="rewrite_nodes")
        self.pipeline.add_link("query_retriever", "join", dest_key="query_nodes")

        # reranker needs the packed nodes and the query string
        self.pipeline.add_link("join", "reranker", dest_key="nodes")
        self.pipeline.add_link(
            "input", "reranker", src_key="query_str", dest_key="query_str"
        )

        # synthesizer needs the reranked nodes and query str
        self.pipeline.add_link("reranker", "response_component", dest_key="nodes")
        self.pipeline.add_link(
            "input", "response_component", src_key="query_str", dest_key="query_str"
        )
        self.pipeline.add_link(
            "input",
            "response_component",
            src_key="chat_history",
            dest_key="chat_history",
        )
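        # The resulting DAG, roughly:
        #
        #   input ─> rewrite_template ─> llm ─> rewrite_retriever ─┐
        #     │                                                    ├─> join ─> reranker ─> response_component
        #     └─────────────────────────────> query_retriever ─────┘
        #
        # (input also feeds query_str to the reranker, and query_str plus
        # chat_history to the response component)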
    def get_response(self, query_str: str, chat_history: List[ChatMessage]):
        # the memory buffer, not the caller-supplied argument, is the source of
        # truth for chat history here
        chat_history = self.memory.get()
        chat_history_str = "\n".join([str(x) for x in chat_history])

        response = self.pipeline.run(
            query_str=query_str,
            chat_history=chat_history,
            chat_history_str=chat_history_str,
        )

        user_msg = ChatMessage(role="user", content=query_str)
        print("user_msg: ", str(user_msg))
        print("response: ", str(response.message))
        self.memory.put(user_msg)
        self.memory.put(response.message)

        return str(response.message)

    def get_stream_response(self, query_str: str, chat_history: List[ChatMessage]):
        # simulate streaming by yielding the completed response word by word
        response = self.get_response(query_str=query_str, chat_history=chat_history)
        for word in response.split():
            yield word + " "
            time.sleep(0.05)
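
# Example usage (a minimal sketch, assuming a ./data directory of documents and
# OPENAI_API_KEY set in the environment; the paths and query are illustrative only):
if __name__ == "__main__":
    from llama_index.core import SimpleDirectoryReader

    # build an index over local documents and wrap it in the pipeline above
    documents = SimpleDirectoryReader("./data").load_data()
    index = VectorStoreIndex.from_documents(documents)
    chat_engine = LlamaCustomV2(model_name="gpt-3.5-turbo-0125", index=index)

    # stream the (simulated) token-by-token response to stdout
    for token in chat_engine.get_stream_response(
        "What is this document about?", chat_history=[]
    ):
        print(token, end="", flush=True)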