Spaces:
Runtime error
Runtime error
import textwrap
import os
import re
import argparse
import requests
import google.generativeai as genai
from IPython.display import Markdown
import gradio as gr

# # Used to securely store your API key
# from google.colab import userdata

# Configure Gemini; '-1' is a sentinel meaning "no key set" (API calls will fail loudly).
gemini_api_key = os.environ.get('GEMINI_API_KEY', '-1')
genai.configure(api_key=gemini_api_key)

# Semantic Scholar API key (None => unauthenticated, rate-limited requests).
S2_API_KEY = os.getenv('S2_API_KEY')

initial_result_limit = 10  # papers fetched from Semantic Scholar per search
final_result_limit = 5     # papers kept after filtering/ranking

# Module-level default so predict() is usable even when this file is imported;
# the __main__ block overrides this from the CLI flag.
optimize_query = False

# Select relevant fields to pull (fixed: 'journal' was listed twice).
fields = 'title,url,abstract,citationCount,journal,isOpenAccess,fieldsOfStudy,year'
def raw_to_markdown(text):
    """Wrap raw text in an IPython Markdown object, block-quoting every line.

    Bullet characters are rewritten as markdown list markers first.
    """
    bulleted = text.replace('•', ' *')
    quoted = textwrap.indent(bulleted, '> ', predicate=lambda _: True)
    return Markdown(quoted)
def markdown_to_raw(markdown_text):
    """
    Convert basic markdown text to raw (plain) text.

    Args:
        markdown_text: The markdown text string to be converted.

    Returns:
        A string containing the raw text equivalent of the markdown text,
        with surrounding whitespace removed.
    """
    # Remove headers ('# ', '## ', ...)
    text = re.sub(r'#+ ?', '', markdown_text)
    # Remove bold and italics markers, keeping the inner text
    text = re.sub(r'\*\*(.+?)\*\*', r'\1', text)  # Bold
    text = re.sub(r'_(.+?)_', r'\1', text)  # Italics
    # Drop backtick code spans entirely (content included — original behavior)
    text = re.sub(r'`(.*?)`', '', text, flags=re.DOTALL)
    # Strip unordered-list bullets, keeping the item text
    text = re.sub(r'\*+ (.*?)$', r'\1', text, flags=re.MULTILINE)
    # BUG FIX: str.strip() returns a new string; the original discarded the
    # result, so leading/trailing whitespace was never actually removed.
    return text.strip()
def find_basis_papers(query):
    """Search Semantic Scholar for papers matching *query*.

    Fetches up to ``initial_result_limit`` results, filters out closed-access
    or abstract-less papers, ranks by (year, citationCount) descending, and
    returns the top ``final_result_limit`` papers.

    Args:
        query: Search string (falsy values short-circuit to None).

    Returns:
        A list of paper dicts, or None when the query is empty or no
        results were found.

    Raises:
        requests.HTTPError: if the Semantic Scholar API returns an error.
    """
    if not query:
        print('No query given')
        return None
    rsp = requests.get(
        'https://api.semanticscholar.org/graph/v1/paper/search',
        headers={'X-API-KEY': S2_API_KEY},
        params={'query': query, 'limit': initial_result_limit, 'fields': fields},
    )
    rsp.raise_for_status()
    results = rsp.json()
    total = results["total"]
    if not total:
        print('No matches found. Please try another query.')
        return None
    print(f'Found {total} initial results. Showing up to {initial_result_limit}.')
    papers = results['data']
    # Keep only open-access papers that actually have an abstract.
    filtered_papers = [paper for paper in papers if isValidPaper(paper)]
    # BUG FIX: year/citationCount can be None in S2 responses, which made the
    # original tuple comparison raise TypeError; treat missing values as 0.
    ranked_papers = sorted(
        filtered_papers,
        key=lambda p: (p.get('year') or 0, p.get('citationCount') or 0),
        reverse=True,
    )
    # BUG FIX: use the module-level limit instead of a hard-coded 5.
    return ranked_papers[:final_result_limit]
| # def print_papers(papers): | |
| # for idx, paper in enumerate(papers): | |
| # print(f"PAPER {idx}") | |
| # for key, value in paper.items(): | |
| # if key != 'abstract': | |
| # print(f"\t{key}: '{value}'") | |
def isValidPaper(paper):
    """Return True when *paper* is open access and has a non-empty abstract.

    Args:
        paper: A Semantic Scholar paper dict with 'isOpenAccess' and
            'abstract' keys.

    Returns:
        bool: True only if both conditions hold.
    """
    # Idiom fix: return the boolean directly instead of if/else True/False.
    return bool(paper['isOpenAccess'] and paper['abstract'])
