dlaima commited on
Commit
9059902
·
verified ·
1 Parent(s): 2b751c4

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +23 -24
retriever.py CHANGED
@@ -4,9 +4,7 @@ from langchain.docstore.document import Document
4
  import datasets
5
  import pandas as pd
6
  import os
7
- import logging
8
 
9
- logging.basicConfig(level=logging.INFO)
10
 
11
  class GuestInfoRetrieverTool(Tool):
12
  name = "guest_info_retriever"
@@ -19,10 +17,16 @@ class GuestInfoRetrieverTool(Tool):
19
  }
20
  output_type = "string"
21
 
 
 
 
 
22
  def __init__(self, docs):
23
- self.docs = docs
 
24
  self.retriever = BM25Retriever.from_documents(docs)
25
 
 
26
  def _generate_conversation_starter(self, doc: Document):
27
  lines = doc.page_content.splitlines()
28
  name = None
@@ -43,22 +47,24 @@ class GuestInfoRetrieverTool(Tool):
43
  else:
44
  return "Try asking about their background—it sounds fascinating!"
45
 
46
- def forward(self, query: str):
47
- query = query.strip().lower()
 
 
 
 
48
 
49
- # Handle guest listing queries
50
- guest_list_keywords = [
51
- "list guests", "guest names", "list all guests",
52
- "show guests", "all guests", "everyone invited"
53
- ]
54
- if any(keyword in query for keyword in guest_list_keywords):
55
- logging.info("Listing all guests from provided dataset.")
56
  return "\n".join([doc.metadata.get("name", "Unknown") for doc in self.docs])
57
 
58
- # Fallback to BM25 search
 
59
  results = self.retriever.get_relevant_documents(query)
60
  if results:
61
  responses = []
 
62
  for doc in results[:10]:
63
  content = doc.page_content
64
  starter = self._generate_conversation_starter(doc)
@@ -69,14 +75,13 @@ class GuestInfoRetrieverTool(Tool):
69
 
70
 
71
  def load_guest_dataset(file_path: str = None, show_example: bool = True):
 
72
  """
73
- Loads a guest dataset from a CSV/JSON file or from Hugging Face if no file provided.
74
- Ensures necessary columns exist.
75
- Returns a Tool that can search for guest info.
76
  """
77
  if file_path and os.path.exists(file_path):
78
  ext = os.path.splitext(file_path)[1].lower()
79
- logging.info(f"📁 Loading guest data from: {file_path}")
80
  if ext == ".csv":
81
  df = pd.read_csv(file_path)
82
  elif ext == ".json":
@@ -84,17 +89,12 @@ def load_guest_dataset(file_path: str = None, show_example: bool = True):
84
  else:
85
  raise ValueError("Unsupported file format. Use .csv or .json.")
86
  else:
87
- logging.info("📡 Loading default guest data from Hugging Face.")
88
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
89
  df = pd.DataFrame(guest_dataset)
90
  if show_example:
91
  print("\n📌 Example guest from Hugging Face dataset:\n")
92
  print(df.head(1).to_markdown(index=False))
93
 
94
- required_columns = {"name", "relation", "description", "email"}
95
- if not required_columns.issubset(df.columns):
96
- raise ValueError(f"Missing required columns. Expected: {required_columns}")
97
-
98
  docs = [
99
  Document(
100
  page_content="\n".join([
@@ -108,5 +108,4 @@ def load_guest_dataset(file_path: str = None, show_example: bool = True):
108
  for _, row in df.iterrows()
109
  ]
110
 
111
- logging.info(f"✅ Loaded {len(docs)} guests into retrieval tool.")
112
- return GuestInfoRetrieverTool(docs)
 
4
  import datasets
5
  import pandas as pd
6
  import os
 
7
 
 
8
 
9
  class GuestInfoRetrieverTool(Tool):
10
  name = "guest_info_retriever"
 
17
  }
18
  output_type = "string"
19
 
20
+ #def __init__(self, docs):
21
+ #self.is_initialized = False
22
+ #self.retriever = BM25Retriever.from_documents(docs)
23
+
24
  def __init__(self, docs):
25
+ self.is_initialized = False
26
+ self.docs = docs # 🔁 store the original list manually
27
  self.retriever = BM25Retriever.from_documents(docs)
28
 
29
+
30
  def _generate_conversation_starter(self, doc: Document):
31
  lines = doc.page_content.splitlines()
32
  name = None
 
47
  else:
48
  return "Try asking about their background—it sounds fascinating!"
49
 
50
+ #def forward(self, query: str):
51
+ # Handle special case for full guest listing
52
+ #if "list" in query.lower() and "guest" in query.lower() and "name" in query.lower():
53
+ #return "\n".join([
54
+ #doc.metadata.get("name", "Unknown") for doc in self.retriever.docs
55
+ #])
56
 
57
+ def forward(self, query: str):
58
+ if any(keyword in query.lower() for keyword in ["list guests", "guest names", "list all guests", "show guests", "all guests", "everyone invited"]):
59
+ #if "list" in query.lower() and "guest" in query.lower() and "name" in query.lower():
 
 
 
 
60
  return "\n".join([doc.metadata.get("name", "Unknown") for doc in self.docs])
61
 
62
+
63
+ # Default BM25 retrieval
64
  results = self.retriever.get_relevant_documents(query)
65
  if results:
66
  responses = []
67
+ #for doc in results[:3]:
68
  for doc in results[:10]:
69
  content = doc.page_content
70
  starter = self._generate_conversation_starter(doc)
 
75
 
76
 
77
  def load_guest_dataset(file_path: str = None, show_example: bool = True):
78
+
79
  """
80
+ Loads guest dataset either from a file (CSV/JSON) or the Hugging Face default dataset.
81
+ If using the Hugging Face dataset, optionally prints a preview example.
 
82
  """
83
  if file_path and os.path.exists(file_path):
84
  ext = os.path.splitext(file_path)[1].lower()
 
85
  if ext == ".csv":
86
  df = pd.read_csv(file_path)
87
  elif ext == ".json":
 
89
  else:
90
  raise ValueError("Unsupported file format. Use .csv or .json.")
91
  else:
 
92
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
93
  df = pd.DataFrame(guest_dataset)
94
  if show_example:
95
  print("\n📌 Example guest from Hugging Face dataset:\n")
96
  print(df.head(1).to_markdown(index=False))
97
 
 
 
 
 
98
  docs = [
99
  Document(
100
  page_content="\n".join([
 
108
  for _, row in df.iterrows()
109
  ]
110
 
111
+ return GuestInfoRetrieverTool(docs)