MrSimple01 committed on
Commit
037203a
·
verified ·
1 Parent(s): 1e62c7c

Update index_retriever.py

Browse files
Files changed (1) hide show
  1. index_retriever.py +104 -104
index_retriever.py CHANGED
@@ -1,105 +1,105 @@
1
- from llama_index.core import VectorStoreIndex, Settings
2
- from llama_index.core.query_engine import RetrieverQueryEngine
3
- from llama_index.core.retrievers import VectorIndexRetriever
4
- from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode
5
- from llama_index.core.prompts import PromptTemplate
6
- from llama_index.retrievers.bm25 import BM25Retriever
7
- from llama_index.core.retrievers import QueryFusionRetriever
8
- from my_logging import log_message
9
- from config import CUSTOM_PROMPT, PROMPT_SIMPLE_POISK
10
-
11
def create_vector_index(documents):
    """Log a per-connection-type summary of table chunks, then build the index.

    Args:
        documents: Collection of document objects, each exposing a
            ``metadata`` mapping. It is iterated once here for the summary
            and then passed unchanged to
            ``VectorStoreIndex.from_documents``, so a one-shot iterator
            would reach the index already exhausted.

    Returns:
        Whatever ``VectorStoreIndex.from_documents(documents)`` returns.
    """
    log_message("Строю векторный индекс")

    # connection type -> one "<document_id> Table <table_number>" label per chunk
    connection_type_sources = {}
    table_count = 0

    for doc in documents:
        metadata = doc.metadata
        if metadata.get('type') != 'table':
            continue
        table_count += 1
        conn_type = metadata.get('connection_type', '')
        if not conn_type:
            # Table chunks without a connection type are counted but not listed.
            continue
        table_id = f"{metadata.get('document_id', 'unknown')} Table {metadata.get('table_number', 'N/A')}"
        connection_type_sources.setdefault(conn_type, []).append(table_id)

    log_message("="*60)
    log_message(f"INDEXING {table_count} TABLE CHUNKS")
    log_message("CONNECTION TYPES IN INDEX WITH SOURCES:")
    for conn_type in sorted(connection_type_sources):
        chunk_labels = connection_type_sources[conn_type]
        # Sorted so the listing is stable between runs; iterating a bare set
        # of strings gives an order that can change from one process to the next.
        sources = sorted(set(chunk_labels))
        log_message(f" {conn_type}: {len(chunk_labels)} chunks from {len(sources)} tables")
        for src in sources:
            log_message(f" - {src}")
    log_message("="*60)

    return VectorStoreIndex.from_documents(documents)
38
-
39
-
40
def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5):
    """Re-order ``nodes`` by reranker relevance to ``query`` and keep the best.

    Args:
        query: Query text paired with every node's text.
        nodes: Sequence of objects exposing a ``text`` attribute. ``None``
            is treated as "no nodes".
        reranker: Object whose ``predict(pairs)`` returns one score per
            ``[query, text]`` pair. If falsy, reranking is skipped.
        top_k: Maximum number of nodes to return.
        min_score_threshold: Nodes scoring below this value are dropped,
            unless that would drop every node.

    Returns:
        At most ``top_k`` nodes ordered by descending score. When reranking
        is skipped or fails, the first ``top_k`` input nodes in their
        original order.
    """
    if nodes is None:
        # Slicing None would raise TypeError; an absent result set is empty.
        return []
    if not nodes or not reranker:
        return nodes[:top_k]

    try:
        log_message(f"Переранжирую {len(nodes)} узлов")

        scores = reranker.predict([[query, node.text] for node in nodes])
        ranked = sorted(zip(nodes, scores), key=lambda pair: pair[1], reverse=True)

        kept = [node for node, score in ranked if score >= min_score_threshold]
        if not kept:
            # Nothing cleared the threshold: fall back to the best-scored nodes.
            kept = [node for node, _ in ranked]
        kept = kept[:top_k]

        log_message(f"Выбрано {len(kept)} узлов")
        return kept

    except Exception as e:
        # Best effort: a reranker failure must not lose the retrieved nodes.
        log_message(f"Ошибка переранжировки: {str(e)}")
        return nodes[:top_k]
67
-
68
def create_query_engine(vector_index):
    """Build a hybrid (vector + BM25) retrieval query engine over ``vector_index``.

    Args:
        vector_index: Index object. Its ``docstore`` feeds the BM25
            retriever and the index itself feeds the vector retriever.

    Returns:
        A ``RetrieverQueryEngine`` combining the fused retriever with a
        tree-summarize response synthesizer that uses ``CUSTOM_PROMPT``.

    Raises:
        Exception: Any error raised during construction is logged and
            re-raised unchanged.
    """
    try:
        from config import CUSTOM_PROMPT

        keyword_retriever = BM25Retriever.from_defaults(
            docstore=vector_index.docstore,
            similarity_top_k=50,
        )

        # NOTE(review): confirm that VectorIndexRetriever actually honours
        # `similarity_cutoff`; if it is only absorbed by **kwargs, no cutoff
        # is applied here.
        dense_retriever = VectorIndexRetriever(
            index=vector_index,
            similarity_top_k=50,
            similarity_cutoff=0.65,
        )

        # Vector retriever first, BM25 second; the fused list is capped at 100.
        fused_retriever = QueryFusionRetriever(
            [dense_retriever, keyword_retriever],
            similarity_top_k=100,
            num_queries=1,
        )

        synthesizer = get_response_synthesizer(
            response_mode=ResponseMode.TREE_SUMMARIZE,
            text_qa_template=PromptTemplate(CUSTOM_PROMPT),
        )

        engine = RetrieverQueryEngine(
            retriever=fused_retriever,
            response_synthesizer=synthesizer,
        )

        log_message("Query engine успешно создан")
        return engine

    except Exception as e:
        log_message(f"Ошибка создания query engine: {str(e)}")
        raise
 
