Spaces:
Sleeping
Sleeping
Commit
·
eefdfd0
1
Parent(s):
90e6b4c
hybrid_retrieve_with_keywords implemented
Browse files- index_retriever.py +113 -137
- utils.py +8 -11
index_retriever.py
CHANGED
|
@@ -49,6 +49,119 @@ def create_query_engine(vector_index):
|
|
| 49 |
log_message(f"Ошибка создания query engine: {str(e)}")
|
| 50 |
raise
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5, diversity_penalty=0.3):
|
| 53 |
if not nodes or not reranker:
|
| 54 |
return nodes[:top_k]
|
|
@@ -114,140 +227,3 @@ def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5, dive
|
|
| 114 |
log_message(f"Ошибка переранжировки: {str(e)}")
|
| 115 |
return nodes[:top_k]
|
| 116 |
|
| 117 |
-
|
| 118 |
-
from rank_bm25 import BM25Okapi
|
| 119 |
-
import numpy as np
|
| 120 |
-
|
| 121 |
-
class HybridRetriever:
    """Combine dense vector retrieval with BM25 and exact metadata matching.

    Scores from the three channels are blended with configurable weights;
    results are returned as lightweight node-like objects.
    """

    def __init__(self, vector_retriever, documents):
        self.vector_retriever = vector_retriever
        self.documents = documents

        # Build the BM25 index over whitespace-tokenized, lower-cased text.
        corpus_tokens = [doc.text.lower().split() for doc in documents]
        self.bm25 = BM25Okapi(corpus_tokens)

        # Inverted index over metadata fields for exact-match lookups.
        self.metadata_index = self._build_metadata_index(documents)

    def _build_metadata_index(self, documents):
        """Index document ids by materials, GOSTs, classes and key terms."""
        index = {
            'materials': {},
            'gosts': {},
            'classes': {},
            'key_terms': {},
        }

        for doc_id, doc in enumerate(documents):
            meta = doc.metadata

            # materials / gosts / classes are indexed verbatim;
            # key terms are lower-cased so lookups are case-insensitive.
            for material in meta.get('materials', []):
                index['materials'].setdefault(material, []).append(doc_id)
            for gost in meta.get('gosts', []):
                index['gosts'].setdefault(gost, []).append(doc_id)
            for cls in meta.get('classes', []):
                index['classes'].setdefault(cls, []).append(doc_id)
            for term in meta.get('key_terms', []):
                index['key_terms'].setdefault(term.lower(), []).append(doc_id)

        return index

    def retrieve(self, query, top_k=20, vector_weight=0.5, bm25_weight=0.3, metadata_weight=0.2):
        """Hybrid retrieval combining vector, BM25, and metadata matching."""
        from types import SimpleNamespace

        # Channel 1: dense vector search, keyed by node id.
        dense_hits = self.vector_retriever.retrieve(query)
        dense_scores = {hit.node_id: hit.score for hit in dense_hits}

        # Channel 2: BM25 over the tokenized query (one score per document).
        sparse_scores = self.bm25.get_scores(query.lower().split())

        # Channel 3: exact metadata matches.
        exact_scores = self._get_metadata_scores(query)

        # Union of candidates: every dense hit plus every corpus document.
        candidates = set(list(dense_scores.keys()) +
                         list(range(len(self.documents))))

        # Blend the three channels; BM25 is normalized by its maximum
        # (epsilon guards against an all-zero score vector).
        bm25_norm = max(sparse_scores) + 1e-6
        blended = {}
        for cand in candidates:
            dense = dense_scores.get(cand, 0.0)
            sparse = sparse_scores[cand] if isinstance(cand, int) and cand < len(sparse_scores) else 0.0
            exact = exact_scores.get(cand, 0.0)
            blended[cand] = (
                vector_weight * dense +
                bm25_weight * (sparse / bm25_norm) +
                metadata_weight * exact
            )

        ranked = sorted(blended.items(), key=lambda kv: kv[1], reverse=True)[:top_k]

        # Materialize only integer ids that map back to corpus documents.
        results = []
        for cand, score in ranked:
            if isinstance(cand, int) and cand < len(self.documents):
                doc = self.documents[cand]
                results.append(SimpleNamespace(
                    text=doc.text,
                    metadata=doc.metadata,
                    score=score,
                    node_id=cand,
                ))

        return results

    def _get_metadata_scores(self, query):
        """Score documents by exact metadata matches found in the query."""
        import re

        scores = {}
        query_lower = query.lower()

        # Material grade codes get the highest exact-match weight.
        material_pattern = r'\b\d{2}[ХНТМКВБА]+\d{1,2}[ХНТМКВБА]*\d*\b'
        for material in re.findall(material_pattern, query, re.IGNORECASE):
            for doc_id in self.metadata_index['materials'].get(material, []):
                scores[doc_id] = scores.get(doc_id, 0) + 1.0

        # GOST standard references.
        gost_pattern = r'ГОСТ\s+[РЕН\s]*\d+[\.\-\d]*'
        for gost in re.findall(gost_pattern, query, re.IGNORECASE):
            for doc_id in self.metadata_index['gosts'].get(gost, []):
                scores[doc_id] = scores.get(doc_id, 0) + 0.8

        # Key terms are substring-matched against the lower-cased query.
        for term, doc_ids in self.metadata_index['key_terms'].items():
            if term in query_lower:
                for doc_id in doc_ids:
                    scores[doc_id] = scores.get(doc_id, 0) + 0.5

        return scores
