VGekko commited on
Commit
556f749
·
verified ·
1 Parent(s): 2e917f1

Upload 3 files

Browse files
Files changed (3) hide show
  1. agent.py +150 -0
  2. app.py +9 -31
  3. requirements.txt +10 -1
agent.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
+ from langchain_community.document_loaders import WikipediaLoader
9
+ from langchain_core.messages import HumanMessage
10
+ from langchain_core.tools import tool
11
+ from langchain_mistralai import ChatMistralAI
12
+
13
+ load_dotenv()
14
+
15
+
16
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        A dict with key "web_results" holding the formatted documents,
        suitable for merging into the graph state.
    """
    search_docs = TavilySearchResults(max_results=3).invoke(input=query)
    # NOTE: the inner f-string expressions must use double quotes; the
    # original `{doc['content']}` inside a single-quoted f-string is a
    # SyntaxError on Python < 3.12 (quote reuse only arrived with PEP 701).
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc["title"]}" page="{doc.get("url", "")}"/>\n{doc["content"]}\n</Document>'
        for doc in search_docs
    )
    return {"web_results": formatted_search_docs}
28
+
29
+
30
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        A dict with key "wiki_results" holding the formatted documents,
        suitable for merging into the graph state.
    """
    # Load at most two Wikipedia pages matching the query.
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
42
+
43
+
44
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of two numbers.

    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
53
+
54
@tool
def add(a: int, b: int) -> int:
    """Return the sum of two numbers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
63
+
64
@tool
def subtract(a: int, b: int) -> int:
    """Return the difference of two numbers (a minus b).

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
73
+
74
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int

    Returns:
        The quotient a / b. True division always yields a float, so the
        return annotation is float (the original `-> int` was incorrect).

    Raises:
        ValueError: If b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
85
+
86
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of a divided by b.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
95
+
96
# Tools exposed to the LLM: basic arithmetic plus Wikipedia and web search.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
]
106
+
107
def build_graph():
    """Construct and compile the LangGraph agent.

    The graph has one assistant node (a Mistral chat model with the
    module-level tools bound) and one tool-execution node; control loops
    between them until the model stops emitting tool calls.
    """
    # Deterministic (temperature=0) Mistral model with bounded retries.
    chat_model = ChatMistralAI(
        model="mistral-small-2503",
        temperature=0,
        max_retries=2,
    )
    model_with_tools = chat_model.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: run the tool-enabled model on the conversation."""
        response = model_with_tools.invoke(state["messages"])
        return {"messages": [response]}

    graph_builder = StateGraph(MessagesState)
    graph_builder.add_node("assistant", assistant)
    graph_builder.add_node("tools", ToolNode(tools))
    graph_builder.add_edge(START, "assistant")
    # Route to the tools node when the assistant requested tool calls,
    # otherwise end the run.
    graph_builder.add_conditional_edges("assistant", tools_condition)
    graph_builder.add_edge("tools", "assistant")

    # Compile graph
    return graph_builder.compile()
137
+
138
if __name__ == "__main__":
    # Quick manual smoke test of the agent graph.
    system_prompt = "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."
    question = "whats 2 + 15?"
    # Compose the prompt explicitly: the original used backslash
    # line-continuations inside an f-string, which silently embedded the
    # source indentation (long runs of spaces) into the prompt text.
    full_prompt = f"{system_prompt} The question: {question}"
    graph = build_graph()
    result = graph.invoke({"messages": [HumanMessage(content=full_prompt)]})
    for message in result["messages"]:
        message.pretty_print()
150
+
app.py CHANGED
@@ -3,8 +3,9 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
 
 
6
 
7
- from mistralai import Mistral
8
 
9
  # (Keep Constants as is)
10
  # --- Constants ---
@@ -18,45 +19,22 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
18
  class BasicAgent:
19
  def __init__(self):
20
  print("BasicAgent initialized.")
 
21
 
22
  def __call__(self, question: str) -> str:
23
  print(f"Agent received question (first 50 chars): {question[:50]}...")
24
 
25
-
26
- api_key = os.environ["MISTRAL_API_KEY"]
27
- mistral_client = Mistral(api_key=api_key)
28
- model = "mistral-small-2503"
29
-
30
- system_prompt = "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."
31
- question = f"{system_prompt} \
32
- The question: \
33
- {question} \
34
- "
35
-
36
-
37
-
38
- response = mistral_client.chat.complete(
39
- model= model,
40
- messages = [
41
- {
42
- "role": "user",
43
- "content": question,
44
- },
45
- ]
46
- )
47
-
48
- answer = response.choices[0].message.content
49
-
50
  #answer = 'static answer'
51
-
52
  #answer = "This is a default answer."
53
  #print(f"Agent returning fixed answer: {answer}")
54
- return answer
55
-
56
-
57
-
58
 
59
 
 
 
 
 
 
 
60
 
61
 
62
 
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ from langchain_core.messages import HumanMessage
7
+ from agent import build_graph
8
 
 
9
 
10
  # (Keep Constants as is)
11
  # --- Constants ---
 
19
  class BasicAgent:
20
  def __init__(self):
21
  print("BasicAgent initialized.")
22
+ self.graph = build_graph()
23
 
24
  def __call__(self, question: str) -> str:
25
  print(f"Agent received question (first 50 chars): {question[:50]}...")
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  #answer = 'static answer'
 
28
  #answer = "This is a default answer."
29
  #print(f"Agent returning fixed answer: {answer}")
 
 
 
 
30
 
31
 
32
+ # Wrap the question in a HumanMessage from langchain_core
33
+ messages = [HumanMessage(content=question)]
34
+ messages = self.graph.invoke({"messages": messages})
35
+ answer = messages['messages'][-1].content
36
+ return answer[14:]
37
+
38
 
39
 
40
 
requirements.txt CHANGED
@@ -1,3 +1,12 @@
1
  gradio
2
  requests
3
- mistralai
 
 
 
 
 
 
 
 
 
 
1
  gradio
2
  requests
3
+ mistralai
4
+ langchain_community
5
+ duckduckgo-search
6
+ langchain
7
+ langchain-core
8
+ langchain-tavily
9
+ langgraph
10
+ pymupdf
11
+ wikipedia
12
+ python-dotenv