MrSimple07 committed on
Commit
806f3f9
·
1 Parent(s): ad8e8ec

Much lower reranking threshold (-0.5 instead of 0.1) + detailed score logging

Browse files
Files changed (3) hide show
  1. documents_prep.py +0 -38
  2. index_retriever.py +21 -19
  3. utils.py +62 -21
documents_prep.py CHANGED
@@ -412,44 +412,6 @@ def extract_sections_from_json(json_path):
412
  return documents
413
 
414
 
415
- def load_table_documents(repo_id, hf_token, table_dir):
416
- """Load and chunk tables"""
417
- log_message("Loading tables...")
418
-
419
- files = list_repo_files(repo_id=repo_id, repo_type="dataset", token=hf_token)
420
- table_files = [f for f in files if f.startswith(table_dir) and f.endswith('.json')]
421
-
422
- all_chunks = []
423
- for file_path in table_files:
424
- try:
425
- local_path = hf_hub_download(
426
- repo_id=repo_id,
427
- filename=file_path,
428
- repo_type="dataset",
429
- token=hf_token
430
- )
431
-
432
- with open(local_path, 'r', encoding='utf-8') as f:
433
- data = json.load(f)
434
-
435
- # Extract file-level document_id
436
- file_doc_id = data.get('document_id', data.get('document', 'unknown'))
437
-
438
- for sheet in data.get('sheets', []):
439
- # Use sheet-level document_id if available, otherwise use file-level
440
- sheet_doc_id = sheet.get('document_id', sheet.get('document', file_doc_id))
441
-
442
- # CRITICAL: Pass document_id to chunk function
443
- chunks = chunk_table_by_content(sheet, sheet_doc_id)
444
- all_chunks.extend(chunks)
445
-
446
- except Exception as e:
447
- log_message(f"Error loading {file_path}: {e}")
448
-
449
- log_message(f"✓ Loaded {len(all_chunks)} table chunks")
450
- return all_chunks
451
-
452
-
453
  def load_image_documents(repo_id, hf_token, image_dir):
454
  """Load image descriptions"""
455
  log_message("Loading images...")
 
412
  return documents
413
 
414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  def load_image_documents(repo_id, hf_token, image_dir):
416
  """Load image descriptions"""
417
  log_message("Loading images...")
index_retriever.py CHANGED
@@ -24,56 +24,58 @@ def keyword_filter_nodes(query, nodes, min_keyword_matches=1):
24
  return filtered
25
 
26
  def create_query_engine(vector_index):
27
- """Create hybrid retrieval engine with deduplication"""
28
  log_message("Creating query engine...")
29
 
30
  vector_retriever = VectorIndexRetriever(
31
  index=vector_index,
32
- similarity_top_k=80
33
  )
34
  bm25_retriever = BM25Retriever.from_defaults(
35
  docstore=vector_index.docstore,
36
- similarity_top_k=80,
37
  )
38
  hybrid_retriever = QueryFusionRetriever(
39
  [vector_retriever, bm25_retriever],
40
- similarity_top_k=100,
41
  num_queries=1
42
  )
43
 
44
  class DeduplicatedQueryEngine(RetrieverQueryEngine):
45
  def retrieve(self, query):
46
  nodes = hybrid_retriever.retrieve(query)
 
47
 
48
- # CRITICAL: Deduplicate by text content hash
49
  seen_hashes = set()
50
  unique_nodes = []
 
51
 
52
  for node in nodes:
53
- text_hash = hash(node.text[:200])
 
54
 
55
  if text_hash not in seen_hashes:
56
  seen_hashes.add(text_hash)
57
  unique_nodes.append(node)
 
 
 
 
58
 
59
- log_message(f"Retrieved: {len(nodes)} Unique: {len(unique_nodes)}")
60
- return unique_nodes[:50] # Return top 50 unique
61
-
62
- # FIX: Override query method to use our retrieve
63
- def query(self, query_bundle):
64
- nodes = self.retrieve(query_bundle.query_str)
65
- response = self._response_synthesizer.synthesize(
66
- query=query_bundle,
67
- nodes=nodes
68
- )
69
- return response
70
 
71
  response_synthesizer = get_response_synthesizer()
72
 
73
  query_engine = DeduplicatedQueryEngine(
74
- retriever=hybrid_retriever, # Still pass it but we override retrieve()
75
  response_synthesizer=response_synthesizer
76
  )
77
 
