File size: 3,663 Bytes
3927a42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
from dotenv import load_dotenv
from typing import TypedDict, List, Dict, Any, Optional

from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_huggingface.chat_models import ChatHuggingFace
from langchain_groq.chat_models import ChatGroq

from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

from tools import (
    add, 
    subtract, multiply, div, modulus, power,
    wikipedia_search, search_web, arxiv_search, 
    save_and_read_file, download_file_from_url, extract_text_from_image,
    pdf_loader
)
from retriever import get_retriever_tool

# Load environment variables (presumably API keys such as GROQ_API_KEY —
# verify against deployment docs) from the local .env file.
load_dotenv(dotenv_path = ".env")

# Configurations
SYSTEM_PROMPT_PATH = "system_prompt.txt"  # file read by load_system_prompt()
DEFAULT_PROVIDER = "groq"                 # default backend for get_llm()/build_graph()
MODEL_NAME = "llama3-70b-8192"            # model id passed to ChatGroq

def load_system_prompt(path: str = SYSTEM_PROMPT_PATH) -> str:
    """Read the system prompt text from *path*.

    Args:
        path: Location of the prompt file (defaults to SYSTEM_PROMPT_PATH).

    Returns:
        The full file contents as a string.

    Raises:
        ValueError: If the file does not exist. (Kept as ValueError rather
            than FileNotFoundError so existing callers' handlers still work.)
    """
    if not os.path.exists(path):
        # Fix: error message previously contained the typo "not foud".
        raise ValueError(f"System prompt file not found at: {path}")
    with open(path, "r", encoding = "utf-8") as f:
        return f.read()


# Build the system message once at import time from the prompt file.
system_prompt = load_system_prompt()
sys_msg = SystemMessage(content = system_prompt)

# Load tools
# The retriever factory also returns the underlying vector store, which the
# graph's "retriever" node uses directly for similarity search.
vector_store, vector_retriever, retriever_tool = get_retriever_tool()

# Full tool belt bound to the LLM and served by the graph's ToolNode.
TOOLS = [
    # Math
    add, subtract, multiply, div, modulus, power,
    # Documents Search
    wikipedia_search, search_web, arxiv_search, 
    # Process Files
    save_and_read_file, download_file_from_url, extract_text_from_image,
    pdf_loader,
    # Retriever
    retriever_tool
]

def get_llm(provider: str = DEFAULT_PROVIDER):
    """Return a chat model instance for the requested provider.

    Args:
        provider: Backend identifier; only "groq" is currently supported.

    Returns:
        A ChatGroq instance configured with MODEL_NAME at temperature 0.

    Raises:
        NotImplementedError: For the not-yet-supported "huggingface" backend.
        ValueError: For any other unrecognized provider string.
    """
    # Guard clauses: reject unsupported providers up front.
    if provider == "huggingface":
        raise NotImplementedError("HuggingFace support not yet implemented.")
    if provider != "groq":
        raise ValueError("Invalid LLM provider. Choose 'groq' or 'huggingface'")
    return ChatGroq(model = MODEL_NAME, temperature = 0)
    

def build_graph(provider: str = DEFAULT_PROVIDER):
    """
    Builds LangGraph graph: START -> retriever -> assistant, with assistant
    looping through the tools node for as long as the model emits tool calls.

    Args:
        provider: LLM backend name forwarded to get_llm().

    Returns:
        The compiled LangGraph graph, ready for .invoke().
    """
    llm = get_llm(provider)
    
    # Add tools to the LLM so it can emit structured tool calls.
    llm_with_tools = llm.bind_tools(TOOLS)

    def assistant(state: MessagesState):
        # Run the tool-enabled model over the accumulated conversation.
        # NOTE(review): invoke() returns a single message rather than a list;
        # MessagesState's reducer appears to accept that — confirm with the
        # installed langgraph version.
        return {"messages": llm_with_tools.invoke(state["messages"])}
    
    def retriever(state: MessagesState):
        # Use the first message's content as the similarity-search query
        # (assumes the first message is the user's question — TODO confirm).
        query = state["messages"][0].content
        similar_qas = vector_store.similarity_search(query)

        if similar_qas:
            # Inject the closest stored Q/A pair as few-shot-style context.
            reference = similar_qas[0].page_content
            example_qa = HumanMessage(
                content = f"I provide a similar question and answer for reference:\n\n{reference}"
            )
            return {"messages": [sys_msg] + state["messages"] + [example_qa]}
        else:
            # No similar example found: just prepend the system prompt.
            return {"messages": [sys_msg] + state["messages"]}
    
    # Graph
    builder = StateGraph(MessagesState)
    
    # Nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(TOOLS))
    
    # Edges
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the last AI message carries
    # tool calls, otherwise to END.
    builder.add_conditional_edges(
        "assistant",
        tools_condition
    )
    builder.add_edge("tools", "assistant")

    return builder.compile()

if __name__ == "__main__":
    # Smoke test: pick a random question from the evaluation set and run it
    # through the agent graph, printing the full message trace.
    import random
    import json

    # Fix: explicit encoding avoids platform-dependent decoding of the JSONL
    # file; skipping blank lines prevents json.loads crashing on a trailing
    # newline or empty line.
    with open("metadata.jsonl", encoding = "utf-8") as dataset_file:
        QAs = [json.loads(line) for line in dataset_file if line.strip()]

    question = random.choice(QAs)["Question"]
    graph = build_graph()
    messages = [HumanMessage(content = question)]
    messages = graph.invoke({"messages": messages})
    for m in messages["messages"]:
        m.pretty_print()