Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| from pydantic import BaseModel, Field | |
| from typing import Literal, List, Any, Annotated | |
| from typing_extensions import TypedDict | |
| from langchain.schema import Document | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.prompts.prompt import PromptTemplate | |
| from langchain_core.messages import HumanMessage, AIMessage, AnyMessage | |
| from langgraph.graph import END, StateGraph, MessagesState, START | |
| from langgraph.graph.message import add_messages | |
| from huggingface_hub import InferenceClient | |
| from dotenv import load_dotenv | |
# Load variables from a local .env file (if present) before validating them.
load_dotenv(verbose=True)

# Fail fast at import time when a required setting is missing. Each assert
# names the missing variable, matching the GROQ_* checks further down.
assert os.getenv("PINECONE_API_KEY") is not None, "PINECONE_API_KEY not set"
assert os.getenv("HUGGINGFACEHUB_EMBEDDINGS_MODEL") is not None, "HUGGINGFACEHUB_EMBEDDINGS_MODEL not set"
assert os.getenv("TAVILY_API_KEY") is not None, "TAVILY_API_KEY not set"

logger = logging.getLogger(__name__)  # Child logger for this module
# NOTE: at INFO level the logger.debug(...) calls in this module are not emitted.
logger.setLevel(logging.INFO)
# Only the first five characters of the Pinecone key are written to the log.
logger.info(f"""correctiveRag.py:Config:
GROQ_MODEL = {os.getenv('GROQ_MODEL')}
HUGGINGFACEHUB_EMBEDDINGS_MODEL = {os.getenv('HUGGINGFACEHUB_EMBEDDINGS_MODEL')}
PINECONE_API_KEY = {os.getenv("PINECONE_API_KEY")[:5]}
""")
# Prepare the LLM
from langchain_groq import ChatGroq

# Both variables must be present before the client is built. The API key is
# not passed explicitly below; presumably ChatGroq reads it from the environment.
assert "GROQ_MODEL" in os.environ, "GROQ_MODEL not set"
assert "GROQ_API_KEY" in os.environ, "GROQ_API_KEY not set"
llm = ChatGroq(
    model_name=os.environ["GROQ_MODEL"],
    temperature=0,
    verbose=True,
)
| # For using Grok | |
| # from langchain_openai import ChatOpenAI | |
| # assert os.getenv('XAI_API_KEY') is not None, "XAI_API_KEY not set" | |
| # assert os.getenv('XAI_MODEL') is not None, "XAI_MODEL not set" | |
| # assert os.getenv('XAI_BASE_URL') is not None, "XAI_BASE_URL not set" | |
| # llm = ChatOpenAI( | |
| # api_key=os.getenv("XAI_API_KEY"), | |
| # base_url=os.getenv("XAI_BASE_URL"), | |
| # model=os.getenv("XAI_MODEL"), | |
| # temperature=0.1) | |
| # from langchain_openai import ChatOpenAI | |
| # assert os.getenv('OPENAI_MODEL_NAME') is not None, "OPENAI_MODEL_NAME not set" | |
| # llm = ChatOpenAI(model=os.getenv("OPENAI_MODEL_NAME"), temperature=0.1, verbose=True) | |
| # Huggingface - Does not support structured_output | |
| # llm = InferenceClient("HuggingFaceH4/zephyr-7b-beta") | |
# Prepare the retriever
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_pinecone import PineconeVectorStore

# Pinecone index and namespace queried by the vector store built below.
index_name = 'courses'
namespace = 'dsa'
| # Simple RAG | |
| # embeddings = HuggingFaceEmbeddings(model_name=os.getenv("HUGGINGFACEHUB_EMBEDDINGS_MODEL")) | |
| # docsearch = PineconeVectorStore.from_existing_index(embedding=embeddings, index_name=index_name, namespace=namespace) | |
| # retriever = docsearch.as_retriever(search_type="mmr", search_kwargs={ 'k': 5 }) | |
| # Large-Small RAG | |
def larger_from_nearby(vectorstore, doc: Document, range: int) -> Document:
    """
    Given a document, build its "parent" document from the range of chunks
    around the central chunk.

    Args:
        vectorstore: Store exposing ``similarity_search(query, k, filter)``.
        doc: The central chunk; its metadata must provide ``document`` and ``chunk``.
        range: Number of chunks to include on each side of the central chunk.

    Returns:
        A new Document whose content is the neighbouring chunks joined in
        ascending ``chunk`` order and whose metadata is that of ``doc``.
    """
    # The parameter name shadows the builtin; it is kept so existing callers
    # keep working, and aliased here to avoid confusion in the body.
    radius = range
    center = doc.metadata['chunk']
    and_filter = {
        "$and": [
            {"document": doc.metadata['document']},
            {"chunk": {"$gte": center - radius}},
            {"chunk": {"$lte": center + radius}},
        ]
    }
    # The empty query is only a vehicle for the metadata filter; at most
    # 2*radius+1 chunks can match it.
    range_docs = vectorstore.similarity_search(query='', k=2 * radius + 1, filter=and_filter)
    # Search hits are presumably ranked by similarity, not by position, so
    # restore reading order before concatenating.
    ordered = sorted(range_docs, key=lambda chunk_doc: chunk_doc.metadata['chunk'])
    content = ''.join(chunk_doc.page_content for chunk_doc in ordered)
    # Use the central chunk's metadata (the original loop variable shadowed
    # `doc`, so the metadata of an arbitrary neighbour was returned).
    return Document(page_content=content, metadata=doc.metadata)
