MrSimple07 commited on
Commit
2edec29
·
1 Parent(s): 359257d

max rows = 10 + new answer_question + reranking

Browse files
Files changed (2) hide show
  1. documents_prep.py +1 -1
  2. utils.py +22 -130
documents_prep.py CHANGED
@@ -53,7 +53,7 @@ def normalize_doc_id(doc_id):
53
  return doc_id
54
 
55
 
56
- def chunk_table_by_rows(table_data, doc_id, max_rows=5):
57
  headers = table_data.get('headers', [])
58
  rows = table_data.get('data', [])
59
  table_num = table_data.get('table_number', 'unknown')
 
53
  return doc_id
54
 
55
 
56
+ def chunk_table_by_rows(table_data, doc_id, max_rows=10):
57
  headers = table_data.get('headers', [])
58
  rows = table_data.get('data', [])
59
  table_num = table_data.get('table_number', 'unknown')
utils.py CHANGED
@@ -41,71 +41,19 @@ def answer_question(question, query_engine, reranker):
41
  try:
42
  log_message(f"\n{'='*70}")
43
  log_message(f"QUERY: {question}")
44
-
45
-
46
- # Detect listing queries - need MORE chunks
47
- is_listing_query = any(phrase in question.lower()
48
- for phrase in ['какие таблиц', 'список', 'перечисл', 'все таблиц'])
49
-
50
  retrieved = query_engine.retriever.retrieve(question)
51
  log_message(f"\nRETRIEVED: {len(retrieved)} nodes")
52
-
53
- # Log retrieved docs
54
- doc_stats = {}
55
- for n in retrieved:
56
- doc_id = n.metadata.get('document_id', 'unknown')
57
- doc_group = n.metadata.get('document_group', doc_id)
58
-
59
- if doc_group not in doc_stats:
60
- doc_stats[doc_group] = {'tables': set(), 'text': 0}
61
-
62
- if n.metadata.get('type') == 'table':
63
- table_id = n.metadata.get('table_identifier', n.metadata.get('table_number', '?'))
64
- doc_stats[doc_group]['tables'].add(table_id)
65
- else:
66
- doc_stats[doc_group]['text'] += 1
67
-
68
- for doc_id in sorted(doc_stats.keys()):
69
- stats = doc_stats[doc_id]
70
- log_message(f" {doc_id}: {len(stats['tables'])} tables, {stats['text']} text")
71
- if stats['tables']:
72
- log_message(f" Tables: {sorted(stats['tables'])}")
73
-
74
- # Adjust reranking based on query type
75
- if is_listing_query:
76
- reranked = rerank_nodes(question, retrieved, reranker, top_k=50, min_score=0.2)
77
- else:
78
- reranked = rerank_nodes(question, retrieved, reranker, top_k=25, min_score=0.3)
79
-
80
  log_message(f"\nRERANKED: {len(reranked)} nodes")
81
-
82
- # Log reranked
83
- doc_stats_reranked = {}
84
- for n in reranked:
85
- doc_group = n.metadata.get('document_group', n.metadata.get('document_id', 'unknown'))
86
-
87
- if doc_group not in doc_stats_reranked:
88
- doc_stats_reranked[doc_group] = {'tables': set(), 'text': 0}
89
-
90
- if n.metadata.get('type') == 'table':
91
- table_id = n.metadata.get('table_identifier', n.metadata.get('table_number', '?'))
92
- doc_stats_reranked[doc_group]['tables'].add(table_id)
93
- else:
94
- doc_stats_reranked[doc_group]['text'] += 1
95
-
96
- for doc_id in sorted(doc_stats_reranked.keys()):
97
- stats = doc_stats_reranked[doc_id]
98
- log_message(f" {doc_id}: {len(stats['tables'])} tables, {stats['text']} text")
99
- if stats['tables']:
100
- log_message(f" Tables: {sorted(stats['tables'])}")
101
-
102
- # Build context
103
  context_parts = []
104
  for n in reranked:
105
  meta = n.metadata
106
  doc_id = meta.get('document_id', 'unknown')
107
  doc_type = meta.get('type', 'text')
108
-
109
  if doc_type == 'table':
110
  table_id = meta.get('table_identifier', meta.get('table_number', 'unknown'))
111
  title = meta.get('table_title', '')
@@ -114,47 +62,21 @@ def answer_question(question, query_engine, reranker):
114
  source_label += f" {title}"
115
  else:
116
  source_label = f"[{doc_id}]"
117
-
118
  context_parts.append(f"{source_label}\n{n.text[:500]}") # Limit context per chunk
119
-
120
- context = "\n\n" + ("="*50 + "\n\n").join(context_parts)
121
-
122
- # Adjust prompt for listing queries
123
- if is_listing_query:
124
- prompt = f"""Контекст содержит информацию о таблицах из документов.
125
-
126
- КОНТЕКСТ:
127
- {context}
128
-
129
- ВОПРОС: {question}
130
 
131
- ИНСТРУКЦИИ:
132
- 1. Перечисли ВСЕ таблицы, найденные в контексте для запрошенного документа
133
- 2. Укажи номер таблицы и название (если есть)
134
- 3. Если таблиц нет - скажи прямо
135
-
136
- ОТВЕТ (список таблиц):"""
137
- else:
138
- prompt = f"""Ты эксперт по технической документации.
139
-
140
- КОНТЕКСТ:
141
- {context}
142
-
143
- ВОПРОС: {question}
144
-
145
- ИНСТРУКЦИИ:
146
- 1. Отвечай ТОЛЬКО на основе контекста
147
- 2. Укажи источник (документ, таблицу)
148
- 3. Если нужно показать содержимое таблицы - покажи ВСЕ данные
149
- 4. Если информации нет - скажи прямо
150
 
151
- ОТВЕТ:"""
152
-
 
 
153
  response = query_engine.query(prompt)
 
