Spaces:
Runtime error
Runtime error
Added timers, trying to figure out slowness
Browse files
RAG.py
CHANGED
|
@@ -15,7 +15,8 @@ from typing import Dict, Any, Optional, List, Tuple
|
|
| 15 |
import json
|
| 16 |
import logging
|
| 17 |
|
| 18 |
-
def retrieve(index_name: str, query: str, embeddings, k: int = 1000) -> Tuple[List[Document], List[float]]:
|
|
|
|
| 19 |
load_dotenv()
|
| 20 |
pinecone_api_key = os.getenv("PINECONE_API_KEY")
|
| 21 |
pc = Pinecone(api_key=pinecone_api_key)
|
|
@@ -31,6 +32,7 @@ def retrieve(index_name: str, query: str, embeddings, k: int = 1000) -> Tuple[Li
|
|
| 31 |
for res, score in results:
|
| 32 |
documents.append(res)
|
| 33 |
scores.append(score)
|
|
|
|
| 34 |
return documents, scores
|
| 35 |
|
| 36 |
def safe_get_json(url: str) -> Optional[Dict]:
|
|
@@ -61,7 +63,8 @@ def extract_text_from_json(json_data: Dict) -> str:
|
|
| 61 |
return " ".join(text_parts) if text_parts else "No content available"
|
| 62 |
|
| 63 |
def rerank(documents: List[Document], query: str) -> List[Document]:
|
| 64 |
-
"""Rerank documents using BM25
|
|
|
|
| 65 |
if not documents:
|
| 66 |
return []
|
| 67 |
|
|
@@ -85,6 +88,7 @@ def rerank(documents: List[Document], query: str) -> List[Document]:
|
|
| 85 |
# Create BM25 retriever with the processed documents
|
| 86 |
reranker = BM25Retriever.from_documents(full_docs, k=min(10, len(full_docs)))
|
| 87 |
reranked_docs = reranker.invoke(query)
|
|
|
|
| 88 |
return reranked_docs
|
| 89 |
|
| 90 |
def parse_xml_and_query(query:str,xml_string:str) -> str:
|
|
@@ -116,8 +120,9 @@ def parse_xml_and_check(xml_string: str) -> str:
|
|
| 116 |
|
| 117 |
def RAG(llm: Any, query: str, index_name: str, embeddings: Any, top: int = 10, k: int = 100) -> Tuple[str, List[Document]]:
|
| 118 |
"""Main RAG function with improved error handling and validation."""
|
|
|
|
| 119 |
try:
|
| 120 |
-
# Retrieve initial documents
|
| 121 |
query_template = PromptTemplate.from_template(
|
| 122 |
"""
|
| 123 |
Your job is to think about a query and then generate a statement that only includes information from the query that would answer the query.
|
|
@@ -147,6 +152,7 @@ def RAG(llm: Any, query: str, index_name: str, embeddings: Any, top: int = 10, k
|
|
| 147 |
query_prompt = query_template.invoke({"query":query})
|
| 148 |
query_response = llm.invoke(query_prompt)
|
| 149 |
new_query = parse_xml_and_query(query=query,xml_string=query_response.content)
|
|
|
|
| 150 |
|
| 151 |
retrieved, _ = retrieve(index_name=index_name, query=new_query, embeddings=embeddings, k=k)
|
| 152 |
if not retrieved:
|
|
@@ -191,6 +197,7 @@ def RAG(llm: Any, query: str, index_name: str, embeddings: Any, top: int = 10, k
|
|
| 191 |
|
| 192 |
# Parse and return response
|
| 193 |
parsed = parse_xml_and_check(response.content)
|
|
|
|
| 194 |
return parsed, reranked
|
| 195 |
|
| 196 |
except Exception as e:
|
|
|
|
| 15 |
import json
|
| 16 |
import logging
|
| 17 |
|
| 18 |
+
def retrieve(index_name: str, query: str, embeddings, k: int = 1000) -> Tuple[List[Document], List[float]]:
|
| 19 |
+
start = time.time()
|
| 20 |
load_dotenv()
|
| 21 |
pinecone_api_key = os.getenv("PINECONE_API_KEY")
|
| 22 |
pc = Pinecone(api_key=pinecone_api_key)
|
|
|
|
| 32 |
for res, score in results:
|
| 33 |
documents.append(res)
|
| 34 |
scores.append(score)
|
| 35 |
+
print(f"Finished Retrieval: {time.time() - start}")
|
| 36 |
return documents, scores
|
| 37 |
|
| 38 |
def safe_get_json(url: str) -> Optional[Dict]:
|
|
|
|
| 63 |
return " ".join(text_parts) if text_parts else "No content available"
|
| 64 |
|
| 65 |
def rerank(documents: List[Document], query: str) -> List[Document]:
|
| 66 |
+
"""Ingest more metadata. Rerank documents using BM25"""
|
| 67 |
+
start = time.time()
|
| 68 |
if not documents:
|
| 69 |
return []
|
| 70 |
|
|
|
|
| 88 |
# Create BM25 retriever with the processed documents
|
| 89 |
reranker = BM25Retriever.from_documents(full_docs, k=min(10, len(full_docs)))
|
| 90 |
reranked_docs = reranker.invoke(query)
|
| 91 |
+
print(f"Finished reranking: {time.time()-start}")
|
| 92 |
return reranked_docs
|
| 93 |
|
| 94 |
def parse_xml_and_query(query:str,xml_string:str) -> str:
|
|
|
|
| 120 |
|
| 121 |
def RAG(llm: Any, query: str, index_name: str, embeddings: Any, top: int = 10, k: int = 100) -> Tuple[str, List[Document]]:
|
| 122 |
"""Main RAG function with improved error handling and validation."""
|
| 123 |
+
start = time.time()
|
| 124 |
try:
|
| 125 |
+
# Retrieve initial documents using rephrased query
|
| 126 |
query_template = PromptTemplate.from_template(
|
| 127 |
"""
|
| 128 |
Your job is to think about a query and then generate a statement that only includes information from the query that would answer the query.
|
|
|
|
| 152 |
query_prompt = query_template.invoke({"query":query})
|
| 153 |
query_response = llm.invoke(query_prompt)
|
| 154 |
new_query = parse_xml_and_query(query=query,xml_string=query_response.content)
|
| 155 |
+
print(f"New_Query: {new_query}")
|
| 156 |
|
| 157 |
retrieved, _ = retrieve(index_name=index_name, query=new_query, embeddings=embeddings, k=k)
|
| 158 |
if not retrieved:
|
|
|
|
| 197 |
|
| 198 |
# Parse and return response
|
| 199 |
parsed = parse_xml_and_check(response.content)
|
| 200 |
+
print(f"RAG Finished: {time.time()-start}")
|
| 201 |
return parsed, reranked
|
| 202 |
|
| 203 |
except Exception as e:
|