sampsong committed on
Commit 3ffdb34 · 1 Parent(s): 98d7add

1) add agent, tool, and system prompt

Files changed (6)
  1. Agents/agent.py +115 -0
  2. Prompts/SystemPrompt.txt +5 -0
  3. Tools/tools.py +104 -0
  4. agent.py +8 -0
  5. app.py +11 -3
  6. tools.py +0 -0
Agents/agent.py ADDED
@@ -0,0 +1,115 @@
+ import os
+ from Tools.tools import webSearch, arxivSearch, wikiSearch, add, multiply, divide, substract, remainder
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from dotenv import load_dotenv
+ from supabase.client import Client, create_client
+ from langchain_groq import ChatGroq
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import ToolNode, tools_condition
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
+ from langchain.tools.retriever import create_retriever_tool
+ from langchain_community.vectorstores import SupabaseVectorStore
+ from langfuse.langchain import CallbackHandler
+
+ load_dotenv()
+
+ langfuse_handler = CallbackHandler()
+
+ with open("Prompts/SystemPrompt.txt", "r", encoding="utf-8") as f:
+     systemPrompt = f.read()
+ print(systemPrompt)
+
+ sysMsg = SystemMessage(content=systemPrompt)
+
+ embeddings = HuggingFaceEmbeddings(
+     model_name="sentence-transformers/all-mpnet-base-v2"
+ )
+
+ supabase: Client = create_client(
+     os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_KEY")
+ )
+
+ vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
+     table_name="documents",
+     query_name="match_documents",
+ )
+
+ # Renamed so the variable does not shadow the imported create_retriever_tool helper.
+ retriever_tool = create_retriever_tool(
+     retriever=vector_store.as_retriever(),
+     name="Question Search",
+     description="A tool to retrieve similar questions from a vector store.",
+ )
+
+ tools = [
+     webSearch,
+     wikiSearch,
+     arxivSearch,
+     multiply,
+     add,
+     substract,
+     divide,
+     remainder,
+ ]
+
+ def build_graph(provider: str = "groq"):
+     if provider == "groq":
+         llm = ChatGroq(model="qwen/qwen3-32b", temperature=0)
+     elif provider == "huggingface":
+         llm = ChatHuggingFace(
+             llm=HuggingFaceEndpoint(
+                 repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+                 task="text-generation",
+                 max_new_tokens=1024,
+                 do_sample=False,
+                 repetition_penalty=1.03,
+                 temperature=0,
+             ),
+             verbose=True,
+         )
+     else:
+         raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")
+     llmWithTools = llm.bind_tools(tools)
+
+     def assistant(state: MessagesState):
+         return {"messages": [llmWithTools.invoke(state["messages"])]}
+
+     def retriever(state: MessagesState):
+         similarQuestion = vector_store.similarity_search(state["messages"][0].content)
+
+         if similarQuestion:
+             exampleMessage = HumanMessage(
+                 content=f"Here I provide a similar question and answer for reference:\n\n{similarQuestion[0].page_content}",
+             )
+             return {"messages": [sysMsg] + state["messages"] + [exampleMessage]}
+         else:
+             return {"messages": [sysMsg] + state["messages"]}
+
+     builder = StateGraph(MessagesState)
+     builder.add_node("retriever", retriever)
+     builder.add_node("assistant", assistant)
+     builder.add_node("tools", ToolNode(tools))
+     builder.add_edge(START, "retriever")
+     builder.add_edge("retriever", "assistant")
+     builder.add_conditional_edges(
+         "assistant",
+         tools_condition,
+     )
+     builder.add_edge("tools", "assistant")
+     return builder.compile()
+
+
+ # test
+ if __name__ == "__main__":
+     question = "When was a picture of St. Thomas Aquinas first added to the wikipedia page on the principle of double effect?"
+     graph = build_graph(provider="groq")
+     messages = [HumanMessage(content=question)]
+     messages = graph.invoke(
+         input={"messages": messages},
+         config={"callbacks": [langfuse_handler]},
+     )
+     graph.get_graph().draw_mermaid_png()
+
+     for m in messages["messages"]:
+         m.pretty_print()
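Note that graph.get_graph().draw_mermaid_png() returns the PNG bytes rather than writing a file, so the test block above discards the diagram. A minimal sketch for persisting it (the agent_graph.png filename is only an illustrative assumption, not part of the commit):

    png_bytes = graph.get_graph().draw_mermaid_png()  # raw PNG bytes of the Mermaid diagram
    with open("agent_graph.png", "wb") as f:          # hypothetical output path
        f.write(png_bytes)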
Prompts/SystemPrompt.txt ADDED
@@ -0,0 +1,5 @@
+ You are a helpful assistant tasked with answering questions using a set of tools.
+ Now I will ask you a question. Report your thoughts, and finish your answer with the following template:
+ FINAL ANSWER: [YOUR FINAL ANSWER].
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma-separated list of numbers and/or strings. If you are asked for a number, don't write it with commas and don't use units such as $ or percent signs unless specified otherwise. If you are asked for a string, don't use articles or abbreviations (e.g. for cities), and write digits in plain text unless specified otherwise. If you are asked for a comma-separated list, apply the above rules to each element depending on whether it is a number or a string.
+ Your answer should only start with "FINAL ANSWER: ", followed by the answer.
Tools/tools.py ADDED
@@ -0,0 +1,104 @@
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader
+ from langchain_community.document_loaders import ArxivLoader
+ from langchain_community.vectorstores import SupabaseVectorStore
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langchain_core.tools import tool
+ from langchain.tools.retriever import create_retriever_tool
+
+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two integers.
+     Args:
+         a: first integer
+         b: second integer
+     """
+     return a + b
+
+ @tool
+ def substract(a: int, b: int) -> int:
+     """Subtract two integers.
+     Args:
+         a: first integer
+         b: second integer
+     """
+     return a - b
+
+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two integers.
+     Args:
+         a: first integer
+         b: second integer
+     """
+     return a * b
+
+ @tool
+ def divide(a: int, b: int) -> float:
+     """Divide two integers.
+     Args:
+         a: first integer
+         b: second integer
+     """
+     return a / b
+
+ @tool
+ def remainder(a: int, b: int) -> int:
+     """Return the remainder of dividing a by b.
+     Args:
+         a: first integer
+         b: second integer
+     """
+     return a % b
+
+ @tool
+ def wikiSearch(searchQuery: str) -> dict:
+     """Search Wikipedia and return up to three matching documents.
+
+     Args:
+         searchQuery: the search query
+     """
+     search_results = WikipediaLoader(query=searchQuery, load_max_docs=3).load()
+     formatted_results = "\n\n--\n\n".join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+             for doc in search_results
+         ]
+     )
+     return {"wiki_results": formatted_results}
+
+ @tool
+ def arxivSearch(searchQuery: str) -> dict:
+     """Search arXiv and return up to three matching documents.
+
+     Args:
+         searchQuery: the search query
+     """
+     search_results = ArxivLoader(query=searchQuery, load_max_docs=3).load()
+     formatted_results = "\n\n--\n\n".join(
+         [
+             f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+             for doc in search_results
+         ]
+     )
+     return {"arxiv_result": formatted_results}
+
+ @tool
+ def webSearch(searchQuery: str) -> dict:
+     """Search the web with Tavily and return up to three matching results.
+
+     Args:
+         searchQuery: the search query
+     """
+     # TavilySearchResults returns a list of result dicts (with "url" and "content"), not Documents.
+     search_results = TavilySearchResults(max_results=3).invoke(searchQuery)
+     formatted_results = "\n\n--\n\n".join(
+         [
+             f'<Document source="{result["url"]}"/>\n{result["content"]}\n</Document>'
+             for result in search_results
+         ]
+     )
+     return {"web_search": formatted_results}
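Since each function above is wrapped by the @tool decorator, it becomes a LangChain structured tool and can also be exercised directly with .invoke(...) outside the graph. A minimal sketch (the argument values are arbitrary examples, not from the commit):

    from Tools.tools import add, wikiSearch

    print(add.invoke({"a": 2, "b": 3}))                       # -> 5
    print(wikiSearch.invoke({"searchQuery": "Alan Turing"}))  # -> {"wiki_results": "..."}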
agent.py ADDED
@@ -0,0 +1,8 @@
+ # Root-level shim: the old root tools.py is deleted in this commit, and app.py
+ # expects `from agent import build_graph`, so re-export from the packages below.
+ from Agents.agent import build_graph
+ from Tools.tools import webSearch, arxivSearch, wikiSearch, add, multiply, divide, substract, remainder
+ from dotenv import load_dotenv
+
+ load_dotenv()
app.py CHANGED
@@ -3,21 +3,29 @@ import gradio as gr
  import requests
  import inspect
  import pandas as pd
+ from langchain_core.messages import HumanMessage
+ from agent import build_graph
+ from langfuse.langchain import CallbackHandler

  # (Keep Constants as is)
  # --- Constants ---
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
+ langfuse_handler = CallbackHandler()

  # --- Basic Agent Definition ---
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
  class BasicAgent:
      def __init__(self):
          print("BasicAgent initialized.")
+         self.graph = build_graph()
      def __call__(self, question: str) -> str:
          print(f"Agent received question (first 50 chars): {question[:50]}...")
-         fixed_answer = "This is a default answer."
-         print(f"Agent returning fixed answer: {fixed_answer}")
-         return fixed_answer
+         messages = [HumanMessage(content=question)]
+         result = self.graph.invoke({"messages": messages}, config={"callbacks": [langfuse_handler]})
+         answer = result["messages"][-1].content
+         # Strip the leading "FINAL ANSWER: " prefix (14 characters) mandated by the system prompt.
+         return answer[14:]

  def run_and_submit_all( profile: gr.OAuthProfile | None):
      """
tools.py DELETED
File without changes
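With the app.py changes above, BasicAgent wraps the compiled LangGraph graph, so it can be smoke-tested locally before submitting answers. A minimal sketch (the question text is only an illustrative assumption):

    agent = BasicAgent()
    print(agent("What is 2 + 2?"))  # prints whatever follows "FINAL ANSWER: " in the model's reply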