# NOTE: Hugging Face Space page-header residue ("Spaces: Sleeping") removed
# from the code so the module parses; the pipe-table scrape artifacts around
# each line are likewise extraction noise, not source text.
| import os | |
| import certifi | |
| os.environ['REQUESTS_CA_BUNDLE'] = certifi.where() | |
| from dotenv import load_dotenv | |
| from langgraph.graph import START, StateGraph, MessagesState | |
| from langgraph.prebuilt import tools_condition | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain_community.document_loaders import WikipediaLoader, ArxivLoader | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langchain_core.tools import tool | |
| from langchain.tools.retriever import create_retriever_tool | |
| from langchain_core.documents import Document | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_groq import ChatGroq | |
# Pull API keys (HUGGINGFACEHUB_API_TOKEN, GROQ_API_KEY, ...) from a .env file.
load_dotenv()

# ---------------- CONFIGURATION ----------------
# Any valid Hugging Face chat-model id works here (e.g. meta-llama/Llama-3-8b-chat-hf).
HF_MODEL_NAME = os.getenv("LLAMA_MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
HF_MODEL_URL = f"https://api-inference.huggingface.co/models/{HF_MODEL_NAME}"
# OpenAI-compatible inference endpoint base URL.
HF_OPENAI_URL = "https://api-inference.huggingface.co/openai"
| # ---------------- UTILITY TOOLS ---------------- | |
def multiply_numbers(x: int, y: int) -> int:
    """Return the product of two integers.

    Args:
        x: First factor.
        y: Second factor.
    """
    product = x * y
    return product
def add_numbers(x: int, y: int) -> int:
    """Return the sum of two integers.

    Args:
        x: First addend.
        y: Second addend.
    """
    total = x + y
    return total
def subtract_numbers(x: int, y: int) -> int:
    """Return x minus y.

    Args:
        x: Minuend.
        y: Subtrahend.
    """
    difference = x - y
    return difference
def divide_numbers(x: int, y: int) -> float:
    """Return x divided by y.

    Args:
        x: Dividend.
        y: Divisor; must be non-zero.

    Raises:
        ValueError: If y is zero.
    """
    if not y:
        raise ValueError("Division by zero is not allowed.")
    return x / y
def modulus_numbers(x: int, y: int) -> int:
    """Return the remainder of x divided by y.

    Args:
        x: Dividend.
        y: Divisor.
    """
    remainder = x % y
    return remainder
def power_numbers(base: float, exponent: float) -> float:
    """Return base raised to the given exponent.

    Args:
        base: The base value.
        exponent: The power to raise it to.
    """
    result = base ** exponent
    return result
def root_number(value: float, n: float) -> float:
    """Return the nth root of a value.

    Args:
        value: The number to take the root of.
        n: The root degree (e.g. 2 for square root).
    """
    inverse_exponent = 1 / n
    return value ** inverse_exponent
def wiki_lookup(query: str) -> str:
    """Search Wikipedia for the query and return up to 2 summarized documents.

    Args:
        query: Free-text search string.

    Returns:
        The matching documents wrapped in <Document> tags, joined by a
        "---" separator.
    """
    loaded = WikipediaLoader(query=query, load_max_docs=2).load()
    rendered = []
    for doc in loaded:
        source = doc.metadata["source"]
        page = doc.metadata.get("page", "")
        rendered.append(
            f'<Document source="{source}" page="{page}"/>{doc.page_content}</Document>'
        )
    return "\n\n---\n\n".join(rendered)
def web_lookup(query: str) -> str:
    """Search the web using Tavily and return up to 3 summarized results.

    Args:
        query: Free-text search string.

    Returns:
        The matching results wrapped in <Document> tags, joined by a
        "---" separator.
    """
    # BUG FIX: Runnable.invoke() takes the tool input as its first positional
    # argument; the original `invoke(query=query)` raised a TypeError. Tavily
    # also returns plain dicts (with "url"/"content" keys), not Document
    # objects, so the original `d.metadata[...]` access could never work.
    results = TavilySearchResults(max_results=3).invoke(query)
    rendered = []
    for item in results:
        source = item.get("url", "") if isinstance(item, dict) else str(item)
        content = item.get("content", "") if isinstance(item, dict) else str(item)
        rendered.append(f'<Document source="{source}" page=""/>{content}</Document>')
    return "\n\n---\n\n".join(rendered)
def arxiv_lookup(query: str) -> str:
    """Search arXiv for the query and return summaries of up to 3 papers.

    Args:
        query: Free-text search string.

    Returns:
        The matching papers wrapped in <Document> tags (content truncated
        to 800 characters each), joined by a "---" separator.
    """
    papers = ArxivLoader(query=query, load_max_docs=3).load()
    rendered = []
    for doc in papers:
        source = doc.metadata["source"]
        page = doc.metadata.get("page", "")
        snippet = doc.page_content[:800]
        rendered.append(f'<Document source="{source}" page="{page}"/>{snippet}</Document>')
    return "\n\n---\n\n".join(rendered)
# NOTE: this span previously contained byte-for-byte duplicate redefinitions of
# add_numbers, subtract_numbers, divide_numbers, modulus_numbers, power_numbers,
# root_number, wiki_lookup, web_lookup and arxiv_lookup. The duplicates silently
# shadowed the identical definitions above; they have been removed with no
# behavior change.
| # # ---------------- SETUP LOCAL VECTORSTORE ---------------- | |
| # embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") | |
| # text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100) | |
| # sample_docs = [Document(page_content="St. Thomas Aquinas was a medieval Catholic priest and philosopher.", metadata={"source": "wiki", "page": "St. Thomas Aquinas"})] | |
| # split_docs = text_splitter.split_documents(sample_docs) | |
| # vector_db = Chroma.from_documents(documents=split_docs, embedding=embedding_model) | |
| # retriever_tool = create_retriever_tool( | |
| # retriever=vector_db.as_retriever(), | |
| # name="SimilarQuestionFinder", | |
| # description="Retrieve similar questions and examples from vector DB." | |
| # ) | |
| # # ---------------- SYSTEM PROMPT ---------------- | |
| # with open("system_prompt.txt", "r", encoding="utf-8") as f: | |
| # system_content = f.read() | |
| # system_message = SystemMessage(content=system_content) | |
| # # ---------------- BUILD STATE GRAPH ---------------- | |
| # def construct_agent_graph(): | |
| # llama_llm = ChatHuggingFace( | |
| # llm=HuggingFaceEndpoint( | |
| # endpoint_url=HF_OPENAI_URL, | |
| # temperature=0 | |
| # ) | |
| # ).bind_tools([ | |
| # multiply_numbers, | |
| # add_numbers, | |
| # subtract_numbers, | |
| # divide_numbers, | |
| # modulus_numbers, | |
| # power_numbers, | |
| # root_number, | |
| # wiki_lookup, | |
| # web_lookup, | |
| # arxiv_lookup, | |
| # retriever_tool, | |
| # ]) | |
| # def retrieve_node(state: MessagesState): | |
| # similar = vector_db.similarity_search(state["messages"][0].content) | |
| # hint = HumanMessage(content=f"Reference example:\n{similar[0].page_content}" if similar else "") | |
| # return {"messages": [system_message] + state["messages"] + [hint]} | |
| # def respond_node(state: MessagesState): | |
| # return {"messages": [llama_llm.invoke(state["messages"]) ]} | |
| # graph_builder = StateGraph(MessagesState) | |
| # graph_builder.add_node("find_similar", retrieve_node) | |
| # graph_builder.add_node("generate_answer", respond_node) | |
| # graph_builder.add_node("tool_executor", ToolNode([])) | |
| # graph_builder.add_edge(START, "find_similar") | |
| # graph_builder.add_edge("find_similar", "generate_answer") | |
| # graph_builder.add_conditional_edges( | |
| # "generate_answer", | |
| # tools_condition, | |
| # {"tools": "tool_executor", "default": "generate_answer"} | |
| # ) | |
| # graph_builder.add_edge("tool_executor", "generate_answer") | |
| # return graph_builder.compile() | |
| # # ---------------- RUN EXAMPLE ---------------- | |
| # if __name__ == "__main__": | |
| # sample_q = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?" | |
| # agent = construct_agent_graph() | |
| # msgs = [HumanMessage(content=sample_q)] | |
| # out = agent.invoke({"messages": msgs}) | |
| # for m in out["messages"]: | |
| # m.pretty_print() | |
# ---------------- EMBEDDINGS & VECTOR DB ----------------
# Build an in-memory Chroma store seeded with one placeholder document; the
# retriever over it is exposed to the agent as a tool.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
sample_docs = [Document(page_content="Sample doc.", metadata={"source": "wiki"})]
split_docs = text_splitter.split_documents(sample_docs)
vector_db = Chroma.from_documents(documents=split_docs, embedding=embedding_model)
retriever_tool = create_retriever_tool(
    retriever=vector_db.as_retriever(),
    name="SimilarQuestionFinder",
    description="Retrieve similar questions and examples from vector DB.",
)

# Every capability the agent may call: arithmetic helpers, the three search
# tools, and the vector-DB retriever.
all_tools = [
    multiply_numbers,
    add_numbers,
    subtract_numbers,
    divide_numbers,
    modulus_numbers,
    power_numbers,
    root_number,
    wiki_lookup,
    web_lookup,
    arxiv_lookup,
    retriever_tool,
]
# ---------------- SYSTEM PROMPT ----------------
# Read the system prompt that seeds every conversation from disk; raises
# FileNotFoundError at import time if system_prompt.txt is missing.
with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
    system_content = prompt_file.read()
system_message = SystemMessage(content=system_content)
| # ---------------- BUILD GRAPH ---------------- | |
def construct_agent_graph():
    """Build and compile the LangGraph agent.

    The graph first prepends the system prompt and a similar stored example
    (find_similar), then loops between the LLM node (generate_answer) and the
    tool-executor node until the model stops emitting tool calls.

    Returns:
        A compiled LangGraph application ready for .invoke({"messages": ...}).
    """
    # BUG FIX: the original never called bind_tools, so the model had no tool
    # schemas, could never emit tool calls, tools_condition always routed to
    # END, and the tool_executor node was unreachable.
    llama_llm = ChatGroq(
        model="qwen-qwq-32b",
        api_key=os.environ["GROQ_API_KEY"],
        temperature=0,
    ).bind_tools(all_tools)

    def retrieve_node(state: MessagesState):
        """Prepend the system prompt and attach the closest stored example."""
        msgs = [system_message] + state["messages"]
        similar = vector_db.similarity_search(state["messages"][0].content)
        if similar:
            msgs.append(HumanMessage(content=f"Reference example:\n{similar[0].page_content}"))
        return {"messages": msgs}

    def respond_node(state: MessagesState):
        """Call the tool-aware LLM on the accumulated messages."""
        return {"messages": [llama_llm.invoke(state["messages"])]}

    graph = StateGraph(MessagesState)
    graph.add_node("find_similar", retrieve_node)
    graph.add_node("generate_answer", respond_node)
    graph.add_node("tool_executor", ToolNode(tools=all_tools))
    graph.add_edge(START, "find_similar")
    graph.add_edge("find_similar", "generate_answer")
    graph.add_conditional_edges(
        "generate_answer",
        tools_condition,
        {"tools": "tool_executor", "__end__": "__end__"},
    )
    graph.add_edge("tool_executor", "generate_answer")
    return graph.compile()
| # ---------------- RUN EXAMPLE ---------------- | |
| if __name__ == "__main__": | |
| agent = construct_agent_graph() | |
| sample_q = "When was St. Thomas Aquinas added to that page?" | |
| out = agent.invoke({"messages": [HumanMessage(content=sample_q)]}) | |
| for m in out["messages"]: | |
| m.pretty_print() | |