MrSimple07 committed on
Commit
d577496
·
1 Parent(s): c0c8ab9

Reranker top_k = 20, max rows per table chunk = 10, max chars = 4000 + new deduplication

Browse files
Files changed (3) hide show
  1. documents_prep.py +1 -1
  2. index_retriever.py +8 -101
  3. utils.py +1 -1
documents_prep.py CHANGED
@@ -38,7 +38,7 @@ def chunk_text_documents(documents):
38
  return chunked
39
 
40
 
41
- def chunk_table_by_rows(table_data, doc_id, rows_per_chunk=5, max_chars=4000):
42
  """
43
  Chunk tables by rows with fallback to character limit.
44
  Keeps 3-4 rows together, but splits individual rows if they're too large.
 
38
  return chunked
39
 
40
 
41
+ def chunk_table_by_rows(table_data, doc_id, rows_per_chunk=10, max_chars=4000):
42
  """
43
  Chunk tables by rows with fallback to character limit.
44
  Keeps 3-4 rows together, but splits individual rows if they're too large.
index_retriever.py CHANGED
@@ -6,12 +6,6 @@ from llama_index.core.retrievers import QueryFusionRetriever
6
  from llama_index.core.response_synthesizers import get_response_synthesizer
7
  from my_logging import log_message
8
 
9
- import re
10
-
11
- import re
12
- from difflib import SequenceMatcher
13
-
14
-
15
  def create_vector_index(documents):
16
  """Create vector index from documents"""
17
  log_message(f"Building vector index from {len(documents)} documents...")
@@ -29,96 +23,21 @@ def keyword_filter_nodes(query, nodes, min_keyword_matches=1):
29
  filtered.append(node)
30
  return filtered
31
 
32
-
33
- def normalize_doc_id(doc_id: str) -> str:
34
- """Normalize document ID - KEEP dots for numeric parts"""
35
- doc_id = doc_id.upper().strip()
36
- doc_id = re.sub(r'\s+', '', doc_id) # Remove spaces only
37
- doc_id = doc_id.replace("ГОСТР", "ГОСТ")
38
- doc_id = doc_id.replace("GOSTR", "ГОСТ")
39
- return doc_id
40
-
41
- def base_number(doc_id: str) -> str:
42
- """Extract full numeric pattern including all parts (e.g., '59023.6' from 'ГОСТ 59023.6')"""
43
- # Match: 59023.6 or 59023.4 or 50.05.01 etc.
44
- m = re.search(r'(\d+(?:\.\d+)*)', doc_id)
45
- return m.group(1) if m else ""
46
-
47
- def filter_nodes_by_doc_id(nodes, doc_ids, threshold=0.85):
48
- """Filter nodes by document ID with strict numeric matching"""
49
- if not doc_ids:
50
- return nodes
51
-
52
- filtered = []
53
- doc_ids_norm = [normalize_doc_id(d) for d in doc_ids]
54
- doc_ids_base = [base_number(d) for d in doc_ids_norm]
55
-
56
- for node in nodes:
57
- node_doc_id = normalize_doc_id(node.metadata.get('document_id', ''))
58
- node_base = base_number(node_doc_id)
59
-
60
- for q_doc, q_base in zip(doc_ids_norm, doc_ids_base):
61
- # STRICT: base number must match exactly
62
- if q_base and node_base and q_base == node_base:
63
- filtered.append(node)
64
- break
65
-
66
- # STRICT: full normalized ID must match exactly or have very high similarity
67
- elif SequenceMatcher(None, node_doc_id, q_doc).ratio() >= threshold:
68
- filtered.append(node)
69
- break
70
-
71
- return filtered if filtered else nodes
72
-
73
-
74
- def extract_doc_id_from_query(query):
75
- """Extract document IDs from query text with better pattern matching"""
76
- patterns = [
77
- r'ГОСТ\s*Р?\s*\d+(?:\.\d+)*(?:-\d{4})?', # ГОСТ 59023.4, ГОСТ Р 50.05.01-2018
78
- r'НП-\d+(?:-\d+)?', # НП-104-18
79
- r'МУ[_\s]\d+(?:\.\d+)+(?:\.\d+)*(?:-\d{4})?', # МУ 1.2.3.07.0057-2018
80
- ]
81
-
82
- found_ids = []
83
- for pattern in patterns:
84
- matches = re.findall(pattern, query, re.IGNORECASE)
85
- found_ids.extend(matches)
86
-
87
- # Normalize spacing and preserve dots
88
- normalized = [re.sub(r'\s+', ' ', id.strip().upper()) for id in found_ids]
89
- return normalized
90
- def russian_tokenizer(text):
91
- """Better tokenizer for Russian document IDs and technical terms"""
92
- import re
93
-
94
- # Keep document ID patterns intact
95
- text = re.sub(r'(ГОСТ\s*Р?\s*[\d\.]+(?:-\d{4})?)', r' \1 ', text)
96
- text = re.sub(r'(НП-\d+(?:-\d+)?)', r' \1 ', text)
97
- text = re.sub(r'(МУ[_\s][\d\.]+)', r' \1 ', text)
98
-
99
- # Split on whitespace and punctuation, but keep numbers with decimals
100
- tokens = re.findall(r'\d+\.\d+|\w+', text.lower())
101
-
102
- return tokens
103
-
104
-
105
  def create_query_engine(vector_index):
106
- """Create hybrid retrieval engine with document ID filtering"""
107
  log_message("Creating query engine...")
108
 
109
  vector_retriever = VectorIndexRetriever(
110
  index=vector_index,
111
- similarity_top_k=100
112
  )