| # def filter_papers(papers): | |
| # filtered_papers = [] | |
| # for paper in papers: | |
| # if paper['isOpenAccess'] and paper['abstract']: | |
| # # paper is acceptable | |
| # filtered_papers.append(paper) | |
| # return filtered_papers | |
def GEMINI_optimize_query(initial_query: str):
    """Ask Gemini to rewrite *initial_query* for academic-paper search.

    Starts a fresh 'gemini-pro' chat, sends the optimization prompt, and
    returns the model's reply stripped of markdown formatting.
    """
    chat = genai.GenerativeModel('gemini-pro').start_chat(history=[])
    prompt = f"""Given a search query, return an optimized version of the query to find related academic papers
  QUERY: {initial_query}.
  Only return the optimized query"""
    response = chat.send_message(prompt)
    return markdown_to_raw(response.text)
def GEMINI_summarize_abstracts(initial_query: str, papers: str):
    """Ask Gemini for a review of related literature over *papers*.

    Starts a fresh 'gemini-pro' chat, sends the summarization prompt for
    *initial_query*, and returns the reply stripped of markdown formatting.
    """
    chat = genai.GenerativeModel('gemini-pro').start_chat(history=[])
    prompt = f"""Given the following academic papers,
  return a review of related literature for the search query: {initial_query}.
  Ignore papers without abstracts.
  Here are the papers {papers}
  """
    response = chat.send_message(prompt)
    return markdown_to_raw(response.text)
def create_gemini_model():
    """Return a fresh ('gemini-pro' model, empty-history chat session) pair."""
    model = genai.GenerativeModel('gemini-pro')
    return model, model.start_chat(history=[])
# Instantiate shared Gemini chat sessions at import time:
# one conversation for summarizing papers, one for optimizing search queries.
# NOTE(review): this performs third-party setup on import — confirm desired.
summarizer_model, summarizer_chat = create_gemini_model()
query_optimizer_model, query_optimizer_chat = create_gemini_model()
| # def get_paper_links(papers): | |
| # urls = [] | |
| # for paper in papers: | |
| # urls = paper['url'] | |
| # return urls | |
def predict(message, history):
    """Gradio chat handler.

    On the first turn (empty history) the message is treated as a research
    query: optionally optimized via Gemini, used to search Semantic Scholar,
    and the resulting papers are summarized into a literature review.
    Subsequent turns are forwarded to the summarizer chat as follow-ups.

    Args:
        message: The latest user message.
        history: Gradio chat history; [] on the first turn.

    Returns:
        The assistant's reply as plain text.
    """
    if history == []:
        query = message
        print(f"INITIAL QUERY: {query}")
        if optimize_query:
            optimizer_prompt = f"""Given a search query, return an optimized
    version of the query to find related academic papers
    QUERY: {query}.
    Only return the optimized query"""
            response = query_optimizer_chat.send_message(optimizer_prompt)
            query = markdown_to_raw(response.text)
            print(f"OPTIMIZED QUERY: {query}")
        # (Possibly optimized) query used to search Semantic Scholar.
        papers = find_basis_papers(query)
        # BUG FIX: find_basis_papers returns None on empty/failed searches;
        # the original interpolated the literal string "None" into the prompt.
        if not papers:
            return ("No suitable papers were found for that query. "
                    "Please try rephrasing it.")
        summarizer_prompt = f"""Given the following academic papers,
    return a review of related literature for the search query: {query}.
    Focus on data/key factors and methodologies considered.
    Here are the papers {papers}
    Include the paper urls at the end of the review of related literature.
    """
        response = summarizer_chat.send_message(summarizer_prompt)
        abstract_summary = markdown_to_raw(response.text)
        return abstract_summary
    # Follow-up turn: continue the existing summarizer conversation.
    response = summarizer_chat.send_message(message)
    return markdown_to_raw(response.text)
def main():
    """Build and launch the Gradio chat UI for the literature-review helper."""
    interface = gr.ChatInterface(
        predict,
        title="LLM Research Helper",
        description="""Start by inputting a brief description/title
  of your research and our assistant will return a review of
  related literature
  ex. Finding optimal site locations for solar farms""",
        examples=[
            'Finding optimal site locations for solar farms',
            'Wildfire prediction',
            'Fish yield prediction',
        ],
    )
    interface.launch()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Literature review chatbot")
    parser.add_argument("-o", "--optimize_query",
                        help="Use query optimization (True, False)", default=False)
    args = parser.parse_args()
    # BUG FIX: argparse yields strings, so the original check
    # `args.optimize_query in [True, False]` was always False and the flag
    # could never enable optimization. Accept common truthy spellings.
    optimize_query = str(args.optimize_query).strip().lower() in ('true', '1', 'yes')
    main()