#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2025/4/27 19:52
# @Author  : hukangzhe
# @File    : retriever.py
# @Description : Vectorization, storage, and retrieval module
import os
import faiss
import numpy as np
import pickle
import logging
from rank_bm25 import BM25Okapi
from typing import List, Dict, Tuple
from .schema import Document, Chunk


class HybridVectorStore:
    def __init__(self, config: dict, embedder):
        self.config = config["vector_store"]
        self.embedder = embedder
        self.faiss_index = None
        self.bm25_index = None
        self.parent_docs: Dict[int, Document] = {}
        self.child_chunks: List[Chunk] = []

    def build(self, parent_docs: Dict[int, Document], child_chunks: List[Chunk]):
        self.parent_docs = parent_docs
        self.child_chunks = child_chunks

        # Build FAISS index over the child chunk embeddings
        child_texts = [chunk.text for chunk in child_chunks]
        embeddings = np.asarray(self.embedder.embed(child_texts), dtype=np.float32)  # FAISS expects float32
        dim = embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatL2(dim)
        self.faiss_index.add(embeddings)
        logging.info(f"FAISS index built with {len(child_chunks)} vectors.")

        # Build BM25 index over whitespace-tokenized chunk texts
        tokenized_chunks = [chunk.text.split(" ") for chunk in child_chunks]
        self.bm25_index = BM25Okapi(tokenized_chunks)
        logging.info(f"BM25 index built for {len(child_chunks)} documents.")

        self.save()

    def search(self, query: str, top_k: int, alpha: float) -> List[Tuple[int, float]]:
        # Vector search: convert L2 distances into similarity-like scores
        query_embedding = np.asarray(self.embedder.embed([query]), dtype=np.float32)
        distances, indices = self.faiss_index.search(query_embedding, k=top_k)
        vector_scores = {idx: 1.0 / (1.0 + dist) for idx, dist in zip(indices[0], distances[0]) if idx != -1}

        # BM25 search: keep the top_k highest-scoring chunks
        tokenized_query = query.split(" ")
        bm25_all_scores = self.bm25_index.get_scores(tokenized_query)
        bm25_top_indices = np.argsort(bm25_all_scores)[::-1][:top_k]
        bm25_scores = {idx: bm25_all_scores[idx] for idx in bm25_top_indices}

        # Hybrid search: take the union of both candidate sets
        all_indices = set(vector_scores.keys()) | set(bm25_scores.keys())
        hybrid_scores = {}

        # Max-normalize each score type, then blend: alpha * vector + (1 - alpha) * bm25
        max_v_score = max(vector_scores.values()) if vector_scores else 1.0
        max_b_score = max(bm25_scores.values()) if bm25_scores else 1.0
        max_v_score = max_v_score or 1.0  # guard against division by zero
        max_b_score = max_b_score or 1.0
        for idx in all_indices:
            v_score = vector_scores.get(idx, 0) / max_v_score
            b_score = bm25_scores.get(idx, 0) / max_b_score
            hybrid_scores[idx] = alpha * v_score + (1 - alpha) * b_score

        sorted_indices = sorted(hybrid_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
        return sorted_indices

    def get_chunks(self, indices: List[int]) -> List[Chunk]:
        return [self.child_chunks[i] for i in indices]

    def get_parent_docs(self, chunks: List[Chunk]) -> List[Document]:
        parent_ids = sorted(list(set(chunk.parent_id for chunk in chunks)))
        return [self.parent_docs[pid] for pid in parent_ids]

    def save(self):
        index_path = self.config['index_path']
        metadata_path = self.config['metadata_path']

        os.makedirs(os.path.dirname(index_path), exist_ok=True)
        os.makedirs(os.path.dirname(metadata_path), exist_ok=True)
        logging.info(f"Saving FAISS index to: {index_path}")
        try:
            faiss.write_index(self.faiss_index, index_path)
        except Exception as e:
            logging.error(f"Failed to save FAISS index: {e}")
            raise

        logging.info(f"Saving metadata data to: {metadata_path}")
        try:
            with open(metadata_path, 'wb') as f:
                metadata = {
                    'parent_docs': self.parent_docs,
                    'child_chunks': self.child_chunks,
                    'bm25_index': self.bm25_index
                }
                pickle.dump(metadata, f)
        except Exception as e:
            logging.error(f"Failed to save metadata: {e}")
            raise

        logging.info("Vector store saved successfully.")

    def load(self) -> bool:
        """

        从磁盘加载整个向量存储状态,成功时返回 True,失败时返回 False。

        """
        index_path = self.config['index_path']
        metadata_path = self.config['metadata_path']

        if not os.path.exists(index_path) or not os.path.exists(metadata_path):
            logging.warning("Index files not found. Cannot load vector store.")
            return False

        logging.info(f"Loading vector store from disk...")
        try:
            # Load FAISS index
            logging.info(f"Loading FAISS index from: {index_path}")
            self.faiss_index = faiss.read_index(index_path)

            # Load metadata
            logging.info(f"Loading metadata from: {metadata_path}")
            with open(metadata_path, 'rb') as f:
                metadata = pickle.load(f)
                self.parent_docs = metadata['parent_docs']
                self.child_chunks = metadata['child_chunks']
                self.bm25_index = metadata['bm25_index']

            logging.info("Vector store loaded successfully.")
            return True

        except Exception as e:
            logging.error(f"Failed to load vector store from disk: {e}")
            self.faiss_index = None
            self.bm25_index = None
            self.parent_docs = {}
            self.child_chunks = []
            return False
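

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the pipeline). The config keys match
# this module's save/load paths; the `embedder` object and the
# `parent_docs`/`child_chunks` inputs are assumed to come from the ingestion
# step elsewhere in the project, so the names below are illustrative only.
#
#   config = {"vector_store": {"index_path": "storage/index.faiss",
#                              "metadata_path": "storage/metadata.pkl"}}
#   store = HybridVectorStore(config, embedder)
#   if not store.load():
#       store.build(parent_docs, child_chunks)
#   hits = store.search("what is hybrid retrieval?", top_k=5, alpha=0.5)
#   chunks = store.get_chunks([idx for idx, _ in hits])
#   parents = store.get_parent_docs(chunks)
# ---------------------------------------------------------------------------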