jinysun commited on
Commit
f51746a
·
verified ·
1 Parent(s): 503b822

Update tool/search.py

Browse files
Files changed (1) hide show
  1. tool/search.py +31 -25
tool/search.py CHANGED
@@ -6,8 +6,12 @@ subprocess.check_call(["pip", "install", "--no-deps", "paper-scraper @ git+https
6
  subprocess.check_call(["pip", "install", "--no-deps", "google-search-results"])
7
 
8
 
 
 
 
9
  import langchain
10
-
 
11
  import paperqa
12
  import paperscraper
13
  from langchain_community.utilities import SerpAPIWrapper
@@ -17,7 +21,9 @@ from langchain_openai import OpenAIEmbeddings
17
  from pypdf.errors import PdfReadError
18
  from rdkit import Chem, DataStructs
19
  from rdkit.Chem import AllChem
20
-
 
 
21
  def is_smiles(text):
22
  try:
23
  m = Chem.MolFromSmiles(text, sanitize=False)
@@ -38,7 +44,7 @@ def is_multiple_smiles(text):
38
  def split_smiles(text):
39
  return text.split(".")
40
 
41
- def paper_scrap(search: str, pdir: str = "query", semantic_scholar_api_key: str = None) -> dict:
42
  try:
43
  return paperscraper.search_papers(
44
  search,
@@ -63,27 +69,26 @@ def paper_search(llm, query, semantic_scholar_api_key=None):
63
  query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt)
64
  if not os.path.isdir("./query"): # todo: move to ckpt
65
  os.mkdir("query/")
66
- search = query_chain.run(query)
67
  print("\nSearch:", search)
68
- papers = paper_scrap(search, pdir=f"query/{re.sub(' ', '', search)}", semantic_scholar_api_key=semantic_scholar_api_key)
69
  return papers
70
 
71
 
72
- def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None, semantic_scholar_api_key=None):
73
  """Useful to answer questions that require
74
  technical knowledge. Ask a specific question."""
75
  papers = paper_search(llm, query, semantic_scholar_api_key=semantic_scholar_api_key)
76
  if len(papers) == 0:
77
  return "Not enough papers found"
78
- docs = paperqa.Docs(
79
- llm=llm,
80
- summary_llm=llm,
81
- embeddings=OpenAIEmbeddings(openai_api_key=openai_api_key),
82
- )
83
  not_loaded = 0
84
  for path, data in papers.items():
85
  try:
86
- docs.add(path, data["citation"])
87
  except (ValueError, FileNotFoundError, PdfReadError):
88
  not_loaded += 1
89
 
@@ -92,12 +97,13 @@ def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None, sema
92
  else:
93
  print(f"\nFound {len(papers.items())} papers and loaded all of them.")
94
 
95
- answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer
96
- return answer
 
97
 
98
 
99
- class Scholar2ResultLLM(BaseTool):
100
- name : str = "LiteratureSearch"
101
  description: str = (
102
  "Useful to answer questions that require technical "
103
  "knowledge. Ask a specific question."
@@ -109,28 +115,30 @@ class Scholar2ResultLLM(BaseTool):
109
 
110
  def __init__(self, llm, openai_api_key, semantic_scholar_api_key):
111
  super().__init__()
112
- self.llm = llm
113
  # api keys
114
  self.openai_api_key = openai_api_key
115
  self.semantic_scholar_api_key = semantic_scholar_api_key
116
-
 
117
  def _run(self, query) -> str:
118
- return scholar2result_llm(
 
 
119
  self.llm,
120
  query,
121
  openai_api_key=self.openai_api_key,
122
  semantic_scholar_api_key=self.semantic_scholar_api_key
123
- )
124
 
125
  async def _arun(self, query) -> str:
126
  """Use the tool asynchronously."""
127
  raise NotImplementedError("this tool does not support async")
128
 
129
-
130
  def web_search(keywords, search_engine="google"):
131
  try:
132
  return SerpAPIWrapper(
133
- serpapi_api_key='3795acda6a74ea15033d34b54eac82982b26f559147d9cf04aca4bfca91c3e9d', search_engine=search_engine
134
  ).run(keywords)
135
  except:
136
  return "No results, try another search"
@@ -156,6 +164,4 @@ class WebSearch(BaseTool):
156
  return web_search(query)
157
 
158
  async def _arun(self, query: str) -> str:
159
- raise NotImplementedError("Async not implemented")
160
-
161
-
 
6
  subprocess.check_call(["pip", "install", "--no-deps", "google-search-results"])
7
 
8
 
9
+ import os
10
+ import re
11
+
12
  import langchain
13
+ from paperqa import Docs, Settings
14
+ import asyncio
15
  import paperqa
16
  import paperscraper
17
  from langchain_community.utilities import SerpAPIWrapper
 
21
  from pypdf.errors import PdfReadError
22
  from rdkit import Chem, DataStructs
23
  from rdkit.Chem import AllChem
24
+ import nest_asyncio
25
+ from langchain_openai import ChatOpenAI
26
+ nest_asyncio.apply()
27
  def is_smiles(text):
28
  try:
29
  m = Chem.MolFromSmiles(text, sanitize=False)
 
44
  def split_smiles(text):
45
  return text.split(".")
46
 
47
+ def paper_scraper(search: str, pdir: str = "query", semantic_scholar_api_key: str = None) -> dict:
48
  try:
49
  return paperscraper.search_papers(
50
  search,
 
69
  query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt)
70
  if not os.path.isdir("./query"): # todo: move to ckpt
71
  os.mkdir("query/")
72
+ search = query_chain.invoke(query)
73
  print("\nSearch:", search)
74
+ papers = paper_scraper(search['text'], semantic_scholar_api_key=semantic_scholar_api_key)
75
  return papers
76
 
77
 
78
+ async def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None, semantic_scholar_api_key=None):
79
  """Useful to answer questions that require
80
  technical knowledge. Ask a specific question."""
81
  papers = paper_search(llm, query, semantic_scholar_api_key=semantic_scholar_api_key)
82
  if len(papers) == 0:
83
  return "Not enough papers found"
84
+ docs = Docs()
85
+ settings = Settings()
86
+ settings.llm = llm
87
+
 
88
  not_loaded = 0
89
  for path, data in papers.items():
90
  try:
91
+ await docs.aadd(path)
92
  except (ValueError, FileNotFoundError, PdfReadError):
93
  not_loaded += 1
94
 
 
97
  else:
98
  print(f"\nFound {len(papers.items())} papers and loaded all of them.")
99
 
100
+
101
+ answer = await docs.aquery(query)
102
+ return answer.answer
103
 
104
 
105
+ class LiteratureSearch(BaseTool):
106
+ name: str = "LiteratureSearch"
107
  description: str = (
108
  "Useful to answer questions that require technical "
109
  "knowledge. Ask a specific question."
 
115
 
116
  def __init__(self, llm, openai_api_key, semantic_scholar_api_key):
117
  super().__init__()
118
+
119
  # api keys
120
  self.openai_api_key = openai_api_key
121
  self.semantic_scholar_api_key = semantic_scholar_api_key
122
+ self.llm = ChatOpenAI(model="gpt-4o-2024-11-20",openai_api_key=self.openai_api_key,
123
+ base_url=os.getenv("OPENAI_API_BASE"))
124
  def _run(self, query) -> str:
125
+ os.environ["OPENAI_API_KEY"] = self.openai_api_key
126
+ os.environ["OPENAI_API_BASE"] = os.getenv("OPENAI_API_BASE")
127
+ return asyncio.run(scholar2result_llm(
128
  self.llm,
129
  query,
130
  openai_api_key=self.openai_api_key,
131
  semantic_scholar_api_key=self.semantic_scholar_api_key
132
+ ))
133
 
134
  async def _arun(self, query) -> str:
135
  """Use the tool asynchronously."""
136
  raise NotImplementedError("this tool does not support async")
137
 
 
138
  def web_search(keywords, search_engine="google"):
139
  try:
140
  return SerpAPIWrapper(
141
+ serpapi_api_key=os.getenv("SERP_API_KEY"), search_engine=search_engine
142
  ).run(keywords)
143
  except:
144
  return "No results, try another search"
 
164
  return web_search(query)
165
 
166
  async def _arun(self, query: str) -> str:
167
+ raise NotImplementedError("Async not implemented")