Spaces:
Sleeping
Sleeping
Commit
·
806f3f9
1
Parent(s):
ad8e8ec
Much lower reranking threshold (-0.5 instead of 0.1) + detailed score logging
Browse files- documents_prep.py +0 -38
- index_retriever.py +21 -19
- utils.py +62 -21
documents_prep.py
CHANGED
|
@@ -412,44 +412,6 @@ def extract_sections_from_json(json_path):
|
|
| 412 |
return documents
|
| 413 |
|
| 414 |
|
| 415 |
-
def load_table_documents(repo_id, hf_token, table_dir):
|
| 416 |
-
"""Load and chunk tables"""
|
| 417 |
-
log_message("Loading tables...")
|
| 418 |
-
|
| 419 |
-
files = list_repo_files(repo_id=repo_id, repo_type="dataset", token=hf_token)
|
| 420 |
-
table_files = [f for f in files if f.startswith(table_dir) and f.endswith('.json')]
|
| 421 |
-
|
| 422 |
-
all_chunks = []
|
| 423 |
-
for file_path in table_files:
|
| 424 |
-
try:
|
| 425 |
-
local_path = hf_hub_download(
|
| 426 |
-
repo_id=repo_id,
|
| 427 |
-
filename=file_path,
|
| 428 |
-
repo_type="dataset",
|
| 429 |
-
token=hf_token
|
| 430 |
-
)
|
| 431 |
-
|
| 432 |
-
with open(local_path, 'r', encoding='utf-8') as f:
|
| 433 |
-
data = json.load(f)
|
| 434 |
-
|
| 435 |
-
# Extract file-level document_id
|
| 436 |
-
file_doc_id = data.get('document_id', data.get('document', 'unknown'))
|
| 437 |
-
|
| 438 |
-
for sheet in data.get('sheets', []):
|
| 439 |
-
# Use sheet-level document_id if available, otherwise use file-level
|
| 440 |
-
sheet_doc_id = sheet.get('document_id', sheet.get('document', file_doc_id))
|
| 441 |
-
|
| 442 |
-
# CRITICAL: Pass document_id to chunk function
|
| 443 |
-
chunks = chunk_table_by_content(sheet, sheet_doc_id)
|
| 444 |
-
all_chunks.extend(chunks)
|
| 445 |
-
|
| 446 |
-
except Exception as e:
|
| 447 |
-
log_message(f"Error loading {file_path}: {e}")
|
| 448 |
-
|
| 449 |
-
log_message(f"✓ Loaded {len(all_chunks)} table chunks")
|
| 450 |
-
return all_chunks
|
| 451 |
-
|
| 452 |
-
|
| 453 |
def load_image_documents(repo_id, hf_token, image_dir):
|
| 454 |
"""Load image descriptions"""
|
| 455 |
log_message("Loading images...")
|
|
|
|
| 412 |
return documents
|
| 413 |
|
| 414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
def load_image_documents(repo_id, hf_token, image_dir):
|
| 416 |
"""Load image descriptions"""
|
| 417 |
log_message("Loading images...")
|
index_retriever.py
CHANGED
|
@@ -24,56 +24,58 @@ def keyword_filter_nodes(query, nodes, min_keyword_matches=1):
|
|
| 24 |
return filtered
|
| 25 |
|
| 26 |
def create_query_engine(vector_index):
|
| 27 |
-
"""Create hybrid retrieval engine with deduplication"""
|
| 28 |
log_message("Creating query engine...")
|
| 29 |
|
| 30 |
vector_retriever = VectorIndexRetriever(
|
| 31 |
index=vector_index,
|
| 32 |
-
similarity_top_k=
|
| 33 |
)
|
| 34 |
bm25_retriever = BM25Retriever.from_defaults(
|
| 35 |
docstore=vector_index.docstore,
|
| 36 |
-
similarity_top_k=
|
| 37 |
)
|
| 38 |
hybrid_retriever = QueryFusionRetriever(
|
| 39 |
[vector_retriever, bm25_retriever],
|
| 40 |
-
similarity_top_k=
|
| 41 |
num_queries=1
|
| 42 |
)
|
| 43 |
|
| 44 |
class DeduplicatedQueryEngine(RetrieverQueryEngine):
|
| 45 |
def retrieve(self, query):
|
| 46 |
nodes = hybrid_retriever.retrieve(query)
|
|
|
|
| 47 |
|
| 48 |
-
#
|
| 49 |
seen_hashes = set()
|
| 50 |
unique_nodes = []
|
|
|
|
| 51 |
|
| 52 |
for node in nodes:
|
| 53 |
-
|
|
|
|
| 54 |
|
| 55 |
if text_hash not in seen_hashes:
|
| 56 |
seen_hashes.add(text_hash)
|
| 57 |
unique_nodes.append(node)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
log_message(f"
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
response = self._response_synthesizer.synthesize(
|
| 66 |
-
query=query_bundle,
|
| 67 |
-
nodes=nodes
|
| 68 |
-
)
|
| 69 |
-
return response
|
| 70 |
|
| 71 |
response_synthesizer = get_response_synthesizer()
|
| 72 |
|
| 73 |
query_engine = DeduplicatedQueryEngine(
|
| 74 |
-
retriever=hybrid_retriever,
|
| 75 |
response_synthesizer=response_synthesizer
|
| 76 |
)
|
| 77 |
|
| 78 |
-
log_message("✓ Query engine created
|
| 79 |
return query_engine
|
|
|
|
| 24 |
return filtered
|
| 25 |
|
| 26 |
def create_query_engine(vector_index):
|
| 27 |
+
"""Create hybrid retrieval engine with better deduplication"""
|
| 28 |
log_message("Creating query engine...")
|
| 29 |
|
| 30 |
vector_retriever = VectorIndexRetriever(
|
| 31 |
index=vector_index,
|
| 32 |
+
similarity_top_k=50 # Reduced to get more diverse results
|
| 33 |
)
|
| 34 |
bm25_retriever = BM25Retriever.from_defaults(
|
| 35 |
docstore=vector_index.docstore,
|
| 36 |
+
similarity_top_k=50,
|
| 37 |
)
|
| 38 |
hybrid_retriever = QueryFusionRetriever(
|
| 39 |
[vector_retriever, bm25_retriever],
|
| 40 |
+
similarity_top_k=60, # Reduced
|
| 41 |
num_queries=1
|
| 42 |
)
|
| 43 |
|
| 44 |
class DeduplicatedQueryEngine(RetrieverQueryEngine):
|
| 45 |
def retrieve(self, query):
|
| 46 |
nodes = hybrid_retriever.retrieve(query)
|
| 47 |
+
log_message(f"Hybrid retrieval returned: {len(nodes)} nodes")
|
| 48 |
|
| 49 |
+
# Better deduplication using longer text snippet
|
| 50 |
seen_hashes = set()
|
| 51 |
unique_nodes = []
|
| 52 |
+
doc_type_counts = {'text': 0, 'table': 0, 'image': 0}
|
| 53 |
|
| 54 |
for node in nodes:
|
| 55 |
+
# Use first 500 chars for dedup hash
|
| 56 |
+
text_hash = hash(node.text[:500])
|
| 57 |
|
| 58 |
if text_hash not in seen_hashes:
|
| 59 |
seen_hashes.add(text_hash)
|
| 60 |
unique_nodes.append(node)
|
| 61 |
+
|
| 62 |
+
# Count by type
|
| 63 |
+
node_type = node.metadata.get('type', 'text')
|
| 64 |
+
doc_type_counts[node_type] = doc_type_counts.get(node_type, 0) + 1
|
| 65 |
|
| 66 |
+
log_message(f"After dedup: {len(unique_nodes)} unique nodes")
|
| 67 |
+
log_message(f"Types: text={doc_type_counts.get('text', 0)}, "
|
| 68 |
+
f"table={doc_type_counts.get('table', 0)}, "
|
| 69 |
+
f"image={doc_type_counts.get('image', 0)}")
|
| 70 |
+
|
| 71 |
+
return unique_nodes[:50]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
response_synthesizer = get_response_synthesizer()
|
| 74 |
|
| 75 |
query_engine = DeduplicatedQueryEngine(
|
| 76 |
+
retriever=hybrid_retriever,
|
| 77 |
response_synthesizer=response_synthesizer
|
| 78 |
)
|
| 79 |
|
| 80 |
+
log_message("✓ Query engine created")
|
| 81 |
return query_engine
|
utils.py
CHANGED
|
@@ -47,29 +47,55 @@ def answer_question(question, query_engine, reranker):
|
|
| 47 |
retrieved = query_engine.retrieve(question)
|
| 48 |
log_message(f"RETRIEVED: {len(retrieved)} unique nodes")
|
| 49 |
|
| 50 |
-
reranked = rerank_nodes(question, retrieved, reranker, top_k=
|
| 51 |
log_message(f"RERANKED: {len(reranked)} nodes")
|
| 52 |
|
| 53 |
-
|
|
|
|
| 54 |
for n in reranked:
|
| 55 |
-
|
| 56 |
-
doc_id
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
else:
|
| 65 |
-
|
| 66 |
-
|
|
|
|
| 67 |
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
from config import CUSTOM_PROMPT
|
| 71 |
prompt = CUSTOM_PROMPT.format(context_str=context, query_str=question)
|
| 72 |
-
log_message(f"PROMPT LENGTH: {len(prompt)} chars")
|
| 73 |
|
| 74 |
from llama_index.core import Settings
|
| 75 |
response = Settings.llm.complete(prompt)
|
|
@@ -83,15 +109,30 @@ def answer_question(question, query_engine, reranker):
|
|
| 83 |
log_message(traceback.format_exc())
|
| 84 |
return f"Ошибка: {e}", ""
|
| 85 |
|
| 86 |
-
def rerank_nodes(query, nodes, reranker, top_k=20, min_score
|
| 87 |
-
"""
|
| 88 |
if not nodes or not reranker:
|
|
|
|
| 89 |
return nodes[:top_k]
|
| 90 |
|
| 91 |
-
pairs = [[query, n.text] for n in nodes]
|
| 92 |
scores = reranker.predict(pairs)
|
| 93 |
scored = sorted(zip(nodes, scores), key=lambda x: x[1], reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
filtered = [n for n, s in scored if s >= min_score]
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
| 47 |
retrieved = query_engine.retrieve(question)
|
| 48 |
log_message(f"RETRIEVED: {len(retrieved)} unique nodes")
|
| 49 |
|
| 50 |
+
reranked = rerank_nodes(question, retrieved, reranker, top_k=15, min_score=-0.5)
|
| 51 |
log_message(f"RERANKED: {len(reranked)} nodes")
|
| 52 |
|
| 53 |
+
# Group by document and type
|
| 54 |
+
doc_groups = {}
|
| 55 |
for n in reranked:
|
| 56 |
+
doc_id = n.metadata.get('document_id', 'unknown')
|
| 57 |
+
if doc_id not in doc_groups:
|
| 58 |
+
doc_groups[doc_id] = {'tables': [], 'text': [], 'images': []}
|
| 59 |
+
|
| 60 |
+
node_type = n.metadata.get('type', 'text')
|
| 61 |
+
if node_type == 'table':
|
| 62 |
+
doc_groups[doc_id]['tables'].append(n)
|
| 63 |
+
elif node_type == 'image':
|
| 64 |
+
doc_groups[doc_id]['images'].append(n)
|
| 65 |
else:
|
| 66 |
+
doc_groups[doc_id]['text'].append(n)
|
| 67 |
+
|
| 68 |
+
log_message(f"Documents found: {list(doc_groups.keys())}")
|
| 69 |
|
| 70 |
+
# Format context by document
|
| 71 |
+
context_parts = []
|
| 72 |
+
for doc_id, groups in doc_groups.items():
|
| 73 |
+
doc_section = [f"=== ДОКУМЕНТ: {doc_id} ==="]
|
| 74 |
+
|
| 75 |
+
# Tables first (most important for your queries)
|
| 76 |
+
if groups['tables']:
|
| 77 |
+
doc_section.append("\n--- ТАБЛИЦЫ ---")
|
| 78 |
+
for n in groups['tables']:
|
| 79 |
+
meta = n.metadata
|
| 80 |
+
table_id = meta.get('table_identifier', meta.get('table_number', 'unknown'))
|
| 81 |
+
title = meta.get('table_title', '')
|
| 82 |
+
doc_section.append(f"\n[Таблица {table_id}] {title}")
|
| 83 |
+
doc_section.append(n.text[:1500]) # Limit length
|
| 84 |
+
|
| 85 |
+
# Then text
|
| 86 |
+
if groups['text']:
|
| 87 |
+
doc_section.append("\n--- ТЕКСТ ---")
|
| 88 |
+
for n in groups['text'][:3]: # Limit text chunks
|
| 89 |
+
doc_section.append(n.text[:800])
|
| 90 |
+
|
| 91 |
+
context_parts.append("\n".join(doc_section))
|
| 92 |
+
|
| 93 |
+
context = "\n\n" + ("="*70 + "\n\n").join(context_parts)
|
| 94 |
+
|
| 95 |
+
log_message(f"Context length: {len(context)} chars")
|
| 96 |
|
| 97 |
from config import CUSTOM_PROMPT
|
| 98 |
prompt = CUSTOM_PROMPT.format(context_str=context, query_str=question)
|
|
|
|
| 99 |
|
| 100 |
from llama_index.core import Settings
|
| 101 |
response = Settings.llm.complete(prompt)
|
|
|
|
| 109 |
log_message(traceback.format_exc())
|
| 110 |
return f"Ошибка: {e}", ""
|
| 111 |
|
| 112 |
+
def rerank_nodes(query, nodes, reranker, top_k=20, min_score=-0.5): # Much lower threshold
|
| 113 |
+
"""Rerank with detailed score logging"""
|
| 114 |
if not nodes or not reranker:
|
| 115 |
+
log_message("WARNING: No nodes or reranker available")
|
| 116 |
return nodes[:top_k]
|
| 117 |
|
| 118 |
+
pairs = [[query, n.text[:500]] for n in nodes] # Limit text length for reranker
|
| 119 |
scores = reranker.predict(pairs)
|
| 120 |
scored = sorted(zip(nodes, scores), key=lambda x: x[1], reverse=True)
|
| 121 |
+
|
| 122 |
+
# Detailed logging
|
| 123 |
+
if scored:
|
| 124 |
+
top_5_scores = [s for _, s in scored[:5]]
|
| 125 |
+
bottom_5_scores = [s for _, s in scored[-5:]]
|
| 126 |
+
log_message(f"Score range: {min(scores):.3f} to {max(scores):.3f}")
|
| 127 |
+
log_message(f"Top 5 scores: {top_5_scores}")
|
| 128 |
+
log_message(f"Bottom 5 scores: {bottom_5_scores}")
|
| 129 |
+
|
| 130 |
+
# Count how many pass threshold
|
| 131 |
+
above_threshold = sum(1 for _, s in scored if s >= min_score)
|
| 132 |
+
log_message(f"Nodes above threshold ({min_score}): {above_threshold}/{len(scored)}")
|
| 133 |
+
|
| 134 |
filtered = [n for n, s in scored if s >= min_score]
|
| 135 |
+
result = filtered[:top_k] if filtered else [n for n, _ in scored[:top_k]]
|
| 136 |
+
|
| 137 |
+
log_message(f"Returning {len(result)} nodes after reranking")
|
| 138 |
+
return result
|