vark101 committed on
Commit
b642bac
·
verified ·
1 Parent(s): 81917a3

Upload 4 files

Browse files
Files changed (4) hide show
  1. app_langgraph.py +86 -0
  2. math_tools.py +52 -0
  3. requirements.txt +100 -2
  4. search_tools.py +53 -0
app_langgraph.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace, HuggingFaceEmbeddings
8
+ from langchain_core.messages import SystemMessage, HumanMessage
9
+ from langchain_core.globals import set_debug
10
+ from langchain_groq import ChatGroq
11
+ from tools.search_tools import web_search, arvix_search, wiki_search
12
+ from tools.math_tools import multiply, add, subtract, divide
13
+ from supabase.client import Client, create_client
14
+ from langchain.tools.retriever import create_retriever_tool
15
+ from langchain_community.vectorstores import SupabaseVectorStore
16
+ import json
17
+
18
+ # set_debug(True)
19
+ load_dotenv()
20
+
21
# Tool belt bound to the LLM: arithmetic helpers plus the three search tools.
tools = [
    multiply, add, subtract, divide,
    web_search, wiki_search, arvix_search,
]
30
+
31
def build_graph():
    """Build and compile the LangGraph agent graph.

    The graph alternates between an "assistant" LLM node (Groq-hosted
    qwen-qwq-32b with the tool belt bound) and a ToolNode that executes
    requested tool calls, looping until the model answers without a
    tool call.

    Returns:
        A compiled LangGraph runnable accepting {"messages": [...]}.
    """
    hf_token = os.getenv("HF_TOKEN")  # kept for the HF-endpoint alternative below
    # llm = HuggingFaceEndpoint(
    #     repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    #     huggingfacehub_api_token=hf_token,
    # )
    # chat = ChatHuggingFace(llm=llm, verbose=True)
    # llm_with_tools = chat.bind_tools(tools)

    llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        # Fix: wrap the prompt in SystemMessage (already imported). A bare
        # str in the message list is coerced to a *human* message, so the
        # instructions would not carry system-prompt weight.
        sys_msg = SystemMessage(content=(
            "You are a helpful assistant with access to tools. Understand user requests accurately. Use your tools when needed to answer effectively. Strictly follow all user instructions and constraints."
            "Pay attention: your output needs to contain only the final answer without any reasoning since it will be strictly evaluated against a dataset which contains only the specific response."
            "Your final output needs to be just the string or integer containing the answer, not an array or technical stuff."
        ))
        return {
            "messages": [llm_with_tools.invoke([sys_msg] + state["messages"])],
        }

    ## The graph
    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "assistant")
    builder.add_conditional_edges(
        "assistant",
        # If the latest message requires a tool, route to tools;
        # otherwise, provide a direct response and end.
        tools_condition,
    )
    builder.add_edge("tools", "assistant")
    return builder.compile()
67
+
68
# Smoke test: run a single question from sample.jsonl through the agent
# and pretty-print the full message trace.
if __name__ == "__main__":

    graph = build_graph()
    with open('sample.jsonl', 'r') as fh:
        lines = fh.readlines()

    start = 10  # revisit 5, 8,
    for raw in lines[start:start + 1]:
        record = json.loads(raw)
        print(f"Question::::::::: {record['Question']}")
        print(f"Final answer::::: {record['Final answer']}")

        result = graph.invoke(
            {"messages": [HumanMessage(content=record['Question'])]}
        )
        for m in result["messages"]:
            m.pretty_print()
math_tools.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.tools import tool
2
+
3
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.
    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
11
+
12
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
21
+
22
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
31
+
32
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int

    Raises:
        ValueError: if b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # Fix: true division returns a float (e.g. 7 / 2 == 3.5); the original
    # "-> int" annotation was wrong and would mislead the tool schema.
    return a / b
