Spaces:
Runtime error
Runtime error
| from langchain_community.retrievers import BM25Retriever | |
| from langchain.docstore.document import Document | |
| from langchain.tools import StructuredTool | |
| from typing import List | |
| import datasets | |
# Lazily-built BM25 retriever shared by every call to load_guest_dataset();
# it stays None until the first successful build.
_guest_bm25_retriever = None


def load_guest_dataset():
    """Return the BM25 retriever over the guest list, building it on first use.

    On the first call the ``train`` split of ``agents-course/unit3-invitees``
    is loaded, each row is turned into a ``Document`` and the documents are
    indexed with ``BM25Retriever``. The result is stored in the module-level
    ``_guest_bm25_retriever`` so later calls return that same object without
    reloading the dataset.

    Returns:
        The cached ``BM25Retriever`` instance.
    """
    global _guest_bm25_retriever

    if _guest_bm25_retriever is None:
        rows = datasets.load_dataset("agents-course/unit3-invitees", split="train")

        # One Document per guest: all four fields go into the searchable
        # text, and the name is additionally kept as metadata.
        guest_documents = []
        for guest in rows:
            field_lines = [
                f"Name: {guest['name']}",
                f"Relation: {guest['relation']}",
                f"Description: {guest['description']}",
                f"Email: {guest['email']}",
            ]
            guest_documents.append(
                Document(
                    page_content="\n".join(field_lines),
                    metadata={"name": guest["name"]},
                )
            )

        _guest_bm25_retriever = BM25Retriever.from_documents(guest_documents)

    return _guest_bm25_retriever
def retrieve_guest_info(query: str) -> str:
    """Look up gala guests matching ``query`` and return their details as text.

    Args:
        query: Free-text search string, e.g. a guest's name or relation.

    Returns:
        The ``page_content`` of at most three documents returned by the
        retriever, separated by blank lines, or the fixed message
        ``"No matching guest information found."`` when the retriever
        returns nothing.
    """
    retriever = load_guest_dataset()
    # ``invoke`` is the supported retriever entry point; the older
    # ``get_relevant_documents`` method is deprecated in langchain-core.
    results: List[Document] = retriever.invoke(query)
    if not results:
        return "No matching guest information found."
    # Cap the answer at the first three hits to keep the tool output short.
    return "\n\n".join(doc.page_content for doc in results[:3])
# Expose the guest lookup to agents as a LangChain structured tool named
# "guest_info"; the explicit description below is what is passed to the tool.
GuestInfoTool = StructuredTool.from_function(
    func=retrieve_guest_info,
    name="guest_info",
    description="Retrieves detailed information about gala guests based on their name or relation.",
)