# unit_3_Alfred_RAG / retriever.py
# JasonJy14's picture
# Create retriever.py
# b7bf918 verified
# retriever.py
# This file creates the RAG retrieval tool for guest information
from smolagents import Tool
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
import datasets
class GuestInfoRetrieverTool(Tool):
    """
    A tool that retrieves guest information from the gala database.

    Alfred uses this to answer questions about party guests. Lookups are
    delegated to a BM25Retriever built over the documents supplied at
    construction time.
    """

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests including their names, backgrounds, preferences, and stories. Use this tool when you need to answer questions about specific guests or find guests with certain characteristics."
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to find relevant guest information. Be specific about what you're looking for."
        }
    }
    output_type = "string"

    def __init__(self, docs, *, k: int = 3, **kwargs):
        """
        Build the retriever over the given guest documents.

        Args:
            docs: Iterable of Document objects to index. Must not be empty.
            k: Maximum number of documents returned per query (default 3).
            **kwargs: Forwarded unchanged to the Tool base class.

        Raises:
            ValueError: If ``docs`` is empty or ``k`` is smaller than 1.
        """
        super().__init__(**kwargs)

        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}")

        # Materialize once so generators can be checked for emptiness and
        # then still be handed to the retriever.
        docs = list(docs)
        if not docs:
            # Fail here with a clear message instead of deep inside the
            # BM25 index construction.
            raise ValueError(
                "GuestInfoRetrieverTool requires at least one document to index."
            )

        self.retriever = BM25Retriever.from_documents(docs)
        self.retriever.k = k  # Cap on the number of results per query

    def forward(self, query: str) -> str:
        """
        Search for guest information based on the query.

        Args:
            query: What to search for in the guest database

        Returns:
            The page contents of the retrieved documents separated by
            "---" dividers, or a fallback message if the query is blank or
            nothing was retrieved.
        """
        no_match = "No matching guest information found. Try a different search query."

        # A blank query has no terms to rank documents by, so answer
        # directly rather than asking the retriever.
        if not query or not query.strip():
            return no_match

        results = self.retriever.invoke(query)
        if not results:
            return no_match

        return "\n\n---\n\n".join(doc.page_content for doc in results)
def load_guest_dataset():
    """
    Fetch the gala invitee dataset and wrap each row in a Document.

    Returns:
        List of Document objects, one per guest. Each document's text lists
        the guest's name, relation, description and email on separate
        lines, and its metadata carries the guest's name.
    """
    # Pull the "train" split of the invitee dataset from the Hugging Face Hub
    invitees = datasets.load_dataset("agents-course/unit3-invitees", split="train")

    documents = []
    for guest in invitees:
        # Fields read from each row: name, relation, description, email
        lines = [
            f"Name: {guest['name']}",
            f"Relation: {guest['relation']}",
            f"Description: {guest['description']}",
            f"Email: {guest['email']}"
        ]
        documents.append(
            Document(
                page_content="\n".join(lines),
                metadata={"name": guest["name"]},
            )
        )
    return documents
def create_guest_retriever_tool():
    """
    Build a guest retriever tool backed by the loaded guest dataset.

    Returns:
        A GuestInfoRetrieverTool constructed from the documents produced by
        load_guest_dataset().
    """
    return GuestInfoRetrieverTool(load_guest_dataset())