alisamak commited on
Commit
86aee1a
·
verified ·
1 Parent(s): ef95f1f

Update basic_agent.py

Browse files
Files changed (1) hide show
  1. basic_agent.py +18 -11
basic_agent.py CHANGED
@@ -6,25 +6,32 @@ from tools import (
6
  )
7
 
8
  class BasicAgent:
 
 
 
 
 
 
 
9
  def __call__(self, question: str) -> str:
10
  try:
11
- tool = self.select_tool(question)
12
- result = tool.invoke({"query": question})
13
- return result
14
  except Exception as e:
15
  return f"Error: {e}"
16
 
17
  def select_tool(self, question: str):
18
  q = question.lower()
19
-
20
  if "youtube.com" in q or "youtu.be" in q or "in the video" in q:
21
- return analyze_youtube_video # direct reference, not index
22
-
23
  elif "how many" in q or "number of" in q:
24
- return extract_number_from_text
25
-
26
  elif "wikipedia" in q or "encyclopedia" in q or "who is" in q:
27
- return search_wikipedia
28
-
29
- return extract_number_from_text # fallback tool
 
30
 
 
6
  )
7
 
8
  class BasicAgent:
9
+ def __init__(self):
10
+ self.tool_registry = {
11
+ "analyze_youtube_video": (analyze_youtube_video, "url"),
12
+ "extract_number_from_text": (extract_number_from_text, "text"),
13
+ "search_wikipedia": (search_wikipedia, "query"),
14
+ }
15
+
16
  def __call__(self, question: str) -> str:
17
  try:
18
+ tool, input_key = self.select_tool(question)
19
+ return tool.invoke({input_key: question})
 
20
  except Exception as e:
21
  return f"Error: {e}"
22
 
23
  def select_tool(self, question: str):
24
  q = question.lower()
25
+
26
  if "youtube.com" in q or "youtu.be" in q or "in the video" in q:
27
+ return self.tool_registry["analyze_youtube_video"]
28
+
29
  elif "how many" in q or "number of" in q:
30
+ return self.tool_registry["extract_number_from_text"]
31
+
32
  elif "wikipedia" in q or "encyclopedia" in q or "who is" in q:
33
+ return self.tool_registry["search_wikipedia"]
34
+
35
+ return self.tool_registry["extract_number_from_text"] # fallback
36
+
37