Spaces:
Runtime error
Runtime error
make global rag index
Browse files
app.py
CHANGED
|
@@ -68,6 +68,8 @@ question: Prior to playing for Michigan State, Keith Nichol played football for
|
|
| 68 |
answer: Norman
|
| 69 |
"""
|
| 70 |
|
|
|
|
|
|
|
| 71 |
class FinchCache(DynamicCache):
|
| 72 |
def __init__(self) -> None:
|
| 73 |
super().__init__()
|
|
@@ -218,9 +220,9 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
|
|
| 218 |
else:
|
| 219 |
rag_text = combined_text
|
| 220 |
print("Creating RAG index")
|
| 221 |
-
|
| 222 |
print("Done")
|
| 223 |
-
state = {
|
| 224 |
|
| 225 |
return (
|
| 226 |
combined_text,
|
|
@@ -438,13 +440,13 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
|
|
| 438 |
return cache
|
| 439 |
|
| 440 |
|
| 441 |
-
def run_naive_rag_query(
|
| 442 |
"""
|
| 443 |
For naive RAG, retrieves top-k chunks (k based on target token size)
|
| 444 |
and generates an answer using those chunks.
|
| 445 |
"""
|
| 446 |
k = max(1, rag_token_size // 256)
|
| 447 |
-
retriever =
|
| 448 |
retrieved_docs = retriever.invoke(query)
|
| 449 |
for doc in retrieved_docs:
|
| 450 |
print("=================")
|
|
@@ -477,9 +479,11 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
|
|
| 477 |
print("Target token size for compression: ", target_token_size)
|
| 478 |
step_size = 2
|
| 479 |
start_time_prefill = time.perf_counter()
|
|
|
|
| 480 |
past_key_values = copy.deepcopy(get_compressed_kv_cache(sink_tokens, step_size, target_token_size,
|
| 481 |
context_ids, context_attention_mask,
|
| 482 |
question_ids, question_attention_mask))
|
|
|
|
| 483 |
compressed_length = past_key_values.get_seq_length()
|
| 484 |
print("Context size after compression: ", compressed_length)
|
| 485 |
print("Compression rate: ", context_ids.size(1) / compressed_length)
|
|
@@ -490,19 +494,17 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
|
|
| 490 |
compressed_length = past_key_values.get_seq_length()
|
| 491 |
|
| 492 |
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
if rag_index is None:
|
| 496 |
if combined_text.startswith(prefix):
|
| 497 |
rag_text = combined_text[len(prefix):]
|
| 498 |
else:
|
| 499 |
rag_text = combined_text
|
| 500 |
-
|
| 501 |
|
| 502 |
state.update({
|
| 503 |
"compressed_cache": past_key_values,
|
| 504 |
"compressed_length": compressed_length,
|
| 505 |
-
"rag_index": rag_index,
|
| 506 |
"target_token_size": target_token_size,
|
| 507 |
"global_local": percentage,
|
| 508 |
"combined_text": combined_text,
|
|
@@ -523,7 +525,6 @@ def chat_response_stream(message: str, history: list, state: dict):
|
|
| 523 |
user_message = message
|
| 524 |
past_key_values = state["compressed_cache"]
|
| 525 |
compressed_length = past_key_values.get_seq_length()
|
| 526 |
-
rag_index = state["rag_index"]
|
| 527 |
retrieval_slider_value = state["retrieval_slider"]
|
| 528 |
percentage = state["global_local"]
|
| 529 |
|
|
@@ -540,7 +541,7 @@ def chat_response_stream(message: str, history: list, state: dict):
|
|
| 540 |
rag_few_shot = ""
|
| 541 |
print("user message: ", user_message)
|
| 542 |
if rag_retrieval_size != 0:
|
| 543 |
-
rag_context = run_naive_rag_query(
|
| 544 |
new_input = rag_context + "\nquestion: " + user_message + suffix + "answer:"
|
| 545 |
else:
|
| 546 |
new_input = "\nquestion: " + user_message + suffix + "answer:"
|
|
|
|
| 68 |
answer: Norman
|
| 69 |
"""
|
| 70 |
|
| 71 |
+
global_rag_index = None
|
| 72 |
+
|
| 73 |
class FinchCache(DynamicCache):
|
| 74 |
def __init__(self) -> None:
|
| 75 |
super().__init__()
|
|
|
|
| 220 |
else:
|
| 221 |
rag_text = combined_text
|
| 222 |
print("Creating RAG index")
|
| 223 |
+
global_rag_index = create_rag_index(rag_text)
|
| 224 |
print("Done")
|
| 225 |
+
state = {}
|
| 226 |
|
| 227 |
return (
|
| 228 |
combined_text,
|
|
|
|
| 440 |
return cache
|
| 441 |
|
| 442 |
|
| 443 |
+
def run_naive_rag_query(query, rag_token_size, prefix, task, few_shot_examples):
|
| 444 |
"""
|
| 445 |
For naive RAG, retrieves top-k chunks (k based on target token size)
|
| 446 |
and generates an answer using those chunks.
|
| 447 |
"""
|
| 448 |
k = max(1, rag_token_size // 256)
|
| 449 |
+
retriever = global_rag_index.as_retriever(search_type="similarity", search_kwargs={"k": k})
|
| 450 |
retrieved_docs = retriever.invoke(query)
|
| 451 |
for doc in retrieved_docs:
|
| 452 |
print("=================")
|
|
|
|
| 479 |
print("Target token size for compression: ", target_token_size)
|
| 480 |
step_size = 2
|
| 481 |
start_time_prefill = time.perf_counter()
|
| 482 |
+
print("Compressing KV Cache")
|
| 483 |
past_key_values = copy.deepcopy(get_compressed_kv_cache(sink_tokens, step_size, target_token_size,
|
| 484 |
context_ids, context_attention_mask,
|
| 485 |
question_ids, question_attention_mask))
|
| 486 |
+
print("Done")
|
| 487 |
compressed_length = past_key_values.get_seq_length()
|
| 488 |
print("Context size after compression: ", compressed_length)
|
| 489 |
print("Compression rate: ", context_ids.size(1) / compressed_length)
|
|
|
|
| 494 |
compressed_length = past_key_values.get_seq_length()
|
| 495 |
|
| 496 |
|
| 497 |
+
|
| 498 |
+
if global_rag_index is None:
|
|
|
|
| 499 |
if combined_text.startswith(prefix):
|
| 500 |
rag_text = combined_text[len(prefix):]
|
| 501 |
else:
|
| 502 |
rag_text = combined_text
|
| 503 |
+
global_rag_index = create_rag_index(rag_text, device)
|
| 504 |
|
| 505 |
state.update({
|
| 506 |
"compressed_cache": past_key_values,
|
| 507 |
"compressed_length": compressed_length,
|
|
|
|
| 508 |
"target_token_size": target_token_size,
|
| 509 |
"global_local": percentage,
|
| 510 |
"combined_text": combined_text,
|
|
|
|
| 525 |
user_message = message
|
| 526 |
past_key_values = state["compressed_cache"]
|
| 527 |
compressed_length = past_key_values.get_seq_length()
|
|
|
|
| 528 |
retrieval_slider_value = state["retrieval_slider"]
|
| 529 |
percentage = state["global_local"]
|
| 530 |
|
|
|
|
| 541 |
rag_few_shot = ""
|
| 542 |
print("user message: ", user_message)
|
| 543 |
if rag_retrieval_size != 0:
|
| 544 |
+
rag_context = run_naive_rag_query(user_message, rag_retrieval_size, rag_prefix, rag_task, rag_few_shot)
|
| 545 |
new_input = rag_context + "\nquestion: " + user_message + suffix + "answer:"
|
| 546 |
else:
|
| 547 |
new_input = "\nquestion: " + user_message + suffix + "answer:"
|