File size: 5,638 Bytes
c69a4d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/27 19:52
# @Author : hukangzhe
# @File : retriever.py
# @Description : Module responsible for embedding, vector storage, and retrieval
import os
import faiss
import numpy as np
import pickle
import logging
from rank_bm25 import BM25Okapi
from typing import List, Dict, Tuple
from .schema import Document, Chunk
class HybridVectorStore:
    """Hybrid retrieval store combining dense (FAISS) and sparse (BM25) search.

    Dense scores come from an exact L2 FAISS index over child-chunk embeddings;
    sparse scores come from BM25 over whitespace-tokenized chunk text.
    ``search`` max-normalizes both score sets and fuses them with weight
    ``alpha`` (dense) vs ``1 - alpha`` (sparse).
    """

    def __init__(self, config: dict, embedder):
        # Keep only the "vector_store" section (index_path / metadata_path).
        self.config = config["vector_store"]
        # Must expose .embed(list[str]) -> 2-D np.ndarray — assumed from usage
        # in build()/search(); confirm against the embedder implementation.
        self.embedder = embedder
        self.faiss_index = None
        self.bm25_index = None
        self.parent_docs: Dict[int, Document] = {}
        self.child_chunks: List[Chunk] = []

    def build(self, parent_docs: Dict[int, Document], child_chunks: List[Chunk]):
        """Build both indexes from scratch, then persist them via save()."""
        self.parent_docs = parent_docs
        self.child_chunks = child_chunks

        # Dense index. FAISS requires a contiguous float32 matrix, so coerce
        # whatever dtype the embedder returns instead of trusting it.
        child_texts = [chunk.text for chunk in child_chunks]
        embeddings = np.ascontiguousarray(
            self.embedder.embed(child_texts), dtype="float32"
        )
        dim = embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatL2(dim)
        self.faiss_index.add(embeddings)
        logging.info(f"FAISS index built with {len(child_chunks)} vectors.")

        # Sparse index. Naive whitespace tokenization — must mirror search().
        tokenized_chunks = [chunk.text.split(" ") for chunk in child_chunks]
        self.bm25_index = BM25Okapi(tokenized_chunks)
        logging.info(f"BM25 index built for {len(child_chunks)} documents.")
        self.save()

    def search(self, query: str, top_k: int, alpha: float) -> List[Tuple[int, float]]:
        """Return the top_k (chunk_index, hybrid_score) pairs for query.

        alpha weights the normalized dense score, (1 - alpha) the normalized
        BM25 score.  Indices refer to positions in self.child_chunks.
        """
        # Dense search; convert L2 distance to a similarity via 1/(1+d).
        # FAISS pads the result with -1 indices when the index holds fewer
        # than top_k vectors — skip those placeholders, otherwise they would
        # later index self.child_chunks[-1] by accident.
        query_embedding = np.ascontiguousarray(
            self.embedder.embed([query]), dtype="float32"
        )
        distances, indices = self.faiss_index.search(query_embedding, k=top_k)
        vector_scores = {
            int(idx): 1.0 / (1.0 + float(dist))
            for idx, dist in zip(indices[0], distances[0])
            if idx >= 0
        }

        # Sparse search over the same whitespace tokenization used in build().
        tokenized_query = query.split(" ")
        all_bm25_scores = self.bm25_index.get_scores(tokenized_query)
        bm25_top_indices = np.argsort(all_bm25_scores)[::-1][:top_k]
        bm25_scores = {int(i): float(all_bm25_scores[i]) for i in bm25_top_indices}

        # Fuse over the union of candidates from either retriever.
        all_indices = set(vector_scores) | set(bm25_scores)
        # Max-normalize each score set.  The trailing `or 1.0` guards the case
        # where every BM25 score is 0 (no query term occurs in the corpus),
        # which previously raised ZeroDivisionError.
        max_v_score = max(vector_scores.values(), default=1.0) or 1.0
        max_b_score = max(bm25_scores.values(), default=1.0) or 1.0
        hybrid_scores = {}
        for idx in all_indices:
            v_score = vector_scores.get(idx, 0.0) / max_v_score
            b_score = bm25_scores.get(idx, 0.0) / max_b_score
            hybrid_scores[idx] = alpha * v_score + (1 - alpha) * b_score
        return sorted(hybrid_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]

    def get_chunks(self, indices: List[int]) -> List[Chunk]:
        """Map chunk indices (as returned by search) back to Chunk objects."""
        return [self.child_chunks[i] for i in indices]

    def get_parent_docs(self, chunks: List[Chunk]) -> List[Document]:
        """Return the deduplicated parent Documents of chunks, ordered by id."""
        parent_ids = sorted({chunk.parent_id for chunk in chunks})
        return [self.parent_docs[pid] for pid in parent_ids]

    def save(self):
        """Persist the FAISS index and pickled metadata to the configured paths.

        Re-raises any write failure after logging it, so callers notice a
        half-saved store instead of silently continuing.
        """
        index_path = self.config['index_path']
        metadata_path = self.config['metadata_path']
        # dirname is "" for bare filenames; makedirs("") would raise, so guard.
        for directory in (os.path.dirname(index_path), os.path.dirname(metadata_path)):
            if directory:
                os.makedirs(directory, exist_ok=True)

        logging.info(f"Saving FAISS index to: {index_path}")
        try:
            faiss.write_index(self.faiss_index, index_path)
        except Exception as e:
            logging.error(f"Failed to save FAISS index: {e}")
            raise

        logging.info(f"Saving metadata data to: {metadata_path}")
        try:
            with open(metadata_path, 'wb') as f:
                metadata = {
                    'parent_docs': self.parent_docs,
                    'child_chunks': self.child_chunks,
                    'bm25_index': self.bm25_index,
                }
                pickle.dump(metadata, f)
        except Exception as e:
            logging.error(f"Failed to save metadata: {e}")
            raise
        logging.info("Vector store saved successfully.")

    def load(self) -> bool:
        """Restore the full store state from disk.

        Returns True on success; False when the files are missing or loading
        fails.  On failure, partially-loaded state is reset to empty so the
        caller can safely rebuild.
        """
        index_path = self.config['index_path']
        metadata_path = self.config['metadata_path']
        if not os.path.exists(index_path) or not os.path.exists(metadata_path):
            logging.warning("Index files not found. Cannot load vector store.")
            return False
        logging.info("Loading vector store from disk...")
        try:
            # Load FAISS index
            logging.info(f"Loading FAISS index from: {index_path}")
            self.faiss_index = faiss.read_index(index_path)
            # Load metadata.  NOTE(review): pickle.load is only safe because
            # these files are trusted local artifacts written by save(); never
            # point metadata_path at untrusted input.
            logging.info(f"Loading metadata from: {metadata_path}")
            with open(metadata_path, 'rb') as f:
                metadata = pickle.load(f)
            self.parent_docs = metadata['parent_docs']
            self.child_chunks = metadata['child_chunks']
            self.bm25_index = metadata['bm25_index']
            logging.info("Vector store loaded successfully.")
            return True
        except Exception as e:
            logging.error(f"Failed to load vector store from disk: {e}")
            # Reset so a failed load never leaves a half-initialized store.
            self.faiss_index = None
            self.bm25_index = None
            self.parent_docs = {}
            self.child_chunks = []
            return False
|