Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| from typing import TypedDict, List, Dict, Any, Optional | |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage | |
| from langchain_huggingface.chat_models import ChatHuggingFace | |
| from langchain_groq.chat_models import ChatGroq | |
| from langgraph.graph.message import add_messages | |
| from langgraph.graph import StateGraph, START, END, MessagesState | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from tools import ( | |
| add, | |
| subtract, multiply, div, modulus, power, | |
| wikipedia_search, search_web, arxiv_search, | |
| save_and_read_file, download_file_from_url, extract_text_from_image, | |
| pdf_loader | |
| ) | |
| from retriever import get_retriever_tool | |
# Load environment variables (e.g. the Groq API key) from the local .env file.
load_dotenv(dotenv_path = ".env")
# Configurations
SYSTEM_PROMPT_PATH = "system_prompt.txt"  # file holding the agent's system prompt
DEFAULT_PROVIDER = "groq"                 # LLM backend used when none is specified
MODEL_NAME = "llama3-70b-8192"            # Groq-hosted model identifier

def load_system_prompt(path: str = SYSTEM_PROMPT_PATH) -> str:
    """Read and return the system prompt text from *path*.

    Args:
        path: Location of the prompt file. Defaults to SYSTEM_PROMPT_PATH.

    Returns:
        The full contents of the file as a string.

    Raises:
        ValueError: If no file exists at *path*.
    """
    if not os.path.exists(path):
        # Fixed typo in the error message ("foud" -> "found").
        raise ValueError(f"System prompt file not found at: {path}")
    with open(path, "r", encoding = "utf-8") as f:
        return f.read()
# Load the system prompt once at import time and wrap it as a SystemMessage
# so it can be prepended to every conversation by the retriever node.
system_prompt = load_system_prompt()
sys_msg = SystemMessage(content = system_prompt)

# Load tools
# NOTE(review): get_retriever_tool() comes from retriever.py — presumably it
# returns (vector store, retriever, LangChain tool); verify its contract there.
vector_store, vector_retriever, retriever_tool = get_retriever_tool()
# Full tool belt bound to the LLM and registered with the graph's ToolNode.
TOOLS = [
    # Math
    add, subtract, multiply, div, modulus, power,
    # Documents Search
    wikipedia_search, search_web, arxiv_search,
    # Process Files
    save_and_read_file, download_file_from_url, extract_text_from_image,
    pdf_loader,
    # Retriever
    retriever_tool
]
def get_llm(provider: str = DEFAULT_PROVIDER):
    """Return a chat model instance for the requested provider.

    Only the "groq" backend is implemented today; "huggingface" is a
    recognised-but-unsupported option, and any other value is rejected.

    Raises:
        NotImplementedError: If provider is "huggingface".
        ValueError: If provider is neither "groq" nor "huggingface".
    """
    # Guard clauses for the unsupported/invalid cases first.
    if provider == "huggingface":
        raise NotImplementedError("HuggingFace support not yet implemented.")
    if provider != "groq":
        raise ValueError("Invalid LLM provider. Choose 'groq' or 'huggingface'")
    # Temperature 0 keeps tool-calling behaviour as deterministic as possible.
    return ChatGroq(model = MODEL_NAME, temperature = 0)
def build_graph(provider: str = DEFAULT_PROVIDER):
    """
    Builds LangGraph graph.

    Flow: START -> retriever (prepends the system prompt and, when available,
    a similar reference Q/A from the vector store) -> assistant (tool-calling
    LLM) -> either the tools node (then back to assistant) or END, as decided
    by tools_condition.
    """
    llm = get_llm(provider)
    # Add tools to the LLM
    llm_with_tools = llm.bind_tools(TOOLS)
    def assistant(state: MessagesState):
        # Run the tool-enabled LLM over the accumulated conversation.
        # NOTE(review): the response is returned bare rather than as a list;
        # the add_messages reducer presumably coerces single messages — confirm.
        return {"messages": llm_with_tools.invoke(state["messages"])}
    def retriever(state: MessagesState):
        # Use the first (original) user message as the similarity-search query.
        query = state["messages"][0].content
        similar_qas = vector_store.similarity_search(query)
        if similar_qas:
            # Attach the closest stored Q/A pair as an in-context example.
            reference = similar_qas[0].page_content
            example_qa = HumanMessage(
                content = f"I provide a similar question and answer for reference:\n\n{reference}"
            )
            return {"messages": [sys_msg] + state["messages"] + [example_qa]}
        else:
            # No similar Q/A found: just prepend the system prompt.
            return {"messages": [sys_msg] + state["messages"]}
    # Graph
    builder = StateGraph(MessagesState)
    # Nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(TOOLS))
    # Edges
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the LLM requested a tool call,
    # otherwise to END.
    builder.add_conditional_edges(
        "assistant",
        tools_condition
    )
    builder.add_edge("tools", "assistant")
    return builder.compile()
if __name__ == "__main__":
    import random
    import json

    # Smoke test: sample one question from the evaluation dataset and run it
    # through the compiled agent graph, printing the full message trace.
    with open("metadata.jsonl") as dataset_file:
        records = [json.loads(line) for line in dataset_file]
    question = random.choice(records)["Question"]

    graph = build_graph()
    result = graph.invoke({"messages": [HumanMessage(content = question)]})
    for message in result["messages"]:
        message.pretty_print()