"""RAG (Retrieval-Augmented Generation) chain implementation"""

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

from legisqa_local.core.llm import get_llm
from legisqa_local.core.vectorstore import get_vectorstore, get_vectorstore_filter
from legisqa_local.utils.formatting import format_docs


# Prompt sent to the LLM as a single human message. `{context}` receives the
# formatted retrieved excerpts and `{query}` the raw user query. Kept at module
# level so it is not rebuilt on every call to `create_rag_chain`.
QUERY_RAG_TEMPLATE = (
    "You are an expert legislative analyst. "
    "Use the following excerpts from US congressional legislation to respond to the user's query. "
    'The excerpts are formatted as a JSON list. '
    'Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. '
    'If a snippet is useful in writing part of your response, then cite the "legis_id", "title", '
    '"introduced_date", and "sponsor" in the response. '
    'When citing legis_id, use the same format as the excerpts (e.g. "116-hr-125"). '
    "If you don't know how to respond, just tell the user. "
    "--- Congressional Legislation Excerpts: {context} --- Query: {query}"
)


def create_rag_chain(llm, retriever):
    """Create a RAG chain with the given LLM and retriever.

    The returned runnable takes the raw query string as input and produces a
    dict with these keys:

    - ``"docs"``: whatever ``retriever`` returns for the query,
    - ``"query"``: the input query, passed through unchanged,
    - ``"context"``: ``format_docs(docs)``, substituted into the prompt,
    - ``"aimessage"``: the output of ``llm`` on the rendered prompt.

    Args:
        llm: Chat model runnable that receives the rendered prompt.
        retriever: Runnable that is invoked with the query string.

    Returns:
        The composed (not yet invoked) runnable.
    """
    prompt = ChatPromptTemplate.from_messages([
        ("human", QUERY_RAG_TEMPLATE),
    ])

    rag_chain = (
        # Fan the query out: run retrieval and keep the query itself.
        RunnableParallel({
            "docs": retriever,
            "query": RunnablePassthrough(),
        })
        # Each .assign adds a key while keeping the existing ones, so the final
        # dict carries docs, query, context and aimessage together.
        .assign(context=lambda x: format_docs(x["docs"]))
        .assign(aimessage=prompt | llm)
    )
    return rag_chain


def process_query(gen_config: dict, ret_config: dict, query: str) -> dict:
    """Process a query using RAG.

    Args:
        gen_config: Generation settings, forwarded to ``get_llm``.
        ret_config: Retrieval settings. Must contain ``"n_ret_docs"`` (the
            number of documents to retrieve). It is also forwarded to
            ``get_vectorstore_filter``.
        query: The user's question.

    Returns:
        The dict produced by the chain from ``create_rag_chain``, with keys
        ``docs``, ``query``, ``context`` and ``aimessage``. If the vectorstore
        is not loaded yet, a placeholder dict with the same keys is returned
        instead: empty ``docs`` and ``context``, and a status string as
        ``aimessage``.

    Raises:
        KeyError: If ``ret_config`` has no ``"n_ret_docs"`` entry.
    """
    # Check if vectorstore is loaded
    vectorstore = get_vectorstore()
    if vectorstore is None:
        # NOTE(review): on this path "aimessage" is a plain str, while the
        # chain path yields the LLM's output message; callers must handle both.
        return {
            "aimessage": "⏳ Vectorstore is still loading. Please wait a moment and try again.",
            "docs": [],
            "query": query,
            # Same key set as the chain output, so callers can index uniformly.
            "context": "",
        }

    llm = get_llm(gen_config)
    vs_filter = get_vectorstore_filter(ret_config)

    # ChromaDB uses 'filter' parameter in search_kwargs
    search_kwargs = {"k": ret_config["n_ret_docs"]}
    # Only pass a filter when one was produced; a falsy filter is omitted.
    if vs_filter:
        search_kwargs["filter"] = vs_filter

    retriever = vectorstore.as_retriever(search_kwargs=search_kwargs)
    rag_chain = create_rag_chain(llm, retriever)
    response = rag_chain.invoke(query)
    return response