154
  sources = format_sources(reranked)
155
-
 
156
  return response.response, sources
157
-
158
  except Exception as e:
159
  log_message(f"Error: {e}")
160
  import traceback
@@ -163,44 +85,14 @@ def answer_question(question, query_engine, reranker):
163
 
164
 
165
  def rerank_nodes(query, nodes, reranker, top_k=25, min_score=0.3):
166
- """Rerank with document grouping awareness"""
167
- if not nodes:
168
- return []
169
-
170
  pairs = [[query, n.text] for n in nodes]
171
  scores = reranker.predict(pairs)
172
-
173
  scored = sorted(zip(nodes, scores), key=lambda x: x[1], reverse=True)
174
-
175
- log_message(f"Top 10 reranking scores: {[f'{s:.3f}' for _, s in scored[:10]]}")
176
-
177
- # More lenient filtering
178
- filtered = [(n, s) for n, s in scored if s >= min_score]
179
-
180
- if not filtered:
181
- cutoff = max(scores) * 0.4
182
- filtered = [(n, s) for n, s in scored if s >= cutoff][:top_k]
183
-
184
- # Group by document for diversity
185
- doc_groups = {}
186
- for node, score in filtered:
187
- doc_group = node.metadata.get('document_group', node.metadata.get('document_id', 'unknown'))
188
- if doc_group not in doc_groups:
189
- doc_groups[doc_group] = []
190
- doc_groups[doc_group].append((node, score))
191
-
192
- # Take top chunks from each document group
193
- selected = []
194
- group_limits = max(3, top_k // max(1, len(doc_groups)))
195
-
196
- for doc_group in doc_groups:
197
- selected.extend([n for n, s in doc_groups[doc_group][:group_limits]])
198
-
199
- # Fill remaining slots with highest scores
200
- if len(selected) < top_k:
201
- remaining = [n for n, s in filtered if n not in selected]
202
- selected.extend(remaining[:top_k - len(selected)])
203
-
204
- log_message(f"Reranked: {len(filtered)} → {len(selected)} (from {len(doc_groups)} doc groups)")
205
-
206
- return selected[:top_k]
 
41
  try:
42
  log_message(f"\n{'='*70}")
43
  log_message(f"QUERY: {question}")
44
+
45
+ # Retrieve and rerank nodes
 
 
 
 
46
  retrieved = query_engine.retriever.retrieve(question)
47
  log_message(f"\nRETRIEVED: {len(retrieved)} nodes")
48
+ reranked = rerank_nodes(question, retrieved, reranker, top_k=25, min_score=0.3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  log_message(f"\nRERANKED: {len(reranked)} nodes")
50
+
51
+ # Build context for prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  context_parts = []
53
  for n in reranked:
54
  meta = n.metadata
55
  doc_id = meta.get('document_id', 'unknown')
56
  doc_type = meta.get('type', 'text')
 
57
  if doc_type == 'table':
58
  table_id = meta.get('table_identifier', meta.get('table_number', 'unknown'))
59
  title = meta.get('table_title', '')
 
62
  source_label += f" {title}"
63
  else:
64
  source_label = f"[{doc_id}]"
 
65
  context_parts.append(f"{source_label}\n{n.text[:500]}") # Limit context per chunk
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ context = "\n\n" + ("="*50 + "\n\n").join(context_parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ # Use only CUSTOM_PROMPT from config
70
+ from config import CUSTOM_PROMPT
71
+ prompt = CUSTOM_PROMPT.format(context_str=context, query_str=question)
72
+ log_message(f"\nPROMPT:\n{prompt[:300]}...\n") # Log first 1000 chars of prompt
73
  response = query_engine.query(prompt)
74
+
75
  sources = format_sources(reranked)
76
+ for i in reranked:
77
+ log_message(f"---\n{i.text[:500]}\n...")
78
  return response.response, sources
79
+
80
  except Exception as e:
81
  log_message(f"Error: {e}")
82
  import traceback
 
85
 
86
 
87
  def rerank_nodes(query, nodes, reranker, top_k=25, min_score=0.3):
88
+ """Simple and effective reranking: sort by score and filter by threshold."""
89
+ if not nodes or not reranker:
90
+ return nodes[:top_k]
91
+
92
  pairs = [[query, n.text] for n in nodes]
93
  scores = reranker.predict(pairs)
 
94
  scored = sorted(zip(nodes, scores), key=lambda x: x[1], reverse=True)
95
+ filtered = [n for n, s in scored if s >= min_score]
96
+
97
+ # Return top_k filtered nodes, or fallback to top_k overall
98
+ return filtered[:top_k] if filtered else [n for n, _ in scored[:top_k]]