File size: 6,779 Bytes
9c78fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a48fbc1
9c78fbf
 
 
ff96769
 
 
 
 
9c78fbf
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import textwrap

import os
import re
import argparse

import requests

import google.generativeai as genai

from IPython.display import Markdown

import gradio as gr

# # Used to securely store your API key
# from google.colab import userdata

# Gemini API key is read from the environment; '-1' is a placeholder sentinel
# when the variable is unset (calls will fail until a real key is provided).
gemini_api_key = os.environ.get('GEMINI_API_KEY', '-1')
genai.configure(api_key=gemini_api_key)


# Semantic Scholar API key; sent as the X-API-KEY header in find_basis_papers.
# May be None if the environment variable is unset.
S2_API_KEY = os.getenv('S2_API_KEY')
initial_result_limit = 10  # papers requested from the initial S2 search
final_result_limit = 5     # papers kept after filtering and ranking

# Select relevant fields to pull
# NOTE(review): 'journal' appears twice in this field list — harmless but redundant.
fields = 'title,url,abstract,citationCount,journal,isOpenAccess,fieldsOfStudy,year,journal'


def raw_to_markdown(text):
    """Wrap plain text in an IPython Markdown blockquote for notebook display.

    Bullet characters are converted to markdown list markers, and every line
    (including blank ones) is prefixed with '> ' to render as a blockquote.
    """
    bulleted = text.replace('•', '  *')
    quoted = textwrap.indent(bulleted, '> ', predicate=lambda _: True)
    return Markdown(quoted)


def markdown_to_raw(markdown_text):
    """
    Convert basic markdown text to raw (plain) text.

    Strips headers, bold/italic markers, inline code spans, and
    unordered-list bullets, then trims surrounding whitespace.

    Args:
        markdown_text: The markdown text string to be converted.

    Returns:
        A string containing the raw text equivalent of the markdown text.
    """
    # Remove headers
    text = re.sub(r'#+ ?', '', markdown_text)

    # Remove bold and italics (can be adjusted based on needs)
    text = re.sub(r'\*\*(.+?)\*\*', r'\1', text)  # Bold
    text = re.sub(r'_(.+?)_', r'\1', text)        # Italics

    # Remove code blocks (content between backticks is dropped entirely)
    text = re.sub(r'`(.*?)`', '', text, flags=re.DOTALL)

    # Remove lists
    text = re.sub(r'\*+ (.*?)$', r'\1', text, flags=re.MULTILINE)  # Unordered lists

    # BUG FIX: str.strip() returns a new string — the original called it and
    # discarded the result, so surrounding whitespace was never removed.
    return text.strip()


def find_basis_papers(query):
    """Search Semantic Scholar for papers related to *query*.

    Fetches up to ``initial_result_limit`` matches, keeps only open-access
    papers with abstracts, ranks them newest-first (ties broken by citation
    count), and returns the top ``final_result_limit``.

    Args:
        query: Free-text search string.

    Returns:
        A list of paper dicts (possibly empty), or None when the query is
        empty or the API reports no matches.

    Raises:
        requests.HTTPError: if the Semantic Scholar API returns an error status.
    """
    if not query:
        print('No query given')
        return None

    rsp = requests.get('https://api.semanticscholar.org/graph/v1/paper/search',
                       headers={'X-API-KEY': S2_API_KEY},
                       params={'query': query, 'limit': initial_result_limit, 'fields': fields})
    rsp.raise_for_status()
    results = rsp.json()
    total = results["total"]
    if not total:
        print('No matches found. Please try another query.')
        return None

    print(f'Found {total} initial results. Showing up to {initial_result_limit}.')
    papers = results['data']

    # Keep only open-access papers that actually have an abstract.
    filtered_papers = list(filter(isValidPaper, papers))

    # Rank newest-first, breaking ties by citation count. The API can return
    # null for either field, which would make the bare tuple comparison raise
    # TypeError — coerce missing values to 0 so they sort last.
    ranked_papers = sorted(
        filtered_papers,
        key=lambda p: (p.get('year') or 0, p.get('citationCount') or 0),
        reverse=True,
    )

    # Return the configured number of best papers (was a hard-coded 5,
    # silently ignoring final_result_limit).
    return ranked_papers[:final_result_limit]


# def print_papers(papers):
#     for idx, paper in enumerate(papers):
#         print(f"PAPER {idx}")
#         for key, value in paper.items():
#             if key != 'abstract':
#                 print(f"\t{key}: '{value}'")


def isValidPaper(paper):
    """Return True if *paper* is open access and has a non-empty abstract.

    Args:
        paper: Paper dict from the Semantic Scholar API with at least the
            'isOpenAccess' and 'abstract' keys.

    Returns:
        bool: True only when both fields are truthy.
    """
    # bool(...) replaces the redundant `if cond: return True else: return False`.
    return bool(paper['isOpenAccess'] and paper['abstract'])


# def filter_papers(papers):
#     filtered_papers = []
#     for paper in papers:
#         if paper['isOpenAccess'] and paper['abstract']:
#             # paper is acceptable
#             filtered_papers.append(paper)
#     return filtered_papers


