Spaces:
Sleeping
Sleeping
| from operator import itemgetter | |
| from langchain_core.vectorstores import VectorStoreRetriever | |
| from langchain.schema.runnable import RunnableLambda, RunnableParallel, RunnableSequence | |
| from langchain.chat_models import ChatOpenAI, AzureChatOpenAI | |
| from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.documents import Document | |
| from langchain_core.messages.ai import AIMessage | |
| from langchain_core.messages.human import HumanMessage | |
| from langchain_core.messages.system import SystemMessage | |
| from langchain_core.messages.function import FunctionMessage | |
# Single-turn RAG prompt: retrieved context followed by the user's question.
# Placeholders: {context}, {question}.
# NOTE(review): not referenced by create_chain, which uses chat_prompt instead.
template = """
You are a helpful assistant, your job is to answer the user's question using the relevant context.
CONTEXT
=========
{context}
=========
User question: {question}
"""
prompt = PromptTemplate.from_template(template)
# System message for the conversational RAG chain.
# Placeholders: {context} (merged documents) and {chat_history}.
_CHAT_SYSTEM_TEMPLATE = """
You are a helpful assistant, your job is to answer the user's question using the relevant context in the context section and in the conversation history.
Make sure to relate the question to the conversation history and the context in the context section. If the question, the context and the conversation history
does not align please let the user know about this and ask for further clarification.
=========
CONTEXT:
{context}
=========
PREVIOUS CONVERSATION HISTORY:
{chat_history}
"""

# chat_history is interpolated into the system text rather than supplied through
# a MessagesPlaceholder, so whatever value the caller passes is rendered as text.
chat_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _CHAT_SYSTEM_TEMPLATE),
        ("human", "{question}"),
    ]
)
def to_doc(input: AIMessage) -> list[Document]:
    """Wrap the LLM's direct answer as a one-element document list.

    ``page_content`` is the fixed marker ``"LLM"``. The answer text itself is
    stored under the ``"text"`` metadata key, and ``chunk`` / ``page_number``
    are set to the constant ``1.0``.

    NOTE(review): this presumably mirrors the metadata layout of the
    retriever's documents so both can be merged into one context list;
    confirm against the vector store schema.
    """
    answer_metadata = {
        "chunk": 1.0,
        "page_number": 1.0,
        "text": input.content,
    }
    return [Document(page_content="LLM", metadata=answer_metadata)]
def merge_docs(a: "dict[str, list[Document]]") -> "list[Document]":
    """Flatten the per-source document lists into a single list.

    Used after the parallel context step, whose output maps each source
    name (``"docs"``, ``"self_knowledge"``) to the documents it produced.

    Args:
        a: Mapping from source name to that source's documents.

    Returns:
        A new list containing every document, in the mapping's iteration
        order. The input mapping and its lists are not modified.
    """
    # Only the values matter; the source names are discarded.
    return [doc for docs in a.values() for doc in docs]
def create_chain(**kwargs) -> RunnableSequence:
    """Build the retrieval-augmented, history-aware QA chain.

    Keyword Args:
        retriever: ``VectorStoreRetriever`` queried with the question.
            Required.
        llm: ``AzureChatOpenAI`` model. It is used twice: once to answer
            the question directly ("self knowledge"), and once to produce
            the final response from ``chat_prompt``. Required.

    Any other keyword arguments are ignored. In particular ``prompt`` is not
    read; the response step always uses the module-level ``chat_prompt``.

    The returned runnable is invoked with a mapping containing
    ``"question"`` and ``"chat_history"``. It produces a dict with:

    * ``"response"``: the LLM output for ``chat_prompt``;
    * ``"context"``: the retrieved documents and the LLM's self-knowledge
      document (see ``to_doc``), flattened by ``merge_docs``;
    * ``"chat_history"``: the input value, passed through unchanged.

    Raises:
        KeyError: if ``retriever`` is not supplied.
        ValueError: if ``retriever`` is not a ``VectorStoreRetriever``, or
            ``llm`` is missing or not an ``AzureChatOpenAI``.
    """
    retriever: VectorStoreRetriever = kwargs["retriever"]
    llm: AzureChatOpenAI = kwargs.get("llm", None)
    if not isinstance(retriever, VectorStoreRetriever):
        raise ValueError(
            "create_chain: 'retriever' must be a VectorStoreRetriever, "
            f"got {type(retriever).__name__}"
        )
    if llm is None:
        raise ValueError("create_chain: the 'llm' keyword argument is required")
    if not isinstance(llm, AzureChatOpenAI):
        raise ValueError(
            "create_chain: 'llm' must be an AzureChatOpenAI instance, "
            f"got {type(llm).__name__}"
        )

    # Two context sources, both driven only by the question:
    # documents from the vector store, and the LLM's own direct answer.
    docs_chain = (itemgetter("question") | retriever).with_config(config={"run_name": "docs"})
    self_knowledge_chain = (itemgetter("question") | llm | to_doc).with_config(
        config={"run_name": "self knowledge"}
    )
    # NOTE(review): chat_history is string-interpolated into the system message
    # of chat_prompt. A list of message objects would be rendered via its text
    # representation; confirm callers pass a pre-formatted string.
    response_chain = (chat_prompt | llm).with_config(config={"run_name": "response"})
    merge_docs_link = RunnableLambda(merge_docs).with_config(config={"run_name": "merge docs"})

    context_chain = (
        RunnableParallel(
            {
                "docs": docs_chain,
                "self_knowledge": self_knowledge_chain,
            }
        ).with_config(config={"run_name": "parallel context"})
        | merge_docs_link
    )

    # Stage 1 gathers question, history and merged context.
    # Stage 2 generates the response while passing context and history through.
    retrieval_augmented_qa_chain = (
        RunnableParallel(
            {
                "question": itemgetter("question"),
                "chat_history": itemgetter("chat_history"),
                "context": context_chain,
            }
        )
        | RunnableParallel(
            {
                "response": response_chain,
                "context": itemgetter("context"),
                "chat_history": itemgetter("chat_history"),
            }
        )
    )
    return retrieval_augmented_qa_chain