|
|
|
|
| 49 |
log_message(f"Ошибка создания query engine: {str(e)}")
|
| 50 |
raise
|
| 51 |
|
| 52 |
+
import re
|
| 53 |
+
from typing import List, Dict, Set
|
| 54 |
+
from my_logging import log_message
|
| 55 |
+
|
| 56 |
+
def extract_keywords_from_query(query: str) -> Dict[str, List[str]]:
    """Extract technical keywords from query.

    Scans the query for material grade codes, GOST standard references,
    classification codes, and a fixed list of technical terms.

    Args:
        query: Free-text user question (Russian technical vocabulary).

    Returns:
        Dict with keys 'materials', 'gosts', 'classes' and
        'technical_terms', each mapping to a (possibly empty) list of
        the matches found, in order of appearance.
    """
    keywords: Dict[str, List[str]] = {
        'materials': [],
        'gosts': [],
        'classes': [],
        'technical_terms': []
    }

    # Material grade codes: 08Х18Н10Т, 12Х18Н10Т, etc.
    # FIX: the previous pattern (\d{2}[ХНТМКВБА]+\d{1,2}[ХНТМКВБА]*\d*)
    # allowed only two letter groups, so grades that alternate
    # letters/digits three or more times (e.g. 12Х18Н10Т) never matched.
    # The repeated group below accepts any number of letter+digit
    # segments and is a strict superset of the old pattern.
    material_pattern = r'\b\d{2}(?:[ХНТМКВБА]+\d*)+\b'
    keywords['materials'] = re.findall(material_pattern, query, re.IGNORECASE)

    # GOST standards, e.g. "ГОСТ 5632-72" (optional Р/ЕН qualifiers).
    gost_pattern = r'ГОСТ\s+[РЕН\s]*\d+[\.\-\d]*'
    keywords['gosts'] = re.findall(gost_pattern, query, re.IGNORECASE)

    # Classification codes: 3СIIIa, 2BII, etc.
    # NOTE(review): the former comment also listed "1А" as an example,
    # but [IV]+ requires a Roman-numeral part, so "1А" cannot match —
    # confirm whether that example was ever intended to be supported.
    class_pattern = r'\b\d[АБВГСD]+[IV]+[a-z]?\b'
    keywords['classes'] = re.findall(class_pattern, query, re.IGNORECASE)

    # Fixed technical vocabulary; lower-case the query once instead of
    # once per term.
    query_lower = query.lower()
    terms = ['полуфабрикат', 'план качества', 'контроль', 'арматура',
             'ультразвуковой', 'сварка', 'испытание']
    keywords['technical_terms'] = [term for term in terms
                                   if term.lower() in query_lower]

    return keywords
|
| 85 |
+
|
| 86 |
+
def keyword_search_nodes(nodes: List, keywords: Dict[str, List[str]]) -> List:
    """Filter nodes by exact keyword matches.

    A node is kept when its text contains any extracted material, GOST
    or classification code, or at least two of the technical terms.
    If no keywords were extracted at all, the input list is returned
    unchanged (no filtering signal).

    Args:
        nodes: Retrieved nodes; each must expose a ``.text`` attribute.
        keywords: Output of ``extract_keywords_from_query``.

    Returns:
        The matching nodes, preserving the original order.
    """
    if not any(keywords.values()):
        return nodes

    matched_nodes = []

    for node in nodes:
        text_lower = node.text.lower()

        # Any single exact code match (material, GOST or class) is
        # enough to keep the node.  (The former implementation used a
        # triple-nested for/else pyramid and computed an unused
        # ``metadata`` local; behavior is unchanged.)
        if (any(m.lower() in text_lower for m in keywords['materials'])
                or any(g.lower() in text_lower for g in keywords['gosts'])
                or any(c.lower() in text_lower for c in keywords['classes'])):
            matched_nodes.append(node)
            continue

        # Technical terms are weaker evidence: require at least 2 hits.
        term_matches = sum(1 for term in keywords['technical_terms']
                           if term.lower() in text_lower)
        if term_matches >= 2:
            matched_nodes.append(node)

    return matched_nodes
