Spaces:
Sleeping
Sleeping
import os
import re
import warnings
from typing import Optional

import datasets
import pandas as pd
from langchain.docstore.document import Document
from langchain_community.retrievers import BM25Retriever
from smolagents import Tool
class GuestInfoRetrieverTool(Tool):
    """smolagents tool that answers questions about gala guests using BM25 search.

    The tool is built from LangChain ``Document`` objects whose ``page_content``
    contains ``Name:`` / ``Relation:`` / ``Description:`` / ``Email:`` lines and
    whose metadata may carry a ``"name"`` key (used for full guest listings).
    """

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    # Phrases that switch forward() from search mode to "list every guest" mode.
    _LIST_KEYWORDS = (
        "list guests",
        "guest names",
        "list all guests",
        "show guests",
        "all guests",
        "everyone invited",
    )

    # Topics recognised in a guest description, in priority order: the first
    # one found becomes the subject of the suggested icebreaker.
    _INTERESTS = ("art", "science", "sports", "music", "history", "technology", "travel", "literature")

    def __init__(self, docs, top_k: int = 10):
        """Index *docs* for BM25 retrieval.

        Args:
            docs: Guest ``Document`` objects. They are stored as a list so the
                full guest listing in ``forward`` still works after indexing.
            top_k: Maximum number of documents a search returns (default 10).
                BM25Retriever's own default is only 4 results.
        """
        super().__init__()
        self.is_initialized = False
        # Materialize first: a generator would otherwise be consumed by the
        # retriever and leave self.docs empty.
        self.docs = list(docs)
        self.top_k = top_k
        # BM25 cannot be built from an empty corpus (rank_bm25 fails on zero
        # documents), so an empty guest list simply gets no retriever.
        self.retriever = (
            BM25Retriever.from_documents(self.docs, k=top_k) if self.docs else None
        )

    def _generate_conversation_starter(self, doc: Document) -> str:
        """Return an icebreaker suggestion based on the guest's description.

        The text after the last ``Description:`` line of ``doc.page_content`` is
        searched for the first known interest. A generic prompt is returned when
        none is found.
        """
        description = ""
        for line in doc.page_content.splitlines():
            if line.startswith("Description:"):
                # Strip only the leading label, not later occurrences of it.
                description = line[len("Description:"):].strip()
        for interest in self._INTERESTS:
            # Whole-word (optionally plural) match, so "art" does not fire on
            # words like "party", "heart" or "partner".
            if re.search(rf"\b{interest}s?\b", description, re.IGNORECASE):
                return f"A good icebreaker could be: 'I heard you're into {interest}. What's your favorite part about it?'"
        return "Try asking about their background—it sounds fascinating!"

    def forward(self, query: str) -> str:
        """Answer *query* with a guest listing or the best-matching guest entries.

        If the query contains one of the listing phrases, every guest name is
        returned, one per line. Otherwise up to ``top_k`` BM25 matches are
        returned, each followed by a conversation starter and separated by
        ``---`` rules.
        """
        lowered = query.lower()
        if any(keyword in lowered for keyword in self._LIST_KEYWORDS):
            # str() guards against non-string names (e.g. NaN from a CSV), which
            # would make str.join raise TypeError.
            return "\n".join(str(doc.metadata.get("name", "Unknown")) for doc in self.docs)

        if self.retriever is None:
            return "No matching guest information found."

        # invoke() replaces the deprecated get_relevant_documents().
        results = self.retriever.invoke(query)
        if not results:
            return "No matching guest information found."

        responses = [
            f"{doc.page_content}\n\n{self._generate_conversation_starter(doc)}"
            for doc in results[: self.top_k]
        ]
        return "\n\n---\n\n".join(responses)
def load_guest_dataset(file_path: Optional[str] = None, show_example: bool = True):
    """Build a ``GuestInfoRetrieverTool`` from a guest dataset.

    Guests are read from *file_path* when it names an existing ``.csv`` or
    ``.json`` file. Otherwise the ``agents-course/unit3-invitees`` dataset is
    loaded from the Hugging Face Hub; a warning is issued if a *file_path* was
    given but does not exist.

    Args:
        file_path: Optional path to a local CSV/JSON file with ``name``,
            ``relation``, ``description`` and ``email`` columns.
        show_example: When the Hugging Face dataset is used, print its first
            row as a preview.

    Returns:
        A ``GuestInfoRetrieverTool`` indexing one ``Document`` per guest.

    Raises:
        ValueError: If *file_path* exists but is neither ``.csv`` nor ``.json``.
        KeyError: If one of the four required columns is missing.
    """
    if file_path and os.path.exists(file_path):
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".csv":
            df = pd.read_csv(file_path)
        elif ext == ".json":
            df = pd.read_json(file_path)
        else:
            raise ValueError("Unsupported file format. Use .csv or .json.")
    else:
        if file_path:
            # Keep the historical fallback, but make it visible: a mistyped path
            # would otherwise silently load a different dataset.
            warnings.warn(
                f"Guest file {file_path!r} not found; falling back to the Hugging Face dataset.",
                stacklevel=2,
            )
        guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
        df = pd.DataFrame(guest_dataset)
        if show_example:
            print("\n📌 Example guest from Hugging Face dataset:\n")
            preview = df.head(1)
            try:
                print(preview.to_markdown(index=False))
            except ImportError:
                # DataFrame.to_markdown needs the optional 'tabulate' package;
                # a preview should never abort the load.
                print(preview.to_string(index=False))

    columns = ["name", "relation", "description", "email"]
    # Blank cells arrive as NaN. Render them as empty text instead of "nan", and
    # keep metadata values as strings so guest listings can join them.
    records = df[columns].fillna("").astype(str).to_dict("records")
    docs = [
        Document(
            page_content="\n".join([
                f"Name: {row['name']}",
                f"Relation: {row['relation']}",
                f"Description: {row['description']}",
                f"Email: {row['email']}",
            ]),
            metadata={"name": row["name"], "email": row["email"]},
        )
        for row in records
    ]
    return GuestInfoRetrieverTool(docs)