| """ | |
| 源文件:local_llm_rag.py | |
| This is a simple example of how to use LangChain to build a local LLM RAG system. | |
| """ | |
import getpass
import os


def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("TAVILY_API_KEY")
_set_env("NOMIC_API_KEY")

# Ollama model name
local_llm = "mistral"
try:
    from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError:
    from langchain.text_splitter import RecursiveCharacterTextSplitter

from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_nomic.embeddings import NomicEmbeddings

urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# Add to vectorDB
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local"),
)
retriever = vectorstore.as_retriever()
### Router

try:
    from langchain_core.prompts import PromptTemplate
except ImportError:
    from langchain.prompts import PromptTemplate

from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser

# LLM
llm = ChatOllama(model=local_llm, format="json", temperature=0)

prompt = PromptTemplate(
    template="""You are an expert at routing a user question to a vectorstore or web search. \n
    Use the vectorstore for questions on LLM agents, prompt engineering, and adversarial attacks. \n
    You do not need to be stringent with the keywords in the question related to these topics. \n
    Otherwise, use web-search. Give a binary choice 'web_search' or 'vectorstore' based on the question. \n
    Return a JSON with a single key 'datasource' and no preamble or explanation. \n
    Question to route: {question}""",
    input_variables=["question"],
)

question_router = prompt | llm | JsonOutputParser()

question = "llm agent memory"
docs = retriever.get_relevant_documents(question)
doc_txt = docs[1].page_content
print(question_router.invoke({"question": question}))
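### Retrieval Grader

# NOTE: grade_documents() further below calls a `retrieval_grader` that this
# script never defines (the otherwise-unused `doc_txt` above is its test
# input). This is a minimal reconstruction following the same
# prompt-pipe-parser pattern as the router and answer grader; the exact
# prompt wording is an assumption.

# LLM
llm = ChatOllama(model=local_llm, format="json", temperature=0)

retrieval_grader_prompt = PromptTemplate(
    template="""You are a grader assessing relevance of a retrieved document to a user question. \n
    Here is the retrieved document: \n\n {document} \n\n
    Here is the user question: {question} \n
    If the document contains keywords related to the user question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
    input_variables=["question", "document"],
)

retrieval_grader = retrieval_grader_prompt | llm | JsonOutputParser()
print(retrieval_grader.invoke({"question": question, "document": doc_txt}))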
### Generate

from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Prompt
prompt = hub.pull("rlm/rag-prompt")

# LLM
llm = ChatOllama(model=local_llm, temperature=0)


# Post-processing: join retrieved documents into a single context string
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


# Chain
rag_chain = prompt | llm | StrOutputParser()

# Run
question = "agent memory"
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
print(generation)
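### Hallucination Grader

# NOTE: grade_generation_v_documents_and_question() further below calls a
# `hallucination_grader` that this script never defines. This is a minimal
# reconstruction following the answer-grader pattern; the exact prompt
# wording is an assumption.

# LLM
llm = ChatOllama(model=local_llm, format="json", temperature=0)

hallucination_grader_prompt = PromptTemplate(
    template="""You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n
    Here are the facts:
    \n ------- \n
    {documents}
    \n ------- \n
    Here is the answer: {generation}
    Give a binary score 'yes' or 'no' to indicate whether the answer is grounded in / supported by the facts. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
    input_variables=["generation", "documents"],
)

hallucination_grader = hallucination_grader_prompt | llm | JsonOutputParser()
print(hallucination_grader.invoke({"documents": docs, "generation": generation}))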
### Answer Grader

# LLM
llm = ChatOllama(model=local_llm, format="json", temperature=0)

# Prompt
prompt = PromptTemplate(
    template="""You are a grader assessing whether an answer is useful to resolve a question. \n
    Here is the answer:
    \n ------- \n
    {generation}
    \n ------- \n
    Here is the question: {question}
    Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve the question. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
    input_variables=["generation", "question"],
)

answer_grader = prompt | llm | JsonOutputParser()
print(answer_grader.invoke({"question": question, "generation": generation}))
### Question Re-writer

# LLM
llm = ChatOllama(model=local_llm, temperature=0)

# Prompt
re_write_prompt = PromptTemplate(
    template="""You are a question re-writer that converts an input question to a better version that is optimized \n
    for vectorstore retrieval. Look at the initial question and formulate an improved question. \n
    Here is the initial question: \n\n {question}. Improved question with no preamble: \n """,
    input_variables=["question"],
)

question_rewriter = re_write_prompt | llm | StrOutputParser()
print(question_rewriter.invoke({"question": question}))
### Search

from langchain_community.tools.tavily_search import TavilySearchResults

# `max_results` caps the number of returned search results
web_search_tool = TavilySearchResults(max_results=3)

from typing import List

from typing_extensions import TypedDict
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        documents: list of documents
    """

    question: str
    generation: str
    documents: List[str]
try:
    from langchain_core.documents import Document
except ImportError:
    from langchain.schema import Document
def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    question = state["question"]

    # Retrieval
    documents = retriever.get_relevant_documents(question)
    return {"documents": documents, "question": question}
def generate(state):
    """
    Generate answer

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation
    generation = rag_chain.invoke(
        {"context": format_docs(documents), "question": question}
    )
    return {"documents": documents, "question": question, "generation": generation}
def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with only filtered relevant documents
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Score each doc
    filtered_docs = []
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score["score"]
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue
    return {"documents": filtered_docs, "question": question}
def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """
    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]

    # Re-write question
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": documents, "question": better_question}
def web_search(state):
    """
    Web search based on the re-phrased question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with web results
    """
    print("---WEB SEARCH---")
    question = state["question"]

    # Web search; collapse the results into a single Document, wrapped in a
    # list to match the documents schema expected downstream
    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    return {"documents": [web_results], "question": question}
### Edges ###


def route_question(state):
    """
    Route question to web search or RAG.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    print(question)
    source = question_router.invoke({"question": question})
    print(source)
    print(source["datasource"])
    if source["datasource"] == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "web_search"
    elif source["datasource"] == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"
def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """
    print("---ASSESS GRADED DOCUMENTS---")
    filtered_documents = state["documents"]

    if not filtered_documents:
        # All documents were filtered out by the relevance check,
        # so re-write the query and retrieve again
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"
def grade_generation_v_documents_and_question(state):
    """
    Determines whether the generation is grounded in the documents and answers the question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Decision for next node to call
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = score["score"]

    # Check hallucination
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score["score"]
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"
from langgraph.graph import END, START, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("web_search", web_search)  # web search
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generate
workflow.add_node("transform_query", transform_query)  # transform_query

# Build graph
workflow.add_conditional_edges(
    START,
    route_question,
    {
        "web_search": "web_search",
        "vectorstore": "retrieve",
    },
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents_and_question,
    {
        "not supported": "generate",
        "useful": END,
        "not useful": "transform_query",
    },
)

# Compile
app = workflow.compile()
from pprint import pprint

# Run
inputs = {"question": "What is the AlphaCodium paper about?"}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint(value, indent=2, width=80, depth=None)
    pprint("\n---\n")

# Final generation
pprint(value["generation"])