Commit
·
ab32d1c
1
Parent(s):
2fa0d66
reduced latency
Browse files
app.py
CHANGED
|
@@ -5,6 +5,7 @@ import hashlib
|
|
| 5 |
import gradio as gr
|
| 6 |
import time
|
| 7 |
from functools import partial
|
|
|
|
| 8 |
from collections import defaultdict
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import List, Dict, Any
|
|
@@ -77,6 +78,8 @@ def initialize_resources():
|
|
| 77 |
|
| 78 |
vectorstore, all_chunks, all_texts, metadatas = initialize_resources()
|
| 79 |
|
|
|
|
|
|
|
| 80 |
# LLMs
|
| 81 |
repharser_llm = ChatNVIDIA(model="mistralai/mistral-7b-instruct-v0.3") | StrOutputParser()
|
| 82 |
instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x22b-instruct-v0.1") | StrOutputParser()
|
|
@@ -162,10 +165,10 @@ answer_prompt_relevant = ChatPromptTemplate.from_template(
|
|
| 162 |
"Answer:"
|
| 163 |
)
|
| 164 |
|
| 165 |
-
|
| 166 |
answer_prompt_fallback = ChatPromptTemplate.from_template(
|
| 167 |
"You are Krishna’s personal AI assistant. The user asked a question unrelated to Krishna’s background.\n"
|
| 168 |
"Respond with a touch of humor, then guide the conversation back to Krishna’s actual skills, experiences, or projects.\n\n"
|
|
|
|
| 169 |
"Krishna's Background:\n{profile}\n\n"
|
| 170 |
"User Question:\n{query}\n\n"
|
| 171 |
"Your Answer:"
|
|
@@ -178,13 +181,13 @@ def parse_rewrites(raw_response: str) -> list[str]:
|
|
| 178 |
def hybrid_retrieve(inputs, exclude_terms=None):
|
| 179 |
# if exclude_terms is None:
|
| 180 |
# exclude_terms = ["cgpa", "university", "b.tech", "m.s.", "certification", "coursera", "edx", "goal", "aspiration", "linkedin", "publication", "ieee", "doi", "degree"]
|
| 181 |
-
|
| 182 |
all_queries = inputs["all_queries"]
|
| 183 |
-
bm25_retriever = BM25Retriever.from_texts(texts=all_texts, metadatas=metadatas)
|
| 184 |
bm25_retriever.k = inputs["k_per_query"]
|
| 185 |
vectorstore = inputs["vectorstore"]
|
| 186 |
alpha = inputs["alpha"]
|
| 187 |
top_k = inputs.get("top_k", 15)
|
|
|
|
| 188 |
|
| 189 |
scored_chunks = defaultdict(lambda: {
|
| 190 |
"vector_scores": [],
|
|
@@ -192,23 +195,45 @@ def hybrid_retrieve(inputs, exclude_terms=None):
|
|
| 192 |
"content": None,
|
| 193 |
"metadata": None,
|
| 194 |
})
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
| 198 |
for doc, score in vec_hits:
|
| 199 |
key = hashlib.md5(doc.page_content.encode("utf-8")).hexdigest()
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
bm_hits = bm25_retriever.invoke(subquery)
|
|
|
|
| 205 |
for rank, doc in enumerate(bm_hits):
|
| 206 |
key = hashlib.md5(doc.page_content.encode("utf-8")).hexdigest()
|
| 207 |
-
bm_score = 1.0 - (rank /
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
all_vec_means = [np.mean(v["vector_scores"]) for v in scored_chunks.values() if v["vector_scores"]]
|
| 213 |
max_vec = max(all_vec_means) if all_vec_means else 1
|
| 214 |
min_vec = min(all_vec_means) if all_vec_means else 0
|
|
@@ -221,8 +246,6 @@ def hybrid_retrieve(inputs, exclude_terms=None):
|
|
| 221 |
final_score = alpha * norm_vec + (1 - alpha) * bm25_score
|
| 222 |
|
| 223 |
content = chunk["content"].lower()
|
| 224 |
-
# if any(term in content for term in exclude_terms):
|
| 225 |
-
# continue
|
| 226 |
if final_score < 0.05 or len(content.strip()) < 100:
|
| 227 |
continue
|
| 228 |
|
|
@@ -334,7 +357,7 @@ def chat_interface(message, history):
|
|
| 334 |
"k_per_query": 3,
|
| 335 |
"alpha": 0.7,
|
| 336 |
"vectorstore": vectorstore,
|
| 337 |
-
"
|
| 338 |
}
|
| 339 |
response = ""
|
| 340 |
for chunk in full_pipeline.stream(inputs):
|
|
@@ -358,7 +381,7 @@ demo = gr.ChatInterface(
|
|
| 358 |
)
|
| 359 |
|
| 360 |
if __name__ == "__main__":
|
| 361 |
-
demo.launch(debug=True)
|
| 362 |
|
| 363 |
# with gr.Blocks(css="""
|
| 364 |
# html, body, .gradio-container {
|
|
|
|
| 5 |
import gradio as gr
|
| 6 |
import time
|
| 7 |
from functools import partial
|
| 8 |
+
import concurrent.futures
|
| 9 |
from collections import defaultdict
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import List, Dict, Any
|
|
|
|
| 78 |
|
| 79 |
# One-time, module-level initialization (runs at import).
vectorstore, all_chunks, all_texts, metadatas = initialize_resources()

# Build the BM25 retriever once at startup instead of per request inside
# hybrid_retrieve — this hoist is the latency reduction this commit targets.
bm25_retriever = BM25Retriever.from_texts(texts=all_texts, metadatas=metadatas)

# LLMs
# NOTE(review): "repharser" looks like a typo for "rephraser"; renaming would
# touch every downstream reference, so flagging only.
repharser_llm = ChatNVIDIA(model="mistralai/mistral-7b-instruct-v0.3") | StrOutputParser()
instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x22b-instruct-v0.1") | StrOutputParser()
|
|
|
|
| 165 |
"Answer:"
|
| 166 |
)
|
| 167 |
|
|
|
|
| 168 |
answer_prompt_fallback = ChatPromptTemplate.from_template(
|
| 169 |
"You are Krishna’s personal AI assistant. The user asked a question unrelated to Krishna’s background.\n"
|
| 170 |
"Respond with a touch of humor, then guide the conversation back to Krishna’s actual skills, experiences, or projects.\n\n"
|
| 171 |
+
"Make it clear that everything you mention afterward comes from Krishna's actual profile.\n\n"
|
| 172 |
"Krishna's Background:\n{profile}\n\n"
|
| 173 |
"User Question:\n{query}\n\n"
|
| 174 |
"Your Answer:"
|
|
|
|
| 181 |
def hybrid_retrieve(inputs, exclude_terms=None):
|
| 182 |
# if exclude_terms is None:
|
| 183 |
# exclude_terms = ["cgpa", "university", "b.tech", "m.s.", "certification", "coursera", "edx", "goal", "aspiration", "linkedin", "publication", "ieee", "doi", "degree"]
|
| 184 |
+
bm25_retriever = inputs["bm25_retriever"]
|
| 185 |
all_queries = inputs["all_queries"]
|
|
|
|
| 186 |
bm25_retriever.k = inputs["k_per_query"]
|
| 187 |
vectorstore = inputs["vectorstore"]
|
| 188 |
alpha = inputs["alpha"]
|
| 189 |
top_k = inputs.get("top_k", 15)
|
| 190 |
+
k_per_query = inputs["k_per_query"]
|
| 191 |
|
| 192 |
scored_chunks = defaultdict(lambda: {
|
| 193 |
"vector_scores": [],
|
|
|
|
| 195 |
"content": None,
|
| 196 |
"metadata": None,
|
| 197 |
})
|
| 198 |
+
|
| 199 |
+
    # Function to process each subquery.
    # Runs one subquery against both retrievers and returns (vec_results,
    # bm_results): lists of (md5-of-content key, doc, score) triples that the
    # caller merges into scored_chunks. The md5 of page_content deduplicates
    # the same chunk retrieved by different subqueries/retrievers.
    # NOTE(review): the default k_per_query=3 SHADOWS the enclosing scope's
    # k_per_query (read from inputs), and the caller submits only the subquery,
    # so inputs["k_per_query"] is silently ignored here (both happen to be 3
    # today) — confirm and align.
    def process_subquery(subquery, k_per_query=3):
        # Vector retrieval
        vec_hits = vectorstore.similarity_search_with_score(subquery, k=k_per_query)
        vec_results = []
        for doc, score in vec_hits:
            key = hashlib.md5(doc.page_content.encode("utf-8")).hexdigest()
            vec_results.append((key, doc, score))

        # BM25 retrieval
        bm_hits = bm25_retriever.invoke(subquery)
        bm_results = []
        for rank, doc in enumerate(bm_hits):
            key = hashlib.md5(doc.page_content.encode("utf-8")).hexdigest()
            # Rank-based score in (0, 1]: rank 0 -> 1.0, decreasing linearly.
            bm_score = 1.0 - (rank / k_per_query)
            bm_results.append((key, doc, bm_score))

        return vec_results, bm_results
|
| 217 |
|
| 218 |
+
# Process subqueries in parallel
|
| 219 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 220 |
+
futures = [executor.submit(process_subquery, q) for q in all_queries]
|
| 221 |
+
for future in concurrent.futures.as_completed(futures):
|
| 222 |
+
vec_results, bm_results = future.result()
|
| 223 |
+
|
| 224 |
+
# Process vector results
|
| 225 |
+
for key, doc, score in vec_results:
|
| 226 |
+
scored_chunks[key]["vector_scores"].append(score)
|
| 227 |
+
scored_chunks[key]["content"] = doc.page_content
|
| 228 |
+
scored_chunks[key]["metadata"] = doc.metadata
|
| 229 |
+
|
| 230 |
+
# Process BM25 results
|
| 231 |
+
for key, doc, bm_score in bm_results:
|
| 232 |
+
scored_chunks[key]["bm25_score"] += bm_score
|
| 233 |
+
scored_chunks[key]["content"] = doc.page_content
|
| 234 |
+
scored_chunks[key]["metadata"] = doc.metadata
|
| 235 |
+
|
| 236 |
+
# Rest of the scoring and filtering logic remains the same
|
| 237 |
all_vec_means = [np.mean(v["vector_scores"]) for v in scored_chunks.values() if v["vector_scores"]]
|
| 238 |
max_vec = max(all_vec_means) if all_vec_means else 1
|
| 239 |
min_vec = min(all_vec_means) if all_vec_means else 0
|
|
|
|
| 246 |
final_score = alpha * norm_vec + (1 - alpha) * bm25_score
|
| 247 |
|
| 248 |
content = chunk["content"].lower()
|
|
|
|
|
|
|
| 249 |
if final_score < 0.05 or len(content.strip()) < 100:
|
| 250 |
continue
|
| 251 |
|
|
|
|
| 357 |
"k_per_query": 3,
|
| 358 |
"alpha": 0.7,
|
| 359 |
"vectorstore": vectorstore,
|
| 360 |
+
"bm25_retriever": bm25_retriever,
|
| 361 |
}
|
| 362 |
response = ""
|
| 363 |
for chunk in full_pipeline.stream(inputs):
|
|
|
|
| 381 |
)
|
| 382 |
|
| 383 |
if __name__ == "__main__":
    # NOTE(review): prevent_thread_lock=True makes launch() return without
    # blocking, while debug=True expects a blocking run; since nothing follows
    # this call, the process may exit immediately after launch when run as a
    # script — confirm against the Gradio version pinned for this app.
    demo.launch(max_threads=4, prevent_thread_lock=True, debug=True)
|
| 385 |
|
| 386 |
# with gr.Blocks(css="""
|
| 387 |
# html, body, .gradio-container {
|