D3MI4N commited on
Commit
06c6a48
·
1 Parent(s): 52762da

replace local config to use HF UI secret variable for search tool

Browse files
Files changed (2) hide show
  1. qa_graph.py +7 -9
  2. tools/search_tool.py +6 -4
qa_graph.py CHANGED
@@ -1,5 +1,3 @@
1
- # qa_graph.py
2
-
3
  from typing import TypedDict
4
  import re
5
  from langgraph.graph import StateGraph, START, END
@@ -7,12 +5,12 @@ from tools.calculator_tool import calculator_tool
7
  from tools.search_tool import search_tool
8
  from transformers import pipeline
9
 
10
- # 1) Define the shape of our state
11
  class QAState(TypedDict):
12
  question: str # incoming question
13
  answer: str # to store tool output or synthesized answer
14
 
15
- # 2) Use text2text-generation for T5 models like flan-t5
16
  synthesizer = pipeline(
17
  "text2text-generation",
18
  model="google/flan-t5-small",
@@ -23,16 +21,16 @@ synthesizer = pipeline(
23
  temperature=0.7
24
  )
25
 
26
- # 3) Tool agent: calculator for math, search for other
27
  def QAAgent(state: QAState) -> QAState:
28
  q = state["question"].strip()
29
  if re.fullmatch(r"[0-9\s\+\-\*\/\.\(\)]+", q):
30
  state["answer"] = calculator_tool.invoke(q)
31
  else:
32
- state["answer"] = search_tool.invoke(q) # update to `.invoke(q)` only if search_tool is a LangChain tool
33
  return state
34
 
35
- # 4) Synthesis agent to generate final response
36
  def SynthesisAgent(state: QAState) -> QAState:
37
  question = state["question"]
38
  tool_out = state["answer"]
@@ -46,7 +44,7 @@ def SynthesisAgent(state: QAState) -> QAState:
46
  state["answer"] = completion
47
  return state
48
 
49
- # 5) Build the graph: START -> QAAgent -> SynthesisAgent -> END
50
  builder = StateGraph(QAState)
51
  builder.set_entry_point("QAAgent")
52
  builder.add_node("QAAgent", QAAgent)
@@ -58,7 +56,7 @@ builder.add_edge("SynthesisAgent", END)
58
 
59
  graph = builder.compile()
60
 
61
- # 6) Local testing
62
  if __name__ == "__main__":
63
  # Math example
64
  s1: QAState = {"question": "2 + 2", "answer": ""}
 
 
 
1
  from typing import TypedDict
2
  import re
3
  from langgraph.graph import StateGraph, START, END
 
5
  from tools.search_tool import search_tool
6
  from transformers import pipeline
7
 
8
+ # Shape of the state
9
  class QAState(TypedDict):
10
  question: str # incoming question
11
  answer: str # to store tool output or synthesized answer
12
 
13
+ # Use text2text-generation for llm model
14
  synthesizer = pipeline(
15
  "text2text-generation",
16
  model="google/flan-t5-small",
 
21
  temperature=0.7
22
  )
23
 
24
+ # Tool agent: calculator for math, search for other
25
  def QAAgent(state: QAState) -> QAState:
26
  q = state["question"].strip()
27
  if re.fullmatch(r"[0-9\s\+\-\*\/\.\(\)]+", q):
28
  state["answer"] = calculator_tool.invoke(q)
29
  else:
30
+ state["answer"] = search_tool.invoke(q)
31
  return state
32
 
33
+ # Synthesis agent to generate final response
34
  def SynthesisAgent(state: QAState) -> QAState:
35
  question = state["question"]
36
  tool_out = state["answer"]
 
44
  state["answer"] = completion
45
  return state
46
 
47
+ # Build the graph
48
  builder = StateGraph(QAState)
49
  builder.set_entry_point("QAAgent")
50
  builder.add_node("QAAgent", QAAgent)
 
56
 
57
  graph = builder.compile()
58
 
59
+ # Local testing
60
  if __name__ == "__main__":
61
  # Math example
62
  s1: QAState = {"question": "2 + 2", "answer": ""}
tools/search_tool.py CHANGED
@@ -1,6 +1,8 @@
1
  from tavily import TavilyClient
2
  from langchain.tools import tool
3
- from config import TAVILY_API_KEY
 
 
4
 
5
  class SearchTool:
6
  def __init__(self, api_key: str):
@@ -8,16 +10,16 @@ class SearchTool:
8
 
9
  def search(self, query: str):
10
  response = self.client.search(query)
11
- # Extract a string summary of results (you can adapt this as needed)
12
  results = response.get("results", [])
13
  if not results:
14
  return "No results found."
15
- # For simplicity, join first 3 results' titles or snippets
16
  summaries = [res.get("title", "") or res.get("snippet", "") for res in results[:3]]
17
  return " | ".join(summaries)
18
 
19
  search_tool_instance = SearchTool(api_key=TAVILY_API_KEY)
20
 
21
- @tool(description="Use this tool to search for information on the web using Tavily API and return a summary of results.")
22
  def search_tool(query: str) -> str:
23
  return search_tool_instance.search(query)
 
1
  from tavily import TavilyClient
2
  from langchain.tools import tool
3
+ # from config import TAVILY_API_KEY
4
+ import os
5
+ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
6
 
7
  class SearchTool:
8
  def __init__(self, api_key: str):
 
10
 
11
  def search(self, query: str):
12
  response = self.client.search(query)
13
+ # Extract a string summary of results
14
  results = response.get("results", [])
15
  if not results:
16
  return "No results found."
17
+ # For simplicity, joining first 3 results' titles or snippets
18
  summaries = [res.get("title", "") or res.get("snippet", "") for res in results[:3]]
19
  return " | ".join(summaries)
20
 
21
  search_tool_instance = SearchTool(api_key=TAVILY_API_KEY)
22
 
23
+ @tool(description="Tool to search for information on the web using Tavily API and return a summary of results.")
24
  def search_tool(query: str) -> str:
25
  return search_tool_instance.search(query)