def GEMINI_optimize_query(initial_query: str):
    """Ask Gemini to rewrite *initial_query* into a better paper-search query.

    Starts a fresh chat session against 'gemini-pro', sends a one-shot
    rewrite prompt, and returns the model's reply converted to plain text.
    """
    # Fresh, history-free chat so prior conversations cannot leak in.
    chat_session = genai.GenerativeModel('gemini-pro').start_chat(history=[])

    rewrite_prompt = f"""Given a search query, return an optimized version of the query to find related academic papers 
    QUERY: {initial_query}. 
    Only return the optimized query"""

    reply = chat_session.send_message(rewrite_prompt)
    return markdown_to_raw(reply.text)


def GEMINI_summarize_abstracts(initial_query: str, papers: str):
    """Ask Gemini for a review of related literature covering *papers*.

    Starts a fresh chat session against 'gemini-pro', sends the papers plus
    the originating search query, and returns the reply as plain text.
    """
    # Fresh, history-free chat so prior conversations cannot leak in.
    chat_session = genai.GenerativeModel('gemini-pro').start_chat(history=[])

    review_prompt = f"""Given the following academic papers, 
    return a review of related literature for the search query: {initial_query}. 
    Ignore papers without abstracts. 
    Here are the papers {papers}
    """

    reply = chat_session.send_message(review_prompt)
    return markdown_to_raw(reply.text)


def create_gemini_model():
    """Create a 'gemini-pro' model plus an empty chat session for it.

    Returns:
        A (model, chat) tuple; the chat starts with no history.
    """
    llm = genai.GenerativeModel('gemini-pro')
    return llm, llm.start_chat(history=[])


# instantiate models
# Module-level chat sessions (created at import time): one for summarizing
# papers and one for query optimization, so each keeps its own conversation
# history across predict() calls. The model handles themselves are unused
# beyond creation.
summarizer_model, summarizer_chat = create_gemini_model()
query_optimizer_model, query_optimizer_chat = create_gemini_model()


# def get_paper_links(papers):
#     urls = []
#     for paper in papers:
#         urls = paper['url']

#     return urls


def predict(message, history):
    """Gradio chat handler.

    On the first turn (empty history) the message is treated as a research
    topic: it is optionally optimized by Gemini, used to search Semantic
    Scholar, and the resulting papers are summarized into a literature
    review. Subsequent turns are forwarded to the summarizer chat as
    follow-up questions.

    Args:
        message: The user's latest chat message.
        history: Prior (user, bot) exchanges supplied by Gradio.

    Returns:
        The assistant's reply as plain text.
    """
    # `not history` also covers None, which Gradio may pass; the original
    # `history == []` only matched an exact empty list.
    if not history:
        query = message
        print(f"INITIAL QUERY: {query}")
        # The optimize_query global is only bound under the __main__ guard;
        # default to False so importing this module doesn't NameError here.
        if globals().get('optimize_query', False):
            optimizer_prompt = f"""Given a search query, return an optimized
            version of the query to find related academic papers
            QUERY: {query}.
            Only return the optimized query"""

            response = query_optimizer_chat.send_message(optimizer_prompt)
            query = markdown_to_raw(response.text)
            print(f"OPTIMIZED QUERY: {query}")

        # optimized query used to search semantic scholar
        papers = find_basis_papers(query)
        # find_basis_papers returns None (or possibly []) on failure; the
        # original interpolated that straight into the LLM prompt as "None".
        if not papers:
            return ('No open-access papers with abstracts were found for '
                    'that query. Please try rephrasing it.')

        summarizer_prompt = f"""Given the following academic papers, 
        return a review of related literature for the search query: {query}.
        Focus on data/key factors and methodologies considered.
        Here are the papers {papers}
        Include the paper urls at the end of the review of related literature.
        """

        response = summarizer_chat.send_message(summarizer_prompt)
        abstract_summary = markdown_to_raw(response.text)

        return abstract_summary

    # Follow-up turn: continue the existing summarizer conversation.
    response = summarizer_chat.send_message(message)
    response_text = markdown_to_raw(response.text)

    return response_text


def main():
    """Launch the Gradio chat UI for the literature-review assistant."""
    # GEMINI optimizes query
    intro_text = """Start by inputting a brief description/title
        of your research and our assistant will return a review of
        related literature
        
        ex. Finding optimal site locations for solar farms"""
    sample_queries = ['Finding optimal site locations for solar farms',
                      'Wildfire prediction',
                      'Fish yield prediction']

    chat_ui = gr.ChatInterface(
        predict,
        title="LLM Research Helper",
        description=intro_text,
        examples=sample_queries,
    )
    chat_ui.launch()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Literature review chatbot")
    parser.add_argument("-o", "--optimize_query",
                        help="Use query optimization (True, False)",
                        default=False)
    args = parser.parse_args()

    # BUG FIX: argparse yields *strings* from the command line, so the old
    # check `args.optimize_query in [True, False]` was never satisfied by
    # `-o True` and the flag silently stayed False. Accept the bool default
    # as well as common truthy strings.
    optimize_query = (args.optimize_query is True
                      or str(args.optimize_query).strip().lower() in ('true', '1', 'yes'))
    main()