mrpe24 committed on
Commit
ad2b3b2
·
1 Parent(s): 8507438

Implemented RAG

Browse files
Files changed (3) hide show
  1. app.py +10 -6
  2. requirements.txt +2 -2
  3. retriever.py +27 -28
app.py CHANGED
@@ -1,10 +1,10 @@
1
- import gradio as gr
2
- import random
3
  from smolagents import GradioUI, CodeAgent, HfApiModel
4
 
 
5
  # Import our custom tools from their modules
6
  from tools import DuckDuckGoSearchTool, WeatherInfoTool, HubStatsTool
7
- from retriever import load_guest_dataset
8
 
9
  # Initialize the Hugging Face model
10
  model = HfApiModel()
@@ -19,15 +19,19 @@ weather_info_tool = WeatherInfoTool()
19
  hub_stats_tool = HubStatsTool()
20
 
21
  # Load the guest dataset and initialize the guest info tool
22
- guest_info_tool = load_guest_dataset()
 
 
 
 
23
 
24
  # Create Alfred with all the tools
25
  alfred = CodeAgent(
26
  tools=[guest_info_tool, weather_info_tool, hub_stats_tool, search_tool],
27
  model=model,
28
- add_base_tools=True, # Add any additional base tools
29
  planning_interval=3 # Enable planning every 3 steps
30
  )
31
 
32
  if __name__ == "__main__":
33
- GradioUI(alfred).launch()
 
1
+ import chromadb
2
+ from chromadb.utils import embedding_functions
3
  from smolagents import GradioUI, CodeAgent, HfApiModel
4
 
5
+ from retriever import load_guest_dataset
6
  # Import our custom tools from their modules
7
  from tools import DuckDuckGoSearchTool, WeatherInfoTool, HubStatsTool
 
8
 
9
  # Initialize the Hugging Face model
10
  model = HfApiModel()
 
19
  hub_stats_tool = HubStatsTool()
20
 
21
  # Load the guest dataset and initialize the guest info tool
22
+ chroma_client = chromadb.Client()
23
+ default_ef = embedding_functions.DefaultEmbeddingFunction()
24
+ collection = chroma_client.create_collection(name="guests-collection", embedding_function=default_ef)
25
+
26
+ guest_info_tool = load_guest_dataset(collection)
27
 
28
  # Create Alfred with all the tools
29
  alfred = CodeAgent(
30
  tools=[guest_info_tool, weather_info_tool, hub_stats_tool, search_tool],
31
  model=model,
32
+ add_base_tools=False, # Don't add the default base tools — all tools are passed explicitly above
33
  planning_interval=3 # Enable planning every 3 steps
34
  )
35
 
36
  if __name__ == "__main__":
37
+ GradioUI(alfred).launch()
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
  datasets
2
  smolagents
3
- langchain-community
4
- rank_bm25
 
1
  datasets
2
  smolagents
3
+ langchain-community
4
+ langchain-huggingface
retriever.py CHANGED
@@ -1,7 +1,7 @@
1
- from smolagents import Tool
2
- from langchain_community.retrievers import BM25Retriever
3
- from langchain.docstore.document import Document
4
  import datasets
 
5
 
6
 
7
  class GuestInfoRetrieverTool(Tool):
@@ -15,39 +15,38 @@ class GuestInfoRetrieverTool(Tool):
15
  }
16
  output_type = "string"
17
 
18
- def __init__(self, docs):
19
  self.is_initialized = False
20
- self.retriever = BM25Retriever.from_documents(docs)
21
-
22
 
23
  def forward(self, query: str):
24
- results = self.retriever.get_relevant_documents(query)
25
- if results:
26
- return "\n\n".join([doc.page_content for doc in results[:3]])
27
- else:
28
- return "No matching guest information found."
 
 
 
29
 
30
 
31
- def load_guest_dataset():
32
  # Load the dataset
33
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
34
 
35
  # Convert dataset entries into Document objects
36
- docs = [
37
- Document(
38
- page_content="\n".join([
39
- f"Name: {guest['name']}",
40
- f"Relation: {guest['relation']}",
41
- f"Description: {guest['description']}",
42
- f"Email: {guest['email']}"
43
- ]),
44
- metadata={"name": guest["name"]}
45
- )
46
- for guest in guest_dataset
47
- ]
48
 
49
  # Return the tool
50
- return GuestInfoRetrieverTool(docs)
51
-
52
-
53
-
 
1
+ from uuid import uuid4
2
+
 
3
  import datasets
4
+ from smolagents import Tool
5
 
6
 
7
  class GuestInfoRetrieverTool(Tool):
 
15
  }
16
  output_type = "string"
17
 
18
+ def __init__(self, vector_store):
19
  self.is_initialized = False
20
+ self.vector_store = vector_store
 
21
 
22
  def forward(self, query: str):
23
+ result = self.vector_store.query(
24
+ query_texts=[query],
25
+ n_results=3
26
+ )
27
+
28
+ distances = [distance for distance in result['distances'][0] if distance < 1.3]
29
+ docs = result['documents'][0]
30
+ return "\n\n".join([docs[idx] for idx in range(0, len(distances))])
31
 
32
 
33
def load_guest_dataset(vector_store):
    """Index the invitee dataset into *vector_store* and build the lookup tool.

    Downloads the ``agents-course/unit3-invitees`` dataset, formats each
    guest record as one newline-separated text document, inserts all of
    them into the Chroma collection in a single batched call (the previous
    version issued one ``add`` round-trip per guest), and returns a
    ``GuestInfoRetrieverTool`` bound to that collection.

    Args:
        vector_store: a Chroma collection exposing ``add`` and ``query``.

    Returns:
        GuestInfoRetrieverTool: tool that retrieves guest info by query.
    """
    # Load the dataset
    guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")

    # Format each guest as a single plain-text document.
    documents = [
        "\n".join([
            f"Name: {guest['name']}",
            f"Relation: {guest['relation']}",
            f"Description: {guest['description']}",
            f"Email: {guest['email']}",
        ])
        for guest in guest_dataset
    ]

    # One batched insert instead of one call per guest.
    # NOTE(review): random UUIDs mean re-running this duplicates the
    # collection's contents — stable ids (e.g. the guest name) would make
    # the load idempotent; confirm before changing.
    vector_store.add(
        documents=documents,
        metadatas=[{"name": guest["name"]} for guest in guest_dataset],
        ids=[str(uuid4()) for _ in documents],
    )

    # Return the tool
    return GuestInfoRetrieverTool(vector_store)