1
+ from llama_index.core import VectorStoreIndex, Settings
2
+ from llama_index.core.query_engine import RetrieverQueryEngine
3
+ from llama_index.core.retrievers import VectorIndexRetriever
4
+ from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode
5
+ from llama_index.core.prompts import PromptTemplate
6
+ from llama_index.retrievers.bm25 import BM25Retriever
7
+ from llama_index.core.retrievers import QueryFusionRetriever
8
+ from my_logging import log_message
9
+ from config import CUSTOM_PROMPT, PROMPT_SIMPLE_POISK
10
+
11
def create_vector_index(documents):
    """Log a per-connection-type summary of table chunks, then build the index.

    Args:
        documents: Collection of document objects, each exposing a
            ``metadata`` mapping. It is iterated once here for the summary
            and then passed unchanged to
            ``VectorStoreIndex.from_documents``, so a one-shot iterator
            would reach the index already exhausted.

    Returns:
        Whatever ``VectorStoreIndex.from_documents(documents)`` returns.
    """
    log_message("Строю векторный индекс")

    # connection type -> one "<document_id> Table <table_number>" label per chunk
    connection_type_sources = {}
    table_count = 0

    for doc in documents:
        metadata = doc.metadata
        if metadata.get('type') != 'table':
            continue
        table_count += 1
        conn_type = metadata.get('connection_type', '')
        if not conn_type:
            # Table chunks without a connection type are counted but not listed.
            continue
        table_id = f"{metadata.get('document_id', 'unknown')} Table {metadata.get('table_number', 'N/A')}"
        connection_type_sources.setdefault(conn_type, []).append(table_id)

    log_message("="*60)
    log_message(f"INDEXING {table_count} TABLE CHUNKS")
    log_message("CONNECTION TYPES IN INDEX WITH SOURCES:")
    for conn_type in sorted(connection_type_sources):
        chunk_labels = connection_type_sources[conn_type]
        # Sorted so the listing is stable between runs; iterating a bare set
        # of strings gives an order that can change from one process to the next.
        sources = sorted(set(chunk_labels))
        log_message(f" {conn_type}: {len(chunk_labels)} chunks from {len(sources)} tables")
        for src in sources:
            log_message(f" - {src}")
    log_message("="*60)

    return VectorStoreIndex.from_documents(documents)
38
+
39
+
40
def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5):
    """Re-order ``nodes`` by reranker relevance to ``query`` and keep the best.

    Args:
        query: Query text paired with every node's text.
        nodes: Sequence of objects exposing a ``text`` attribute. ``None``
            is treated as "no nodes".
        reranker: Object whose ``predict(pairs)`` returns one score per
            ``[query, text]`` pair. If falsy, reranking is skipped.
        top_k: Maximum number of nodes to return.
        min_score_threshold: Nodes scoring below this value are dropped,
            unless that would drop every node.

    Returns:
        At most ``top_k`` nodes ordered by descending score. When reranking
        is skipped or fails, the first ``top_k`` input nodes in their
        original order.
    """
    if nodes is None:
        # Slicing None would raise TypeError; an absent result set is empty.
        return []
    if not nodes or not reranker:
        return nodes[:top_k]

    try:
        log_message(f"Переранжирую {len(nodes)} узлов")

        scores = reranker.predict([[query, node.text] for node in nodes])
        ranked = sorted(zip(nodes, scores), key=lambda pair: pair[1], reverse=True)

        kept = [node for node, score in ranked if score >= min_score_threshold]
        if not kept:
            # Nothing cleared the threshold: fall back to the best-scored nodes.
            kept = [node for node, _ in ranked]
        kept = kept[:top_k]

        log_message(f"Выбрано {len(kept)} узлов")
        return kept

    except Exception as e:
        # Best effort: a reranker failure must not lose the retrieved nodes.
        log_message(f"Ошибка переранжировки: {str(e)}")
        return nodes[:top_k]
67
+
68
def create_query_engine(vector_index):
    """Build a hybrid (vector + BM25) retrieval query engine over ``vector_index``.

    Args:
        vector_index: Index object. Its ``docstore`` feeds the BM25
            retriever and the index itself feeds the vector retriever.

    Returns:
        A ``RetrieverQueryEngine`` combining the fused retriever with a
        tree-summarize response synthesizer that uses ``CUSTOM_PROMPT``.

    Raises:
        Exception: Any error raised during construction is logged and
            re-raised unchanged.
    """
    try:
        from config import CUSTOM_PROMPT

        keyword_retriever = BM25Retriever.from_defaults(
            docstore=vector_index.docstore,
            similarity_top_k=50,
        )

        # NOTE(review): confirm that VectorIndexRetriever actually honours
        # `similarity_cutoff`; if it is only absorbed by **kwargs, no cutoff
        # is applied here.
        dense_retriever = VectorIndexRetriever(
            index=vector_index,
            similarity_top_k=50,
            similarity_cutoff=0.7,
        )

        # Vector retriever first, BM25 second; the fused list is capped at 100.
        fused_retriever = QueryFusionRetriever(
            [dense_retriever, keyword_retriever],
            similarity_top_k=100,
            num_queries=1,
        )

        synthesizer = get_response_synthesizer(
            response_mode=ResponseMode.TREE_SUMMARIZE,
            text_qa_template=PromptTemplate(CUSTOM_PROMPT),
        )

        engine = RetrieverQueryEngine(
            retriever=fused_retriever,
            response_synthesizer=synthesizer,
        )

        log_message("Query engine успешно создан")
        return engine

    except Exception as e:
        log_message(f"Ошибка создания query engine: {str(e)}")
        raise