Basti-1995 commited on
Commit
d565e36
·
1 Parent(s): 8507438

1st commit - advanced retriever

Browse files
__pycache__/retriever.cpython-310.pyc ADDED
Binary file (2.07 kB). View file
 
__pycache__/tools.cpython-310.pyc ADDED
Binary file (1.88 kB). View file
 
app.py CHANGED
@@ -23,7 +23,12 @@ guest_info_tool = load_guest_dataset()
23
 
24
  # Create Alfred with all the tools
25
  alfred = CodeAgent(
26
- tools=[guest_info_tool, weather_info_tool, hub_stats_tool, search_tool],
 
 
 
 
 
27
  model=model,
28
  add_base_tools=True, # Add any additional base tools
29
  planning_interval=3 # Enable planning every 3 steps
 
23
 
24
  # Create Alfred with all the tools
25
  alfred = CodeAgent(
26
+ tools=[
27
+ guest_info_tool,
28
+ weather_info_tool,
29
+ hub_stats_tool,
30
+ search_tool
31
+ ],
32
  model=model,
33
  add_base_tools=True, # Add any additional base tools
34
  planning_interval=3 # Enable planning every 3 steps
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- datasets
2
- smolagents
3
- langchain-community
4
  rank_bm25
 
1
+ datasets
2
+ smolagents
3
+ langchain-community
4
  rank_bm25
retriever.py CHANGED
@@ -1,12 +1,75 @@
1
- from smolagents import Tool
2
- from langchain_community.retrievers import BM25Retriever
3
  from langchain.docstore.document import Document
 
 
 
 
4
  import datasets
5
 
6
 
7
- class GuestInfoRetrieverTool(Tool):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  name = "guest_info_retriever"
9
- description = "Retrieves detailed information about gala guests based on their name or relation."
 
 
 
10
  inputs = {
11
  "query": {
12
  "type": "string",
@@ -15,24 +78,19 @@ class GuestInfoRetrieverTool(Tool):
15
  }
16
  output_type = "string"
17
 
18
- def __init__(self, docs):
19
- self.is_initialized = False
20
- self.retriever = BM25Retriever.from_documents(docs)
21
-
22
 
23
  def forward(self, query: str):
24
  results = self.retriever.get_relevant_documents(query)
25
  if results:
26
- return "\n\n".join([doc.page_content for doc in results[:3]])
27
  else:
28
  return "No matching guest information found."
29
 
30
 
31
  def load_guest_dataset():
32
- # Load the dataset
33
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
34
-
35
- # Convert dataset entries into Document objects
36
  docs = [
37
  Document(
38
  page_content="\n".join([
@@ -45,9 +103,5 @@ def load_guest_dataset():
45
  )
46
  for guest in guest_dataset
47
  ]
48
-
49
- # Return the tool
50
- return GuestInfoRetrieverTool(docs)
51
-
52
-
53
 
 
1
+ from langchain_community.retrievers import BM25Retriever, EnsembleRetriever
2
+ from langchain.vectorstores import FAISS
3
  from langchain.docstore.document import Document
4
+ from langchain_community.embeddings import HuggingFaceEmbeddings
5
+ from sentence_transformers.util import cos_sim
6
+ from smolagents import Tool
7
+ import numpy as np
8
  import datasets
9
 
10
 
11
+ class HybridRetriever:
12
+ def __init__(self, docs, mode="rerank", k=5):
13
+ """
14
+ mode: "ensemble" or "rerank"
15
+ k: number of top docs to return
16
+ """
17
+ self.docs = docs
18
+ self.mode = mode
19
+ self.k = k
20
+ self.embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
21
+
22
+ # Initialize BM25 retriever
23
+ self.bm25 = BM25Retriever.from_documents(docs)
24
+ self.bm25.k = 20
25
+
26
+ # Initialize FAISS retriever
27
+ self.faiss = FAISS.from_documents(docs, self.embedding_model)
28
+ self.faiss_retriever = self.faiss.as_retriever(search_kwargs={"k": 20})
29
+
30
+ # For reranker mode, cache doc embeddings
31
+ self.doc_embeddings = {
32
+ doc.page_content: self.embedding_model.embed_query(doc.page_content)
33
+ for doc in docs
34
+ }
35
+
36
+ # Ensemble retriever setup
37
+ if mode == "ensemble":
38
+ self.retriever = EnsembleRetriever(
39
+ retrievers=[self.bm25, self.faiss_retriever],
40
+ weights=[0.5, 0.5]
41
+ )
42
+
43
+ def get_relevant_documents(self, query: str):
44
+ if self.mode == "ensemble":
45
+ return self.retriever.get_relevant_documents(query)[:self.k]
46
+
47
+ elif self.mode == "rerank":
48
+ bm25_candidates = self.bm25.get_relevant_documents(query)
49
+ query_embedding = self.embedding_model.embed_query(query)
50
+
51
+ scores = []
52
+ for doc in bm25_candidates:
53
+ doc_vec = self.doc_embeddings.get(doc.page_content)
54
+ if doc_vec is not None:
55
+ sim = np.dot(query_embedding, doc_vec) / (
56
+ np.linalg.norm(query_embedding) * np.linalg.norm(doc_vec)
57
+ )
58
+ scores.append((sim, doc))
59
+
60
+ top_docs = sorted(scores, key=lambda x: x[0], reverse=True)[:self.k]
61
+ return [doc for _, doc in top_docs]
62
+
63
+ else:
64
+ raise ValueError(f"Unsupported mode: {self.mode}")
65
+
66
+
67
+ class GuestInfoHybridTool(Tool):
68
  name = "guest_info_retriever"
69
+ description = (
70
+ "Retrieves detailed information about gala guests based on their name or relation "
71
+ "using a hybrid of BM25 and embeddings. Supports ensemble or reranking."
72
+ )
73
  inputs = {
74
  "query": {
75
  "type": "string",
 
78
  }
79
  output_type = "string"
80
 
81
+ def __init__(self, docs, mode="rerank"):
82
+ self.retriever = HybridRetriever(docs, mode=mode)
 
 
83
 
84
  def forward(self, query: str):
85
  results = self.retriever.get_relevant_documents(query)
86
  if results:
87
+ return "\n\n".join([doc.page_content for doc in results])
88
  else:
89
  return "No matching guest information found."
90
 
91
 
92
  def load_guest_dataset():
 
93
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
 
 
94
  docs = [
95
  Document(
96
  page_content="\n".join([
 
103
  )
104
  for guest in guest_dataset
105
  ]
106
+ return GuestInfoHybridTool(docs, mode="rerank")
 
 
 
 
107
 
tools.py CHANGED
@@ -45,7 +45,7 @@ class HubStatsTool(Tool):
45
  try:
46
  # List models from the specified author, sorted by downloads
47
  models = list(list_models(author=author, sort="downloads", direction=-1, limit=1))
48
-
49
  if models:
50
  model = models[0]
51
  return f"The most downloaded model by {author} is {model.id} with {model.downloads:,} downloads."
 
45
  try:
46
  # List models from the specified author, sorted by downloads
47
  models = list(list_models(author=author, sort="downloads", direction=-1, limit=1))
48
+
49
  if models:
50
  model = models[0]
51
  return f"The most downloaded model by {author} is {model.id} with {model.downloads:,} downloads."