78
- log_message("✓ Query engine created (with deduplication)")
79
  return query_engine
 
24
  return filtered
25
 
26
  def create_query_engine(vector_index):
27
+ """Create hybrid retrieval engine with better deduplication"""
28
  log_message("Creating query engine...")
29
 
30
  vector_retriever = VectorIndexRetriever(
31
  index=vector_index,
32
+ similarity_top_k=50 # Reduced to get more diverse results
33
  )
34
  bm25_retriever = BM25Retriever.from_defaults(
35
  docstore=vector_index.docstore,
36
+ similarity_top_k=50,
37
  )
38
  hybrid_retriever = QueryFusionRetriever(
39
  [vector_retriever, bm25_retriever],
40
+ similarity_top_k=60, # Reduced
41
  num_queries=1
42
  )
43
 
44
  class DeduplicatedQueryEngine(RetrieverQueryEngine):
45
  def retrieve(self, query):
46
  nodes = hybrid_retriever.retrieve(query)
47
+ log_message(f"Hybrid retrieval returned: {len(nodes)} nodes")
48
 
49
+ # Better deduplication using longer text snippet
50
  seen_hashes = set()
51
  unique_nodes = []
52
+ doc_type_counts = {'text': 0, 'table': 0, 'image': 0}
53
 
54
  for node in nodes:
55
+ # Use first 500 chars for dedup hash
56
+ text_hash = hash(node.text[:500])
57
 
58
  if text_hash not in seen_hashes:
59
  seen_hashes.add(text_hash)
60
  unique_nodes.append(node)
61
+
62
+ # Count by type
63
+ node_type = node.metadata.get('type', 'text')
64
+ doc_type_counts[node_type] = doc_type_counts.get(node_type, 0) + 1
65
 
66
+ log_message(f"After dedup: {len(unique_nodes)} unique nodes")
67
+ log_message(f"Types: text={doc_type_counts.get('text', 0)}, "
68
+ f"table={doc_type_counts.get('table', 0)}, "
69
+ f"image={doc_type_counts.get('image', 0)}")
70
+
71
+ return unique_nodes[:50]
 
 
 
 
 
72
 
73
  response_synthesizer = get_response_synthesizer()
74
 
75
  query_engine = DeduplicatedQueryEngine(
76
+ retriever=hybrid_retriever,
77
  response_synthesizer=response_synthesizer
78
  )
79
 
80
+ log_message("✓ Query engine created")
81
  return query_engine
utils.py CHANGED
@@ -47,29 +47,55 @@ def answer_question(question, query_engine, reranker):
47
  retrieved = query_engine.retrieve(question)
48
  log_message(f"RETRIEVED: {len(retrieved)} unique nodes")
49
 
50
- reranked = rerank_nodes(question, retrieved, reranker, top_k=25, min_score=0.1)
51
  log_message(f"RERANKED: {len(reranked)} nodes")
52
 
53
- context_parts = []
 
54
  for n in reranked:
55
- meta = n.metadata
56
- doc_id = meta.get('document_id', 'unknown')
57
- doc_type = meta.get('type', 'text')
58
- if doc_type == 'table':
59
- table_id = meta.get('table_identifier', meta.get('table_number', 'unknown'))
60
- title = meta.get('table_title', '')
61
- source_label = f"[{doc_id} - Таблица {table_id}]"
62
- if title:
63
- source_label += f" {title}"
64
  else:
65
- source_label = f"[{doc_id}]"
66
- context_parts.append(f"{source_label}\n{n.text}")
 
67
 
