jinysun commited on
Commit
0d546a3
·
verified ·
1 Parent(s): b58ae5c

Update tool/search.py

Browse files
Files changed (1) hide show
  1. tool/search.py +141 -72
tool/search.py CHANGED
@@ -3,8 +3,8 @@ import re
3
  import langchain
4
  from paperqa import Docs, Settings
5
  import asyncio
6
- import paperqa
7
- import paperscraper
8
  from langchain_community.utilities import SerpAPIWrapper
9
  from langchain.base_language import BaseLanguageModel
10
  from langchain.tools import BaseTool
@@ -25,6 +25,43 @@ def is_smiles(text):
25
 
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def is_multiple_smiles(text):
29
  if is_smiles(text):
30
  return "." in text
@@ -45,85 +82,117 @@ def paper_scraper(search: str, pdir: str = "query", semantic_scholar_api_key: st
45
  return {}
46
 
47
 
48
- def paper_search(llm, query, semantic_scholar_api_key=None):
49
- prompt = langchain.prompts.PromptTemplate(
50
- input_variables=["question"],
51
- template="""
52
- I would like to find scholarly papers to answer
53
- this question: {question}. Your response must be at
54
- most 10 words long.
55
- 'A search query that would bring up papers that can answer
56
- this question would be: '""",
57
- )
58
-
59
- query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt)
60
- if not os.path.isdir("./query"): # todo: move to ckpt
61
- os.mkdir("query/")
62
- search = query_chain.invoke(query)
63
- print("\nSearch:", search)
64
- papers = paper_scraper(search['text'], semantic_scholar_api_key=semantic_scholar_api_key)
65
- return papers
66
-
67
-
68
- async def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None, semantic_scholar_api_key=None):
69
- """Useful to answer questions that require
70
- technical knowledge. Ask a specific question."""
71
- papers = paper_search(llm, query, semantic_scholar_api_key=semantic_scholar_api_key)
72
- if len(papers) == 0:
73
- return "Not enough papers found"
74
- docs = Docs()
75
- settings = Settings()
76
- settings.llm = llm
77
 
78
- not_loaded = 0
79
- for path, data in papers.items():
80
- try:
81
- await docs.aadd(path)
82
- except (ValueError, FileNotFoundError, PdfReadError):
83
- not_loaded += 1
84
-
85
- if not_loaded > 0:
86
- print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}.")
87
- else:
88
- print(f"\nFound {len(papers.items())} papers and loaded all of them.")
89
 
90
 
91
- answer = await docs.aquery(query)
92
- return answer.answer
 
 
 
 
 
 
 
 
 
 
 
93
 
94
 
95
- class LiteratureSearch(BaseTool):
96
- name: str = "LiteratureSearch"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  description: str = (
98
- "Useful to answer questions that require technical "
99
- "knowledge. Ask a specific question."
100
  )
101
- llm: BaseLanguageModel = None
102
- openai_api_key: str = None
103
- semantic_scholar_api_key: str = None
104
-
105
 
106
- def __init__(self, llm, openai_api_key, semantic_scholar_api_key):
107
  super().__init__()
108
-
109
- # api keys
110
- self.openai_api_key = openai_api_key
111
- self.semantic_scholar_api_key = semantic_scholar_api_key
112
- self.llm = ChatOpenAI(model="gpt-4o-2024-11-20",openai_api_key=self.openai_api_key,
113
- base_url=os.getenv("OPENAI_API_BASE"))
114
- def _run(self, query) -> str:
115
- os.environ["OPENAI_API_KEY"] = self.openai_api_key
116
- os.environ["OPENAI_API_BASE"] = os.getenv("OPENAI_API_BASE")
117
- return asyncio.run(scholar2result_llm(
118
- self.llm,
119
- query,
120
- openai_api_key=self.openai_api_key,
121
- semantic_scholar_api_key=self.semantic_scholar_api_key
122
- ))
123
-
124
- async def _arun(self, query) -> str:
125
- """Use the tool asynchronously."""
126
- raise NotImplementedError("this tool does not support async")
127
 
128
  def web_search(keywords, search_engine="google"):
129
  try:
 
3
  import langchain
4
  from paperqa import Docs, Settings
5
  import asyncio
6
+ #import paperqa
7
+ #import paperscraper
8
  from langchain_community.utilities import SerpAPIWrapper
9
  from langchain.base_language import BaseLanguageModel
10
  from langchain.tools import BaseTool
 
25
 
26
 
27
 
28
+ def is_multiple_smiles(text):
29
+ if is_smiles(text):
30
+ return "." in text
31
+ return False
32
+
33
+
34
+ def split_smiles(text):
35
+ return text.split(".")
36
+ import os
37
+ import re
38
+
39
+ import langchain
40
+ from paperqa import Docs, Settings
41
+ import asyncio
42
+ # import paperqa
43
+ # import paperscraper
44
+ from langchain_community.utilities import SerpAPIWrapper
45
+ from langchain.base_language import BaseLanguageModel
46
+ from langchain.tools import BaseTool
47
+ from langchain_openai import OpenAIEmbeddings
48
+ from pypdf.errors import PdfReadError
49
+ from rdkit import Chem, DataStructs
50
+ from rdkit.Chem import AllChem
51
+ import nest_asyncio
52
+ from langchain_openai import ChatOpenAI
53
+ nest_asyncio.apply()
54
+ def is_smiles(text):
55
+ try:
56
+ m = Chem.MolFromSmiles(text, sanitize=False)
57
+ if m is None:
58
+ return False
59
+ return True
60
+ except:
61
+ return False
62
+
63
+
64
+
65
  def is_multiple_smiles(text):
