# agent-gala-smolagents / retriever.py
# Author: MeghanaNanuvala
# Last commit: c6bbed8 ("Update retriever.py", verified)
import datasets
from langchain_core.documents import Document
guest_dataset = datasets.load_dataset('agents-course/unit3-invitees', split='train')


def _guest_to_document(guest):
    """Render one invitee record as a LangChain ``Document``.

    The page content is four ``Label: value`` lines (name, relation,
    description, email). The guest's name is also copied into the
    document metadata.
    """
    # Implicit concatenation of adjacent f-strings; the "\n" separators
    # reproduce a newline-join of the four lines (no trailing newline).
    content = (
        f"Name: {guest['name']}\n"
        f"Relation: {guest['relation']}\n"
        f"Description: {guest['description']}\n"
        f"Email: {guest['email']}"
    )
    return Document(page_content=content, metadata={'name': guest['name']})


# One searchable document per row of the invitee dataset.
docs = [_guest_to_document(guest) for guest in guest_dataset]
from smolagents import Tool
from langchain_community.retrievers import BM25Retriever
class GuestRetrieverTool(Tool):
    """smolagents tool that looks up gala guests with BM25 keyword search.

    The supplied ``Document`` objects are indexed once at construction. Each
    call ranks them against the query and returns the page content of the
    best matches, separated by blank lines.
    """

    name = "guest_info_retriever"
    description = "Retrieves information about gala guests by name or relation."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest."
        }
    }
    output_type = "string"

    def __init__(self, docs, top_k: int = 3):
        """Build the BM25 index over ``docs``.

        Args:
            docs: LangChain ``Document`` objects to search over.
            top_k: Maximum number of documents returned per query. Defaults
                to 3, the value that was previously hard-coded.
        """
        # Run the base-class initializer so the state smolagents' Tool relies
        # on (its ``is_initialized`` flag, checked in ``Tool.__call__``)
        # exists; without it, calling the tool raises AttributeError.
        super().__init__()
        self.retriever = BM25Retriever.from_documents(docs)
        # The retriever's own ``k`` is the single source of truth for how
        # many results are returned.
        self.retriever.k = top_k
        # The index is built eagerly above, so there is no deferred setup
        # left to do; mark the tool as ready.
        self.is_initialized = True
        # Retained so any code that read the old flag keeps working.
        self.initialized = True

    def forward(self, query: str) -> str:
        """Return the best-matching guest records for ``query``.

        Args:
            query: Free-text guest name or relation to search for.

        Returns:
            Page contents of up to ``top_k`` matching documents joined by a
            blank line, or a fixed message when the retriever returns nothing.
        """
        # ``invoke`` is the current retriever API; ``get_relevant_documents``
        # is deprecated in langchain-core.
        results = self.retriever.invoke(query)
        if not results:
            return "No matching guest info found..."
        top_results = results[: self.retriever.k]
        return "\n\n".join(doc.page_content for doc in top_results)
# Module-level tool instance, indexing every document built from the dataset.
guest_info_tool = GuestRetrieverTool(docs=docs)
# Don't run agents at import time: the demo below only runs when this file
# is executed as a script.
def main():
    """Ask a CodeAgent equipped with the guest retriever about one guest."""
    # The agent and model classes are only needed for this demo, so they are
    # imported here rather than at module level.
    from smolagents import CodeAgent, InferenceClientModel

    model = InferenceClientModel(model_id="HuggingFaceH4/zephyr-7b-beta")
    alfred = CodeAgent(tools=[guest_info_tool], model=model)
    print(alfred.run("Tell me about guest named 'Lady Ada Lovelace'"))


if __name__ == "__main__":
    main()