junaid17 commited on
Commit
3e10e7d
·
verified ·
1 Parent(s): 2b77321

Update tools.py

Browse files
Files changed (1) hide show
  1. tools.py +16 -6
tools.py CHANGED
@@ -25,10 +25,13 @@ retriever = None
25
 
26
 
27
  def load_retriever():
 
28
  global retriever
29
  try:
30
  embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
31
- if os.path.exists(os.path.join(VECTORSTORE_DIR, "index.faiss")):
 
 
32
  vectorstore = FAISS.load_local(
33
  VECTORSTORE_DIR,
34
  embeddings,
@@ -41,6 +44,7 @@ def load_retriever():
41
 
42
 
43
  def build_vectorstore(path: str):
 
44
  loader = PyPDFLoader(path)
45
  docs = loader.load()
46
 
@@ -50,22 +54,23 @@ def build_vectorstore(path: str):
50
  )
51
 
52
  split_docs = splitter.split_documents(docs)
53
-
54
  embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
55
- vectorstore = FAISS.from_documents(split_docs, embeddings)
56
 
 
57
  vectorstore.save_local(VECTORSTORE_DIR)
 
58
  return vectorstore
59
 
60
 
61
  def update_retriever(pdf_path: str):
 
62
  global retriever
63
  vectorstore = build_vectorstore(pdf_path)
64
  retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
65
 
66
 
67
  # ==============================
68
- # RAG TOOL (FIXED)
69
  # ==============================
70
  def create_rag_tool():
71
 
@@ -73,6 +78,7 @@ def create_rag_tool():
73
  def rag_search(query: str) -> str:
74
  """
75
  Retrieve relevant information from uploaded documents.
 
76
  """
77
  global retriever
78
 
@@ -92,10 +98,13 @@ def create_rag_tool():
92
  return rag_search
93
 
94
 
95
- # ---------------- OTHER TOOLS ---------------- #
 
 
96
 
97
  @tool
98
  def arxiv_search(query: str) -> dict:
 
99
  try:
100
  arxiv = ArxivQueryRun(api_wrapper=ArxivAPIWrapper())
101
  return {"results": arxiv.run(query)}
@@ -105,6 +114,7 @@ def arxiv_search(query: str) -> dict:
105
 
106
  @tool
107
  def wikipedia_search(query: str) -> dict:
 
108
  try:
109
  wiki = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
110
  return {"results": wiki.run(query)}
@@ -114,9 +124,9 @@ def wikipedia_search(query: str) -> dict:
114
 
115
  @tool
116
  def tavily_search(query: str) -> dict:
 
117
  try:
118
  search = TavilySearchResults(max_results=5)
119
  return {"results": search.run(query)}
120
  except Exception as e:
121
  return {"error": str(e)}
122
-
 
25
 
26
 
27
  def load_retriever():
28
+ """Load FAISS retriever from disk if available."""
29
  global retriever
30
  try:
31
  embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
32
+ index_path = os.path.join(VECTORSTORE_DIR, "index.faiss")
33
+
34
+ if os.path.exists(index_path):
35
  vectorstore = FAISS.load_local(
36
  VECTORSTORE_DIR,
37
  embeddings,
 
44
 
45
 
46
  def build_vectorstore(path: str):
47
+ """Build FAISS vector store from uploaded PDF."""
48
  loader = PyPDFLoader(path)
49
  docs = loader.load()
50
 
 
54
  )
55
 
56
  split_docs = splitter.split_documents(docs)
 
57
  embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
 
58
 
59
+ vectorstore = FAISS.from_documents(split_docs, embeddings)
60
  vectorstore.save_local(VECTORSTORE_DIR)
61
+
62
  return vectorstore
63
 
64
 
65
  def update_retriever(pdf_path: str):
66
+ """Update retriever after document upload."""
67
  global retriever
68
  vectorstore = build_vectorstore(pdf_path)
69
  retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
70
 
71
 
72
  # ==============================
73
+ # RAG TOOL
74
  # ==============================
75
  def create_rag_tool():
76
 
 
78
  def rag_search(query: str) -> str:
79
  """
80
  Retrieve relevant information from uploaded documents.
81
+ Uses FAISS-based semantic search.
82
  """
83
  global retriever
84
 
 
98
  return rag_search
99
 
100
 
101
+ # ==============================
102
+ # EXTERNAL TOOLS
103
+ # ==============================
104
 
105
  @tool
106
  def arxiv_search(query: str) -> dict:
107
+ """Search academic papers from arXiv."""
108
  try:
109
  arxiv = ArxivQueryRun(api_wrapper=ArxivAPIWrapper())
110
  return {"results": arxiv.run(query)}
 
114
 
115
  @tool
116
  def wikipedia_search(query: str) -> dict:
117
+ """Search Wikipedia for relevant information."""
118
  try:
119
  wiki = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
120
  return {"results": wiki.run(query)}
 
124
 
125
  @tool
126
  def tavily_search(query: str) -> dict:
127
+ """Search the web using Tavily."""
128
  try:
129
  search = TavilySearchResults(max_results=5)
130
  return {"results": search.run(query)}
131
  except Exception as e:
132
  return {"error": str(e)}