"""LangGraph Agent with Hugging Face LLM and Robust Retriever"""
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from supabase.client import Client, create_client

# Load environment variables from .env file
load_dotenv()
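
# Environment variables expected in .env (assumed names, matching the lookups
# below): SUPABASE_URL, SUPABASE_SERVICE_KEY, HUGGINGFACEHUB_API_TOKEN, and
# TAVILY_API_KEY (read implicitly by TavilySearchResults).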

# Define mathematical tools for basic operations
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.
    Args:
        a: First integer
        b: Second integer
    Returns:
        Product of a and b
    """
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two numbers.
    Args:
        a: First integer
        b: Second integer
    Returns:
        Sum of a and b
    """
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.
    Args:
        a: First integer
        b: Second integer
    Returns:
        Difference of a and b
    """
    return a - b

@tool
def divide(a: int, b: int) -> int:
    """Divide two numbers.
    Args:
        a: First integer
        b: Second integer
    Returns:
        Quotient of a divided by b
    Raises:
        ValueError: If b is zero
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a // b  # Integer division for consistency

@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.
    Args:
        a: First integer
        b: Second integer
    Returns:
        Remainder of a divided by b
    """
    return a % b

# Define search tools for external information retrieval
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return up to 2 results.
    Args:
        query: The search query
    Returns:
        Dictionary with formatted Wikipedia results
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ]
    )
    return {"wiki_results": formatted_search_docs}

@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return up to 3 results.
    Args:
        query: The search query
    Returns:
        Dictionary with formatted web search results
    """
    search_docs = TavilySearchResults(max_results=3).invoke(query)
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc["url"]}" title="{doc.get("title", "")}">\n{doc["content"]}\n</Document>'
            for doc in search_docs
        ]
    )
    return {"web_results": formatted_search_docs}

@tool
def arxiv_search(query: str) -> dict:
    """Search Arxiv for a query and return up to 3 results.
    Args:
        query: The search query
    Returns:
        Dictionary with formatted Arxiv results
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # ArxivLoader metadata may not include a "source" key, so fall back to
    # the paper's "Title" field to avoid a KeyError.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" page="{doc.metadata.get("page", "")}">\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        ]
    )
    return {"arxiv_results": formatted_search_docs}

# Load system prompt from file
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# Create system message for the LLM
sys_msg = SystemMessage(content=system_prompt)

# Initialize embeddings for vector store
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Initialize Supabase client and vector store; fail fast if credentials are missing
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
if not supabase_url or not supabase_key:
    raise RuntimeError("SUPABASE_URL and SUPABASE_SERVICE_KEY must be set.")
supabase: Client = create_client(supabase_url, supabase_key)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain"
)
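
# Note: SupabaseVectorStore expects a matching RPC function in the database
# (assumed here to be named "match_documents_langchain") and a "documents"
# table whose embedding column matches the model's dimensionality
# (768 for all-mpnet-base-v2); see the LangChain Supabase integration docs
# for the reference schema.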

# Define tools list
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arxiv_search
]

def build_graph(provider: str = "huggingface"):
    """Build the LangGraph workflow for the agent.
    Args:
        provider: The LLM provider to use ('huggingface' by default)
    Returns:
        Compiled LangGraph workflow
    """
    # Ensure environment variables are loaded (idempotent if already called)
    load_dotenv()

    # Initialize LLM based on provider
    if provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
                temperature=0.1,  # Low temperature for deterministic responses
                max_new_tokens=512,  # Limit response length
                timeout=60  # Set timeout for API calls
            )
        )
    else:
        raise ValueError("Only 'huggingface' provider is supported.")

    # Bind tools to LLM for tool invocation
    llm_with_tools = llm.bind_tools(tools)

    # Define assistant node to process queries with LLM
    def assistant(state: MessagesState):
        """Assistant node to generate responses using the LLM.
        Args:
            state: Current state with messages
        Returns:
            Updated state with LLM response
        """
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    # Define retriever node to fetch similar documents
    def retriever(state: MessagesState):
        """Retriever node to search vector store for similar questions.
        Args:
            state: Current state with messages
        Returns:
            Updated state with retrieved answer or fallback message
        """
        query = state["messages"][-1].content
        results = vector_store.similarity_search(query, k=1)
        # MessagesState appends returned messages via its add_messages
        # reducer, so only the new message needs to be returned.
        if not results:
            return {"messages": [AIMessage(content="No relevant information found in the vector store. Relying on LLM and tools.")]}
        content = results[0].page_content
        if "Final answer :" in content:
            answer = content.split("Final answer :")[-1].strip()
        else:
            answer = content.strip()
        return {"messages": [AIMessage(content=answer)]}

    # Initialize graph
    builder = StateGraph(MessagesState)

    # Add nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Define edges
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,  # Route to tools if needed
    )
    builder.add_edge("tools", "assistant")

    # Compile and return graph
    return builder.compile()
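
# Minimal smoke test (a sketch; assumes the environment variables above are
# set and the Supabase "documents" table is populated):
if __name__ == "__main__":
    graph = build_graph()
    question = "What is 6 multiplied by 7?"
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()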