"""Guest-information retrieval tool backed by BM25 keyword search."""

import os
import re
import warnings
from typing import Optional

import datasets
import pandas as pd
from langchain.docstore.document import Document
from langchain_community.retrievers import BM25Retriever
from smolagents import Tool

# Phrases that switch `forward` from search mode to "list every guest" mode.
_LIST_GUESTS_PHRASES = (
    "list guests",
    "guest names",
    "list all guests",
    "show guests",
    "all guests",
    "everyone invited",
)

# Topics scanned for in a guest description, in priority order: the first
# one found is used in the suggested icebreaker.
_INTEREST_KEYWORDS = (
    "art",
    "science",
    "sports",
    "music",
    "history",
    "technology",
    "travel",
    "literature",
)

# Columns that `load_guest_dataset` reads from every row.
_REQUIRED_COLUMNS = ("name", "relation", "description", "email")

# Upper bound on how many retrieved documents are included in one answer.
# The retriever may return fewer than this.
_MAX_RESULTS = 10

_DESCRIPTION_PREFIX = "Description:"


class GuestInfoRetrieverTool(Tool):
    """Look up gala guests by name or relation using BM25 retrieval.

    Each matching guest record is returned together with a suggested
    conversation starter derived from the guest's description.
    """

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about.",
        }
    }
    output_type = "string"

    def __init__(self, docs):
        """Build the BM25 index over ``docs``.

        Args:
            docs: Iterable of ``Document`` objects, one per guest.

        Raises:
            ValueError: If ``docs`` is empty.
        """
        super().__init__()
        self.is_initialized = False
        # Materialize once: the retriever consumes the iterable, and the
        # "list all guests" path in `forward` needs to walk it again.
        self.docs = list(docs)
        if not self.docs:
            raise ValueError("GuestInfoRetrieverTool requires at least one guest document.")
        self.retriever = BM25Retriever.from_documents(self.docs)

    def _generate_conversation_starter(self, doc: Document) -> str:
        """Return an icebreaker sentence tailored to the guest in ``doc``.

        The ``Description:`` line of ``doc.page_content`` is scanned for the
        topics in ``_INTEREST_KEYWORDS``; the first topic found (in that
        tuple's order) is used. If none is found, a generic suggestion is
        returned.
        """
        description = ""
        for line in doc.page_content.splitlines():
            if line.startswith(_DESCRIPTION_PREFIX):
                # Strip only the leading label, not later occurrences of it.
                description = line[len(_DESCRIPTION_PREFIX):].strip()

        # Whole-word match (optional plural "s") so that "art" is not
        # triggered by words such as "party" or "start".
        interest = next(
            (
                topic
                for topic in _INTEREST_KEYWORDS
                if re.search(rf"\b{re.escape(topic)}s?\b", description, re.IGNORECASE)
            ),
            None,
        )
        if interest is not None:
            return f"A good icebreaker could be: 'I heard you're into {interest}. What's your favorite part about it?'"
        return "Try asking about their background—it sounds fascinating!"

    def forward(self, query: str) -> str:
        """Answer a guest query.

        Args:
            query: A guest name or relation, or a request to list all guests
                (any phrase from ``_LIST_GUESTS_PHRASES``).

        Returns:
            For a listing request, one guest name per line. Otherwise the
            matching guest records, each followed by a conversation starter
            and separated by ``---``; or a "no match" message.
        """
        if not query or not query.strip():
            return "No matching guest information found."

        normalized = query.lower()
        if any(phrase in normalized for phrase in _LIST_GUESTS_PHRASES):
            # str() guards against non-string names (e.g. NaN from a CSV).
            return "\n".join(str(doc.metadata.get("name", "Unknown")) for doc in self.docs)

        # Default BM25 retrieval. `invoke` replaces the deprecated
        # `get_relevant_documents`.
        results = self.retriever.invoke(query)
        if not results:
            return "No matching guest information found."

        responses = [
            f"{doc.page_content}\n\n{self._generate_conversation_starter(doc)}"
            for doc in results[:_MAX_RESULTS]
        ]
        return "\n\n---\n\n".join(responses)


def load_guest_dataset(file_path: Optional[str] = None, show_example: bool = True):
    """Load guest records and wrap them in a ``GuestInfoRetrieverTool``.

    Records come from ``file_path`` (CSV or JSON) when it exists, otherwise
    from the Hugging Face dataset ``agents-course/unit3-invitees``.

    Args:
        file_path: Optional path to a ``.csv`` or ``.json`` file with the
            columns ``name``, ``relation``, ``description`` and ``email``.
        show_example: When the Hugging Face dataset is used, print the first
            row as a preview.

    Returns:
        A ``GuestInfoRetrieverTool`` indexed over one document per row.

    Raises:
        ValueError: If ``file_path`` exists but is neither ``.csv`` nor ``.json``.
        KeyError: If a required column is missing from the loaded data.
    """
    if file_path and os.path.exists(file_path):
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".csv":
            df = pd.read_csv(file_path)
        elif ext == ".json":
            df = pd.read_json(file_path)
        else:
            raise ValueError("Unsupported file format. Use .csv or .json.")
    else:
        if file_path:
            # Keep the fallback, but make it visible instead of silent.
            warnings.warn(
                f"Guest file {file_path!r} not found; "
                "falling back to the Hugging Face dataset.",
                stacklevel=2,
            )
        guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
        df = pd.DataFrame(guest_dataset)
        if show_example:
            print("\n📌 Example guest from Hugging Face dataset:\n")
            first_row = df.head(1)
            try:
                preview = first_row.to_markdown(index=False)
            except ImportError:
                # `to_markdown` needs the optional `tabulate` package.
                preview = first_row.to_string(index=False)
            print(preview)

    missing = [column for column in _REQUIRED_COLUMNS if column not in df.columns]
    if missing:
        raise KeyError(f"Guest data is missing required column(s): {', '.join(missing)}")

    docs = [
        Document(
            page_content="\n".join(
                [
                    f"Name: {row['name']}",
                    f"Relation: {row['relation']}",
                    f"Description: {row['description']}",
                    f"Email: {row['email']}",
                ]
            ),
            metadata={"name": row["name"], "email": row["email"]},
        )
        for row in df.to_dict(orient="records")
    ]
    return GuestInfoRetrieverTool(docs)