PAOLO SANDEJAS committed
Commit · cd1ffaa
Parent(s): 97da4ac
Add app files
- app.py +488 -0
- papers/.DS_Store +0 -0
- requirements.txt +233 -0
app.py
ADDED
@@ -0,0 +1,488 @@
import pathlib
import textwrap

import os
import re
import json

import requests

import google.generativeai as genai

from IPython.display import display
from IPython.display import Markdown

from chromadb import Documents, EmbeddingFunction, Embeddings
from pypdf import PdfReader
from pypdf.errors import PdfReadError
import chromadb
from typing import List
import shutil
import ast

# from timeout import timeout, TimeoutError

import gradio as gr

PAPERS_DIR = "/Users/paoloantoniosandejas/Documents/experiment-3/experiment-3/initial_experiments/ps/LLM Research Helper v2/papers"
RAG_DIR = "/Users/paoloantoniosandejas/Documents/experiment-3/experiment-3/initial_experiments/ps/LLM Research Helper v2/RAG/contents"


gemini_api_key = os.environ.get('GEMINI_API_KEY', '-1')
genai.configure(api_key=gemini_api_key)

S2_API_KEY = os.getenv('S2_API_KEY')
initial_result_limit = 10
final_result_limit = 5

# Select relevant fields to pull
fields = 'title,url,abstract,citationCount,authors,isOpenAccess,fieldsOfStudy,year,journal,openAccessPdf'
def raw_to_markdown(text):
    text = text.replace('•', ' *')
    return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))


def markdown_to_raw(markdown_text):
    """
    This function converts basic markdown text to raw text.

    Args:
        markdown_text: The markdown text string to be converted.

    Returns:
        A string containing the raw text equivalent of the markdown text.
    """
    # Remove headers
    text = re.sub(r'#+ ?', '', markdown_text)

    # Remove bold and italics (can be adjusted based on needs)
    text = re.sub(r'\*\*(.+?)\*\*', r'\1', text)  # Bold
    text = re.sub(r'_(.+?)_', r'\1', text)  # Italics

    # Remove code blocks
    text = re.sub(r'`(.*?)`', '', text, flags=re.DOTALL)

    # Remove lists
    text = re.sub(r'\*+ (.*?)$', r'\1', text, flags=re.MULTILINE)  # Unordered lists
    text = text.strip()  # Remove extra whitespace (str.strip returns a new string rather than mutating)

    return text
def find_basis_papers(query):
    papers = None
    if not query:
        print('No query given')
        return None

    rsp = requests.get('https://api.semanticscholar.org/graph/v1/paper/search',
                       headers={'X-API-KEY': S2_API_KEY},
                       params={'query': query, 'limit': initial_result_limit, 'fields': fields})
    rsp.raise_for_status()
    results = rsp.json()
    total = results["total"]
    if not total:
        print('No matches found. Please try another query.')
        return None

    print(f'Found {total} initial results. Showing up to {initial_result_limit}.')
    papers = results['data']
    # print("INITIAL RESULTS")
    # print_papers(papers)

    # Filter paper results
    filtered_papers = list(filter(isValidPaper, papers))

    # print("FILTERED RESULTS")
    # print_papers(filtered_papers)

    # Rank paper results by recency, then citation count
    ranked_papers = sorted(filtered_papers, key=lambda x: (x['year'], x['citationCount']), reverse=True)

    # print("RANKED RESULTS")
    # print_papers(ranked_papers)

    # Return the best papers
    return ranked_papers[:final_result_limit]


# def find_recommendations(paper):
#     print(f"Up to {result_limit} recommendations based on: {paper['title']}")
#     rsp = requests.get(f"https://api.semanticscholar.org/recommendations/v1/papers/forpaper/{paper['paperId']}",
#                        headers={'X-API-KEY': S2_API_KEY},
#                        params={'fields': fields, 'limit': 10})
#     rsp.raise_for_status()
#     results = rsp.json()
#     print_papers(results['recommendedPapers'])
#     return results['recommendedPapers']


def print_papers(papers):
    for idx, paper in enumerate(papers):
        print(f"PAPER {idx}")
        for key, value in paper.items():
            if key != 'abstract':
                print(f"\t{key}: '{value}'")


def isValidPaper(paper):
    # A paper is usable only if it is open access with both an abstract and a PDF link
    return bool(paper['isOpenAccess'] and paper['abstract'] and paper['openAccessPdf'])


# def filter_papers(papers):
#     filtered_papers = []
#     for paper in papers:
#         if paper['isOpenAccess'] and paper['abstract'] and paper['openAccessPdf']:
#             # paper is acceptable
#             filtered_papers.append(paper)
#     return filtered_papers
def load_pdf(file_path):
    """
    Reads the text content from a PDF file and returns it as a single string.

    Parameters:
    - file_path (str): The file path to the PDF file.

    Returns:
    - str: The concatenated text content of all pages in the PDF.
    """
    # Logic to read pdf
    reader = PdfReader(file_path)

    # Loop over each page and accumulate its text
    text = ""
    for page in reader.pages:
        text += page.extract_text()

    return text


def split_text(text: str):
    """
    Splits a text string into a list of non-empty substrings based on the specified pattern.
    The "\n \n" pattern splits the document paragraph by paragraph.

    Parameters:
    - text (str): The input text to be split.

    Returns:
    - List[str]: A list containing non-empty substrings obtained by splitting the input text.
    """
    split_text = re.split('\n \n', text)
    return [i for i in split_text if i != ""]


class GeminiEmbeddingFunction(EmbeddingFunction):
    """
    Custom embedding function using the Gemini AI API for document retrieval.

    This class extends the EmbeddingFunction class and implements the __call__ method
    to generate embeddings for a given set of documents using the Gemini AI API.

    Parameters:
    - input (Documents): A collection of documents to be embedded.

    Returns:
    - Embeddings: Embeddings generated for the input documents.
    """
    def __call__(self, input: Documents) -> Embeddings:
        gemini_api_key = os.getenv("GEMINI_API_KEY")
        if not gemini_api_key:
            raise ValueError("Gemini API Key not provided. Please provide GEMINI_API_KEY as an environment variable")
        genai.configure(api_key=gemini_api_key)
        model = "models/embedding-001"
        title = "Custom query"
        return genai.embed_content(model=model,
                                   content=input,
                                   task_type="retrieval_document",
                                   title=title)["embedding"]


def create_chroma_db(documents: List, path: str, name: str):
    """
    Creates a Chroma database using the provided documents, path, and collection name.

    Parameters:
    - documents: An iterable of documents to be added to the Chroma database.
    - path (str): The path where the Chroma database will be stored.
    - name (str): The name of the collection within the Chroma database.

    Returns:
    - Tuple[chromadb.Collection, str]: A tuple containing the created Chroma Collection and its name.
    """
    chroma_client = chromadb.PersistentClient(path=path)
    db = chroma_client.create_collection(name=name, embedding_function=GeminiEmbeddingFunction())

    for i, d in enumerate(documents):
        db.add(documents=d, ids=str(i))

    return db, name


def load_chroma_collection(path, name):
    """
    Loads an existing Chroma collection from the specified path with the given name.

    Parameters:
    - path (str): The path where the Chroma database is stored.
    - name (str): The name of the collection within the Chroma database.

    Returns:
    - chromadb.Collection: The loaded Chroma Collection.
    """
    chroma_client = chromadb.PersistentClient(path=path)
    db = chroma_client.get_collection(name=name, embedding_function=GeminiEmbeddingFunction())

    return db


def delete_chroma_collection(path, name):
    chroma_client = chromadb.PersistentClient(path=path)
    chroma_client.delete_collection(name=name)


def delete_all_paper_dbs(papers):
    for idx in range(len(papers)):
        delete_chroma_collection(path=RAG_DIR,
                                 name=f"paper_{idx}")


def get_relevant_passage(query, db, n_results):
    passage = db.query(query_texts=[query], n_results=n_results)['documents'][0]
    return passage
def make_rag_prompt(query, relevant_passage):
    escaped = relevant_passage.replace("'", "").replace('"', "").replace("\n", " ")
    prompt = ("""You are a helpful and informative bot that answers questions using text from the reference passage included below. \
    Be sure to respond in a complete sentence, being comprehensive, including all relevant background information. \
    However, you are talking to a non-technical audience, so be sure to break down complicated concepts and \
    strike a friendly and conversational tone. \
    If the passage is irrelevant to the answer, you may ignore it.
    QUESTION: '{query}'
    PASSAGE: '{relevant_passage}'

    ANSWER:
    """).format(query=query, relevant_passage=escaped)

    return prompt


def generate_answer_prompt(prompt):
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        raise ValueError("Gemini API Key not provided. Please provide GEMINI_API_KEY as an environment variable")
    genai.configure(api_key=gemini_api_key)
    model = genai.GenerativeModel('gemini-pro')
    answer = model.generate_content(prompt)
    return answer.text


def generate_answer_db(db, query):
    # Retrieve the top 3 relevant text chunks
    relevant_text = get_relevant_passage(query, db, n_results=3)
    # print(relevant_text)
    prompt = make_rag_prompt(query,
                             relevant_passage="".join(relevant_text))  # join the relevant chunks into a single passage
    answer = generate_answer_prompt(prompt)

    return answer


def pull_paper(paper_url, filepath):
    r = requests.get(paper_url, timeout=60)  # bound the download so a slow host cannot hang the app
    with open(filepath, 'wb') as outfile:
        outfile.write(r.content)
def RAG_create_paper_dbs(papers):
    if os.path.exists(RAG_DIR) and os.path.isdir(RAG_DIR):
        # Delete current dbs
        shutil.rmtree(RAG_DIR)

    vector_dbs = {}  # key: name, value: db

    urls = [p['openAccessPdf']['url'] if p['openAccessPdf'] else None for p in papers]

    print(urls)

    for idx, test_paper in enumerate(papers):
        # Get full paper
        paper_title = test_paper['title']
        paper_primary_author = test_paper['authors'][0]['name']
        paper_url = test_paper['openAccessPdf']['url'] if test_paper['openAccessPdf'] else None
        paper_year = test_paper['year']
        paper_abstract = test_paper['abstract']

        filename = f"{paper_primary_author} {paper_year} - {paper_title}.pdf"
        filepath = f"{PAPERS_DIR}/{filename}"

        print(f'getting {filename}...')

        # Indices listed here are skipped (e.g. papers already saved locally)
        skip_idxs = []
        if idx not in skip_idxs and paper_url is not None:
            try:
                pull_paper(paper_url, filepath)
            except requests.exceptions.Timeout:
                print("Paper taking too long...")

        print('\t- DONE!')

        # Initialize to the abstract as a fallback
        pdf_text = paper_abstract

        try:
            pdf_text = load_pdf(file_path=filepath)
        except (PdfReadError, OSError):
            print("\t- invalid PDF file! Using abstract as fallback")

        print(f'saving {filename} as a vector db...')

        # Save paper as vector DB
        chunked_text = split_text(text=pdf_text)

        db, name = create_chroma_db(documents=chunked_text,
                                    path=RAG_DIR,
                                    name=f"paper_{idx}")

        vector_dbs[name] = db

        print('\t- DONE!')

    return vector_dbs


def ask_all_papers(vector_dbs, query):
    answers = {}
    for name, db in vector_dbs.items():
        db = load_chroma_collection(path=RAG_DIR, name=name)

        answer = generate_answer_db(db, query=query)

        # print(f"{name} answer: {answer}\n\n")

        answers[name] = answer

    return answers
def GEMINI_list_features(answers):
    generation_config = {
        "temperature": 0.5
        # "top_p": 0.95,
        # "top_k": 0,
        # "max_output_tokens": 8192,
    }

    model = genai.GenerativeModel(model_name='gemini-pro', generation_config=generation_config)
    chat = model.start_chat(history=[])

    prompt = f"""Given the following lists of variables considered,
    return a list of the common variables. Only return a python list.
    LISTS OF VARIABLES CONSIDERED: {answers}"""

    response = chat.send_message(prompt)
    response = markdown_to_raw(response.text)

    return response


def GEMINI_predict_target(initial_query: str):
    # Initialize Gemini LLM
    model = genai.GenerativeModel('gemini-pro')
    chat = model.start_chat(history=[])

    prompt = f"""Given this search query, what does the user want to predict?
    QUERY: {initial_query}.
    Only return the answer"""

    response = chat.send_message(prompt)
    predict_target = markdown_to_raw(response.text)

    return predict_target


def GEMINI_optimize_query(initial_query: str):
    # Initialize Gemini LLM
    model = genai.GenerativeModel('gemini-pro')
    chat = model.start_chat(history=[])

    prompt = f"""Given a search query, return an optimized version of the query to find related academic papers
    QUERY: {initial_query}.
    Only return the optimized query. If you feel the query is already concise and optimized, return the original query"""

    response = chat.send_message(prompt)
    optimized_query = markdown_to_raw(response.text)

    return optimized_query


def GEMINI_summarize_abstracts(initial_query: str, papers: str):
    # Initialize Gemini LLM
    model = genai.GenerativeModel('gemini-pro')
    chat = model.start_chat(history=[])

    prompt = f"""Given the following academic papers,
    return a review of related literature for the search query: {initial_query}.
    Focus on data/key factors and methodologies considered.
    Here are the papers {papers}
    Include the paper urls at the end of the review of related literature.
    """
    response = chat.send_message(prompt)
    abstract_summary = markdown_to_raw(response.text)

    return abstract_summary
def predict(message, history):
    # if message == "delete":
    #     delete_all_paper_dbs(papers)
    # if history == []:
    predict_target = GEMINI_predict_target(message)
    papers = find_basis_papers(message)
    vector_dbs = RAG_create_paper_dbs(papers)
    # predict_target = 'solar site score'

    answers = ask_all_papers(vector_dbs, f"list the independent variables considered to predict {predict_target}")
    feature_list = GEMINI_list_features(answers)
    res = ast.literal_eval(feature_list)

    response = f"""
    COMMON FEATURES TO CONSIDER: {res}

    vectordb answers: {answers}
    """

    delete_all_paper_dbs(papers)
    # response = summarizer_chat.send_message(message)
    # response_text = markdown_to_raw(response.text)

    return response


def main():
    # GEMINI optimizes query
    gr.ChatInterface(
        predict,
        title="LLM Research Helper",
        description="""Start by inputting a brief description/title
        of your research and our assistant will return a review of
        related literature

        ex. Finding optimal site locations for solar farms"""
    ).launch(debug=True)


if __name__ == '__main__':
    main()
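
For context, a minimal sketch of how these pieces compose outside the Gradio interface. It is illustrative, not part of the commit: it assumes GEMINI_API_KEY and S2_API_KEY are set in the environment and that app.py is importable as a module named app.

# Minimal usage sketch (assumed module name `app`; sample query taken from the UI description)
from app import find_basis_papers, RAG_create_paper_dbs, ask_all_papers, delete_all_paper_dbs

query = "Finding optimal site locations for solar farms"  # illustrative query
papers = find_basis_papers(query)                         # search, filter, and rank via Semantic Scholar
if papers:
    dbs = RAG_create_paper_dbs(papers)                    # download PDFs, build one Chroma collection per paper
    answers = ask_all_papers(dbs, "list the independent variables considered")
    for name, answer in answers.items():
        print(f"{name}: {answer}")
    delete_all_paper_dbs(papers)                          # drop the per-paper collections when done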
papers/.DS_Store
ADDED
Binary file (6.15 kB)
requirements.txt
ADDED
@@ -0,0 +1,233 @@
aiofiles==23.2.1
aiohttp==3.9.4
aiosignal==1.3.1
altair==5.3.0
annotated-types==0.6.0
anyio==4.3.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
arxiv==2.1.0
asgiref==3.8.1
asttokens==2.4.1
async-lru==2.0.4
attrs==23.2.0
Babel==2.14.0
backoff==2.2.1
bcrypt==4.1.3
beautifulsoup4==4.12.3
bleach==6.1.0
build==1.2.1
cachetools==5.3.3
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
chroma-hnswlib==0.7.3
chromadb==0.4.22
click==8.1.7
coloredlogs==15.0.1
comm==0.2.2
contourpy==1.2.1
cycler==0.12.1
dataclasses-json==0.6.5
datasets==2.18.0
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
Deprecated==1.2.14
dill==0.3.8
distro==1.9.0
evaluate==0.4.1
executing==2.0.1
fastapi==0.110.1
fastjsonschema==2.19.1
feedparser==6.0.10
ffmpy==0.3.2
filelock==3.13.4
flatbuffers==24.3.25
fonttools==4.51.0
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.2.0
google-ai-generativelanguage==0.6.2
google-api-core==2.19.0
google-api-python-client==2.125.0
google-auth==2.29.0
google-auth-httplib2==0.2.0
google-generativeai==0.5.2
googleapis-common-protos==1.63.0
gradio==4.28.3
gradio_client==0.16.0
grpcio==1.62.1
grpcio-status==1.62.1
h11==0.14.0
httpcore==1.0.5
httplib2==0.22.0
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.22.2
humanfriendly==10.0
idna==3.7
importlib-metadata==7.0.0
importlib_resources==6.4.0
ipykernel==6.29.4
ipython==8.23.0
ipywidgets==8.1.2
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.3
jiwer==3.0.3
json5==0.9.25
jsonpatch==1.33
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.14.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.6
jupyterlab_pygments==0.3.0
jupyterlab_server==2.26.0
jupyterlab_widgets==3.0.10
kiwisolver==1.4.5
kubernetes==29.0.0
langchain==0.1.17
langchain-community==0.0.36
langchain-core==0.1.50
langchain-google-genai==1.0.3
langchain-text-splitters==0.0.1
langsmith==0.1.54
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.21.2
matplotlib==3.8.4
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==3.0.2
mmh3==4.1.0
monotonic==1.6
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
mypy-extensions==1.0.0
nbclient==0.10.0
nbconvert==7.16.3
nbformat==5.10.4
nest-asyncio==1.6.0
notebook==7.1.2
notebook_shim==0.2.4
numpy==1.26.4
oauthlib==3.2.2
onnxruntime==1.17.3
openai==1.17.1
opentelemetry-api==1.24.0
opentelemetry-exporter-otlp-proto-common==1.24.0
opentelemetry-exporter-otlp-proto-grpc==1.24.0
opentelemetry-instrumentation==0.45b0
opentelemetry-instrumentation-asgi==0.45b0
opentelemetry-instrumentation-fastapi==0.45b0
opentelemetry-proto==1.24.0
opentelemetry-sdk==1.24.0
opentelemetry-semantic-conventions==0.45b0
opentelemetry-util-http==0.45b0
orjson==3.10.0
overrides==7.7.0
packaging==23.2
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==10.3.0
platformdirs==4.2.0
posthog==3.5.0
prometheus_client==0.20.0
prompt-toolkit==3.0.43
proto-plus==1.23.0
protobuf==4.25.3
psutil==5.9.8
ptyprocess==0.7.0
pulsar-client==3.5.0
pure-eval==0.2.2
pyarrow==15.0.2
pyarrow-hotfix==0.6
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycparser==2.22
pydantic==2.7.0
pydantic_core==2.18.1
pydub==0.25.1
Pygments==2.17.2
pyparsing==3.1.2
pypdf==4.0.0
PyPDF2==3.0.1
PyPika==0.48.9
pyproject_hooks==1.1.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
python-multipart==0.0.9
pytz==2024.1
PyYAML==6.0.1
pyzmq==25.1.2
qtconsole==5.5.1
QtPy==2.4.1
rapidfuzz==3.8.1
referencing==0.34.0
requests==2.31.0
requests-oauthlib==2.0.0
responses==0.18.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rpds-py==0.18.0
rsa==4.9
ruff==0.3.7
semantic-version==2.10.0
Send2Trash==1.8.3
setuptools==68.2.2
sgmllib3k==1.0.0
shellingham==1.5.4
six==1.16.0
sniffio==1.3.1
soupsieve==2.5
SQLAlchemy==2.0.30
stack-data==0.6.3
starlette==0.37.2
sympy==1.12
tenacity==8.2.3
terminado==0.18.1
tinycss2==1.2.1
tokenizers==0.19.1
tomlkit==0.12.0
toolz==0.12.1
tornado==6.4
tqdm==4.66.2
traitlets==5.14.2
typer==0.12.3
types-python-dateutil==2.9.0.20240316
typing-inspect==0.9.0
typing_extensions==4.11.0
tzdata==2024.1
uri-template==1.3.0
uritemplate==4.1.1
urllib3==2.2.1
uvicorn==0.29.0
uvloop==0.19.0
watchfiles==0.21.0
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
websockets==11.0.3
wheel==0.41.2
widgetsnbextension==4.0.10
wrapt==1.16.0
xxhash==3.4.1
yarl==1.9.4
zipp==3.18.1
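
Note that the list pins a full development environment (Jupyter, LangChain, matplotlib) beyond what app.py itself imports. A plausible local setup, assuming a fresh virtual environment: run pip install -r requirements.txt, export GEMINI_API_KEY (and S2_API_KEY if available), then run python app.py to launch the Gradio app.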