TaoZewen commited on
Commit
cbe1829
·
1 Parent(s): 56e29c4

add Wikipedia QA tool

Browse files
Files changed (1) hide show
  1. gemini_agent.py +22 -2
gemini_agent.py CHANGED
@@ -22,6 +22,9 @@ from PIL import Image
22
  import google.generativeai as genai
23
  from pydantic import Field
24
 
 
 
 
25
  from smolagents import WikipediaSearchTool
26
 
27
  class SmolagentToolWrapper(BaseTool):
@@ -278,6 +281,18 @@ def analyze_excel_file(file_path: str, query: str) -> str:
278
  return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
279
  except Exception as e:
280
  return f"Error analyzing Excel file: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
  class GeminiAgent:
283
  def __init__(self, api_key: str, model_name: str = "gemini-2.0-flash"):
@@ -324,7 +339,12 @@ class GeminiAgent:
324
  name="web_search",
325
  func=self._web_search,
326
  description="Search the web for information"
327
- )
 
 
 
 
 
328
  ]
329
 
330
  # Setup memory
@@ -591,7 +611,7 @@ New question: {input}
591
  agent=agent,
592
  tools=self.tools,
593
  memory=self.memory,
594
- max_iterations=20,
595
  verbose=True,
596
  handle_parsing_errors=True,
597
  return_only_outputs=True # This ensures we only get the final output
 
22
  import google.generativeai as genai
23
  from pydantic import Field
24
 
25
+ from transformers import pipeline
26
+ import wikipedia
27
+
28
  from smolagents import WikipediaSearchTool
29
 
30
  class SmolagentToolWrapper(BaseTool):
 
281
  return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
282
  except Exception as e:
283
  return f"Error analyzing Excel file: {str(e)}"
284
+
285
+ qa_pipeline = pipeline(
286
+ "question-answering",
287
+ model="distilbert-base-cased-distilled-squad",
288
+ tokenizer="distilbert-base-cased-distilled-squad"
289
+ )
290
+
291
+ def wikipedia_qa(question: str) -> str:
292
+ # 先搜条目——可以传“Mercedes Sosa”或直接用 query
293
+ page = wikipedia.page("Mercedes Sosa").content
294
+ out = qa_pipeline(question=question, context=page)
295
+ return out["answer"]
296
 
297
  class GeminiAgent:
298
  def __init__(self, api_key: str, model_name: str = "gemini-2.0-flash"):
 
339
  name="web_search",
340
  func=self._web_search,
341
  description="Search the web for information"
342
+ ),
343
+ Tool(
344
+ name="wikipedia_qa",
345
+ func=wikipedia_qa,
346
+ description="给定问题 + 维基文章,直接抽取精确答案"
347
+ )
348
  ]
349
 
350
  # Setup memory
 
611
  agent=agent,
612
  tools=self.tools,
613
  memory=self.memory,
614
+ max_iterations=10,
615
  verbose=True,
616
  handle_parsing_errors=True,
617
  return_only_outputs=True # This ensures we only get the final output