# Mini-RAG / core / vector_store.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/4/27 19:52
# @Author : hukangzhe
# @File : vector_store.py
# @Description : Module responsible for embedding, storage, and retrieval
import os
import faiss
import numpy as np
import pickle
import logging
from rank_bm25 import BM25Okapi
from typing import List, Dict, Tuple
from .schema import Document, Chunk
class HybridVectorStore:
    """Hybrid retrieval store for a parent/child chunking scheme.

    Maintains two indexes over the child chunks — a FAISS ``IndexFlatL2``
    for dense vector search and a BM25 index for keyword search — and a
    mapping from child chunks back to their parent documents. Scores from
    both indexes are max-normalized and fused with a weighted sum.
    """

    def __init__(self, config: dict, embedder):
        """
        :param config: full application config; only ``config["vector_store"]``
            is used (expects ``index_path`` and ``metadata_path`` keys).
        :param embedder: object exposing ``embed(list[str]) -> np.ndarray``
            returning one embedding row per input text.
        """
        self.config = config["vector_store"]
        self.embedder = embedder
        self.faiss_index = None
        self.bm25_index = None
        self.parent_docs: Dict[int, Document] = {}
        self.child_chunks: List[Chunk] = []

    def build(self, parent_docs: Dict[int, Document], child_chunks: List[Chunk]) -> None:
        """Build both indexes from scratch and persist everything to disk.

        :param parent_docs: parent documents keyed by their id.
        :param child_chunks: child chunks; list position becomes the index id.
        """
        self.parent_docs = parent_docs
        self.child_chunks = child_chunks
        # --- FAISS index. FAISS requires float32, C-contiguous input. ---
        child_texts = [chunk.text for chunk in child_chunks]
        embeddings = np.ascontiguousarray(
            self.embedder.embed(child_texts), dtype="float32"
        )
        dim = embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatL2(dim)
        self.faiss_index.add(embeddings)
        logging.info(f"FAISS index built with {len(child_chunks)} vectors.")
        # --- BM25 index over whitespace-tokenized chunk text. ---
        tokenized_chunks = [chunk.text.split(" ") for chunk in child_chunks]
        self.bm25_index = BM25Okapi(tokenized_chunks)
        logging.info(f"BM25 index built for {len(child_chunks)} documents.")
        self.save()

    def search(self, query: str, top_k: int, alpha: float) -> List[Tuple[int, float]]:
        """Run a hybrid (vector + BM25) search over the child chunks.

        :param query: raw query string (BM25 side tokenizes on single spaces).
        :param top_k: number of candidates taken from each index and returned.
        :param alpha: weight of the vector score in [0, 1]; BM25 gets (1 - alpha).
        :return: up to ``top_k`` ``(chunk_index, hybrid_score)`` pairs, best first.
        :raises ValueError: if the store has not been built or loaded yet.
        """
        if self.faiss_index is None or self.bm25_index is None:
            raise ValueError("Vector store is empty. Call build() or load() first.")
        # --- Vector search ---
        query_embedding = np.ascontiguousarray(
            self.embedder.embed([query]), dtype="float32"
        )
        distances, indices = self.faiss_index.search(query_embedding, k=top_k)
        # FAISS pads results with -1 when fewer than top_k vectors exist;
        # a -1 key would later index the *last* chunk, so it must be skipped.
        # L2 distance is mapped to a similarity in (0, 1] via 1 / (1 + d).
        vector_scores = {
            int(idx): 1.0 / (1.0 + float(dist))
            for idx, dist in zip(indices[0], distances[0])
            if idx != -1
        }
        # --- BM25 search ---
        bm25_all = self.bm25_index.get_scores(query.split(" "))
        bm25_top_indices = np.argsort(bm25_all)[::-1][:top_k]
        bm25_scores = {int(i): float(bm25_all[i]) for i in bm25_top_indices}
        # --- Fuse via max-normalization over the union of candidates ---
        # `or 1.0` guards the all-zero case (e.g. no query term appears in
        # the corpus), which would otherwise raise ZeroDivisionError.
        max_v_score = max(vector_scores.values(), default=0.0) or 1.0
        max_b_score = max(bm25_scores.values(), default=0.0) or 1.0
        hybrid_scores: Dict[int, float] = {}
        for idx in vector_scores.keys() | bm25_scores.keys():
            v_score = vector_scores.get(idx, 0.0) / max_v_score
            b_score = bm25_scores.get(idx, 0.0) / max_b_score
            hybrid_scores[idx] = alpha * v_score + (1 - alpha) * b_score
        return sorted(hybrid_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]

    def get_chunks(self, indices: List[int]) -> List[Chunk]:
        """Resolve chunk indices (as returned by :meth:`search`) to Chunk objects."""
        return [self.child_chunks[i] for i in indices]

    def get_parent_docs(self, chunks: List[Chunk]) -> List[Document]:
        """Return the deduplicated parent documents of *chunks*, ordered by id."""
        parent_ids = sorted(set(chunk.parent_id for chunk in chunks))
        return [self.parent_docs[pid] for pid in parent_ids]

    def save(self) -> None:
        """Persist the FAISS index and pickled metadata to the configured paths.

        :raises Exception: re-raises any failure from FAISS or pickle after logging.
        """
        index_path = self.config['index_path']
        metadata_path = self.config['metadata_path']
        os.makedirs(os.path.dirname(index_path), exist_ok=True)
        os.makedirs(os.path.dirname(metadata_path), exist_ok=True)
        logging.info(f"Saving FAISS index to: {index_path}")
        try:
            faiss.write_index(self.faiss_index, index_path)
        except Exception as e:
            logging.error(f"Failed to save FAISS index: {e}")
            raise
        logging.info(f"Saving metadata data to: {metadata_path}")
        try:
            with open(metadata_path, 'wb') as f:
                metadata = {
                    'parent_docs': self.parent_docs,
                    'child_chunks': self.child_chunks,
                    'bm25_index': self.bm25_index
                }
                pickle.dump(metadata, f)
        except Exception as e:
            logging.error(f"Failed to save metadata: {e}")
            raise
        logging.info("Vector store saved successfully.")

    def load(self) -> bool:
        """Load the full vector-store state from disk.

        On any failure the in-memory state is reset to empty so the store is
        never left half-loaded.

        :return: True on success, False if files are missing or unreadable.
        """
        index_path = self.config['index_path']
        metadata_path = self.config['metadata_path']
        if not os.path.exists(index_path) or not os.path.exists(metadata_path):
            logging.warning("Index files not found. Cannot load vector store.")
            return False
        logging.info(f"Loading vector store from disk...")
        try:
            # Load FAISS index
            logging.info(f"Loading FAISS index from: {index_path}")
            self.faiss_index = faiss.read_index(index_path)
            # Load metadata.
            # NOTE(review): pickle.load is only safe on trusted, locally
            # produced files — never point metadata_path at untrusted data.
            logging.info(f"Loading metadata from: {metadata_path}")
            with open(metadata_path, 'rb') as f:
                metadata = pickle.load(f)
            self.parent_docs = metadata['parent_docs']
            self.child_chunks = metadata['child_chunks']
            self.bm25_index = metadata['bm25_index']
            logging.info("Vector store loaded successfully.")
            return True
        except Exception as e:
            logging.error(f"Failed to load vector store from disk: {e}")
            self.faiss_index = None
            self.bm25_index = None
            self.parent_docs = {}
            self.child_chunks = []
            return False