MrSimple07 commited on
Commit
2d1ebe6
·
1 Parent(s): 7c138ed

new embeeding model + new create_quer_engine with keyword matching

Browse files
Files changed (2) hide show
  1. index_retriever.py +74 -23
  2. utils.py +6 -3
index_retriever.py CHANGED
@@ -27,38 +27,89 @@ def create_vector_index(documents):
27
  index = VectorStoreIndex.from_documents(documents)
28
  log_message("✓ Index created")
29
  return index
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def create_query_engine(vector_index):
32
- """Create hybrid retrieval engine"""
33
  log_message("Creating query engine...")
34
 
35
- # Vector retriever
36
- vector_retriever = VectorIndexRetriever(
37
- index=vector_index,
38
- similarity_top_k=50
39
- )
40
-
41
- # BM25 retriever
42
- bm25_retriever = BM25Retriever.from_defaults(
43
- docstore=vector_index.docstore,
44
- similarity_top_k=50
45
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # Hybrid fusion
48
- hybrid_retriever = QueryFusionRetriever(
49
- [vector_retriever, bm25_retriever],
50
- similarity_top_k=60,
51
- num_queries=1
52
- )
53
 
54
- # Response synthesizer
55
  response_synthesizer = get_response_synthesizer()
56
-
57
- # Query engine
58
  query_engine = RetrieverQueryEngine(
59
- retriever=hybrid_retriever,
60
  response_synthesizer=response_synthesizer
61
  )
62
 
63
- log_message("✓ Query engine created")
64
  return query_engine
 
27
  index = VectorStoreIndex.from_documents(documents)
28
  log_message("✓ Index created")
29
  return index
30
+ from llama_index.core.vector_stores import MetadataFilters, ExactMatchFilter
31
+ import re
32
+
33
+ def extract_document_id(query):
34
+ """Extract GOST document ID from query"""
35
+ patterns = [
36
+ r'ГОСТ\s*Р?\s*([\d\.]+(?:-\d{4})?)',
37
+ r'НП-[\d\-]+',
38
+ r'ПН\s+АЭ\s+Г-[\d\-]+'
39
+ ]
40
+
41
+ for pattern in patterns:
42
+ match = re.search(pattern, query, re.IGNORECASE)
43
+ if match:
44
+ doc_id = match.group(0)
45
+ # Normalize
46
+ doc_id = re.sub(r'ГОСТ\s*Р', 'ГОСТ Р', doc_id, flags=re.IGNORECASE)
47
+ if 'ГОСТ' in doc_id and '-' not in doc_id:
48
+ doc_id += '-2020'
49
+ return doc_id
50
+ return None
51
+
52
 
53
  def create_query_engine(vector_index):
54
+ """Create hybrid retrieval engine with document filtering"""
55
  log_message("Creating query engine...")
56
 
57
+ def retrieve_with_filter(query_str):
58
+ """Custom retrieval with optional document filtering"""
59
+ doc_id = extract_document_id(query_str)
60
+
61
+ if doc_id:
62
+ log_message(f"Detected document filter: {doc_id}")
63
+
64
+ # Try filtered retrieval first
65
+ filters = MetadataFilters(
66
+ filters=[ExactMatchFilter(key="document_id", value=doc_id)]
67
+ )
68
+
69
+ filtered_retriever = VectorIndexRetriever(
70
+ index=vector_index,
71
+ similarity_top_k=30,
72
+ filters=filters
73
+ )
74
+
75
+ filtered_results = filtered_retriever.retrieve(query_str)
76
+ log_message(f"Filtered retrieval: {len(filtered_results)} results from {doc_id}")
77
+
78
+ if len(filtered_results) >= 10:
79
+ # Good enough, use filtered results
80
+ return filtered_results
81
+ else:
82
+ log_message("Not enough filtered results, falling back to hybrid")
83
+
84
+ # Fallback to hybrid retrieval
85
+ vector_retriever = VectorIndexRetriever(
86
+ index=vector_index,
87
+ similarity_top_k=50
88
+ )
89
+
90
+ bm25_retriever = BM25Retriever.from_defaults(
91
+ docstore=vector_index.docstore,
92
+ similarity_top_k=50
93
+ )
94
+
95
+ hybrid_retriever = QueryFusionRetriever(
96
+ [vector_retriever, bm25_retriever],
97
+ similarity_top_k=60,
98
+ num_queries=1
99
+ )
100
+
101
+ return hybrid_retriever.retrieve(query_str)
102
 
103
+ # Create custom query engine
104
+ class CustomRetriever:
105
+ def retrieve(self, query_str):
106
+ return retrieve_with_filter(query_str)
 
 
107
 
 
108
  response_synthesizer = get_response_synthesizer()
 
 
109
  query_engine = RetrieverQueryEngine(
110
+ retriever=CustomRetriever(),
111
  response_synthesizer=response_synthesizer
112
  )
113
 
114
+ log_message("✓ Query engine created with document filtering")
115
  return query_engine
utils.py CHANGED
@@ -7,9 +7,12 @@ def get_llm_model(api_key, model_name="gemini-2.0-flash"):
7
  """Get LLM model"""
8
  return GoogleGenAI(model=model_name, api_key=api_key)
9
 
10
- def get_embedding_model(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
11
- """Get embedding model"""
12
- return HuggingFaceEmbedding(model_name=model_name)
 
 
 
13
 
14
  def get_reranker_model(model_name='cross-encoder/ms-marco-MiniLM-L-12-v2'):
15
  """Get reranker model"""
 
7
  """Get LLM model"""
8
  return GoogleGenAI(model=model_name, api_key=api_key)
9
 
10
+ def get_embedding_model(model_name="intfloat/multilingual-e5-large"):
11
+ """Use better multilingual embedding model"""
12
+ return HuggingFaceEmbedding(
13
+ model_name=model_name,
14
+ trust_remote_code=True
15
+ )
16
 
17
  def get_reranker_model(model_name='cross-encoder/ms-marco-MiniLM-L-12-v2'):
18
  """Get reranker model"""