68
- context = "\n\n" + ("="*50 + "\n\n").join(context_parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  from config import CUSTOM_PROMPT
71
  prompt = CUSTOM_PROMPT.format(context_str=context, query_str=question)
72
- log_message(f"PROMPT LENGTH: {len(prompt)} chars")
73
 
74
  from llama_index.core import Settings
75
  response = Settings.llm.complete(prompt)
@@ -83,15 +109,30 @@ def answer_question(question, query_engine, reranker):
83
  log_message(traceback.format_exc())
84
  return f"Ошибка: {e}", ""
85
 
86
- def rerank_nodes(query, nodes, reranker, top_k=20, min_score=0.1):
87
- """Simple and effective reranking: sort by score and filter by threshold."""
88
  if not nodes or not reranker:
 
89
  return nodes[:top_k]
90
 
91
- pairs = [[query, n.text] for n in nodes]
92
  scores = reranker.predict(pairs)
93
  scored = sorted(zip(nodes, scores), key=lambda x: x[1], reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  filtered = [n for n, s in scored if s >= min_score]
95
-
96
- # Return top_k filtered nodes, or fallback to top_k overall
97
- return filtered[:top_k] if filtered else [n for n, _ in scored[:top_k]]
 
 
47
  retrieved = query_engine.retrieve(question)
48
  log_message(f"RETRIEVED: {len(retrieved)} unique nodes")
49
 
50
+ reranked = rerank_nodes(question, retrieved, reranker, top_k=15, min_score=-0.5)
51
  log_message(f"RERANKED: {len(reranked)} nodes")
52
 
53
+ # Group by document and type
54
+ doc_groups = {}
55
  for n in reranked:
56
+ doc_id = n.metadata.get('document_id', 'unknown')
57
+ if doc_id not in doc_groups:
58
+ doc_groups[doc_id] = {'tables': [], 'text': [], 'images': []}
59
+
60
+ node_type = n.metadata.get('type', 'text')
61
+ if node_type == 'table':
62
+ doc_groups[doc_id]['tables'].append(n)
63
+ elif node_type == 'image':
64
+ doc_groups[doc_id]['images'].append(n)
65
  else:
66
+ doc_groups[doc_id]['text'].append(n)
67
+
68
+ log_message(f"Documents found: {list(doc_groups.keys())}")
69
 
70
+ # Format context by document
71
+ context_parts = []
72
+ for doc_id, groups in doc_groups.items():
73
+ doc_section = [f"=== ДОКУМЕНТ: {doc_id} ==="]
74
+
75
+ # Tables first (most important for your queries)
76
+ if groups['tables']:
77
+ doc_section.append("\n--- ТАБЛИЦЫ ---")
78
+ for n in groups['tables']:
79
+ meta = n.metadata
80
+ table_id = meta.get('table_identifier', meta.get('table_number', 'unknown'))
81
+ title = meta.get('table_title', '')
82
+ doc_section.append(f"\n[Таблица {table_id}] {title}")
83
+ doc_section.append(n.text[:1500]) # Limit length
84
+
85
+ # Then text
86
+ if groups['text']:
87
+ doc_section.append("\n--- ТЕКСТ ---")
88
+ for n in groups['text'][:3]: # Limit text chunks
89
+ doc_section.append(n.text[:800])
90
+
91
+ context_parts.append("\n".join(doc_section))
92
+
93
+ context = "\n\n" + ("="*70 + "\n\n").join(context_parts)
94
+
95
+ log_message(f"Context length: {len(context)} chars")
96
 
97
  from config import CUSTOM_PROMPT
98
  prompt = CUSTOM_PROMPT.format(context_str=context, query_str=question)
 
99
 
100
  from llama_index.core import Settings
101
  response = Settings.llm.complete(prompt)
 
109
  log_message(traceback.format_exc())
110
  return f"Ошибка: {e}", ""
111
 
112
+ def rerank_nodes(query, nodes, reranker, top_k=20, min_score=-0.5): # Much lower threshold
113
+ """Rerank with detailed score logging"""
114
  if not nodes or not reranker:
115
+ log_message("WARNING: No nodes or reranker available")
116
  return nodes[:top_k]
117
 
118
+ pairs = [[query, n.text[:500]] for n in nodes] # Limit text length for reranker
119
  scores = reranker.predict(pairs)
120
  scored = sorted(zip(nodes, scores), key=lambda x: x[1], reverse=True)
121
+
122
+ # Detailed logging
123
+ if scored:
124
+ top_5_scores = [s for _, s in scored[:5]]
125
+ bottom_5_scores = [s for _, s in scored[-5:]]
126
+ log_message(f"Score range: {min(scores):.3f} to {max(scores):.3f}")
127
+ log_message(f"Top 5 scores: {top_5_scores}")
128
+ log_message(f"Bottom 5 scores: {bottom_5_scores}")
129
+
130
+ # Count how many pass threshold
131
+ above_threshold = sum(1 for _, s in scored if s >= min_score)
132
+ log_message(f"Nodes above threshold ({min_score}): {above_threshold}/{len(scored)}")
133
+
134
  filtered = [n for n, s in scored if s >= min_score]
135
+ result = filtered[:top_k] if filtered else [n for n, _ in scored[:top_k]]
136
+
137
+ log_message(f"Returning {len(result)} nodes after reranking")
138
+ return result