Spaces: Runtime error
trying some optimizations
app_gradio.py CHANGED (+342 -4)
@@ -45,10 +45,231 @@ from string import punctuation
 import pytextrank
 from prompts import *
 
+import os
+from datasets import load_dataset
+import pickle
+import faiss
+import numpy as np
+from functools import lru_cache
+import asyncio
+import aiohttp
+from concurrent.futures import ThreadPoolExecutor
+import time
+
+# Add to your main function
+import gc
+
+def cleanup_memory():
+    """Force garbage collection and clear caches"""
+    gc.collect()
+    chromadb.api.client.SharedSystemClient.clear_system_cache()
+
 openai_key = os.environ['openai_key']
 cohere_key = os.environ['cohere_key']
 os.environ["OPENAI_API_KEY"] = os.environ['openai_key']
 
+os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid tokenizer warnings
+os.environ["HF_DATASETS_CACHE"] = "./cache"  # Control cache location
+
+# Use Hugging Face's built-in caching
+from datasets import enable_caching
+enable_caching()
+
+class OptimizedDatasetLoader:
+    def __init__(self, cache_dir="./cache"):
+        self.cache_dir = cache_dir
+        os.makedirs(cache_dir, exist_ok=True)
+
+    @lru_cache(maxsize=1)
+    def load_arxiv_corpus_cached(self):
+        """Load dataset with aggressive caching"""
+        cache_path = os.path.join(self.cache_dir, "arxiv_corpus.pkl")
+        index_path = os.path.join(self.cache_dir, "faiss_index.bin")
+
+        # Try to load from cache first
+        if os.path.exists(cache_path) and os.path.exists(index_path):
+            print("Loading from cache...")
+            with open(cache_path, 'rb') as f:
+                arxiv_corpus = pickle.load(f)
+
+            # Load pre-built FAISS index
+            index = faiss.read_index(index_path)
+            arxiv_corpus._indexes = {'embed': index}
+            return arxiv_corpus
+
+        # If not cached, load and cache
+        print("Loading dataset and building cache...")
+        arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
+        arxiv_corpus.add_faiss_index(column='embed')
+
+        # Cache the dataset
+        with open(cache_path, 'wb') as f:
+            pickle.dump(arxiv_corpus, f)
+
+        # Cache the FAISS index
+        faiss.write_index(arxiv_corpus._indexes['embed'], index_path)
+
+        return arxiv_corpus
+
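A caveat on the caching scheme above: Dataset._indexes is a private attribute, and in the datasets library it maps index names to FaissIndex wrapper objects rather than raw faiss indexes, so faiss.write_index(arxiv_corpus._indexes['embed'], ...) and the raw-index assignment on the load path are fragile across library versions. The public save_faiss_index/load_faiss_index methods (the same API the commented-out line in load_arxiv_corpus below refers to) cover the same ground; a minimal sketch, with an illustrative cache path:

    import os
    from datasets import load_dataset

    def load_corpus_with_index(index_file="./cache/embed.faiss"):
        # load_dataset() already caches the rows under HF_DATASETS_CACHE,
        # so only the FAISS index needs explicit persistence.
        corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
        if os.path.exists(index_file):
            corpus.load_faiss_index('embed', index_file)  # reuse prebuilt index
        else:
            corpus.add_faiss_index(column='embed')        # build once
            corpus.save_faiss_index('embed', index_file)  # persist for next boot
        return corpus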
+class AsyncRetrievalSystem:
+    def __init__(self):
+        self.dataset = arxiv_corpus
+        self.openai_key = os.environ['openai_key']
+        self.executor = ThreadPoolExecutor(max_workers=4)
+
+    async def async_embedding_call(self, texts, session):
+        """Async embedding API call"""
+        headers = {
+            "Authorization": f"Bearer {self.openai_key}",
+            "Content-Type": "application/json"
+        }
+
+        data = {
+            "input": texts if isinstance(texts, list) else [texts],
+            "model": "text-embedding-3-small"
+        }
+
+        async with session.post(
+            "https://api.openai.com/v1/embeddings",
+            headers=headers,
+            json=data
+        ) as response:
+            result = await response.json()
+            return [item['embedding'] for item in result['data']]
+
+    async def async_llm_call(self, messages, session, temperature=0):
+        """Async LLM API call"""
+        headers = {
+            "Authorization": f"Bearer {self.openai_key}",
+            "Content-Type": "application/json"
+        }
+
+        data = {
+            "model": "gpt-4o-mini",
+            "messages": messages,
+            "temperature": temperature
+        }
+
+        async with session.post(
+            "https://api.openai.com/v1/chat/completions",
+            headers=headers,
+            json=data
+        ) as response:
+            result = await response.json()
+            return result['choices'][0]['message']['content']
+
+    async def parallel_retrieve_and_analyze(self, query, top_k=10):
+        """Run multiple operations in parallel"""
+        async with aiohttp.ClientSession() as session:
+            # Start all async operations
+            tasks = []
+
+            # 1. Get query embedding
+            embedding_task = self.async_embedding_call(query, session)
+            tasks.append(embedding_task)
+
+            # 2. Generate HyDE document (if enabled)
+            hyde_messages = [
+                ("system", "You are an expert astronomer. Generate an abstract..."),
+                ("human", query)
+            ]
+            hyde_task = self.async_llm_call(hyde_messages, session, temperature=0.5)
+            tasks.append(hyde_task)
+
+            # 3. Question type classification
+            qtype_messages = [
+                ("system", "Classify this question type..."),
+                ("human", query)
+            ]
+            qtype_task = self.async_llm_call(qtype_messages, session)
+            tasks.append(qtype_task)
+
+            # Wait for all to complete
+            query_embedding, hyde_doc, question_type = await asyncio.gather(*tasks)
+
+            return {
+                'embedding': query_embedding[0],
+                'hyde_doc': hyde_doc,
+                'question_type': question_type
+            }
+
+    def run_parallel_search(self, query, top_k=10):
+        """Wrapper to run async function"""
+        return asyncio.run(self.parallel_retrieve_and_analyze(query, top_k))
+
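Two caveats on AsyncRetrievalSystem as written: async_llm_call forwards messages verbatim, but parallel_retrieve_and_analyze builds them as ("system", ...) tuples, a LangChain convention; the raw /v1/chat/completions endpoint expects role/content dicts. Also, run_parallel_search uses asyncio.run, which raises RuntimeError if called from an already-running event loop (for example inside an async Gradio handler). A minimal sketch of the payload shape the endpoint accepts (the query value is illustrative):

    query = "What is the Hubble tension?"
    messages = [
        {"role": "system", "content": "You are an expert astronomer. Generate an abstract..."},
        {"role": "user", "content": query},
    ]
    data = {"model": "gpt-4o-mini", "messages": messages, "temperature": 0.5}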
+class OptimizedEmbedding:
+    def __init__(self, openai_key, batch_size=100):
+        self.client = OpenAI(api_key=openai_key)
+        self.batch_size = batch_size
+        self.embed_model = "text-embedding-3-small"
+
+    def batch_embeddings(self, texts):
+        """Process embeddings in batches for efficiency"""
+        all_embeddings = []
+
+        for i in range(0, len(texts), self.batch_size):
+            batch = texts[i:i + self.batch_size]
+            try:
+                response = self.client.embeddings.create(
+                    input=batch,
+                    model=self.embed_model
+                )
+                batch_embeddings = [item.embedding for item in response.data]
+                all_embeddings.extend(batch_embeddings)
+            except Exception as e:
+                print(f"Batch embedding failed: {e}")
+                # Fallback to individual processing
+                for text in batch:
+                    emb = self.client.embeddings.create(
+                        input=[text],
+                        model=self.embed_model
+                    ).data[0].embedding
+                    all_embeddings.append(emb)
+
+        return all_embeddings
+
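Assuming the file already imports OpenAI (the constructor references it directly), usage is a one-liner per corpus; batching cuts HTTP round-trips from len(texts) down to roughly len(texts)/batch_size. A hypothetical call, with illustrative inputs:

    embedder = OptimizedEmbedding(openai_key, batch_size=100)
    abstracts = ["First paper abstract...", "Second paper abstract..."]
    vectors = embedder.batch_embeddings(abstracts)  # one API call for up to 100 texts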
+class MemoryOptimizedRAG:
+    def __init__(self):
+        self.vectorstore_cache = {}
+
+    def create_vectorstore_cached(self, documents, collection_name):
+        """Cache vectorstore to avoid recreation"""
+        cache_key = f"{collection_name}_{len(documents)}"
+
+        if cache_key in self.vectorstore_cache:
+            return self.vectorstore_cache[cache_key]
+
+        # Clear ChromaDB cache before creating new vectorstore
+        chromadb.api.client.SharedSystemClient.clear_system_cache()
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=150,
+            chunk_overlap=50,
+            add_start_index=True
+        )
+        splits = text_splitter.split_documents(documents)
+
+        vectorstore = Chroma.from_documents(
+            documents=splits,
+            embedding=embeddings,
+            collection_name=collection_name
+        )
+
+        self.vectorstore_cache[cache_key] = vectorstore
+        return vectorstore
+
+    def cleanup_old_vectorstores(self, max_cache_size=3):
+        """Clean up old vectorstores to free memory"""
+        if len(self.vectorstore_cache) > max_cache_size:
+            # Remove oldest entries
+            oldest_keys = list(self.vectorstore_cache.keys())[:-max_cache_size]
+            for key in oldest_keys:
+                try:
+                    self.vectorstore_cache[key].delete_collection()
+                except:
+                    pass
+                del self.vectorstore_cache[key]
+
 def load_nlp():
     nlp = spacy.load("en_core_web_sm")
     nlp.add_pipe("textrank")
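The eviction in cleanup_old_vectorstores leans on dicts preserving insertion order (guaranteed since Python 3.7), so list(keys())[:-max_cache_size] really does mean "all but the newest". A hypothetical call site, where documents and the collection name stand in for the per-query values the app would supply:

    rag_cache = MemoryOptimizedRAG()
    vs = rag_cache.create_vectorstore_cached(documents, collection_name="query_123")
    rag_cache.cleanup_old_vectorstores(max_cache_size=3)  # drop the oldest collections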
@@ -89,9 +310,11 @@ def load_arxiv_corpus():
     # arxiv_corpus.load_faiss_index('embed', 'data/astrophindex.faiss')
 
     # keeping it up to date with the dataset
-    arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
-    arxiv_corpus.add_faiss_index(column='embed')
-    print('loading arxiv corpus from disk')
+    # arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
+    # arxiv_corpus.add_faiss_index(column='embed')
+    # print('loading arxiv corpus from disk')
+    loader = OptimizedDatasetLoader()
+    arxiv_corpus = loader.load_arxiv_corpus_cached()
     return arxiv_corpus
 
 class RetrievalSystem():
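One subtlety in the new load_arxiv_corpus: lru_cache keys on every argument, including self, so the in-memory memoization only pays off if the same OptimizedDatasetLoader instance is reused; constructing a fresh loader on each call falls back to the disk cache. A minimal illustration of that behavior:

    from functools import lru_cache

    class Loader:
        @lru_cache(maxsize=1)
        def load(self):
            print("loading...")
            return object()

    a = Loader()
    a.load()         # prints "loading..."
    a.load()         # cache hit on the same instance: no print
    Loader().load()  # new self -> cache miss, prints again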
@@ -649,6 +872,121 @@ def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type,
 
     yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
 
+
+
+async def run_pathfinder_optimized(query, top_k, extra_keywords, toggles,
+                                   prompt_type, rag_type, ec=None, progress=None):
+    """Optimized version of run_pathfinder with parallel processing"""
+
+    # Early validation
+    if check_mod(query):
+        yield None, "Query flagged by moderation", None, None, None
+        return
+
+    # Setup
+    input_keywords = [kw.strip() for kw in extra_keywords.split(',')] if extra_keywords else []
+    query_keywords = get_keywords(query)
+    ec.query_input_keywords = input_keywords + query_keywords
+    ec.toggles = toggles
+
+    # Configure retrieval method
+    ec.hyde = rag_type in ["Semantic + HyDE", "Semantic + HyDE + CoHERE"]
+    ec.rerank = rag_type in ["Semantic + CoHERE", "Semantic + HyDE + CoHERE"]
+
+    try:
+        if prompt_type == "Deep Research (BETA)":
+            # Deep research is inherently sequential, keep original implementation
+            formatted_df, rag_answer = deep_research(query, top_k=top_k, ec=ec)
+            yield formatted_df, rag_answer['answer'], None, None, None
+        else:
+            # Phase 1: Parallel initial operations
+            gr.Info("Starting parallel search operations...")
+
+            async with aiohttp.ClientSession() as session:
+                # Start retrieval
+                retrieval_task = asyncio.create_task(
+                    async_retrieve(ec, query, top_k, session)
+                )
+
+                # Start question type analysis (independent operation)
+                qtype_task = asyncio.create_task(
+                    async_question_type_analysis(query, session)
+                )
+
+                # Wait for retrieval to complete first
+                rs, small_df = await retrieval_task
+                formatted_df = ec.return_formatted_df(rs, small_df)
+                yield formatted_df, None, None, None, None
+
+                # Phase 2: RAG QA while question type analysis continues
+                gr.Info("Generating answer...")
+                rag_answer = await async_rag_qa(query, formatted_df, prompt_type, session)
+                yield formatted_df, rag_answer['answer'], None, None, None
+
+                # Phase 3: Parallel consensus and remaining operations
+                gr.Info("Finalizing analysis...")
+
+                consensus_task = asyncio.create_task(
+                    async_consensus_evaluation(query, formatted_df, session)
+                )
+
+                plot_task = asyncio.create_task(
+                    async_make_plot(formatted_df, top_k)
+                )
+
+                # Wait for question type and consensus
+                question_type_gen, consensus_answer = await asyncio.gather(
+                    qtype_task, consensus_task
+                )
+
+                # Format outputs
+                consensus = f'## Consensus \n{consensus_answer.consensus}\n\n{consensus_answer.explanation}\n\n > Relevance: {consensus_answer.relevance_score:.1f}'
+                qn_type = format_question_type(question_type_gen)
+
+                yield formatted_df, rag_answer['answer'], consensus, qn_type, None
+
+                # Final plot
+                fig = await plot_task
+                yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
+
+    except Exception as e:
+        print(f"Error in pathfinder: {e}")
+        yield None, f"Error: {str(e)}", None, None, None
+
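The async_* helpers below all wrap the app's existing synchronous functions the same way: push the blocking call onto the default thread pool with run_in_executor so the event loop stays free, then await the result. A self-contained illustration of the pattern (slow_sync_op is hypothetical):

    import asyncio
    import time

    def slow_sync_op(x):
        time.sleep(1)  # stands in for a blocking retrieval or LLM call
        return x * 2

    async def main():
        loop = asyncio.get_event_loop()
        # Both blocking calls run in worker threads concurrently:
        a, b = await asyncio.gather(
            loop.run_in_executor(None, slow_sync_op, 1),
            loop.run_in_executor(None, slow_sync_op, 2),
        )
        print(a, b)  # finishes in about 1 second rather than 2

    asyncio.run(main())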
+async def async_retrieve(ec, query, top_k, session):
+    """Async wrapper for retrieval"""
+    loop = asyncio.get_event_loop()
+    return await loop.run_in_executor(None, ec.retrieve, query, top_k, True)
+
+async def async_rag_qa(query, formatted_df, prompt_type, session):
+    """Async wrapper for RAG QA"""
+    loop = asyncio.get_event_loop()
+    return await loop.run_in_executor(None, run_rag_qa, query, formatted_df, prompt_type)
+
+async def async_consensus_evaluation(query, formatted_df, session):
+    """Async consensus evaluation"""
+    abstracts = [formatted_df['abstract'][i+1] for i in range(len(formatted_df))]
+    loop = asyncio.get_event_loop()
+    return await loop.run_in_executor(None, evaluate_overall_consensus, query, abstracts)
+
+async def async_question_type_analysis(query, session):
+    """Async question type analysis"""
+    loop = asyncio.get_event_loop()
+    return await loop.run_in_executor(None, guess_question_type, query)
+
+async def async_make_plot(formatted_df, top_k):
+    """Async plot generation"""
+    loop = asyncio.get_event_loop()
+    return await loop.run_in_executor(None, make_embedding_plot, formatted_df, top_k, None)
+
+def format_question_type(question_type_gen):
+    """Clean up question type output"""
+    if '<categorization>' in question_type_gen:
+        question_type_gen = question_type_gen.split('<categorization>')[1]
+    if '</categorization>' in question_type_gen:
+        question_type_gen = question_type_gen.split('</categorization>')[0]
+    return question_type_gen.replace('\n', ' \n')
+
 def create_interface():
     custom_css = """
     #custom-slider-* {
@@ -687,7 +1025,7 @@ def create_interface():
 
     inputs = [query, top_k, keywords, toggles, prompt_type, rag_type]
     outputs = [ret_papers, search_results_state, qntype, conc, plot]
-    btn.click(fn=run_pathfinder, inputs=inputs, outputs=outputs)
+    btn.click(fn=run_pathfinder_optimized, inputs=inputs, outputs=outputs)
 
     return demo
 
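Wiring btn.click directly to run_pathfinder_optimized relies on Gradio accepting async-generator callbacks: each yielded 5-tuple is streamed positionally into the five outputs, so the outputs list must line up with the yield order. A toy sketch of that streaming behavior (not app code):

    import asyncio
    import gradio as gr

    async def stream(msg):
        yield "searching..."
        await asyncio.sleep(1)
        yield f"answer for: {msg}"

    with gr.Blocks() as demo:
        box_in = gr.Textbox(label="query")
        box_out = gr.Textbox(label="answer")
        btn = gr.Button("go")
        btn.click(fn=stream, inputs=box_in, outputs=box_out)

    # demo.launch()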