wiki_tool_use / pyserini_wikipedia_kilt_doc.py
SiliangZ's picture
Upload tool
59324d3 verified
from pyserini.search.lucene import LuceneSearcher
from transformers import Tool
import json
# Create the searcher inside a function so that nothing is executed at module import time
# Lazily created searcher shared by every call to get_searcher().
_SEARCHER = None


def get_searcher():
    """Return the shared searcher for the prebuilt ``wikipedia-kilt-doc`` index.

    The searcher is built on the first call rather than at import time, so
    importing this module stays free of side effects. The instance is then
    cached at module level so that subsequent calls reuse it instead of
    constructing a new searcher for every query.

    Returns:
        The object returned by
        ``LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')``.
    """
    global _SEARCHER
    if _SEARCHER is None:
        _SEARCHER = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
        # Alternatively, use a local index:
        # _SEARCHER = LuceneSearcher('index-wikipedia-kilt-doc-20210421-f29307.b8ec8feb654f7aaa86f9901dc6c804a8')
    return _SEARCHER
def search(query):
    """Return the contents of the top-ranked document for ``query``.

    Args:
        query: The search query passed to the searcher.

    Returns:
        The ``contents`` field of the best-matching document, or an empty
        string when the search yields no hits.
    """
    searcher = get_searcher()
    hits = searcher.search(query, k=1)
    # Guard against an empty result list; indexing hits[0] would otherwise
    # raise IndexError for queries that match nothing.
    if not hits:
        return ""
    # Look the document up by id; its raw form is a JSON string whose
    # 'contents' field holds the text we want to return.
    doc = searcher.doc(hits[0].docid)
    return json.loads(doc.raw())['contents']
class PyseriniWikipediaKiltDoc(Tool):
    """Tool wrapper that answers a query with the top hit returned by ``search``."""

    name = "pyserini_wikipedia_kilt_doc"
    description = "This is a tool that returns the top result from the Wikipedia KILT index."

    # Single input: the free-text query.
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to find information from Wikipedia",
        },
    }

    # Changed to "string" rather than "str".
    output_type = "string"
    outputs = {
        "type": "string",
        "description": "Wikipedia article content matching the query",
    }

    def __call__(self, query: str):
        """Delegate to the module-level ``search`` and return its result."""
        result = search(query)
        return result
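# # Do not create an instance at module level; create one only when it is needed.
# # That way, importing the module does not immediately load the index and initialize the tool.
# def get_tool():
# return PyseriniWikipediaKiltDoc()
# # Test code
# if __name__ == "__main__":
# tool = get_tool()
# print(tool("What is the capital of France?"))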