paolosandejas-stratpoint's picture
Fix typo
a48fbc1 verified
raw
history blame
6.78 kB
import textwrap
import os
import re
import argparse
import requests
import google.generativeai as genai
from IPython.display import Markdown
import gradio as gr
# # Used to securely store your API key
# from google.colab import userdata
# Gemini API key from the environment; '-1' is a placeholder when unset.
gemini_api_key = os.environ.get('GEMINI_API_KEY', '-1')
genai.configure(api_key=gemini_api_key)
# Semantic Scholar API key; may be None (the API allows unauthenticated use at lower rate limits).
S2_API_KEY = os.getenv('S2_API_KEY')
# How many search hits to request from Semantic Scholar ...
initial_result_limit = 10
# ... and how many top-ranked papers to keep after filtering.
final_result_limit = 5
# Select relevant fields to pull
fields = 'title,url,abstract,citationCount,journal,isOpenAccess,fieldsOfStudy,year,journal'
def raw_to_markdown(text):
    """Render plain text as a block-quoted IPython Markdown object.

    Bullet characters ('•') become Markdown list markers, and every line
    (including empty ones) is prefixed with '> ' so the whole text shows
    as a quote block.
    """
    bulleted = text.replace('•', ' *')
    quoted = textwrap.indent(bulleted, '> ', predicate=lambda _: True)
    return Markdown(quoted)
def markdown_to_raw(markdown_text):
    """Convert basic Markdown text to plain text.

    Strips headers, unwraps bold/italic markers, removes backtick code
    spans (including their contents), strips unordered-list bullets, and
    trims surrounding whitespace.

    Args:
        markdown_text: The Markdown text string to be converted.

    Returns:
        A string containing the raw text equivalent of the markdown text.
    """
    # Remove ATX headers ('#', '##', ...) and the space after them.
    text = re.sub(r'#+ ?', '', markdown_text)
    # Unwrap bold (**x**) and italics (_x_), keeping the inner text.
    text = re.sub(r'\*\*(.+?)\*\*', r'\1', text)  # Bold
    text = re.sub(r'_(.+?)_', r'\1', text)  # Italics
    # Remove code spans entirely (NOTE: contents are dropped, not unwrapped).
    text = re.sub(r'`(.*?)`', '', text, flags=re.DOTALL)
    # Strip unordered-list bullets, keeping the item text.
    text = re.sub(r'\*+ (.*?)$', r'\1', text, flags=re.MULTILINE)
    # BUG FIX: str.strip() returns a new string; the original discarded
    # the result, so leading/trailing whitespace was never removed.
    return text.strip()
def find_basis_papers(query):
    """Search Semantic Scholar for papers to base a literature review on.

    Requests up to ``initial_result_limit`` matches, keeps only
    open-access papers that have an abstract, ranks the rest by
    (year, citation count) descending, and returns the best ones.

    Args:
        query: Free-text search string.

    Returns:
        A list of paper dicts (at most ``final_result_limit``), or None
        when the query is empty or produced no matches.

    Raises:
        requests.HTTPError: If the Semantic Scholar API call fails.
    """
    if not query:
        print('No query given')
        return None
    rsp = requests.get('https://api.semanticscholar.org/graph/v1/paper/search',
                       headers={'X-API-KEY': S2_API_KEY},
                       params={'query': query, 'limit': initial_result_limit, 'fields': fields})
    rsp.raise_for_status()
    results = rsp.json()
    total = results["total"]
    if not total:
        print('No matches found. Please try another query.')
        return None
    print(f'Found {total} initial results. Showing up to {initial_result_limit}.')
    papers = results['data']
    # Filter paper results: open access with an abstract only.
    filtered_papers = list(filter(isValidPaper, papers))
    # Rank newest-first, breaking ties by citation count. The API can
    # return null for either field, so treat missing values as 0 rather
    # than crashing on a None comparison.
    ranked_papers = sorted(filtered_papers,
                           key=lambda p: (p['year'] or 0, p['citationCount'] or 0),
                           reverse=True)
    # BUG FIX: the cap was hard-coded to 5, ignoring final_result_limit.
    return ranked_papers[:final_result_limit]
# def print_papers(papers):
# for idx, paper in enumerate(papers):
# print(f"PAPER {idx}")
# for key, value in paper.items():
# if key != 'abstract':
# print(f"\t{key}: '{value}'")
def isValidPaper(paper):
    """Return True when a paper is open access and has an abstract.

    Args:
        paper: A Semantic Scholar paper dict containing at least the
            'isOpenAccess' and 'abstract' fields.

    Returns:
        bool: True if the paper is usable as a literature-review source.
    """
    # bool() collapses the truthy/falsy 'and' result to a real boolean,
    # matching the original explicit if/else return of True/False.
    return bool(paper['isOpenAccess'] and paper['abstract'])