66
  if is_smiles(text):
67
  return "." in text
 
82
  return {}
83
 
84
 
85
+ # def paper_search(llm, query, semantic_scholar_api_key=None):
86
+ # prompt = langchain.prompts.PromptTemplate(
87
+ # input_variables=["question"],
88
+ # template="""
89
+ # I would like to find scholarly papers to answer
90
+ # this question: {question}. Your response must be at
91
+ # most 10 words long.
92
+ # 'A search query that would bring up papers that can answer
93
+ # this question would be: '""",
94
+ # )
95
+
96
+ # query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt)
97
+ # if not os.path.isdir("./query"): # todo: move to ckpt
98
+ # os.mkdir("query/")
99
+ # search = query_chain.invoke(query)
100
+ # print("\nSearch:", search)
101
+ # papers = paper_scraper(search['text'], semantic_scholar_api_key=semantic_scholar_api_key)
102
+ # return papers
103
+
104
+
105
+ # async def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None, semantic_scholar_api_key=None):
106
+ # """Useful to answer questions that require
107
+ # technical knowledge. Ask a specific question."""
108
+ # papers = paper_search(llm, query, semantic_scholar_api_key=semantic_scholar_api_key)
109
+ # if len(papers) == 0:
110
+ # return "Not enough papers found"
111
+ # docs = Docs()
112
+ # settings = Settings()
113
+ # settings.llm = llm
114
 
115
+ # not_loaded = 0
116
+ # for path, data in papers.items():
117
+ # try:
118
+ # await docs.aadd(path)
119
+ # except (ValueError, FileNotFoundError, PdfReadError):
120
+ # not_loaded += 1
121
+
122
+ # if not_loaded > 0:
123
+ # print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}.")
124
+ # else:
125
+ # print(f"\nFound {len(papers.items())} papers and loaded all of them.")
126
 
127
 
128
+ # answer = await docs.aquery(query)
129
+ # return answer.answer
130
+
131
+
132
+ # class LiteratureSearch(BaseTool):
133
+ # name: str = "LiteratureSearch"
134
+ # description: str = (
135
+ # "Useful to answer questions that require technical "
136
+ # "knowledge. Ask a specific question."
137
+ # )
138
+ # llm: BaseLanguageModel = None
139
+ # openai_api_key: str = None
140
+ # semantic_scholar_api_key: str = None
141
 
142
 
143
+ # def __init__(self, llm, openai_api_key, semantic_scholar_api_key):
144
+ # super().__init__()
145
+
146
+ # # api keys
147
+ # self.openai_api_key = openai_api_key
148
+ # self.semantic_scholar_api_key = semantic_scholar_api_key
149
+ # self.llm = ChatOpenAI(model="gpt-4o-2024-11-20",openai_api_key=self.openai_api_key,
150
+ # base_url=os.getenv("OPENAI_API_BASE"))
151
+ # def _run(self, query) -> str:
152
+ # os.environ["OPENAI_API_KEY"] = self.openai_api_key
153
+ # os.environ["OPENAI_API_BASE"] = os.getenv("OPENAI_API_BASE")
154
+ # return asyncio.run(scholar2result_llm(
155
+ # self.llm,
156
+ # query,
157
+ # openai_api_key=self.openai_api_key,
158
+ # semantic_scholar_api_key=self.semantic_scholar_api_key
159
+ # ))
160
+
161
+ # async def _arun(self, query) -> str:
162
+ # """Use the tool asynchronously."""
163
+ # raise NotImplementedError("this tool does not support async")
164
+
165
+ def web_search(keywords, search_engine="google"):
166
+ try:
167
+ return SerpAPIWrapper(
168
+ serpapi_api_key=os.getenv("SERP_API_KEY"), search_engine=search_engine
169
+ ).run(keywords)
170
+ except:
171
+ return "No results, try another search"
172
+
173
+
174
+ class WebSearch(BaseTool):
175
+ name: str = "WebSearch"
176
  description: str = (
177
+ "Input a specific question, returns an answer from web search. "
178
+ "Give more detailed information and use more general features to formulate your questions."
179
  )
180
+ serp_api_key: str = None
 
 
 
181
 
182
+ def __init__(self, serp_api_key: str = None):
183
  super().__init__()
184
+ self.serp_api_key = serp_api_key
185
+
186
+ def _run(self, query: str) -> str:
187
+ if not self.serp_api_key:
188
+ return (
189
+ "No SerpAPI key found. This tool may not be used without a SerpAPI key."
190
+ )
191
+ return web_search(query)
192
+
193
+ async def _arun(self, query: str) -> str:
194
+ raise NotImplementedError("Async not implemented")
195
+
 
 
 
 
 
 
 
196
 
197
  def web_search(keywords, search_engine="google"):
198
  try: