ernani committed on
Commit
e7c51fb
·
1 Parent(s): 857e8c6

updating retriever to use LangGraph

Browse files
Files changed (1) hide show
  1. retriever.py +27 -38
retriever.py CHANGED
@@ -1,44 +1,11 @@
1
- from smolagents import Tool
2
- # from langchain_community.retrievers import BM25Retriever
3
  from langchain.docstore.document import Document
4
- import datasets
5
  from sentence_transformers import SentenceTransformer
6
  import torch
 
7
 
8
-
9
- class GuestInfoRetrieverTool(Tool):
10
- name = "guest_info_retriever"
11
- description = "Retrieves detailed information about gala guests based on their name or relation."
12
- inputs = {
13
- "query": {
14
- "type": "string",
15
- "description": "The name or relation of the guest you want information about."
16
- }
17
- }
18
- output_type = "string"
19
-
20
- def __init__(self, docs):
21
- self.is_initialized = False
22
- # Use sentence-transformers for embeddings
23
- self.model = SentenceTransformer('all-MiniLM-L6-v2')
24
- self.embeddings = self.model.encode([doc.page_content for doc in docs], convert_to_tensor=True)
25
- self.docs = docs
26
-
27
- def forward(self, query: str):
28
- query_embedding = self.model.encode(query, convert_to_tensor=True)
29
- # Compute cosine similarities
30
- similarities = torch.nn.functional.cosine_similarity(query_embedding, self.embeddings)
31
- # Get the top 3 most similar documents
32
- top_k = torch.topk(similarities, k=3)
33
- results = [self.docs[i] for i in top_k.indices]
34
- if results:
35
- return "\n\n".join([doc.page_content for doc in results])
36
- else:
37
- return "No matching guest information found."
38
-
39
-
40
  def load_guest_dataset():
41
- # Load the dataset
42
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
43
 
44
  # Convert dataset entries into Document objects
@@ -55,8 +22,30 @@ def load_guest_dataset():
55
  for guest in guest_dataset
56
  ]
57
 
58
- # Return the tool
59
- return GuestInfoRetrieverTool(docs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
 
62
 
 
1
+ from langchain.tools import Tool
 
2
  from langchain.docstore.document import Document
 
3
  from sentence_transformers import SentenceTransformer
4
  import torch
5
+ import datasets
6
 
7
+ # Load the dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def load_guest_dataset():
 
9
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
10
 
11
  # Convert dataset entries into Document objects
 
22
  for guest in guest_dataset
23
  ]
24
 
25
+ # Initialize the sentence-transformers model
26
+ model = SentenceTransformer('all-MiniLM-L6-v2')
27
+ embeddings = model.encode([doc.page_content for doc in docs], convert_to_tensor=True)
28
+
29
+ # Define the extraction function
30
+ def extract_text(query: str) -> str:
31
+ """Retrieves detailed information about gala guests based on their name or relation."""
32
+ query_embedding = model.encode(query, convert_to_tensor=True)
33
+ similarities = torch.nn.functional.cosine_similarity(query_embedding, embeddings)
34
+ top_k = torch.topk(similarities, k=3)
35
+ results = [docs[i] for i in top_k.indices]
36
+ if results:
37
+ return "\n\n".join([doc.page_content for doc in results])
38
+ else:
39
+ return "No matching guest information found."
40
+
41
+ # Create the tool
42
+ guest_info_tool = Tool(
43
+ name="guest_info_retriever",
44
+ func=extract_text,
45
+ description="Retrieves detailed information about gala guests based on their name or relation."
46
+ )
47
+
48
+ return guest_info_tool
49
 
50
 
51