def larger_retriever(vectorstore, query: str, topK: int):
    """
    Retrieve the ``topK`` chunks most similar to ``query`` and expand each one
    into a larger document made of its neighbouring chunks.

    Returns:
        A list with one expanded document per similarity-search hit.
    """
    RANGE = 2  # -RANGE...+RANGE
    logger.info(f'larger_retriever: with RANGE={RANGE}')
    hits = vectorstore.similarity_search(query, k=topK)
    larger_documents = [larger_from_nearby(vectorstore, hit, RANGE) for hit in hits]
    logger.info(f'larger_retriever: Found {len(larger_documents)} documents.')
    return larger_documents
# Embedding model and Pinecone-backed vector store shared by the retriever.
embeddings = HuggingFaceEmbeddings(
    model_name=os.getenv("HUGGINGFACEHUB_EMBEDDINGS_MODEL"),
)
vectorstore = PineconeVectorStore.from_existing_index(
    index_name=index_name,
    namespace=namespace,
    embedding=embeddings,
)
# docs = larger_retriever(vectorstore, query, 5)


def retriever(query):
    """Return the expanded documents for ``query`` using the module-level vector store."""
    return larger_retriever(vectorstore, query, 5)  # TODO topK
| # Classify question | |
# Schema for the question classifier's structured output; it is bound to the
# LLM with `with_structured_output` right below this class.
# NOTE(review): the docstring and the Field description are presumably sent to
# the model as part of the schema, so treat them as prompt text — confirm
# before rewording.
class ClassifyQuestion(BaseModel):
    """Binary score to decide if need to retrieve documents from the vectorstore about data structures and algorithms.
    The binary_score is "yes" to indicate that document retrieval is needed, otherwise is "no"."""
    # Declared as a plain `str`: nothing here restricts the value to "yes"/"no".
    binary_score: str = Field(description="If the question is about data structures and algorithms answer `yes`, otherwise answer `no`")
    # justification: str = Field(description="Explained reasoning for giving the yes/no score")
# LLM with function call
# The wrapped LLM yields objects exposing `binary_score` (see ClassifyQuestion).
structured_llm_grader = llm.with_structured_output(ClassifyQuestion)

# Prompt
system = """You are an expert at classifying user questions.
If the question are specific about data structures and algorithms, then answer `yes` to indicate that document retrieval is needed.
Otherwise, it is a question as a general question, answer `no`.
"""
grade_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "Question: {question}"),
])

# Chain deciding whether a question needs document retrieval.
# NOTE: `structured_llm_grader`, `system` and `grade_prompt` are rebound later
# in this module; the chain keeps the objects it was built from here.
retriever_grader = grade_prompt | structured_llm_grader
| # Retrieval grader | |
# Schema for the relevance grader's structured output; it is bound to the LLM
# with `with_structured_output` right below this class.
# NOTE(review): the docstring and the Field description are presumably sent to
# the model as part of the schema — confirm before rewording.
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""
    # Declared as a plain `str`: nothing here restricts the value to 'yes'/'no'.
    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )
# LLM with function call
# Rebinds `structured_llm_grader`, this time to the GradeDocuments schema.
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Prompt
system = """You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
])

# Chain grading one retrieved document against the user question.
retrieval_grader = grade_prompt | structured_llm_grader
# Create the RAG chain
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# A local template is used instead of the hub prompt:
# prompt = hub.pull("rlm/rag-prompt")
# print('----', prompt, '---')
template = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Please keep the answer concise and to the point.
Context: {context}
Question: {question}
Answer:
"""
prompt_template = PromptTemplate.from_template(template)