43
+
44
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
requirements.txt CHANGED
@@ -1,2 +1,100 @@
1
- gradio
2
- requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.6.1
2
+ aiohttp==3.12.11
3
+ aiosignal==1.3.2
4
+ aiosqlite==0.21.0
5
+ annotated-types==0.7.0
6
+ anyio==4.9.0
7
+ attrs==25.3.0
8
+ banks==2.1.2
9
+ certifi==2025.4.26
10
+ charset-normalizer==3.4.2
11
+ click==8.2.1
12
+ colorama==0.4.6
13
+ dataclasses-json==0.6.7
14
+ Deprecated==1.2.18
15
+ dirtyjson==1.0.8
16
+ distro==1.9.0
17
+ filelock==3.18.0
18
+ filetype==1.2.0
19
+ frozenlist==1.6.2
20
+ fsspec==2025.3.2
21
+ greenlet==3.2.3
22
+ griffe==1.7.3
23
+ h11==0.16.0
24
+ hf-xet==1.1.0
25
+ httpcore==1.0.9
26
+ httpx==0.28.1
27
+ httpx-sse==0.4.0
28
+ huggingface-hub==0.31.1
29
+ idna==3.10
30
+ Jinja2==3.1.6
31
+ jiter==0.10.0
32
+ joblib==1.5.1
33
+ jsonpatch==1.33
34
+ jsonpointer==3.0.0
35
+ langchain==0.3.25
36
+ langchain-community==0.3.25
37
+ langchain-core==0.3.65
38
+ langchain-huggingface==0.3.0
39
+ langchain-openai==0.3.21
40
+ langchain-text-splitters==0.3.8
41
+ langgraph==0.4.8
42
+ langgraph-checkpoint==2.0.26
43
+ langgraph-prebuilt==0.2.2
44
+ langgraph-sdk==0.1.70
45
+ langsmith==0.3.45
46
+ llama-index-core==0.12.41
47
+ llama-index-embeddings-huggingface==0.5.4
48
+ llama-index-llms-huggingface-api==0.5.0
49
+ markdown-it-py==3.0.0
50
+ MarkupSafe==3.0.2
51
+ marshmallow==3.26.1
52
+ mdurl==0.1.2
53
+ mpmath==1.3.0
54
+ multidict==6.4.4
55
+ mypy_extensions==1.1.0
56
+ nest-asyncio==1.6.0
57
+ networkx==3.5
58
+ nltk==3.9.1
59
+ numpy==2.3.0
60
+ openai==1.85.0
61
+ orjson==3.10.18
62
+ ormsgpack==1.10.0
63
+ packaging==24.2
64
+ pillow==11.2.1
65
+ platformdirs==4.3.8
66
+ propcache==0.3.1
67
+ pydantic==2.11.5
68
+ pydantic-settings==2.9.1
69
+ pydantic_core==2.33.2
70
+ Pygments==2.19.1
71
+ python-dotenv==1.1.0
72
+ PyYAML==6.0.2
73
+ regex==2024.11.6
74
+ requests==2.32.3
75
+ requests-toolbelt==1.0.0
76
+ rich==14.0.0
77
+ safetensors==0.5.3
78
+ scikit-learn==1.7.0
79
+ scipy==1.15.3
80
+ sentence-transformers==4.1.0
81
+ smolagents==1.15.0
82
+ sniffio==1.3.1
83
+ SQLAlchemy==2.0.41
84
+ sympy==1.14.0
85
+ tavily-python==0.7.5
86
+ tenacity==9.1.2
87
+ threadpoolctl==3.6.0
88
+ tiktoken==0.9.0
89
+ tokenizers==0.21.1
90
+ torch==2.7.1
91
+ tqdm==4.67.1
92
+ transformers==4.52.4
93
+ typing-inspect==0.9.0
94
+ typing-inspection==0.4.1
95
+ typing_extensions==4.13.2
96
+ urllib3==2.4.0
97
+ wrapt==1.17.2
98
+ xxhash==3.5.0
99
+ yarl==1.20.0
100
+ zstandard==0.23.0
search_tools.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.tools import tool
2
+ from langchain_community.document_loaders import WikipediaLoader
3
+ from langchain_community.document_loaders import ArxivLoader
4
+ # Search engine specifically for LLMs
5
+ # from langchain_community.tools.tavily_search import TavilySearchResults
6
+ from langchain_tavily import TavilySearch
7
+
8
+
9
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query."""
    # print(f"Web search query:::::::::::: {query}")
    # TavilySearch returns a dict with a "results" list; wrap each hit in a
    # pseudo-XML <Document> envelope so the LLM can cite URL and title.
    search_docs = TavilySearch(max_results=3).invoke({"query": query})
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc["url"]}" page="{doc["title"]}"/>\n{doc["content"]}\n</Document>'
        for doc in search_docs['results']
    )
    # print(f"Web search result:::::::::::: {formatted_search_docs}")
    # Fix: the function returns a dict, not a str — annotation corrected.
    return {"web_results": formatted_search_docs}
24
+
25
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query."""
    # Each loaded Document carries its article URL in metadata["source"];
    # "page" is usually absent for Wikipedia, hence the .get default.
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    # Fix: the function returns a dict, not a str — annotation corrected.
    return {"wiki_results": formatted_search_docs}
40
+
41
@tool
def arvix_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 result.

    Args:
        query: The search query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # Truncate each paper to its first 1000 characters to keep the tool
    # output small enough for the model's context window.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    # Fix: the function returns a dict, not a str — annotation corrected.
    return {"arvix_results": formatted_search_docs}