# CindyDelage's picture
# Update retriever.py
# 634d586 verified
from smolagents import Tool
from langchain_community.retrievers import BM25Retriever
from langchain.docstore.document import Document
class VegetableInfoRetrieverTool(Tool):
    """Tool that answers vegetable queries from a BM25 index over documents.

    The index is built once, at construction time, from the documents handed
    to ``__init__``. Each ``forward`` call runs the query against that index
    and returns the text of the leading matches.
    """

    # Metadata attributes declared for the ``Tool`` base class.
    name = "vegetable_info_retriever"
    description = "Retrieve information about vegetables based on their name or characteristics."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or description of the vegetable you want information about."
        }
    }
    output_type = "string"

    def __init__(self, docs):
        """Build the BM25 retriever used by ``forward``.

        Args:
            docs: Documents to index; passed unchanged to
                ``BM25Retriever.from_documents``.
        """
        # Let the base class run its own setup before adding our state.
        super().__init__()
        # Kept explicit so the flag is defined no matter what the base
        # initializer does.
        self.is_initialized = False
        self.retriever = BM25Retriever.from_documents(docs)

    def forward(self, query: str) -> str:
        """Return the text of the best-matching vegetable documents.

        Args:
            query: Name or free-text description of a vegetable.

        Returns:
            The ``page_content`` of at most three retrieved documents,
            separated by blank lines, or a fixed "not found" message when
            the retriever returns nothing.
        """
        # ``invoke`` is the supported retriever entry point; the older
        # ``get_relevant_documents`` is deprecated in LangChain.
        results = self.retriever.invoke(query)
        if not results:
            return "No matching vegetable information found."
        # Only the first three results are reported, whatever the retriever
        # itself is configured to return.
        return "\n\n".join(doc.page_content for doc in results[:3])
def load_vegetable_dataset():
    """Create a ``VegetableInfoRetrieverTool`` over a built-in vegetable list.

    The vegetable records are hard-coded below. Each record is rendered as a
    three-line text block (name, type, description) and wrapped in a
    ``Document`` whose metadata carries the vegetable's name.

    Returns:
        A ``VegetableInfoRetrieverTool`` constructed from those documents.
    """
    # Hand-maintained catalogue; every record has the same three keys.
    catalogue = [
        {
            "name": "Sweet Potatoes",
            "type": "root",
            "description": "A starchy root vegetable with sweet flavor, rich in fiber and beta-carotene.",
        },
        {
            "name": "Fresh Basil",
            "type": "herb",
            "description": "An aromatic green herb often used fresh in Mediterranean cuisine.",
        },
        {
            "name": "Broccoli",
            "type": "cruciferous",
            "description": "Green vegetable with edible flowering heads and stalks, high in fiber and vitamin C.",
        },
        {
            "name": "Celery",
            "type": "stem",
            "description": "Crisp green stalks often eaten raw or cooked, known for high water content and fiber.",
        },
        {
            "name": "Lettuce",
            "type": "leaf",
            "description": "Leafy green used in salads, comes in various varieties like romaine or butterhead.",
        },
    ]

    # Render each record to text and wrap it in a Document, keeping the
    # catalogue order.
    documents = []
    for entry in catalogue:
        text_lines = (
            f"Name: {entry['name']}",
            f"Type: {entry['type']}",
            f"Description: {entry['description']}",
        )
        document = Document(
            page_content="\n".join(text_lines),
            metadata={"name": entry["name"]},
        )
        documents.append(document)

    return VegetableInfoRetrieverTool(documents)