# Prompt -> LLM -> plain string answer.
rag_chain = prompt_template | llm | StrOutputParser()
# Question rewriter
system = """You a question re-writer that converts an input question to a better version that is optimized
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.
Return only the re-written question. Do not return anything else.
"""
re_write_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "Here is the initial question: \n\n {question} \n Formulate an improved question."),
])

# Prompt -> LLM -> plain string holding the rewritten question.
question_rewriter = re_write_prompt | llm | StrOutputParser()
# Web search tool
from langchain_community.tools.tavily_search import TavilySearchResults

# `max_results` is the tool's result-limit field; the previous `k=3` keyword is
# not a field of TavilySearchResults and therefore did not limit anything.
web_search_tool = TavilySearchResults(max_results=3)
| # Define the workflow Graph | |
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        messages: conversation history; new messages are merged in by the
            `add_messages` reducer
        generation: LLM generation
        web_search: "Yes"/"No" flag set by grade_documents to request a web search
        documents: list of retrieved (and possibly web-searched) documents
        question: the last user question
    """
    messages: Annotated[list[AnyMessage], add_messages]
    generation: str
    web_search: str
    # The nodes store Document objects here (they read `.page_content`), not strings.
    documents: List[Document]
    question: str
def chatbot(state: GraphState):
    """
    Answer from the conversation history alone, without retrieved context.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): The reply appended to messages, and the same text in generation
    """
    logger.info("---GENERATE (no context)---")
    logger.info(state)
    history = state["messages"]
    generation = (llm | StrOutputParser()).invoke(history)
    logger.info(generation)
    reply = AIMessage(content=generation)
    return {"messages": [reply], "generation": generation}
def retrieve(state):
    """
    Retrieve documents for the most recent message.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): The retrieved documents and the question they were fetched for
    """
    logger.info("---RETRIEVE---")
    last_message = state['messages'][-1]  # Last Human message
    question = last_message.content
    logger.info(f'question: {question}')

    # Retrieval
    # documents = retriever.invoke(question)
    documents = retriever(question)  # Large-small retriever
    # logger.debug(documents)
    # Log a short (id, first 20 characters) preview of every hit.
    preview = [(doc.metadata['id'], doc.page_content[:20]) for doc in documents]
    logger.info(preview)
    return {"documents": documents, "question": question}
def generate_with_context(state):
    """
    Generate an answer from the question and the documents held in the state.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): documents and question unchanged, plus the new generation
    """
    logger.debug("---GENERATE WITH CONTEXT---")
    logger.debug(f'state: {state}')
    question, documents = state["question"], state["documents"]

    # RAG generation: the document list is handed to the prompt as-is.
    rag_inputs = {"context": documents, "question": question}
    generation = rag_chain.invoke(rag_inputs)
    logger.debug(generation)
    return {"documents": documents, "question": question, "generation": generation}
def web_search(state):
    """
    Web search based on the re-phrased question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with appended web results
    """
    logger.debug("---WEB SEARCH---")
    question = state["question"]
    # Work on a copy so the list already stored in the state is not mutated.
    documents = list(state["documents"])

    # Web search
    logger.debug(f'question: {question}')
    docs = web_search_tool.invoke({"query": question})  # Returns str if error
    logger.debug(f'type(docs) = {type(docs)}')
    logger.debug(docs)

    if isinstance(docs, str):
        # The tool reported an error as text; continue with the documents we
        # already have instead of failing on `d["content"]`.
        logger.warning(f'web_search: search failed, no web results added: {docs}')
    else:
        web_results = "\n".join(d["content"] for d in docs)
        documents.append(Document(page_content=web_results))

    # Return the whole list (previously only the single web Document was
    # returned, which discarded the relevant documents kept by the grader).
    return {"documents": documents, "question": question}
def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with only filtered relevant documents,
        and sets web_search to "Yes" when at least one document was rejected or
        no relevant document remains, otherwise "No"
    """
    logger.debug("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Score each doc
    filtered_docs = []
    needs_web_search = "No"
    for d in documents:
        score = retrieval_grader.invoke({"question": question, "document": d.page_content})
        # Normalise so that answers such as "Yes" or " yes" still count as relevant.
        grade = score.binary_score.strip().lower()
        if grade == "yes":
            logger.debug("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            logger.debug("---GRADE: DOCUMENT NOT RELEVANT---")
            needs_web_search = "Yes"

    # Nothing relevant to answer from (this also covers an empty retrieval):
    # fall back to the web search.
    if not filtered_docs:
        needs_web_search = "Yes"
    return {"documents": filtered_docs, "question": question, "web_search": needs_web_search}
def transform_query(state):
    """
    Rewrite the current question into a better one.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): documents unchanged, question replaced by the re-phrased one
    """
    logger.debug("---TRANSFORM QUERY---")
    original_question = state["question"]

    # Re-write question
    better_question = question_rewriter.invoke({"question": original_question})
    return {"documents": state["documents"], "question": better_question}
| ### Edges ### | |
| # For conditional edges | |
def decide_to_retrieve(state):
    """
    Determines whether to retrieve a context for answering a question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Name of the next node to call, "retrieve" or "chatbot"
    """
    logger.debug("---ASSESS NEED FOR RETRIEVAL---")
    # logger.debug(state)
    question = state['messages'][-1].content  # Last Human message
    logger.debug(question)
    response = retriever_grader.invoke({'question': question})
    logger.debug(response)
    logger.debug(response.binary_score)

    # Normalise so that answers such as "Yes" or " yes" still trigger retrieval.
    if response.binary_score.strip().lower() == "yes":
        # The classifier judged the question to be about data structures and
        # algorithms, so fetch documents for it.
        logger.debug("---DECISION: RETRIEVE DOCUMENTS---")
        return "retrieve"

    # General question: answer from the conversation alone.
    logger.debug("---DECISION: GENERAL QUESTION, NO RETRIEVAL---")
    # state['question'] = question
    return "chatbot"
def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Name of the next node to call, "transform_query" or
        "generate_with_context"
    """
    logger.debug("---ASSESS GRADED DOCUMENTS---")

    if state["web_search"] == "Yes":
        # grade_documents raised the flag: rewrite the question so the web
        # search can supplement the documents.
        logger.debug(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"

    # We have relevant documents, so generate answer
    logger.debug("---DECISION: GENERATE---")
    return "generate_with_context"
# Prepare and compile the Graph
workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("chatbot", chatbot)  # answer without retrieved context
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate_with_context", generate_with_context)  # generate
workflow.add_node("transform_query", transform_query)  # transform_query
workflow.add_node("web_search_node", web_search)  # web search

# Build graph
# workflow.add_edge(START, "retrieve")
# Entry point: decide_to_retrieve returns the name of the first node to run
# ("retrieve" or "chatbot").
workflow.add_conditional_edges(START, decide_to_retrieve)
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate_with_context": "generate_with_context",
    },
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate_with_context")
# NOTE(review): chatbot re-invokes the LLM on state["messages"] only and
# overwrites "generation", so the answer produced by generate_with_context is
# not the one returned — confirm this is intended.
workflow.add_edge("generate_with_context", "chatbot")
workflow.add_edge("chatbot", END)
# Compile
from langgraph.checkpoint.memory import MemorySaver

# Checkpointer handed to the compiled graph; callers select a conversation
# through the "thread_id" entry of the config they pass in.
memory = MemorySaver()
app = workflow.compile(debug=False, checkpointer=memory)
if __name__ == "__main__":
    # Use the graph
    from pprint import pprint

    # print(retriever.invoke("What is an algorithm?"))
    # All calls below share one thread id, hence one checkpointed conversation.
    config = {"configurable": {"thread_id": "abc123"}}

    def query_graph(question: str):
        """Run the graph once for ``question`` and return the resulting state."""
        messages = [HumanMessage(question)]
        response = app.invoke({"messages": messages}, config)
        # print('TYPE >>', type(response)) # langgraph.pregel.io.AddableValuesDict
        return response

    def print_generation(response: dict):
        """Pretty-print the ``generation`` entry of a graph response."""
        # pprint(type(response['generation'])) # str
        # pprint(response['messages'])
        pprint(response['generation'])  # AIMessage (no context)

    # question = "Cual es el orden de ingreso y egresos de elementos en un Queue?"
    # pprint(query_graph(question))
    # print_generation(query_graph("Hi, my name is George and I would like to learn about algorithms"))
    # print_generation(query_graph("Do you remember my name? What algorithm would you use to reverse the letters in my name?"))
    # print_generation(query_graph("Que es un algoritmo?"))
    # print_generation(query_graph("Que es una heuristica?"))
    # print_generation(query_graph("Que se entiende por orden de crecimiento de un algoritmo?"))
    # print_generation(query_graph("Que es la función tilde?"))
    # print_generation(query_graph("Cuál es la diferencia entre función tilde y orden de crecimiento?"))

    # In stream mode, returns the full 'chatbot' message
    for x in app.stream({"messages": "What is the answer to the question of everything?"}, config):
        print(x)