JasonJy14 commited on
Commit
b7bf918
·
verified ·
1 Parent(s): 05f416b

Create retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +88 -0
retriever.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # retriever.py
2
+ # This file creates the RAG retrieval tool for guest information
3
+
4
+ from smolagents import Tool
5
+ from langchain_community.retrievers import BM25Retriever
6
+ from langchain_core.documents import Document
7
+ import datasets
8
+
9
+ class GuestInfoRetrieverTool(Tool):
10
+ """
11
+ A tool that retrieves guest information from the gala database.
12
+ Alfred uses this to answer questions about party guests.
13
+ """
14
+
15
+ name = "guest_info_retriever"
16
+ description = "Retrieves detailed information about gala guests including their names, backgrounds, preferences, and stories. Use this tool when you need to answer questions about specific guests or find guests with certain characteristics."
17
+
18
+ inputs = {
19
+ "query": {
20
+ "type": "string",
21
+ "description": "The search query to find relevant guest information. Be specific about what you're looking for."
22
+ }
23
+ }
24
+ output_type = "string"
25
+
26
+ def __init__(self, docs, **kwargs):
27
+ super().__init__(**kwargs)
28
+ self.retriever = BM25Retriever.from_documents(docs)
29
+ self.retriever.k = 3 # Return top 3 results
30
+
31
+ def forward(self, query: str) -> str:
32
+ """
33
+ Search for guest information based on the query.
34
+
35
+ Args:
36
+ query: What to search for in the guest database
37
+
38
+ Returns:
39
+ Relevant guest information or a message if nothing found
40
+ """
41
+ results = self.retriever.invoke(query)
42
+
43
+ if results:
44
+ return "\n\n---\n\n".join([doc.page_content for doc in results])
45
+ else:
46
+ return "No matching guest information found. Try a different search query."
47
+
48
+
49
+ def load_guest_dataset():
50
+ """
51
+ Load the guest dataset and convert it to Document objects.
52
+
53
+ Returns:
54
+ List of Document objects containing guest information
55
+ """
56
+ # Load the dataset from Hugging Face Hub
57
+ guest_dataset = datasets.load_dataset(
58
+ "agents-course/unit3-invitees",
59
+ split="train"
60
+ )
61
+
62
+ # Convert to Document objects for the retriever
63
+ # Dataset has fields: name, relation, description, email
64
+ docs = [
65
+ Document(
66
+ page_content="\n".join([
67
+ f"Name: {guest['name']}",
68
+ f"Relation: {guest['relation']}",
69
+ f"Description: {guest['description']}",
70
+ f"Email: {guest['email']}"
71
+ ]),
72
+ metadata={"name": guest["name"]}
73
+ )
74
+ for guest in guest_dataset
75
+ ]
76
+
77
+ return docs
78
+
79
+
80
+ def create_guest_retriever_tool():
81
+ """
82
+ Create and return the guest retriever tool.
83
+
84
+ Returns:
85
+ GuestInfoRetrieverTool ready to use
86
+ """
87
+ docs = load_guest_dataset()
88
+ return GuestInfoRetrieverTool(docs)