Ventahana commited on
Commit
c14934b
·
verified ·
1 Parent(s): 13012d5

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +21 -24
retriever.py CHANGED
@@ -1,59 +1,56 @@
1
  from smolagents import Tool
2
  import datasets
3
- from langchain_core.documents import Document
4
 
5
- print("🔄 Loading RAG dataset...")
6
 
7
- # Load the course dataset
8
  try:
 
9
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
10
  print(f"✅ Loaded {len(guest_dataset)} guests")
11
  except:
12
  # Fallback data
13
  guest_dataset = [
14
- {"name": "Lady Ada Lovelace", "relation": "mathematician", "description": "First computer programmer", "email": "ada@example.com"},
15
- {"name": "Dr. Nikola Tesla", "relation": "inventor", "description": "Electrical engineering pioneer", "email": "tesla@example.com"},
16
- {"name": "Marie Curie", "relation": "scientist", "description": "Nobel prize winning physicist", "email": "marie@example.com"}
 
17
  ]
18
  print(f"⚠️ Using {len(guest_dataset)} sample guests")
19
 
20
- # Create documents for RAG
21
- docs = [
22
- Document(
23
- page_content=f"Name: {guest['name']}\nRelation: {guest['relation']}\nDescription: {guest['description']}\nEmail: {guest['email']}",
24
- metadata={"name": guest["name"]}
25
- )
26
- for guest in guest_dataset
27
- ]
28
-
29
  class GuestInfoRetrieverTool(Tool):
30
  name = "guest_info_retriever"
31
- description = "Retrieves detailed information about gala guests using RAG."
32
  inputs = {
33
  "query": {
34
  "type": "string",
35
- "description": "The name or relation of the guest."
36
  }
37
  }
38
  output_type = "string"
39
 
40
  def forward(self, query: str):
41
- # Simple RAG: search through documents
42
  query_lower = query.lower()
43
  results = []
44
 
45
- for doc in docs:
46
- if query_lower in doc.page_content.lower():
47
- results.append(doc.page_content)
 
 
 
 
 
 
 
48
  if len(results) >= 3:
49
  break
50
 
51
  if results:
52
- return "🧠 RAG Results:\n\n" + "\n\n---\n\n".join(results)
53
  else:
54
- return f"No guest found for '{query}'"
55
 
56
- # Create RAG tool
57
  guest_info_tool = GuestInfoRetrieverTool()
58
  print("✅ RAG tool created")
59
 
 
1
  from smolagents import Tool
2
  import datasets
 
3
 
4
+ print("📂 Loading guest dataset...")
5
 
 
6
  try:
7
+ # Load dataset as shown in course
8
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
9
  print(f"✅ Loaded {len(guest_dataset)} guests")
10
  except:
11
  # Fallback data
12
  guest_dataset = [
13
+ {"name": "Lady Ada Lovelace", "relation": "mathematician",
14
+ "description": "First computer programmer", "email": "ada@example.com"},
15
+ {"name": "Dr. Nikola Tesla", "relation": "inventor",
16
+ "description": "Electrical engineering pioneer", "email": "tesla@example.com"}
17
  ]
18
  print(f"⚠️ Using {len(guest_dataset)} sample guests")
19
 
 
 
 
 
 
 
 
 
 
20
  class GuestInfoRetrieverTool(Tool):
21
  name = "guest_info_retriever"
22
+ description = "Retrieves detailed information about gala guests based on their name or relation."
23
  inputs = {
24
  "query": {
25
  "type": "string",
26
+ "description": "The name or relation of the guest you want information about."
27
  }
28
  }
29
  output_type = "string"
30
 
31
  def forward(self, query: str):
 
32
  query_lower = query.lower()
33
  results = []
34
 
35
+ for guest in guest_dataset:
36
+ if (query_lower in guest['name'].lower() or
37
+ query_lower in guest['relation'].lower()):
38
+
39
+ guest_info = f"""Name: {guest['name']}
40
+ Relation: {guest['relation']}
41
+ Description: {guest['description']}
42
+ Email: {guest['email']}"""
43
+
44
+ results.append(guest_info)
45
  if len(results) >= 3:
46
  break
47
 
48
  if results:
49
+ return "\n\n---\n\n".join(results)
50
  else:
51
+ return f"No guest found matching '{query}'"
52
 
53
+ # Create tool instance
54
  guest_info_tool = GuestInfoRetrieverTool()
55
  print("✅ RAG tool created")
56