# RAG_System / retriever.py
# Last change: "Update retriever.py" by Ventahana (commit c14934b, verified).
from smolagents import Tool
import datasets
# Load the gala guest list once at import time. The Hugging Face Hub copy is
# preferred; if it cannot be loaded for any reason (offline, missing package
# data, network/auth errors) a small built-in sample is used instead so the
# tool below still works.
print("📂 Loading guest dataset...")
try:
    # Load dataset as shown in course
    guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
except Exception as exc:
    # Deliberate best-effort fallback. Unlike a bare `except:`, this lets
    # KeyboardInterrupt/SystemExit propagate, and it reports why the Hub
    # load failed instead of hiding it.
    guest_dataset = [
        {"name": "Lady Ada Lovelace", "relation": "mathematician",
         "description": "First computer programmer", "email": "ada@example.com"},
        {"name": "Dr. Nikola Tesla", "relation": "inventor",
         "description": "Electrical engineering pioneer", "email": "tesla@example.com"},
    ]
    print(f"⚠️ Using {len(guest_dataset)} sample guests "
          f"(dataset load failed: {type(exc).__name__}: {exc})")
else:
    print(f"✅ Loaded {len(guest_dataset)} guests")
class GuestInfoRetrieverTool(Tool):
    """smolagents tool that looks up gala guests by name or relation.

    Searches the module-level ``guest_dataset`` with a case-insensitive
    substring match and returns up to ``max_results`` formatted records.
    """

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    # Maximum number of matching guests included in one response.
    max_results = 3

    def forward(self, query: str) -> str:
        """Return formatted records for guests whose name or relation matches *query*.

        Matching is a case-insensitive (``str.casefold``) substring test against
        each guest's ``name`` and ``relation`` fields, after trimming surrounding
        whitespace from *query*. Matches are returned in dataset order, at most
        ``max_results`` of them, separated by a ``---`` line.

        Args:
            query: A guest name or relation fragment, e.g. ``"Ada"`` or ``"inventor"``.

        Returns:
            The formatted matches. If the query is blank or nothing matches,
            returns a human-readable message instead.
        """
        def field(guest, key):
            # Missing or null fields render as empty text instead of raising.
            value = guest.get(key)
            return "" if value is None else str(value)

        needle = (query or "").strip().casefold()
        if not needle:
            # "" is a substring of every string, so without this guard a blank
            # query would "match" the first few guests in the dataset.
            return "Please provide a guest name or relation to search for."

        results = []
        for guest in guest_dataset:
            if (needle in field(guest, "name").casefold()
                    or needle in field(guest, "relation").casefold()):
                # Built line by line (not a triple-quoted string) so that code
                # indentation can never leak into the returned text.
                results.append(
                    f"Name: {field(guest, 'name')}\n"
                    f"Relation: {field(guest, 'relation')}\n"
                    f"Description: {field(guest, 'description')}\n"
                    f"Email: {field(guest, 'email')}"
                )
                if len(results) >= self.max_results:
                    break

        if not results:
            return f"No guest found matching '{query}'"
        return "\n\n---\n\n".join(results)
# Create tool instance.
# NOTE(review): presumably imported by the agent setup code as
# `guest_info_tool` — confirm against callers before renaming.
guest_info_tool = GuestInfoRetrieverTool()
print("✅ RAG tool created")