SergeyO7 commited on
Commit
8f95373
·
verified ·
1 Parent(s): 54b2820

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +64 -18
retriever.py CHANGED
@@ -2,12 +2,14 @@ from smolagents import Tool
2
  from langchain_community.vectorstores import FAISS
3
  from langchain_huggingface import HuggingFaceEmbeddings
4
  from langchain.docstore.document import Document
 
 
5
  from tools import DuckDuckGoSearchTool
6
  import datasets
7
 
8
- class GuestInfoRetrieverTool(Tool):
9
- name = "guest_info_retriever"
10
- description = "Retrieves detailed information about gala guests using semantic search."
11
  inputs = {
12
  "query": {
13
  "type": "string",
@@ -16,27 +18,61 @@ class GuestInfoRetrieverTool(Tool):
16
  }
17
  output_type = "string"
18
 
19
- def __init__(self, docs):
20
  self.is_initialized = False
21
- # Initialize embedding model
22
  self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
23
- # Create FAISS vector store
24
- self.retriever = FAISS.from_documents(docs, self.embeddings).as_retriever(
25
- search_kwargs={"k": 3}) # Return top 3 results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  self.web_search_tool = DuckDuckGoSearchTool()
27
-
28
  def forward(self, query: str):
29
- results = self.retriever.get_relevant_documents(query)
 
 
30
  if results:
31
- return "\n\n".join([doc.page_content for doc in results])
32
- else:
33
- # Fallback to web search
34
- web_results = self.web_search_tool.forward(f"Who is {query}?")
35
- return f"No guest found in dataset. Web search results:\n{web_results}"
 
 
 
 
 
 
 
 
 
36
 
37
  def load_guest_dataset():
 
38
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
39
- docs = [
40
  Document(
41
  page_content="\n".join([
42
  f"Name: {guest['name']}",
@@ -44,8 +80,18 @@ def load_guest_dataset():
44
  f"Description: {guest['description']}",
45
  f"Email: {guest['email']}"
46
  ]),
47
- metadata={"name": guest["name"]}
48
  )
49
  for guest in guest_dataset
50
  ]
51
- return GuestInfoRetrieverTool(docs)
 
 
 
 
 
 
 
 
 
 
 
2
  from langchain_community.vectorstores import FAISS
3
  from langchain_huggingface import HuggingFaceEmbeddings
4
  from langchain.docstore.document import Document
5
+ from langchain.retrievers import EnsembleRetriever
6
+ from langchain_community.retrievers import BM25Retriever
7
  from tools import DuckDuckGoSearchTool
8
  import datasets
9
 
10
+ class MultiIndexRetrieverTool(Tool):
11
+ name = "multi_index_guest_retriever"
12
+ description = "Retrieves guest information from multiple indexes and verified sources."
13
  inputs = {
14
  "query": {
15
  "type": "string",
 
18
  }
19
  output_type = "string"
20
 
21
+ def __init__(self, primary_docs, secondary_docs=None):
22
  self.is_initialized = False
 
23
  self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
24
+
25
+ # Primary index (guest dataset)
26
+ self.primary_retriever = FAISS.from_documents(
27
+ primary_docs, self.embeddings
28
+ ).as_retriever(search_kwargs={"k": 3})
29
+
30
+ # Secondary index (e.g., Wikipedia or another dataset)
31
+ self.secondary_retriever = None
32
+ if secondary_docs:
33
+ self.secondary_retriever = FAISS.from_documents(
34
+ secondary_docs, self.embeddings
35
+ ).as_retriever(search_kwargs={"k": 3})
36
+
37
+ # BM25 for keyword-based fallback
38
+ self.bm25_retriever = BM25Retriever.from_documents(primary_docs)
39
+ self.bm25_retriever.k = 3
40
+
41
+ # Ensemble retriever (combines primary and secondary)
42
+ retrievers = [self.primary_retriever, self.bm25_retriever]
43
+ if self.secondary_retriever:
44
+ retrievers.append(self.secondary_retriever)
45
+
46
+ self.ensemble_retriever = EnsembleRetriever(
47
+ retrievers=retrievers, weights=[0.5, 0.3, 0.2] if self.secondary_retriever else [0.7, 0.3]
48
+ )
49
+
50
  self.web_search_tool = DuckDuckGoSearchTool()
51
+
52
  def forward(self, query: str):
53
+ # Retrieve from ensemble
54
+ results = self.ensemble_retriever.get_relevant_documents(query)
55
+
56
  if results:
57
+ # Filter for verified sources (e.g., prioritize dataset over secondary)
58
+ verified_results = [
59
+ doc for doc in results if doc.metadata.get("source", "").startswith("unit3-invitees")
60
+ ]
61
+ other_results = [
62
+ doc for doc in results if not doc.metadata.get("source", "").startswith("unit3-invitees")
63
+ ]
64
+ combined_results = verified_results[:2] + other_results[:1] # Prioritize verified
65
+ if combined_results:
66
+ return "\n\n".join([doc.page_content for doc in combined_results])
67
+
68
+ # Fallback to web search
69
+ web_results = self.web_search_tool.forward(f"Who is {query}?")
70
+ return f"No guest found in indexes. Web search results:\n{web_results}"
71
 
72
  def load_guest_dataset():
73
+ # Primary dataset
74
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
75
+ primary_docs = [
76
  Document(
77
  page_content="\n".join([
78
  f"Name: {guest['name']}",
 
80
  f"Description: {guest['description']}",
81
  f"Email: {guest['email']}"
82
  ]),
83
+ metadata={"name": guest["name"], "source": "unit3-invitees"}
84
  )
85
  for guest in guest_dataset
86
  ]
87
+
88
+ # Secondary dataset (example: Wikipedia-like data)
89
+ secondary_docs = [
90
+ Document(
91
+ page_content="Name: Ada Lovelace\nDescription: Known as the first computer programmer, wrote the first algorithm for Charles Babbage's Analytical Engine.",
92
+ metadata={"name": "Ada Lovelace", "source": "wikipedia"}
93
+ )
94
+ # Add more secondary documents as needed
95
+ ]
96
+
97
+ return MultiIndexRetrieverTool(primary_docs, secondary_docs)