MrSimple07 commited on
Commit
ad5ae30
·
1 Parent(s): f74c675

Extracts weld type from query

Browse files
Files changed (1) hide show
  1. index_retriever.py +106 -4
index_retriever.py CHANGED
@@ -65,9 +65,96 @@ def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5):
65
  log_message(f"Ошибка переранжировки: {str(e)}")
66
  return nodes[:top_k]
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def create_query_engine(vector_index):
69
  try:
70
  from config import CUSTOM_PROMPT
 
 
 
 
 
71
 
72
  bm25_retriever = BM25Retriever.from_defaults(
73
  docstore=vector_index.docstore,
@@ -92,14 +179,29 @@ def create_query_engine(vector_index):
92
  text_qa_template=custom_prompt_template
93
  )
94
 
95
- query_engine = RetrieverQueryEngine(
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  retriever=hybrid_retriever,
97
- response_synthesizer=response_synthesizer
 
 
98
  )
99
 
100
- log_message("Query engine успешно создан")
101
  return query_engine
102
 
103
  except Exception as e:
104
- log_message(f"Ошибка создания query engine: {str(e)}")
105
  raise
 
65
  log_message(f"Ошибка переранжировки: {str(e)}")
66
  return nodes[:top_k]
67
 
68
+ def extract_weld_type_from_query(query):
69
+ """Extract welded joint type (С-XX, У-XX, etc.) from query"""
70
+ import re
71
+
72
+ # Pattern for Russian weld types: С-25, У-12, Т-5, etc.
73
+ patterns = [
74
+ r'[СУТ]-\d+', # Matches С-25, У-12, Т-5
75
+ r'(?:тип|тип[а-я]*)\s+([СУТ]-\d+)', # "тип С-25" or "тип: С-25"
76
+ ]
77
+
78
+ for pattern in patterns:
79
+ match = re.search(pattern, query, re.IGNORECASE)
80
+ if match:
81
+ if '-' in match.group(0):
82
+ return match.group(0).upper()
83
+ elif len(match.groups()) > 0:
84
+ return match.group(1).upper()
85
+
86
+ return None
87
+
88
+
89
+ def retrieve_nodes_with_weld_type_priority(query, vector_index, hybrid_retriever, reranker, top_k=20):
90
+ """
91
+ Enhanced retrieval that prioritizes welded joint type matches
92
+ """
93
+ from utils import deduplicate_nodes
94
+
95
+ log_message(f"Enhanced retrieval for query: {query}")
96
+
97
+ # Step 1: Try to extract weld type from query
98
+ weld_type = extract_weld_type_from_query(query)
99
+
100
+ if weld_type:
101
+ log_message(f"Detected weld type in query: {weld_type}")
102
+
103
+ # Step 2: Direct lookup in docstore for this weld type
104
+ direct_matches = []
105
+ all_nodes = list(vector_index.docstore.docs.values())
106
+
107
+ for node in all_nodes:
108
+ metadata = node.metadata if hasattr(node, 'metadata') else {}
109
+
110
+ # Check if this is a table node with matching weld type
111
+ if metadata.get('type') == 'table':
112
+ table_num = metadata.get('table_number', '')
113
+ table_title = metadata.get('table_title', '')
114
+
115
+ # Check multiple fields for the weld type
116
+ if (weld_type in str(table_num) or
117
+ weld_type in str(table_title) or
118
+ weld_type in str(metadata.get('section', ''))):
119
+ direct_matches.append(node)
120
+ log_message(f" Direct match found: {metadata.get('document_id')} - {table_title}")
121
+
122
+ if direct_matches:
123
+ # Remove duplicates
124
+ direct_matches = deduplicate_nodes(direct_matches)
125
+ log_message(f"Found {len(direct_matches)} direct matches for {weld_type}")
126
+
127
+ # Add some context from hybrid retriever
128
+ hybrid_results = hybrid_retriever.retrieve(query)
129
+
130
+ # Combine: prioritize direct matches, supplement with hybrid results
131
+ combined = direct_matches + hybrid_results
132
+ combined = deduplicate_nodes(combined)
133
+
134
+ # Rerank combined results
135
+ reranked = rerank_nodes(query, combined, reranker, top_k=top_k)
136
+ log_message(f"Combined retrieval: {len(direct_matches)} direct + hybrid, returning {len(reranked)} reranked")
137
+
138
+ return reranked
139
+
140
+ # Step 3: Fall back to normal hybrid retrieval if no weld type found
141
+ log_message("No weld type detected, using standard hybrid retrieval")
142
+ retrieved_nodes = hybrid_retriever.retrieve(query)
143
+ retrieved_nodes = deduplicate_nodes(retrieved_nodes)
144
+ reranked_nodes = rerank_nodes(query, retrieved_nodes, reranker, top_k=top_k)
145
+
146
+ return reranked_nodes
147
+
148
+
149
+ # Update create_query_engine to use the enhanced retrieval
150
  def create_query_engine(vector_index):
151
  try:
152
  from config import CUSTOM_PROMPT
153
+ from llama_index.core.prompts import PromptTemplate
154
+ from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode
155
+ from llama_index.core.query_engine import RetrieverQueryEngine
156
+ from llama_index.retrievers.bm25 import BM25Retriever
157
+ from llama_index.core.retrievers import QueryFusionRetriever, VectorIndexRetriever
158
 
159
  bm25_retriever = BM25Retriever.from_defaults(
160
  docstore=vector_index.docstore,
 
179
  text_qa_template=custom_prompt_template
180
  )
181
 
182
+ # Create custom query engine with enhanced retrieval
183
+ class EnhancedRetrieverQueryEngine(RetrieverQueryEngine):
184
+ def __init__(self, retriever, response_synthesizer, vector_index, reranker):
185
+ super().__init__(retriever=retriever, response_synthesizer=response_synthesizer)
186
+ self.vector_index = vector_index
187
+ self.reranker = reranker
188
+
189
+ def retrieve(self, query):
190
+ """Override retrieve to use enhanced weld-type-aware retrieval"""
191
+ return retrieve_nodes_with_weld_type_priority(
192
+ query, self.vector_index, self.retriever, self.reranker, top_k=20
193
+ )
194
+
195
+ query_engine = EnhancedRetrieverQueryEngine(
196
  retriever=hybrid_retriever,
197
+ response_synthesizer=response_synthesizer,
198
+ vector_index=vector_index,
199
+ reranker=None # Will be passed in later
200
  )
201
 
202
+ log_message("Enhanced query engine created with weld-type prioritization")
203
  return query_engine
204
 
205
  except Exception as e:
206
+ log_message(f"Error creating enhanced query engine: {str(e)}")
207
  raise