Spaces:
Sleeping
Sleeping
| # retriever.py | |
| # This file creates the RAG retrieval tool for guest information | |
| from smolagents import Tool | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain_core.documents import Document | |
| import datasets | |
class GuestInfoRetrieverTool(Tool):
    """
    A tool that retrieves guest information from the gala database.
    Alfred uses this to answer questions about party guests.

    The tool builds a BM25 keyword retriever over the supplied documents
    and, for each query, returns the page content of the top-ranked
    documents separated by ``---`` dividers.
    """

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests including their names, backgrounds, preferences, and stories. Use this tool when you need to answer questions about specific guests or find guests with certain characteristics."
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to find relevant guest information. Be specific about what you're looking for."
        }
    }
    output_type = "string"

    def __init__(self, docs, *, k: int = 3, **kwargs):
        """
        Index the guest documents for keyword search.

        Args:
            docs: Iterable of ``Document`` objects to index. It is
                materialized into a list once so it can be validated.
            k: Maximum number of documents returned per query. Defaults
                to 3, the value this tool has always used.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.

        Raises:
            ValueError: If ``docs`` is empty or ``k`` is less than 1.
        """
        docs = list(docs)
        if not docs:
            # Fail fast with a clear message; an empty corpus would otherwise
            # surface as an opaque error from inside the BM25 index construction
            # (NOTE(review): verify against the installed langchain version).
            raise ValueError("GuestInfoRetrieverTool requires at least one document to index.")
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}.")
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_documents(docs)
        self.retriever.k = k  # Return at most k results per query

    def forward(self, query: str) -> str:
        """
        Search for guest information based on the query.

        The parameter list must stay in sync with ``inputs``; smolagents
        checks the ``forward`` signature against it.

        Args:
            query: What to search for in the guest database

        Returns:
            The page content of up to ``k`` retrieved documents joined by
            ``---`` dividers, or an explanatory message when the query is
            blank or nothing was retrieved.
        """
        if not query or not query.strip():
            # A blank query has no search terms to rank by, so don't return
            # arbitrary guests as if they matched.
            return "Please provide a non-empty search query describing the guest(s) you are looking for."

        results = self.retriever.invoke(query)
        # NOTE(review): BM25Retriever appears to return up to k documents even
        # when no query term matches, so this fallback may rarely trigger — confirm.
        if not results:
            return "No matching guest information found. Try a different search query."
        return "\n\n---\n\n".join(doc.page_content for doc in results)
def load_guest_dataset(
    dataset_name: str = "agents-course/unit3-invitees",
    split: str = "train",
) -> "list[Document]":
    """
    Load the guest dataset and convert it to Document objects.

    Args:
        dataset_name: Hugging Face Hub dataset id passed to
            ``datasets.load_dataset``. Defaults to the course invitee list.
        split: Dataset split to load (default ``"train"``).

    Returns:
        List of Document objects containing guest information, one per
        dataset row. Each document's ``page_content`` holds the
        name/relation/description/email lines, and its ``metadata["name"]``
        holds the guest's name.
    """
    # Load the dataset from Hugging Face Hub
    guest_dataset = datasets.load_dataset(dataset_name, split=split)

    # Convert to Document objects for the retriever.
    # Each row is indexed by the keys: name, relation, description, email
    # (the default dataset provides exactly these fields).
    return [
        Document(
            page_content="\n".join([
                f"Name: {guest['name']}",
                f"Relation: {guest['relation']}",
                f"Description: {guest['description']}",
                f"Email: {guest['email']}",
            ]),
            metadata={"name": guest["name"]},
        )
        for guest in guest_dataset
    ]
def create_guest_retriever_tool():
    """
    Build a ready-to-use guest retriever tool.

    The guest documents are produced by ``load_guest_dataset`` and handed
    straight to a new ``GuestInfoRetrieverTool`` for indexing.

    Returns:
        GuestInfoRetrieverTool: the tool wrapping the loaded guest documents.
    """
    return GuestInfoRetrieverTool(load_guest_dataset())