kiyer committed on
Commit
99c0738
·
verified ·
1 Parent(s): 0bf4d73

trying some optimizations

Browse files
Files changed (1) hide show
  1. app_gradio.py +342 -4
app_gradio.py CHANGED
@@ -45,10 +45,231 @@ from string import punctuation
45
  import pytextrank
46
  from prompts import *
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  openai_key = os.environ['openai_key']
49
  cohere_key = os.environ['cohere_key']
50
  os.environ["OPENAI_API_KEY"] = os.environ['openai_key']
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def load_nlp():
53
  nlp = spacy.load("en_core_web_sm")
54
  nlp.add_pipe("textrank")
@@ -89,9 +310,11 @@ def load_arxiv_corpus():
89
  # arxiv_corpus.load_faiss_index('embed', 'data/astrophindex.faiss')
90
 
91
  # keeping it up to date with the dataset
92
- arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
93
- arxiv_corpus.add_faiss_index(column='embed')
94
- print('loading arxiv corpus from disk')
 
 
95
  return arxiv_corpus
96
 
97
  class RetrievalSystem():
@@ -649,6 +872,121 @@ def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type,
649
 
650
  yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
651
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
  def create_interface():
653
  custom_css = """
654
  #custom-slider-* {
@@ -687,7 +1025,7 @@ def create_interface():
687
 
688
  inputs = [query, top_k, keywords, toggles, prompt_type, rag_type]
689
  outputs = [ret_papers, search_results_state, qntype, conc, plot]
690
- btn.click(fn=run_pathfinder, inputs=inputs, outputs=outputs)
691
 
692
  return demo
693
 
 
45
  import pytextrank
46
  from prompts import *
47
 
48
+ import os
49
+ from datasets import load_dataset
50
+ import pickle
51
+ import faiss
52
+ import numpy as np
53
+ from functools import lru_cache
54
+ import asyncio
55
+ import aiohttp
56
+ from concurrent.futures import ThreadPoolExecutor
57
+ import time
58
+
59
+ # Add to your main function
60
+ import gc
61
+
62
+ def cleanup_memory():
63
+ """Force garbage collection and clear caches"""
64
+ gc.collect()
65
+ chromadb.api.client.SharedSystemClient.clear_system_cache()
66
+
67
  openai_key = os.environ['openai_key']
68
  cohere_key = os.environ['cohere_key']
69
  os.environ["OPENAI_API_KEY"] = os.environ['openai_key']
70
 
71
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid tokenizer warnings
72
+ os.environ["HF_DATASETS_CACHE"] = "./cache" # Control cache location
73
+
74
+ # Use Hugging Face's built-in caching
75
+ from datasets import enable_caching
76
+ enable_caching()
77
+
78
+ class OptimizedDatasetLoader:
79
+ def __init__(self, cache_dir="./cache"):
80
+ self.cache_dir = cache_dir
81
+ os.makedirs(cache_dir, exist_ok=True)
82
+
83
+ @lru_cache(maxsize=1)
84
+ def load_arxiv_corpus_cached(self):
85
+ """Load dataset with aggressive caching"""
86
+ cache_path = os.path.join(self.cache_dir, "arxiv_corpus.pkl")
87
+ index_path = os.path.join(self.cache_dir, "faiss_index.bin")
88
+
89
+ # Try to load from cache first
90
+ if os.path.exists(cache_path) and os.path.exists(index_path):
91
+ print("Loading from cache...")
92
+ with open(cache_path, 'rb') as f:
93
+ arxiv_corpus = pickle.load(f)
94
+
95
+ # Load pre-built FAISS index
96
+ index = faiss.read_index(index_path)
97
+ arxiv_corpus._indexes = {'embed': index}
98
+ return arxiv_corpus
99
+
100
+ # If not cached, load and cache
101
+ print("Loading dataset and building cache...")
102
+ arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
103
+ arxiv_corpus.add_faiss_index(column='embed')
104
+
105
+ # Cache the dataset
106
+ with open(cache_path, 'wb') as f:
107
+ pickle.dump(arxiv_corpus, f)
108
+
109
+ # Cache the FAISS index
110
+ faiss.write_index(arxiv_corpus._indexes['embed'], index_path)
111
+
112
+ return arxiv_corpus
113
+
114
+ class AsyncRetrievalSystem:
115
+ def __init__(self):
116
+ self.dataset = arxiv_corpus
117
+ self.openai_key = os.environ['openai_key']
118
+ self.executor = ThreadPoolExecutor(max_workers=4)
119
+
120
+ async def async_embedding_call(self, texts, session):
121
+ """Async embedding API call"""
122
+ headers = {
123
+ "Authorization": f"Bearer {self.openai_key}",
124
+ "Content-Type": "application/json"
125
+ }
126
+
127
+ data = {
128
+ "input": texts if isinstance(texts, list) else [texts],
129
+ "model": "text-embedding-3-small"
130
+ }
131
+
132
+ async with session.post(
133
+ "https://api.openai.com/v1/embeddings",
134
+ headers=headers,
135
+ json=data
136
+ ) as response:
137
+ result = await response.json()
138
+ return [item['embedding'] for item in result['data']]
139
+
140
+ async def async_llm_call(self, messages, session, temperature=0):
141
+ """Async LLM API call"""
142
+ headers = {
143
+ "Authorization": f"Bearer {self.openai_key}",
144
+ "Content-Type": "application/json"
145
+ }
146
+
147
+ data = {
148
+ "model": "gpt-4o-mini",
149
+ "messages": messages,
150
+ "temperature": temperature
151
+ }
152
+
153
+ async with session.post(
154
+ "https://api.openai.com/v1/chat/completions",
155
+ headers=headers,
156
+ json=data
157
+ ) as response:
158
+ result = await response.json()
159
+ return result['choices'][0]['message']['content']
160
+
161
+ async def parallel_retrieve_and_analyze(self, query, top_k=10):
162
+ """Run multiple operations in parallel"""
163
+ async with aiohttp.ClientSession() as session:
164
+ # Start all async operations
165
+ tasks = []
166
+
167
+ # 1. Get query embedding
168
+ embedding_task = self.async_embedding_call(query, session)
169
+ tasks.append(embedding_task)
170
+
171
+ # 2. Generate HyDE document (if enabled)
172
+ hyde_messages = [
173
+ ("system", "You are an expert astronomer. Generate an abstract..."),
174
+ ("human", query)
175
+ ]
176
+ hyde_task = self.async_llm_call(hyde_messages, session, temperature=0.5)
177
+ tasks.append(hyde_task)
178
+
179
+ # 3. Question type classification
180
+ qtype_messages = [
181
+ ("system", "Classify this question type..."),
182
+ ("human", query)
183
+ ]
184
+ qtype_task = self.async_llm_call(qtype_messages, session)
185
+ tasks.append(qtype_task)
186
+
187
+ # Wait for all to complete
188
+ query_embedding, hyde_doc, question_type = await asyncio.gather(*tasks)
189
+
190
+ return {
191
+ 'embedding': query_embedding[0],
192
+ 'hyde_doc': hyde_doc,
193
+ 'question_type': question_type
194
+ }
195
+
196
+ def run_parallel_search(self, query, top_k=10):
197
+ """Wrapper to run async function"""
198
+ return asyncio.run(self.parallel_retrieve_and_analyze(query, top_k))
199
+
200
+ class OptimizedEmbedding:
201
+ def __init__(self, openai_key, batch_size=100):
202
+ self.client = OpenAI(api_key=openai_key)
203
+ self.batch_size = batch_size
204
+ self.embed_model = "text-embedding-3-small"
205
+
206
+ def batch_embeddings(self, texts):
207
+ """Process embeddings in batches for efficiency"""
208
+ all_embeddings = []
209
+
210
+ for i in range(0, len(texts), self.batch_size):
211
+ batch = texts[i:i + self.batch_size]
212
+ try:
213
+ response = self.client.embeddings.create(
214
+ input=batch,
215
+ model=self.embed_model
216
+ )
217
+ batch_embeddings = [item.embedding for item in response.data]
218
+ all_embeddings.extend(batch_embeddings)
219
+ except Exception as e:
220
+ print(f"Batch embedding failed: {e}")
221
+ # Fallback to individual processing
222
+ for text in batch:
223
+ emb = self.client.embeddings.create(
224
+ input=[text],
225
+ model=self.embed_model
226
+ ).data[0].embedding
227
+ all_embeddings.append(emb)
228
+
229
+ return all_embeddings
230
+
231
+ class MemoryOptimizedRAG:
232
+ def __init__(self):
233
+ self.vectorstore_cache = {}
234
+
235
+ def create_vectorstore_cached(self, documents, collection_name):
236
+ """Cache vectorstore to avoid recreation"""
237
+ cache_key = f"{collection_name}_{len(documents)}"
238
+
239
+ if cache_key in self.vectorstore_cache:
240
+ return self.vectorstore_cache[cache_key]
241
+
242
+ # Clear ChromaDB cache before creating new vectorstore
243
+ chromadb.api.client.SharedSystemClient.clear_system_cache()
244
+
245
+ text_splitter = RecursiveCharacterTextSplitter(
246
+ chunk_size=150,
247
+ chunk_overlap=50,
248
+ add_start_index=True
249
+ )
250
+ splits = text_splitter.split_documents(documents)
251
+
252
+ vectorstore = Chroma.from_documents(
253
+ documents=splits,
254
+ embedding=embeddings,
255
+ collection_name=collection_name
256
+ )
257
+
258
+ self.vectorstore_cache[cache_key] = vectorstore
259
+ return vectorstore
260
+
261
+ def cleanup_old_vectorstores(self, max_cache_size=3):
262
+ """Clean up old vectorstores to free memory"""
263
+ if len(self.vectorstore_cache) > max_cache_size:
264
+ # Remove oldest entries
265
+ oldest_keys = list(self.vectorstore_cache.keys())[:-max_cache_size]
266
+ for key in oldest_keys:
267
+ try:
268
+ self.vectorstore_cache[key].delete_collection()
269
+ except:
270
+ pass
271
+ del self.vectorstore_cache[key]
272
+
273
  def load_nlp():
274
  nlp = spacy.load("en_core_web_sm")
275
  nlp.add_pipe("textrank")
 
310
  # arxiv_corpus.load_faiss_index('embed', 'data/astrophindex.faiss')
311
 
312
  # keeping it up to date with the dataset
313
+ # arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
314
+ # arxiv_corpus.add_faiss_index(column='embed')
315
+ # print('loading arxiv corpus from disk')
316
+ loader = OptimizedDatasetLoader()
317
+ arxiv_corpus = loader.load_arxiv_corpus_cached()
318
  return arxiv_corpus
319
 
320
  class RetrievalSystem():
 
872
 
873
  yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
874
 
875
+
876
+
877
+ async def run_pathfinder_optimized(query, top_k, extra_keywords, toggles,
878
+ prompt_type, rag_type, ec=None, progress=None):
879
+ """Optimized version of run_pathfinder with parallel processing"""
880
+
881
+ # Early validation
882
+ if check_mod(query):
883
+ yield None, "Query flagged by moderation", None, None, None
884
+ return
885
+
886
+ # Setup
887
+ input_keywords = [kw.strip() for kw in extra_keywords.split(',')] if extra_keywords else []
888
+ query_keywords = get_keywords(query)
889
+ ec.query_input_keywords = input_keywords + query_keywords
890
+ ec.toggles = toggles
891
+
892
+ # Configure retrieval method
893
+ ec.hyde = rag_type in ["Semantic + HyDE", "Semantic + HyDE + CoHERE"]
894
+ ec.rerank = rag_type in ["Semantic + CoHERE", "Semantic + HyDE + CoHERE"]
895
+
896
+ try:
897
+ if prompt_type == "Deep Research (BETA)":
898
+ # Deep research is inherently sequential, keep original implementation
899
+ formatted_df, rag_answer = deep_research(query, top_k=top_k, ec=ec)
900
+ yield formatted_df, rag_answer['answer'], None, None, None
901
+ else:
902
+ # Phase 1: Parallel initial operations
903
+ gr.Info("Starting parallel search operations...")
904
+
905
+ async with aiohttp.ClientSession() as session:
906
+ # Start retrieval
907
+ retrieval_task = asyncio.create_task(
908
+ async_retrieve(ec, query, top_k, session)
909
+ )
910
+
911
+ # Start question type analysis (independent operation)
912
+ qtype_task = asyncio.create_task(
913
+ async_question_type_analysis(query, session)
914
+ )
915
+
916
+ # Wait for retrieval to complete first
917
+ rs, small_df = await retrieval_task
918
+ formatted_df = ec.return_formatted_df(rs, small_df)
919
+ yield formatted_df, None, None, None, None
920
+
921
+ # Phase 2: RAG QA while question type analysis continues
922
+ gr.Info("Generating answer...")
923
+ rag_answer = await async_rag_qa(query, formatted_df, prompt_type, session)
924
+ yield formatted_df, rag_answer['answer'], None, None, None
925
+
926
+ # Phase 3: Parallel consensus and remaining operations
927
+ gr.Info("Finalizing analysis...")
928
+
929
+ consensus_task = asyncio.create_task(
930
+ async_consensus_evaluation(query, formatted_df, session)
931
+ )
932
+
933
+ plot_task = asyncio.create_task(
934
+ async_make_plot(formatted_df, top_k)
935
+ )
936
+
937
+ # Wait for question type and consensus
938
+ question_type_gen, consensus_answer = await asyncio.gather(
939
+ qtype_task, consensus_task
940
+ )
941
+
942
+ # Format outputs
943
+ consensus = f'## Consensus \n{consensus_answer.consensus}\n\n{consensus_answer.explanation}\n\n > Relevance: {consensus_answer.relevance_score:.1f}'
944
+ qn_type = format_question_type(question_type_gen)
945
+
946
+ yield formatted_df, rag_answer['answer'], consensus, qn_type, None
947
+
948
+ # Final plot
949
+ fig = await plot_task
950
+ yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
951
+
952
+ except Exception as e:
953
+ print(f"Error in pathfinder: {e}")
954
+ yield None, f"Error: {str(e)}", None, None, None
955
+
956
+ async def async_retrieve(ec, query, top_k, session):
957
+ """Async wrapper for retrieval"""
958
+ loop = asyncio.get_event_loop()
959
+ return await loop.run_in_executor(None, ec.retrieve, query, top_k, True)
960
+
961
+ async def async_rag_qa(query, formatted_df, prompt_type, session):
962
+ """Async wrapper for RAG QA"""
963
+ loop = asyncio.get_event_loop()
964
+ return await loop.run_in_executor(None, run_rag_qa, query, formatted_df, prompt_type)
965
+
966
+ async def async_consensus_evaluation(query, formatted_df, session):
967
+ """Async consensus evaluation"""
968
+ abstracts = [formatted_df['abstract'][i+1] for i in range(len(formatted_df))]
969
+ loop = asyncio.get_event_loop()
970
+ return await loop.run_in_executor(None, evaluate_overall_consensus, query, abstracts)
971
+
972
+ async def async_question_type_analysis(query, session):
973
+ """Async question type analysis"""
974
+ loop = asyncio.get_event_loop()
975
+ return await loop.run_in_executor(None, guess_question_type, query)
976
+
977
+ async def async_make_plot(formatted_df, top_k):
978
+ """Async plot generation"""
979
+ loop = asyncio.get_event_loop()
980
+ return await loop.run_in_executor(None, make_embedding_plot, formatted_df, top_k, None)
981
+
982
+ def format_question_type(question_type_gen):
983
+ """Clean up question type output"""
984
+ if '<categorization>' in question_type_gen:
985
+ question_type_gen = question_type_gen.split('<categorization>')[1]
986
+ if '</categorization>' in question_type_gen:
987
+ question_type_gen = question_type_gen.split('</categorization>')[0]
988
+ return question_type_gen.replace('\n', ' \n')
989
+
990
  def create_interface():
991
  custom_css = """
992
  #custom-slider-* {
 
1025
 
1026
  inputs = [query, top_k, keywords, toggles, prompt_type, rag_type]
1027
  outputs = [ret_papers, search_results_state, qntype, conc, plot]
1028
+ btn.click(fn=run_pathfinder_optimized, inputs=inputs, outputs=outputs)
1029
 
1030
  return demo
1031