from uuid import uuid4

import datasets
from smolagents import Tool


class GuestInfoRetrieverTool(Tool):
    """Tool that looks up gala guest records in a vector store.

    The store passed to the constructor must expose a
    ``query(query_texts=..., n_results=...)`` method whose result is indexable
    by ``"documents"`` and ``"distances"``, each holding one list per query
    text (presumably a ChromaDB collection -- confirm against the caller).
    """

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about.",
        }
    }
    output_type = "string"

    def __init__(self, vector_store):
        """Store the vector store that ``forward`` will query.

        Args:
            vector_store: Object providing the ``query`` method described in
                the class docstring.
        """
        # NOTE(review): the base ``Tool.__init__`` is intentionally not called
        # here, matching the original code; the flag is set by hand instead.
        self.is_initialized = False
        self.vector_store = vector_store

    def forward(self, query: str) -> str:
        """Return the stored guest records that match ``query``.

        Up to three candidates are requested from the store, and only those
        whose distance is below the cutoff are kept.

        Args:
            query: Name or relation of the guest to look up.

        Returns:
            The matching guest records separated by blank lines, or an empty
            string when no candidate is below the cutoff.
        """
        # Cutoff below which a hit is kept (presumably smaller distance means
        # a closer match -- depends on the store's distance metric).
        max_distance = 1.3

        result = self.vector_store.query(
            query_texts=[query],
            n_results=3,
        )
        # One query text was sent, so the hits live at index 0.
        documents = result["documents"][0]
        distances = result["distances"][0]

        # Pair each document with its own distance so the filter stays correct
        # even if the store does not return hits sorted by ascending distance.
        return "\n\n".join(
            document
            for document, distance in zip(documents, distances)
            if distance < max_distance
        )


def _format_guest(guest) -> str:
    """Render one guest record as the multi-line text stored in the index."""
    return "\n".join(
        [
            f"Name: {guest['name']}",
            f"Relation: {guest['relation']}",
            f"Description: {guest['description']}",
            f"Email: {guest['email']}",
        ]
    )


def load_guest_dataset(vector_store):
    """Index the invitee dataset into ``vector_store`` and build the tool.

    Args:
        vector_store: Object providing ``add(documents=..., metadatas=...,
            ids=...)`` for indexing and ``query`` for retrieval.

    Returns:
        A ``GuestInfoRetrieverTool`` bound to ``vector_store``.
    """
    # Load the dataset
    guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")

    # Store each guest as a plain-text document with a random unique id; the
    # guest name is kept as metadata alongside the text.
    for guest in guest_dataset:
        vector_store.add(
            documents=[_format_guest(guest)],
            metadatas=[{"name": guest["name"]}],
            ids=[str(uuid4())],
        )

    # Return the tool
    return GuestInfoRetrieverTool(vector_store)