Spaces:
Runtime error
Runtime error
Update retriever.py
Browse files- retriever.py +21 -24
retriever.py
CHANGED
|
@@ -1,59 +1,56 @@
|
|
| 1 |
from smolagents import Tool
|
| 2 |
import datasets
|
| 3 |
-
from langchain_core.documents import Document
|
| 4 |
|
| 5 |
-
print("
|
| 6 |
|
| 7 |
-
# Load the course dataset
|
| 8 |
try:
|
|
|
|
| 9 |
guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
|
| 10 |
print(f"✅ Loaded {len(guest_dataset)} guests")
|
| 11 |
except:
|
| 12 |
# Fallback data
|
| 13 |
guest_dataset = [
|
| 14 |
-
{"name": "Lady Ada Lovelace", "relation": "mathematician",
|
| 15 |
-
|
| 16 |
-
{"name": "
|
|
|
|
| 17 |
]
|
| 18 |
print(f"⚠️ Using {len(guest_dataset)} sample guests")
|
| 19 |
|
| 20 |
-
# Create documents for RAG
|
| 21 |
-
docs = [
|
| 22 |
-
Document(
|
| 23 |
-
page_content=f"Name: {guest['name']}\nRelation: {guest['relation']}\nDescription: {guest['description']}\nEmail: {guest['email']}",
|
| 24 |
-
metadata={"name": guest["name"]}
|
| 25 |
-
)
|
| 26 |
-
for guest in guest_dataset
|
| 27 |
-
]
|
| 28 |
-
|
| 29 |
class GuestInfoRetrieverTool(Tool):
|
| 30 |
name = "guest_info_retriever"
|
| 31 |
-
description = "Retrieves detailed information about gala guests
|
| 32 |
inputs = {
|
| 33 |
"query": {
|
| 34 |
"type": "string",
|
| 35 |
-
"description": "The name or relation of the guest."
|
| 36 |
}
|
| 37 |
}
|
| 38 |
output_type = "string"
|
| 39 |
|
| 40 |
def forward(self, query: str):
|
| 41 |
-
# Simple RAG: search through documents
|
| 42 |
query_lower = query.lower()
|
| 43 |
results = []
|
| 44 |
|
| 45 |
-
for
|
| 46 |
-
if query_lower in
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
if len(results) >= 3:
|
| 49 |
break
|
| 50 |
|
| 51 |
if results:
|
| 52 |
-
return "
|
| 53 |
else:
|
| 54 |
-
return f"No guest found
|
| 55 |
|
| 56 |
-
# Create
|
| 57 |
guest_info_tool = GuestInfoRetrieverTool()
|
| 58 |
print("✅ RAG tool created")
|
| 59 |
|
|
|
|
| 1 |
from smolagents import Tool
|
| 2 |
import datasets
|
|
|
|
| 3 |
|
| 4 |
+
print("📂 Loading guest dataset...")

# Load the invitee dataset as shown in the course. If loading fails for any
# ordinary reason, fall back to a small built-in sample with the same four
# fields so the retriever tool keeps working.
try:
    guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
    print(f"✅ Loaded {len(guest_dataset)} guests")
except Exception as exc:
    # Catch Exception rather than using a bare ``except`` so that
    # KeyboardInterrupt and SystemExit still propagate, and report the cause
    # instead of silently discarding it.
    print(f"⚠️ Could not load guest dataset: {exc!r}")
    # Fallback data
    guest_dataset = [
        {
            "name": "Lady Ada Lovelace",
            "relation": "mathematician",
            "description": "First computer programmer",
            "email": "ada@example.com",
        },
        {
            "name": "Dr. Nikola Tesla",
            "relation": "inventor",
            "description": "Electrical engineering pioneer",
            "email": "tesla@example.com",
        },
    ]
    print(f"⚠️ Using {len(guest_dataset)} sample guests")
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
class GuestInfoRetrieverTool(Tool):
    """Tool that looks up gala guests by name or relation.

    Runs a case-insensitive substring search over the module-level
    ``guest_dataset`` and returns the formatted details of the first
    matches.
    """

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the details of up to three guests matching ``query``.

        Args:
            query: Text to look for inside each guest's name or relation.
                Matching is case-insensitive and ignores surrounding
                whitespace.

        Returns:
            The matching guests' details joined by ``---`` separator lines,
            or a "No guest found" message when nothing matches or the query
            is blank.
        """
        needle = query.strip().lower()

        # An empty string is a substring of every name, so a blank query
        # would otherwise "match" the first three guests in the dataset.
        if not needle:
            return f"No guest found matching '{query}'"

        results = []
        for guest in guest_dataset:
            if needle in guest['name'].lower() or needle in guest['relation'].lower():
                results.append(
                    f"Name: {guest['name']}\n"
                    f"Relation: {guest['relation']}\n"
                    f"Description: {guest['description']}\n"
                    f"Email: {guest['email']}"
                )
                # Cap the answer at three guests to keep the tool output short.
                if len(results) >= 3:
                    break

        if results:
            return "\n\n---\n\n".join(results)
        return f"No guest found matching '{query}'"
|
| 52 |
|
| 53 |
+
# Create the tool instance at import time and report it on stdout.
|
| 54 |
guest_info_tool = GuestInfoRetrieverTool()
|
| 55 |
print("✅ RAG tool created")
|
| 56 |
|