Agent_paper / tools.py
Agentneed's picture
n
4b95d23
from utils import *
import os
import time
import arxiv
import io, sys
import traceback
import matplotlib
import numpy as np
import multiprocessing
from pypdf import PdfReader
from datasets import load_dataset
from psutil._common import bytes2human
from datasets import load_dataset_builder
from semanticscholar import SemanticScholar
from sklearn.metrics.pairwise import linear_kernel
from sklearn.feature_extraction.text import TfidfVectorizer
class HFDataSearch:
def __init__(self, like_thr=3, dwn_thr=50) -> None:
"""
Class for finding relevant huggingface datasets
:param like_thr:
:param dwn_thr:
"""
self.dwn_thr = dwn_thr
self.like_thr = like_thr
self.ds = load_dataset("nkasmanoff/huggingface-datasets")["train"]
# Initialize lists to collect filtered data
filtered_indices = []
filtered_descriptions = []
filtered_likes = []
filtered_downloads = []
# Iterate over the dataset and filter based on criteria
for idx, item in enumerate(self.ds):
# Get likes and downloads, handling None values
likes = int(item['likes']) if item['likes'] is not None else 0
downloads = int(item['downloads']) if item['downloads'] is not None else 0
# Check if likes and downloads meet the thresholds
if likes >= self.like_thr and downloads >= self.dwn_thr:
# Check if the description is a non-empty string
description = item['description']
if isinstance(description, str) and description.strip():
# Collect the data
filtered_indices.append(idx)
filtered_descriptions.append(description)
filtered_likes.append(likes)
filtered_downloads.append(downloads)
# Check if any datasets meet all criteria
if not filtered_indices:
print("No datasets meet the specified criteria.")
self.ds = []
self.descriptions = []
self.likes_norm = []
self.downloads_norm = []
self.description_vectors = None
return # Exit the constructor
# Filter the datasets using the collected indices
self.ds = self.ds.select(filtered_indices)
# Update descriptions, likes, and downloads
self.descriptions = filtered_descriptions
self.likes = np.array(filtered_likes)
self.downloads = np.array(filtered_downloads)
# Normalize likes and downloads
self.likes_norm = self._normalize(self.likes)
self.downloads_norm = self._normalize(self.downloads)
# Vectorize the descriptions
self.vectorizer = TfidfVectorizer()
self.description_vectors = self.vectorizer.fit_transform(self.descriptions)
def _normalize(self, arr):
min_val = arr.min()
max_val = arr.max()
if max_val - min_val == 0:
return np.zeros_like(arr, dtype=float)
return (arr - min_val) / (max_val - min_val)
def retrieve_ds(self, query, N=10, sim_w=1.0, like_w=0.0, dwn_w=0.0):
"""
Retrieves the top N datasets matching the query, weighted by likes and downloads.
:param query: The search query string.
:param N: The number of results to return.
:param sim_w: Weight for cosine similarity.
:param like_w: Weight for likes.
:param dwn_w: Weight for downloads.
:return: List of top N dataset items.
"""
if not self.ds or self.description_vectors is None:
print("No datasets available to search.")
return []
query_vector = self.vectorizer.transform([query])
cosine_similarities = linear_kernel(query_vector, self.description_vectors).flatten()
# Normalize cosine similarities
cosine_similarities_norm = self._normalize(cosine_similarities)
# Compute final scores
final_scores = (
sim_w * cosine_similarities_norm +
like_w * self.likes_norm +
dwn_w * self.downloads_norm
)
# Get top N indices
top_indices = final_scores.argsort()[-N:][::-1]
# Convert indices to Python ints
top_indices = [int(i) for i in top_indices]
top_datasets = [self.ds[i] for i in top_indices]
# check if dataset has a test & train set
has_test_set = list()
has_train_set = list()
ds_size_info = list()
for i in top_indices:
try:
dbuilder = load_dataset_builder(self.ds[i]["id"], trust_remote_code=True).info
except Exception as e:
has_test_set.append(False)
has_train_set.append(False)
ds_size_info.append((None, None, None, None))
continue
if dbuilder.splits is None:
has_test_set.append(False)
has_train_set.append(False)
ds_size_info.append((None, None, None, None))
continue
# Print number of examples for
has_test, has_train = "test" in dbuilder.splits, "train" in dbuilder.splits
has_test_set.append(has_test)
has_train_set.append(has_train)
test_dwn_size, test_elem_size = None, None
train_dwn_size, train_elem_size = None, None
if has_test:
test_dwn_size = bytes2human(dbuilder.splits["test"].num_bytes)
test_elem_size = dbuilder.splits["test"].num_examples
if has_train:
train_dwn_size = bytes2human(dbuilder.splits["train"].num_bytes)
train_elem_size = dbuilder.splits["train"].num_examples
ds_size_info.append((test_dwn_size, test_elem_size, train_dwn_size, train_elem_size))
for _i in range(len(top_datasets)):
top_datasets[_i]["has_test_set"] = has_test_set[_i]
top_datasets[_i]["has_train_set"] = has_train_set[_i]
top_datasets[_i]["test_download_size"] = ds_size_info[_i][0]
top_datasets[_i]["test_element_size"] = ds_size_info[_i][1]
top_datasets[_i]["train_download_size"] = ds_size_info[_i][2]
top_datasets[_i]["train_element_size"] = ds_size_info[_i][3]
return top_datasets
def results_str(self, results):
"""
Provide results as list of results in human-readable format.
:param results: (list(dict)) list of results from search
:return: (list(str)) list of results in human-readable format
"""
result_strs = list()
for result in results:
res_str = f"Dataset ID: {result['id']}\n"
res_str += f"Description: {result['description']}\n"
res_str += f"Likes: {result['likes']}\n"
res_str += f"Downloads: {result['downloads']}\n"
res_str += f"Has Testing Set: {result['has_test_set']}\n"
res_str += f"Has Training Set: {result['has_train_set']}\n"
res_str += f"Test Download Size: {result['test_download_size']}\n"
res_str += f"Test Dataset Size: {result['test_element_size']}\n"
res_str += f"Train Download Size: {result['train_download_size']}\n"
res_str += f"Train Dataset Size: {result['train_element_size']}\n"
result_strs.append(res_str)
return result_strs
class SemanticScholarSearch:
def __init__(self):
self.sch_engine = SemanticScholar(retry=False)
def find_papers_by_str(self, query, N=10):
paper_sums = list()
results = self.sch_engine.search_paper(query, limit=N, min_citation_count=3, open_access_pdf=True)
for _i in range(len(results)):
paper_sum = f'Title: {results[_i].title}\n'
paper_sum += f'Abstract: {results[_i].abstract}\n'
paper_sum += f'Citations: {results[_i].citationCount}\n'
paper_sum += f'Release Date: year {results[_i].publicationDate.year}, month {results[_i].publicationDate.month}, day {results[_i].publicationDate.day}\n'
paper_sum += f'Venue: {results[_i].venue}\n'
paper_sum += f'Paper ID: {results[_i].externalIds["DOI"]}\n'
paper_sums.append(paper_sum)
return paper_sums
def retrieve_full_paper_text(self, query):
pass
class ArxivSearch:
def __init__(self):
# Construct the default API client.
self.sch_engine = arxiv.Client()
def _process_query(self, query: str) -> str:
"""Process query string to fit within MAX_QUERY_LENGTH while preserving as much information as possible"""
MAX_QUERY_LENGTH = 300
if len(query) <= MAX_QUERY_LENGTH:
return query
# Split into words
words = query.split()
processed_query = []
current_length = 0
# Add words while staying under the limit
# Account for spaces between words
for word in words:
# +1 for the space that will be added between words
if current_length + len(word) + 1 <= MAX_QUERY_LENGTH:
processed_query.append(word)
current_length += len(word) + 1
else:
break
return ' '.join(processed_query)
def find_papers_by_str(self, query, N=20):
processed_query = self._process_query(query)
max_retries = 3
retry_count = 0
while retry_count < max_retries:
try:
search = arxiv.Search(
query="abs:" + processed_query,
max_results=N,
sort_by=arxiv.SortCriterion.Relevance)
paper_sums = list()
# `results` is a generator; you can iterate over its elements one by one...
for r in self.sch_engine.results(search):
paperid = r.pdf_url.split("/")[-1]
pubdate = str(r.published).split(" ")[0]
paper_sum = f"Title: {r.title}\n"
paper_sum += f"Summary: {r.summary}\n"
paper_sum += f"Publication Date: {pubdate}\n"
#paper_sum += f"Categories: {' '.join(r.categories)}\n"
paper_sum += f"arXiv paper ID: {paperid}\n"
paper_sums.append(paper_sum)
time.sleep(2.0)
return "\n".join(paper_sums)
except Exception as e:
retry_count += 1
if retry_count < max_retries:
time.sleep(2 * retry_count)
continue
return None
def retrieve_full_paper_text(self, query, MAX_LEN=50000):
pdf_text = str()
paper = next(arxiv.Client().results(arxiv.Search(id_list=[query])))
# Download the PDF to the PWD with a custom filename.
paper.download_pdf(filename="downloaded-paper.pdf")
# creating a pdf reader object
reader = PdfReader('downloaded-paper.pdf')
# Iterate over all the pages
for page_number, page in enumerate(reader.pages, start=1):
# Extract text from the page
try:
text = page.extract_text()
except Exception as e:
os.remove("downloaded-paper.pdf")
time.sleep(2.0)
return "EXTRACTION FAILED"
# Do something with the text (e.g., print it)
pdf_text += f"--- Page {page_number} ---"
pdf_text += text
pdf_text += "\n"
os.remove("downloaded-paper.pdf")
time.sleep(2.0)
return pdf_text[:MAX_LEN]
# Set the non-interactive backend early in the module
matplotlib.use('Agg')
import matplotlib.pyplot as plt
def worker_run_code(code_str, output_queue):
output_capture = io.StringIO()
sys.stdout = output_capture
try:
# Create a globals dictionary with __name__ set to "__main__"
globals_dict = {"__name__": "__main__"}
exec(code_str, globals_dict)
except Exception as e:
output_capture.write(f"[CODE EXECUTION ERROR]: {str(e)}\n")
traceback.print_exc(file=output_capture)
finally:
sys.stdout = sys.__stdout__
output_queue.put(output_capture.getvalue())
def execute_code(code_str, timeout=600, MAX_LEN=1000):
#code_str = code_str.replace("\\n", "\n")
code_str = "from utils import *\n" + code_str
if "load_dataset('pubmed" in code_str:
return "[CODE EXECUTION ERROR] pubmed Download took way too long. Program terminated"
if "exit(" in code_str:
return "[CODE EXECUTION ERROR] The exit() command is not allowed you must remove this."
output_queue = multiprocessing.Queue()
proc = multiprocessing.Process(target=worker_run_code, args=(code_str, output_queue))
proc.start()
proc.join(timeout)
if proc.is_alive():
proc.terminate() # Forcefully kill the process
proc.join()
return (f"[CODE EXECUTION ERROR]: Code execution exceeded the timeout limit of {timeout} seconds. "
"You must reduce the time complexity of your code.")
else:
if not output_queue.empty(): output = output_queue.get()
else: output = ""
return output