add Wikipedia QA tool
Browse files- gemini_agent.py +22 -2
gemini_agent.py
CHANGED
|
@@ -22,6 +22,9 @@ from PIL import Image
|
|
| 22 |
import google.generativeai as genai
|
| 23 |
from pydantic import Field
|
| 24 |
|
|
|
|
|
|
|
|
|
|
| 25 |
from smolagents import WikipediaSearchTool
|
| 26 |
|
| 27 |
class SmolagentToolWrapper(BaseTool):
|
|
@@ -278,6 +281,18 @@ def analyze_excel_file(file_path: str, query: str) -> str:
|
|
| 278 |
return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
|
| 279 |
except Exception as e:
|
| 280 |
return f"Error analyzing Excel file: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
class GeminiAgent:
|
| 283 |
def __init__(self, api_key: str, model_name: str = "gemini-2.0-flash"):
|
|
@@ -324,7 +339,12 @@ class GeminiAgent:
|
|
| 324 |
name="web_search",
|
| 325 |
func=self._web_search,
|
| 326 |
description="Search the web for information"
|
| 327 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
]
|
| 329 |
|
| 330 |
# Setup memory
|
|
@@ -591,7 +611,7 @@ New question: {input}
|
|
| 591 |
agent=agent,
|
| 592 |
tools=self.tools,
|
| 593 |
memory=self.memory,
|
| 594 |
-
max_iterations=
|
| 595 |
verbose=True,
|
| 596 |
handle_parsing_errors=True,
|
| 597 |
return_only_outputs=True # This ensures we only get the final output
|
|
|
|
| 22 |
import google.generativeai as genai
|
| 23 |
from pydantic import Field
|
| 24 |
|
| 25 |
+
from transformers import pipeline
|
| 26 |
+
import wikipedia
|
| 27 |
+
|
| 28 |
from smolagents import WikipediaSearchTool
|
| 29 |
|
| 30 |
class SmolagentToolWrapper(BaseTool):
|
|
|
|
| 281 |
return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
|
| 282 |
except Exception as e:
|
| 283 |
return f"Error analyzing Excel file: {str(e)}"
|
| 284 |
+
|
| 285 |
+
qa_pipeline = pipeline(
|
| 286 |
+
"question-answering",
|
| 287 |
+
model="distilbert-base-cased-distilled-squad",
|
| 288 |
+
tokenizer="distilbert-base-cased-distilled-squad"
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
def wikipedia_qa(question: str) -> str:
|
| 292 |
+
# 先搜条目——可以传“Mercedes Sosa”或直接用 query
|
| 293 |
+
page = wikipedia.page("Mercedes Sosa").content
|
| 294 |
+
out = qa_pipeline(question=question, context=page)
|
| 295 |
+
return out["answer"]
|
| 296 |
|
| 297 |
class GeminiAgent:
|
| 298 |
def __init__(self, api_key: str, model_name: str = "gemini-2.0-flash"):
|
|
|
|
| 339 |
name="web_search",
|
| 340 |
func=self._web_search,
|
| 341 |
description="Search the web for information"
|
| 342 |
+
),
|
| 343 |
+
Tool(
|
| 344 |
+
name="wikipedia_qa",
|
| 345 |
+
func=wikipedia_qa,
|
| 346 |
+
description="给定问题 + 维基文章,直接抽取精确答案"
|
| 347 |
+
)
|
| 348 |
]
|
| 349 |
|
| 350 |
# Setup memory
|
|
|
|
| 611 |
agent=agent,
|
| 612 |
tools=self.tools,
|
| 613 |
memory=self.memory,
|
| 614 |
+
max_iterations=10,
|
| 615 |
verbose=True,
|
| 616 |
handle_parsing_errors=True,
|
| 617 |
return_only_outputs=True # This ensures we only get the final output
|