# def filter_papers(papers):
# filtered_papers = []
# for paper in papers:
# if paper['isOpenAccess'] and paper['abstract']:
# # paper is acceptable
# filtered_papers.append(paper)
# return filtered_papers
def GEMINI_optimize_query(initial_query: str):
    """Ask Gemini to rewrite a search query for academic-paper search.

    Starts a fresh 'gemini-pro' chat session, prompts it to optimize the
    given query, and returns the model's reply converted to plain text.
    """
    # initialize gemini LLM
    llm = genai.GenerativeModel('gemini-pro')
    session = llm.start_chat(history=[])
    prompt = f"""Given a search query, return an optimized version of the query to find related academic papers
QUERY: {initial_query}.
Only return the optimized query"""
    reply = session.send_message(prompt)
    return markdown_to_raw(reply.text)
def GEMINI_summarize_abstracts(initial_query: str, papers: str):
    """Ask Gemini for a review of related literature over given papers.

    Starts a fresh 'gemini-pro' chat session, prompts it with the search
    query and the papers, and returns the reply as plain text.
    """
    # initialize gemini LLM
    llm = genai.GenerativeModel('gemini-pro')
    session = llm.start_chat(history=[])
    prompt = f"""Given the following academic papers,
return a review of related literature for the search query: {initial_query}.
Ignore papers without abstracts.
Here are the papers {papers}
"""
    reply = session.send_message(prompt)
    return markdown_to_raw(reply.text)
def create_gemini_model():
    """Build a 'gemini-pro' model together with an empty chat session.

    Returns:
        A (model, chat) tuple ready for send_message calls.
    """
    # initialize gemini LLM
    llm = genai.GenerativeModel('gemini-pro')
    session = llm.start_chat(history=[])
    return llm, session
# instantiate models
# Two independent chat sessions so summarization history never mixes
# with query-optimization history.
summarizer_model, summarizer_chat = create_gemini_model()
query_optimizer_model, query_optimizer_chat = create_gemini_model()
# def get_paper_links(papers):
# urls = []
# for paper in papers:
# urls = paper['url']
# return urls
def predict(message, history):
    """Gradio chat callback.

    On the first turn (empty history) the message is treated as a
    research topic: it is optionally rewritten by the query-optimizer
    chat, used to search Semantic Scholar, and the resulting papers are
    summarized into a review of related literature. Later turns are
    forwarded to the summarizer chat as a normal conversation.

    Args:
        message: The user's latest chat message.
        history: Gradio chat history; empty on the first turn.

    Returns:
        The model's reply converted to plain text.
    """
    if not history:
        query = message
        print(f"INITIAL QUERY: {query}")
        # BUG FIX: 'optimize_query' is only bound inside the __main__
        # guard, so referencing it directly raises NameError when this
        # module is imported (e.g. by a hosting runner). Default to False
        # when the flag was never set.
        if globals().get('optimize_query', False):
            optimizer_prompt = f"""Given a search query, return an optimized
version of the query to find related academic papers
QUERY: {query}.
Only return the optimized query"""
            response = query_optimizer_chat.send_message(optimizer_prompt)
            query = markdown_to_raw(response.text)
            print(f"OPTIMIZED QUERY: {query}")
        # optimized query used to search semantic scholar
        papers = find_basis_papers(query)
        summarizer_prompt = f"""Given the following academic papers,
return a review of related literature for the search query: {query}.
Focus on data/key factors and methodologies considered.
Here are the papers {papers}
Include the paper urls at the end of the review of related literature.
"""
        response = summarizer_chat.send_message(summarizer_prompt)
        return markdown_to_raw(response.text)
    # Follow-up turn: continue the summarizer conversation directly.
    response = summarizer_chat.send_message(message)
    return markdown_to_raw(response.text)
def main():
    """Launch the Gradio chat UI for the literature-review assistant."""
    # GEMINI optimizes query
    chat_ui = gr.ChatInterface(
        predict,
        title="LLM Research Helper",
        description="""Start by inputting a brief description/title
of your research and our assistant will return a review of
related literature
ex. Finding optimal site locations for solar farms""",
        examples=['Finding optimal site locations for solar farms',
                  'Wildfire prediction',
                  'Fish yield prediction'],
    )
    chat_ui.launch()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Literature review chatbot")
    parser.add_argument("-o", "--optimize_query",
                        help="Use query optimization (True, False)",
                        default=False)
    args = parser.parse_args()
    # BUG FIX: argparse delivers command-line values as *strings*, so the
    # original membership test `args.optimize_query in [True, False]` was
    # never satisfied for user input and optimization could never be
    # enabled. Accept a case-insensitive 'true' (and the literal boolean
    # default) instead.
    optimize_query = str(args.optimize_query).strip().lower() == 'true'
    main()