113
  bm25_retriever = BM25Retriever.from_defaults(
114
  docstore=vector_index.docstore,
115
- similarity_top_k=100,
116
- tokenizer=russian_tokenizer # Add custom tokenizer
117
-
118
  )
119
  hybrid_retriever = QueryFusionRetriever(
120
  [vector_retriever, bm25_retriever],
121
- similarity_top_k=60,
122
  num_queries=1
123
  )
124
 
@@ -127,28 +46,20 @@ def create_query_engine(vector_index):
127
  nodes = hybrid_retriever.retrieve(query)
128
  log_message(f"Hybrid retrieval returned: {len(nodes)} nodes")
129
 
130
- # Extract document IDs from query
131
- doc_ids = extract_doc_id_from_query(query)
132
- if doc_ids:
133
- log_message(f"Detected document IDs in query: {doc_ids}")
134
- before = len(nodes)
135
- nodes = filter_nodes_by_doc_id(nodes, doc_ids)
136
- after = len(nodes)
137
- log_message(f"Filtered by doc ID: {after}/{before} nodes kept (fallback safe)")
138
-
139
-
140
- # Deduplication
141
  seen_hashes = set()
142
  unique_nodes = []
143
  doc_type_counts = {'text': 0, 'table': 0, 'image': 0}
144
 
145
  for node in nodes:
 
146
  text_hash = hash(node.text[:500])
147
 
148
  if text_hash not in seen_hashes:
149
  seen_hashes.add(text_hash)
150
  unique_nodes.append(node)
151
 
 
152
  node_type = node.metadata.get('type', 'text')
153
  doc_type_counts[node_type] = doc_type_counts.get(node_type, 0) + 1
154
 
@@ -157,10 +68,6 @@ def create_query_engine(vector_index):
157
  f"table={doc_type_counts.get('table', 0)}, "
158
  f"image={doc_type_counts.get('image', 0)}")
159
 
160
- # Log which documents we're returning
161
- returned_docs = set(n.metadata.get('document_id', 'unknown') for n in unique_nodes[:50])
162
- log_message(f"Returning nodes from: {sorted(returned_docs)}")
163
-
164
  return unique_nodes[:50]
165
 
166
  response_synthesizer = get_response_synthesizer()
@@ -170,5 +77,5 @@ def create_query_engine(vector_index):
170
  response_synthesizer=response_synthesizer
171
  )
172
 
173
- log_message("✓ Query engine created with doc ID filtering")
174
  return query_engine
 
6
  from llama_index.core.response_synthesizers import get_response_synthesizer
7
  from my_logging import log_message
8
 
 
 
 
 
 
 
9
  def create_vector_index(documents):
10
  """Create vector index from documents"""
11
  log_message(f"Building vector index from {len(documents)} documents...")
 
23
  filtered.append(node)
24
  return filtered
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def create_query_engine(vector_index):
27
+ """Create hybrid retrieval engine with better deduplication"""
28
  log_message("Creating query engine...")
29
 
30
  vector_retriever = VectorIndexRetriever(
31
  index=vector_index,
32
+ similarity_top_k=50 # Reduced to get more diverse results
33
  )
34
  bm25_retriever = BM25Retriever.from_defaults(
35
  docstore=vector_index.docstore,
36
+ similarity_top_k=50,
 
 
37
  )
38
  hybrid_retriever = QueryFusionRetriever(
39
  [vector_retriever, bm25_retriever],
40
+ similarity_top_k=60, # Limit on fused results (same value as before this commit)
41
  num_queries=1
42
  )
43
 
 
46
  nodes = hybrid_retriever.retrieve(query)
47
  log_message(f"Hybrid retrieval returned: {len(nodes)} nodes")
48
 
49
+ # Deduplicate nodes by hashing a fixed-length prefix of their text
 
 
 
 
 
 
 
 
 
 
50
  seen_hashes = set()
51
  unique_nodes = []
52
  doc_type_counts = {'text': 0, 'table': 0, 'image': 0}
53
 
54
  for node in nodes:
55
+ # Use first 500 chars for dedup hash
56
  text_hash = hash(node.text[:500])
57
 
58
  if text_hash not in seen_hashes:
59
  seen_hashes.add(text_hash)
60
  unique_nodes.append(node)
61
 
62
+ # Count by type
63
  node_type = node.metadata.get('type', 'text')
64
  doc_type_counts[node_type] = doc_type_counts.get(node_type, 0) + 1
65
 
 
68
  f"table={doc_type_counts.get('table', 0)}, "
69
  f"image={doc_type_counts.get('image', 0)}")
70
 
 
 
 
 
71
  return unique_nodes[:50]
72
 
73
  response_synthesizer = get_response_synthesizer()
 
77
  response_synthesizer=response_synthesizer
78
  )
79
 
80
+ log_message("✓ Query engine created")
81
  return query_engine
utils.py CHANGED
@@ -47,7 +47,7 @@ def answer_question(question, query_engine, reranker):
47
  retrieved = query_engine.retrieve(question)
48
  log_message(f"RETRIEVED: {len(retrieved)} unique nodes")
49
 
50
- reranked = rerank_nodes(question, retrieved, reranker, top_k=25, min_score=-0.5)
51
  log_message(f"RERANKED: {len(reranked)} nodes")
52
 
53
  # Group by document and type
 
47
  retrieved = query_engine.retrieve(question)
48
  log_message(f"RETRIEVED: {len(retrieved)} unique nodes")
49
 
50
+ reranked = rerank_nodes(question, retrieved, reranker, top_k=20, min_score=-0.5)
51
  log_message(f"RERANKED: {len(reranked)} nodes")
52
 
53
  # Group by document and type