junaid17 committed on
Commit
367fd43
·
verified ·
1 Parent(s): 3e10e7d

Update tools.py

Browse files
Files changed (1) hide show
  1. tools.py +31 -38
tools.py CHANGED
@@ -8,7 +8,6 @@ from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
8
  from langchain_community.tools.tavily_search import TavilySearchResults
9
  from dotenv import load_dotenv
10
  import os
11
- import requests
12
 
13
  load_dotenv()
14
 
@@ -18,15 +17,13 @@ load_dotenv()
18
  VECTORSTORE_DIR = "data/vectorstore"
19
  os.makedirs(VECTORSTORE_DIR, exist_ok=True)
20
 
21
- # ==============================
22
- # GLOBAL RETRIEVER
23
- # ==============================
24
  retriever = None
25
 
26
 
27
  def load_retriever():
28
  """Load FAISS retriever from disk if available."""
29
  global retriever
 
30
  try:
31
  embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
32
  index_path = os.path.join(VECTORSTORE_DIR, "index.faiss")
@@ -35,16 +32,22 @@ def load_retriever():
35
  vectorstore = FAISS.load_local(
36
  VECTORSTORE_DIR,
37
  embeddings,
38
- allow_dangerous_deserialization=True
39
  )
40
- retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
41
- print("✅ Vectorstore loaded from disk")
 
 
 
42
  except Exception as e:
43
- print("❌ Failed to load vectorstore:", e)
 
 
 
 
44
 
45
 
46
  def build_vectorstore(path: str):
47
- """Build FAISS vector store from uploaded PDF."""
48
  loader = PyPDFLoader(path)
49
  docs = loader.load()
50
 
@@ -53,20 +56,18 @@ def build_vectorstore(path: str):
53
  chunk_overlap=100
54
  )
55
 
56
- split_docs = splitter.split_documents(docs)
57
  embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
58
 
59
- vectorstore = FAISS.from_documents(split_docs, embeddings)
60
  vectorstore.save_local(VECTORSTORE_DIR)
61
 
62
  return vectorstore
63
 
64
 
65
- def update_retriever(pdf_path: str):
66
- """Update retriever after document upload."""
67
  global retriever
68
- vectorstore = build_vectorstore(pdf_path)
69
- retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
70
 
71
 
72
  # ==============================
@@ -76,48 +77,41 @@ def create_rag_tool():
76
 
77
  @tool
78
  def rag_search(query: str) -> str:
79
- """
80
- Retrieve relevant information from uploaded documents.
81
- Uses FAISS-based semantic search.
82
- """
83
- global retriever
84
 
85
- if retriever is None:
86
- load_retriever()
87
 
88
  if retriever is None:
89
- return "No document has been uploaded yet."
90
 
91
  docs = retriever.invoke(query)
92
 
93
  if not docs:
94
- return "No relevant information found in the uploaded document."
95
 
96
  return "\n\n".join(d.page_content for d in docs)
97
 
98
  return rag_search
99
 
100
 
101
- # ==============================
102
- # EXTERNAL TOOLS
103
- # ==============================
104
 
105
  @tool
106
- def arxiv_search(query: str) -> dict:
107
- """Search academic papers from arXiv."""
108
  try:
109
- arxiv = ArxivQueryRun(api_wrapper=ArxivAPIWrapper())
110
- return {"results": arxiv.run(query)}
111
  except Exception as e:
112
  return {"error": str(e)}
113
 
114
 
115
  @tool
116
- def wikipedia_search(query: str) -> dict:
117
- """Search Wikipedia for relevant information."""
118
  try:
119
- wiki = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
120
- return {"results": wiki.run(query)}
121
  except Exception as e:
122
  return {"error": str(e)}
123
 
@@ -126,7 +120,6 @@ def wikipedia_search(query: str) -> dict:
126
  def tavily_search(query: str) -> dict:
127
  """Search the web using Tavily."""
128
  try:
129
- search = TavilySearchResults(max_results=5)
130
- return {"results": search.run(query)}
131
  except Exception as e:
132
- return {"error": str(e)}
 
8
  from langchain_community.tools.tavily_search import TavilySearchResults
9
  from dotenv import load_dotenv
10
  import os
 
11
 
12
  load_dotenv()
13
 
 
17
  VECTORSTORE_DIR = "data/vectorstore"
18
  os.makedirs(VECTORSTORE_DIR, exist_ok=True)
19
 
 
 
 
20
  retriever = None
21
 
22
 
23
  def load_retriever():
24
  """Load FAISS retriever from disk if available."""
25
  global retriever
26
+
27
  try:
28
  embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
29
  index_path = os.path.join(VECTORSTORE_DIR, "index.faiss")
 
32
  vectorstore = FAISS.load_local(
33
  VECTORSTORE_DIR,
34
  embeddings,
35
+ allow_dangerous_deserialization=True,
36
  )
37
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
38
+ print("✅ Retriever loaded successfully")
39
+ else:
40
+ print("⚠️ No vectorstore found yet")
41
+
42
  except Exception as e:
43
+ print("❌ Retriever load error:", e)
44
+
45
+
46
+ # Load on startup
47
+ load_retriever()
48
 
49
 
50
  def build_vectorstore(path: str):
 
51
  loader = PyPDFLoader(path)
52
  docs = loader.load()
53
 
 
56
  chunk_overlap=100
57
  )
58
 
59
+ chunks = splitter.split_documents(docs)
60
  embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
61
 
62
+ vectorstore = FAISS.from_documents(chunks, embeddings)
63
  vectorstore.save_local(VECTORSTORE_DIR)
64
 
65
  return vectorstore
66
 
67
 
68
+ def update_retriever(path: str):
 
69
  global retriever
70
+ retriever = build_vectorstore(path).as_retriever(search_kwargs={"k": 4})
 
71
 
72
 
73
  # ==============================
 
77
 
78
  @tool
79
  def rag_search(query: str) -> str:
80
+ """Retrieve relevant context from uploaded documents."""
 
 
 
 
81
 
82
+ global retriever
 
83
 
84
  if retriever is None:
85
+ return "No document uploaded yet."
86
 
87
  docs = retriever.invoke(query)
88
 
89
  if not docs:
90
+ return "No relevant information found in the document."
91
 
92
  return "\n\n".join(d.page_content for d in docs)
93
 
94
  return rag_search
95
 
96
 
97
+ # -----------------------------
98
+ # External tools (safe)
99
+ # -----------------------------
100
 
101
  @tool
102
+ def wikipedia_search(query: str) -> dict:
103
+ """Search Wikipedia."""
104
  try:
105
+ return {"results": WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()).run(query)}
 
106
  except Exception as e:
107
  return {"error": str(e)}
108
 
109
 
110
  @tool
111
+ def arxiv_search(query: str) -> dict:
112
+ """Search academic papers on arXiv."""
113
  try:
114
+ return {"results": ArxivQueryRun(api_wrapper=ArxivAPIWrapper()).run(query)}
 
115
  except Exception as e:
116
  return {"error": str(e)}
117
 
 
120
  def tavily_search(query: str) -> dict:
121
  """Search the web using Tavily."""
122
  try:
123
+ return {"results": TavilySearchResults(max_results=5).run(query)}
 
124
  except Exception as e:
125
+ return {"error": str(e)}