|
| 122 |
+
|
| 123 |
+
def hybrid_retrieve_with_keywords(question: str, query_engine, top_k: int = 40) -> List:
    """Retrieve using both vector search and keyword matching.

    Runs the dense retriever, then filters its results by keywords
    extracted from the question.  When keyword matches exist they are
    placed first and the remainder of the list is padded with the other
    vector hits; otherwise the plain vector results are returned.

    Args:
        question: The user query.
        query_engine: Engine exposing ``.retriever.retrieve``.
        top_k: Maximum number of nodes to return.

    Returns:
        At most ``top_k`` nodes, keyword matches first.
    """
    keywords = extract_keywords_from_query(question)
    log_message(f"Извлечены ключевые слова: {keywords}")

    vector_nodes = query_engine.retriever.retrieve(question)
    log_message(f"Векторный поиск: {len(vector_nodes)} узлов")

    # No keyword signal at all — nothing to re-rank by.
    if not any(keywords.values()):
        return vector_nodes[:top_k]

    keyword_nodes = keyword_search_nodes(vector_nodes, keywords)
    log_message(f"После фильтрации по ключевым словам: {len(keyword_nodes)} узлов")

    if not keyword_nodes:
        return vector_nodes[:top_k]

    # Keyword hits are a subset of the vector hits, so deduplicating by
    # object identity is sufficient: put matches first, then pad with
    # the remaining vector results up to top_k.
    seen_ids = set()
    merged = []
    for candidate in list(keyword_nodes[:top_k]) + list(vector_nodes):
        if len(merged) >= top_k:
            break
        marker = id(candidate)
        if marker in seen_ids:
            continue
        seen_ids.add(marker)
        merged.append(candidate)

    return merged[:top_k]
|
| 164 |
+
|
| 165 |
def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5, diversity_penalty=0.3):
|
| 166 |
if not nodes or not reranker:
|
| 167 |
return nodes[:top_k]
|
|
|
|
| 227 |
log_message(f"Ошибка переранжировки: {str(e)}")
|
| 228 |
return nodes[:top_k]
|
| 229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils.py
CHANGED
|
@@ -232,6 +232,7 @@ def generate_sources_html(nodes, chunks_df=None):
|
|
| 232 |
html += "</div>"
|
| 233 |
return html
|
| 234 |
def answer_question(question, query_engine, reranker, current_model, chunks_df=None, hybrid_retriever=None):
|
|
|
|
| 235 |
if query_engine is None:
|
| 236 |
return "<div style='background-color: #e53e3e; color: white; padding: 20px; border-radius: 10px;'>Система не инициализирована</div>", "", ""
|
| 237 |
|
|
@@ -240,22 +241,18 @@ def answer_question(question, query_engine, reranker, current_model, chunks_df=N
|
|
| 240 |
|
| 241 |
llm = get_llm_model(current_model)
|
| 242 |
|
| 243 |
-
# Use
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
log_message(f"Hybrid retrieval: получено {len(retrieved_nodes)} узлов")
|
| 247 |
-
else:
|
| 248 |
-
retrieved_nodes = query_engine.retriever.retrieve(question)
|
| 249 |
-
log_message(f"Vector retrieval: получено {len(retrieved_nodes)} узлов")
|
| 250 |
|
| 251 |
-
# Rerank
|
| 252 |
reranked_nodes = rerank_nodes(
|
| 253 |
question,
|
| 254 |
retrieved_nodes,
|
| 255 |
reranker,
|
| 256 |
-
top_k=25,
|
| 257 |
-
min_score_threshold=0.3,
|
| 258 |
-
diversity_penalty=0.2
|
| 259 |
)
|
| 260 |
|
| 261 |
formatted_context = format_context_for_llm(reranked_nodes)
|
|
|
|
| 232 |
html += "</div>"
|
| 233 |
return html
|
| 234 |
def answer_question(question, query_engine, reranker, current_model, chunks_df=None, hybrid_retriever=None):
|
| 235 |
+
from index_retriever import hybrid_retrieve_with_keywords
|
| 236 |
if query_engine is None:
|
| 237 |
return "<div style='background-color: #e53e3e; color: white; padding: 20px; border-radius: 10px;'>Система не инициализирована</div>", "", ""
|
| 238 |
|
|
|
|
| 241 |
|
| 242 |
llm = get_llm_model(current_model)
|
| 243 |
|
| 244 |
+
# Use keyword-enhanced retrieval
|
| 245 |
+
retrieved_nodes = hybrid_retrieve_with_keywords(question, query_engine, top_k=40)
|
| 246 |
+
log_message(f"Hybrid keyword retrieval: получено {len(retrieved_nodes)} узлов")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
+
# Rerank
|
| 249 |
reranked_nodes = rerank_nodes(
|
| 250 |
question,
|
| 251 |
retrieved_nodes,
|
| 252 |
reranker,
|
| 253 |
+
top_k=25,
|
| 254 |
+
min_score_threshold=0.3,
|
| 255 |
+
diversity_penalty=0.2
|
| 256 |
)
|
| 257 |
|
| 258 |
formatted_context = format_context_for_llm(reranked_nodes)
|