ernani commited on
Commit
953b948
·
1 Parent(s): 8507438

improving retriever BM25 - using sentence transformers - adding memory management tool - combine with websearch - integrate multiple indexes

Browse files
Files changed (3) hide show
  1. app.py +60 -10
  2. retriever.py +14 -5
  3. tools.py +25 -3
app.py CHANGED
@@ -1,13 +1,29 @@
1
  import gradio as gr
2
  import random
3
- from smolagents import GradioUI, CodeAgent, HfApiModel
4
 
5
- # Import our custom tools from their modules
6
- from tools import DuckDuckGoSearchTool, WeatherInfoTool, HubStatsTool
 
 
 
 
 
 
7
  from retriever import load_guest_dataset
8
 
9
  # Initialize the Hugging Face model
10
- model = HfApiModel()
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Initialize the web search tool
13
  search_tool = DuckDuckGoSearchTool()
@@ -21,13 +37,47 @@ hub_stats_tool = HubStatsTool()
21
  # Load the guest dataset and initialize the guest info tool
22
  guest_info_tool = load_guest_dataset()
23
 
24
- # Create Alfred with all the tools
25
- alfred = CodeAgent(
26
- tools=[guest_info_tool, weather_info_tool, hub_stats_tool, search_tool],
27
- model=model,
28
- add_base_tools=True, # Add any additional base tools
29
- planning_interval=3 # Enable planning every 3 steps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  )
31
 
 
 
 
32
  if __name__ == "__main__":
33
  GradioUI(alfred).launch()
 
1
  import gradio as gr
2
  import random
3
+ from typing import TypedDict, Annotated
4
 
5
+ from langgraph.graph.message import add_messages
6
+ from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
7
+ from langgraph.prebuilt import ToolNode
8
+ from langgraph.graph import START, StateGraph
9
+ from langgraph.prebuilt import tools_condition
10
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
11
+
12
+ from tools import DuckDuckGoSearchTool, WeatherInfoTool, HubStatsTool, MemoryManagementTool
13
  from retriever import load_guest_dataset
14
 
15
  # Initialize the Hugging Face model
16
+ llm = HuggingFaceEndpoint(
17
+ repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
18
+ huggingfacehub_api_token=HF_TOKEN,
19
+ )
20
+
21
+ chat = ChatHuggingFace(llm=llm, verbose=True)
22
+ tools = [guest_info_tool, weather_info_tool, hub_stats_tool, search_tool, memory_management_tool]
23
+ chat_with_tools = chat.bind_tools(tools)
24
+
25
+ # Initialize the memory store
26
+ memory_management_tool = MemoryManagementTool()
27
 
28
  # Initialize the web search tool
29
  search_tool = DuckDuckGoSearchTool()
 
37
  # Load the guest dataset and initialize the guest info tool
38
  guest_info_tool = load_guest_dataset()
39
 
40
+
41
+ # Generate the AgentState and Agent graph
42
+ class AgentState(TypedDict):
43
+ messages: Annotated[list[AnyMessage], add_messages]
44
+
45
+ def assistant(state: AgentState):
46
+ # Retrieve past messages from memory
47
+ past_messages = memory_store.retrieve(state["messages"])
48
+ # Add new messages to memory
49
+ memory_store.add(state["messages"])
50
+
51
+ # Check if the query is about an unfamiliar guest
52
+ if "guest" in state["messages"][-1].content:
53
+ search_results = search_tool.forward(state["messages"][-1].content)
54
+ return {
55
+ "messages": [search_results]
56
+ }
57
+ else:
58
+ return {
59
+ "messages": [chat_with_tools.invoke(state["messages"] + past_messages)],
60
+ }
61
+
62
+ ## The graph
63
+ builder = StateGraph(AgentState)
64
+
65
+ # Define nodes: these do the work
66
+ builder.add_node("assistant", assistant)
67
+ builder.add_node("tools", ToolNode(tools))
68
+
69
+ # Define the graph
70
+
71
+ builder.add_edge(START, "assistant")
72
+ builder.add_conditional_edges(
73
+ "assistant",
74
+ # If the latest message requires a tool, route to tools
75
+ # Otherwise, provide a direct response
76
+ tools_condition,
77
  )
78
 
79
+ builder.add_edge("tools", "assistant")
80
+ alfred = builder.compile()
81
+
82
  if __name__ == "__main__":
83
  GradioUI(alfred).launch()
retriever.py CHANGED
@@ -1,7 +1,9 @@
1
  from smolagents import Tool
2
- from langchain_community.retrievers import BM25Retriever
3
  from langchain.docstore.document import Document
4
  import datasets
 
 
5
 
6
 
7
  class GuestInfoRetrieverTool(Tool):
@@ -17,13 +19,20 @@ class GuestInfoRetrieverTool(Tool):
17
 
18
  def __init__(self, docs):
19
  self.is_initialized = False
20
- self.retriever = BM25Retriever.from_documents(docs)
21
-
 
 
22
 
23
  def forward(self, query: str):
24
- results = self.retriever.get_relevant_documents(query)
 
 
 
 
 
25
  if results:
26
- return "\n\n".join([doc.page_content for doc in results[:3]])
27
  else:
28
  return "No matching guest information found."
29
 
 
1
  from smolagents import Tool
2
+ # from langchain_community.retrievers import BM25Retriever
3
  from langchain.docstore.document import Document
4
  import datasets
5
+ from sentence_transformers import SentenceTransformer
6
+ import torch
7
 
8
 
9
  class GuestInfoRetrieverTool(Tool):
 
19
 
20
  def __init__(self, docs):
21
  self.is_initialized = False
22
+ # Use sentence-transformers for embeddings
23
+ self.model = SentenceTransformer('all-MiniLM-L6-v2')
24
+ self.embeddings = self.model.encode([doc.page_content for doc in docs], convert_to_tensor=True)
25
+ self.docs = docs
26
 
27
  def forward(self, query: str):
28
+ query_embedding = self.model.encode(query, convert_to_tensor=True)
29
+ # Compute cosine similarities
30
+ similarities = torch.nn.functional.cosine_similarity(query_embedding, self.embeddings)
31
+ # Get the top 3 most similar documents
32
+ top_k = torch.topk(similarities, k=3)
33
+ results = [self.docs[i] for i in top_k.indices]
34
  if results:
35
+ return "\n\n".join([doc.page_content for doc in results])
36
  else:
37
  return "No matching guest information found."
38
 
tools.py CHANGED
@@ -2,11 +2,13 @@ from smolagents import DuckDuckGoSearchTool
2
  from smolagents import Tool
3
  import random
4
  from huggingface_hub import list_models
5
-
6
 
7
  # Initialize the DuckDuckGo search tool
8
- #search_tool = DuckDuckGoSearchTool()
9
-
 
 
10
 
11
  class WeatherInfoTool(Tool):
12
  name = "weather_info"
@@ -54,3 +56,23 @@ class HubStatsTool(Tool):
54
  except Exception as e:
55
  return f"Error fetching models for {author}: {str(e)}"
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from smolagents import Tool
3
  import random
4
  from huggingface_hub import list_models
5
+ from langgraph.store.memory import InMemoryStore
6
 
7
  # Initialize the DuckDuckGo search tool
8
+ search_tool = DuckDuckGoSearchTool()
9
+ store = InMemoryStore(
10
+ index={"embed": "openai:text-embedding-3-small"}
11
+ )
12
 
13
  class WeatherInfoTool(Tool):
14
  name = "weather_info"
 
56
  except Exception as e:
57
  return f"Error fetching models for {author}: {str(e)}"
58
 
59
+ class MemoryManagementTool(Tool):
60
+ name = "memory_management"
61
+ description = "Manages and queries conversation memory."
62
+ inputs = {
63
+ "query": {
64
+ "type": "string",
65
+ "description": "The query to search in memory."
66
+ }
67
+ }
68
+ output_type = "string"
69
+
70
+ def forward(self, query: str):
71
+ # Retrieve relevant memory entries
72
+ results = store.retrieve(query)
73
+ if results:
74
+ return "\n\n".join(results)
75
+ else:
76
+ return "No relevant memory found."
77
+
78
+