GalaGuide_Agentic_RAG / retriever.py
dlaima's picture
Update retriever.py
9059902 verified
import os
import re
import warnings

import datasets
import pandas as pd
from langchain.docstore.document import Document
from langchain_community.retrievers import BM25Retriever
from smolagents import Tool
class GuestInfoRetrieverTool(Tool):
    """smolagents tool that answers questions about gala guests via BM25 search.

    Each guest is a LangChain ``Document`` whose ``page_content`` holds
    ``Name:`` / ``Relation:`` / ``Description:`` / ``Email:`` lines and whose
    ``metadata`` carries at least ``name``.
    """

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    def __init__(self, docs, top_k: int = 10):
        """Index ``docs`` for BM25 search.

        Args:
            docs: Iterable of guest ``Document`` objects. It is copied into a
                list so it can be both indexed and enumerated later; a
                generator would otherwise be exhausted by the indexer.
            top_k: Maximum number of guests returned per query. It is passed
                to the retriever as ``k`` because the retriever's own default
                (4 in langchain_community) would otherwise cap the results.
        """
        super().__init__()
        # Flag used by smolagents' Tool machinery (presumably gates setup()).
        self.is_initialized = False
        self.docs = list(docs)
        self.top_k = top_k
        # Building a BM25 index from an empty document list raises, so leave
        # the retriever unset and let forward() report "no match" instead.
        self.retriever = (
            BM25Retriever.from_documents(self.docs, k=top_k) if self.docs else None
        )

    def _generate_conversation_starter(self, doc: Document) -> str:
        """Return a one-line icebreaker suggestion for the guest in ``doc``.

        Uses the first known interest keyword found in the ``Description:``
        line. Falls back to a generic prompt when none is found.
        """
        description = ""
        for line in doc.page_content.splitlines():
            if line.startswith("Description:"):
                # Strip only the leading label; str.replace would also delete
                # any later "Description:" text inside the description itself.
                description = line[len("Description:"):].strip()
        for interest in ("art", "science", "sports", "music", "history", "technology", "travel", "literature"):
            # Anchor at a word start so e.g. "art" does not fire on "party",
            # "start" or "heart"; IGNORECASE matches the original lower() test.
            if re.search(r"\b" + re.escape(interest), description, re.IGNORECASE):
                return f"A good icebreaker could be: 'I heard you're into {interest}. What's your favorite part about it?'"
        return "Try asking about their background—it sounds fascinating!"

    def forward(self, query: str) -> str:
        """Answer ``query`` with matching guest records.

        A request to list everyone (e.g. "list all guests") returns just the
        guest names, one per line. Any other query runs BM25 retrieval and
        returns up to ``top_k`` guest records. Each record is followed by a
        suggested icebreaker, and records are separated by ``---`` lines.
        """
        lowered = query.lower()
        list_keywords = ("list guests", "guest names", "list all guests", "show guests", "all guests", "everyone invited")
        if any(keyword in lowered for keyword in list_keywords):
            # str() guards against non-string names (e.g. NaN from a CSV).
            return "\n".join(str(doc.metadata.get("name", "Unknown")) for doc in self.docs)

        if self.retriever is None:
            return "No matching guest information found."
        results = self.retriever.invoke(query)[: self.top_k]
        if not results:
            return "No matching guest information found."
        return "\n\n---\n\n".join(
            f"{doc.page_content}\n\n{self._generate_conversation_starter(doc)}"
            for doc in results
        )
def load_guest_dataset(file_path: "str | None" = None, show_example: bool = True) -> GuestInfoRetrieverTool:
    """Build a ``GuestInfoRetrieverTool`` from a guest table.

    The table is read from ``file_path`` (``.csv`` or ``.json``) when that file
    exists. Otherwise it comes from the ``train`` split of the
    ``agents-course/unit3-invitees`` Hugging Face dataset. Every row must
    provide ``name``, ``relation``, ``description`` and ``email`` columns.

    Args:
        file_path: Optional path to a local CSV/JSON guest file. If it is given
            but does not exist, a warning is emitted and the Hugging Face
            dataset is used instead.
        show_example: When loading from Hugging Face, print the first guest
            as a preview.

    Returns:
        A GuestInfoRetrieverTool indexing one Document per guest row.

    Raises:
        ValueError: ``file_path`` exists but is not a ``.csv`` or ``.json`` file.
        KeyError: a required column is missing from the table.
    """
    if file_path and os.path.exists(file_path):
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".csv":
            df = pd.read_csv(file_path)
        elif ext == ".json":
            df = pd.read_json(file_path)
        else:
            raise ValueError("Unsupported file format. Use .csv or .json.")
    else:
        if file_path:
            # A mistyped path used to switch datasets silently; make it visible
            # while keeping the fallback callers may rely on.
            warnings.warn(
                f"Guest file {file_path!r} not found; falling back to the Hugging Face dataset.",
                stacklevel=2,
            )
        guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
        df = pd.DataFrame(guest_dataset)
        if show_example:
            print("\n📌 Example guest from Hugging Face dataset:\n")
            preview = df.head(1)
            try:
                print(preview.to_markdown(index=False))
            except ImportError:
                # to_markdown needs the optional 'tabulate' package; a cosmetic
                # preview should not make the whole load fail.
                print(preview.to_string(index=False))

    docs = [
        Document(
            page_content="\n".join([
                f"Name: {row['name']}",
                f"Relation: {row['relation']}",
                f"Description: {row['description']}",
                f"Email: {row['email']}",
            ]),
            metadata={"name": row["name"], "email": row["email"]},
        )
        for row in df.to_dict(orient="records")
    ]
    return GuestInfoRetrieverTool(docs)