OCR_RAG-AX650N / vector_store.py
H022329's picture
Upload folder using huggingface_hub
0378e25 verified
Raw
History Blame Contribute Delete
10.6 kB
"""
============================================================
向量数据库存储模块
============================================================
嵌入模型: Qwen3-Embedding 系列
向量数据库: Chroma / FAISS
功能:
1. 文档批量向量化入库
2. 相似度检索 / MMR / 元数据过滤
3. 持久化与增量更新
"""
from pathlib import Path
from typing import List, Optional, Dict, Any, Callable
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores import Chroma, FAISS
from loguru import logger
import config
from embeddings import get_embedding_model
# ============================================================
# 向量数据库工厂
# ============================================================
class VectorStoreFactory:
    """Factory for new vector stores backed by Chroma or FAISS.

    Every constructor falls back to ``get_embedding_model()`` when no
    embedding function is supplied.
    """

    @staticmethod
    def create_chroma(
        persist_directory: Optional[str | Path] = None,
        collection_name: str = config.CHROMA_COLLECTION_NAME,
        embedding_function: Optional[Embeddings] = None,
    ) -> Chroma:
        """Create a Chroma store configured for cosine distance.

        Args:
            persist_directory: Directory for the Chroma data. Defaults to
                ``config.VECTOR_DB_DIR / "chroma"``.
            collection_name: Name of the Chroma collection.
            embedding_function: Embedding model; defaults to
                ``get_embedding_model()``.

        Returns:
            A ``Chroma`` instance bound to ``persist_directory``.
        """
        persist_dir = str(persist_directory or config.VECTOR_DB_DIR / "chroma")
        embedding = embedding_function or get_embedding_model()
        logger.info(f"创建 Chroma 向量数据库: {persist_dir} (集合: {collection_name})")
        return Chroma(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_dir,
            collection_metadata={
                "hnsw:space": "cosine",  # Qwen3-Embedding vectors are compared by cosine similarity
                "hnsw:construction_ef": 200,
                "hnsw:M": 48,
            },
        )

    @staticmethod
    def create_faiss(
        embedding_function: Optional[Embeddings] = None,
    ) -> FAISS:
        """Create an empty, in-memory FAISS store with a flat L2 index.

        The store is built with a real index and docstore so documents can
        be added to it right away; constructing ``FAISS`` with
        ``index=None`` / ``docstore=None`` yields an object that cannot
        accept documents.

        Args:
            embedding_function: Embedding model; defaults to
                ``get_embedding_model()``.

        Returns:
            An empty ``FAISS`` store ready for ``add_documents``.
        """
        # Local imports: only the FAISS path needs them, and both live in
        # the langchain_community package this module already depends on.
        from langchain_community.docstore.in_memory import InMemoryDocstore
        from langchain_community.vectorstores.faiss import dependable_faiss_import

        embedding = embedding_function or get_embedding_model()
        logger.info("创建 FAISS 向量数据库 (flat L2 index)")
        faiss = dependable_faiss_import()
        # The index must know the vector dimension up front; embed a short
        # probe string once to discover it.
        dimension = len(embedding.embed_query("dimension probe"))
        return FAISS(
            embedding_function=embedding,
            index=faiss.IndexFlatL2(dimension),
            docstore=InMemoryDocstore(),
            index_to_docstore_id={},
        )

    @staticmethod
    def create(store_type: Optional[str] = None, **kwargs) -> VectorStore:
        """Create a store of the requested type.

        Args:
            store_type: ``"chroma"`` or ``"faiss"``; defaults to
                ``config.VECTOR_STORE_TYPE``.
            **kwargs: Forwarded unchanged to ``create_chroma`` or
                ``create_faiss``. Note that ``create_faiss`` only accepts
                ``embedding_function``.

        Raises:
            ValueError: If ``store_type`` is neither ``"chroma"`` nor
                ``"faiss"``.
        """
        store_type = store_type or config.VECTOR_STORE_TYPE
        if store_type == "chroma":
            return VectorStoreFactory.create_chroma(**kwargs)
        if store_type == "faiss":
            return VectorStoreFactory.create_faiss(**kwargs)
        raise ValueError(f"不支持的向量数据库: {store_type}. 可选: chroma, faiss")
# ============================================================
# 向量数据库管理器
# ============================================================
class VectorStoreManager:
    """Wrapper around a LangChain vector store (Chroma or FAISS).

    Loads an existing store from ``persist_directory`` or creates a new one,
    and offers batched ingestion, filtered retrieval and basic maintenance.
    """

    def __init__(
        self,
        vector_store: Optional[VectorStore] = None,
        store_type: Optional[str] = None,
        embedding_function: Optional[Embeddings] = None,
        persist_directory: Optional[str | Path] = None,
    ):
        """Initialise the manager.

        Args:
            vector_store: Pre-built store to wrap. When omitted, a store is
                loaded from disk or created according to ``store_type``.
            store_type: ``"chroma"`` or ``"faiss"``; defaults to
                ``config.VECTOR_STORE_TYPE``.
            embedding_function: Embedding model; defaults to
                ``get_embedding_model()``.
            persist_directory: Root directory for persisted data; defaults
                to ``config.VECTOR_DB_DIR``.

        Raises:
            ValueError: If no ``vector_store`` is given and ``store_type``
                is not supported.
        """
        self.store_type = store_type or config.VECTOR_STORE_TYPE
        self.embedding_function = embedding_function or get_embedding_model()
        self.persist_directory = str(persist_directory or config.VECTOR_DB_DIR)
        self._store = vector_store if vector_store is not None else self._init_store()

    @staticmethod
    def _validate_batch_size(batch_size: int) -> None:
        """Reject batch sizes that would skip or break batched ingestion."""
        if batch_size <= 0:
            raise ValueError(f"batch_size 必须为正整数, 当前值: {batch_size}")

    def _init_store(self) -> VectorStore:
        """Load or create the store selected by ``self.store_type``."""
        if self.store_type == "chroma":
            return self._init_chroma()
        if self.store_type == "faiss":
            return self._init_faiss()
        raise ValueError(f"不支持的向量数据库: {self.store_type}")

    def _init_chroma(self) -> Chroma:
        """Open the Chroma data under ``<persist_directory>/chroma`` or create it."""
        persist_dir = Path(self.persist_directory) / "chroma"
        if persist_dir.exists() and any(persist_dir.iterdir()):
            logger.info(f"加载已有 Chroma 数据库: {persist_dir}")
            return Chroma(
                persist_directory=str(persist_dir),
                embedding_function=self.embedding_function,
                collection_name=config.CHROMA_COLLECTION_NAME,
            )
        return VectorStoreFactory.create_chroma(
            persist_directory=str(persist_dir),
            embedding_function=self.embedding_function,
        )

    def _init_faiss(self) -> FAISS:
        """Load the FAISS index under ``<persist_directory>/faiss_index`` or create one."""
        index_path = Path(self.persist_directory) / "faiss_index"
        # Check for the index file itself rather than just the directory:
        # _persist() creates the directory before saving, so a failed save
        # can leave an empty directory that load_local cannot read.
        # "index.faiss" is the file name save_local writes by default.
        if (index_path / "index.faiss").is_file():
            logger.info(f"加载已有 FAISS 数据库: {index_path}")
            # SECURITY NOTE: load_local unpickles the stored docstore. Only
            # point persist_directory at data written by this application.
            return FAISS.load_local(
                str(index_path),
                self.embedding_function,
                allow_dangerous_deserialization=True,
            )
        return VectorStoreFactory.create_faiss(
            embedding_function=self.embedding_function,
        )

    @property
    def store(self) -> VectorStore:
        """The underlying LangChain vector store."""
        return self._store

    # ---- Ingestion ----
    def add_documents(
        self,
        documents: List[Document],
        batch_size: int = 50,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> int:
        """Embed and store ``documents`` in batches, then persist.

        Args:
            documents: Documents to add.
            batch_size: Number of documents per call to the store.
            progress_callback: Optional ``callback(done, total)`` invoked
                after each batch.

        Returns:
            The number of documents submitted (0 for an empty list).

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if not documents:
            logger.warning("文档列表为空, 跳过入库")
            return 0
        self._validate_batch_size(batch_size)
        total = len(documents)
        logger.info(f"开始向量化入库: {total} 个文档块 (批大小={batch_size})")
        for start in range(0, total, batch_size):
            self._store.add_documents(documents[start : start + batch_size])
            if progress_callback:
                progress_callback(min(start + batch_size, total), total)
        self._persist()
        logger.info(f"向量化入库完成: {total} 个文档块")
        return total

    def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[dict]] = None,
        batch_size: int = 50,
    ) -> List[str]:
        """Embed and store raw texts in batches, then persist.

        Args:
            texts: Texts to add.
            metadatas: Optional metadata dicts, sliced in step with ``texts``.
            batch_size: Number of texts per call to the store.

        Returns:
            The ids returned by the store for all added texts.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if not texts:
            return []
        self._validate_batch_size(batch_size)
        all_ids: List[str] = []
        for start in range(0, len(texts), batch_size):
            stop = start + batch_size
            batch_metas = metadatas[start:stop] if metadatas else None
            all_ids.extend(self._store.add_texts(texts[start:stop], batch_metas))
        self._persist()
        return all_ids

    # ---- Retrieval ----
    def similarity_search(
        self,
        query: str,
        k: int = config.RETRIEVAL_TOP_K,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> List[Document]:
        """Return the ``k`` documents most similar to ``query``.

        ``filter`` is forwarded to the store whenever it is given, for
        FAISS as well as Chroma, so a requested filter is never silently
        ignored.
        """
        if filter:
            kwargs["filter"] = filter
        return self._store.similarity_search(query, k=k, **kwargs)

    def similarity_search_with_score(
        self,
        query: str,
        k: int = config.RETRIEVAL_TOP_K,
        filter: Optional[Dict[str, Any]] = None,
        score_threshold: float = 0.3,
        **kwargs,
    ) -> List[tuple]:
        """Return ``(document, relevance_score)`` pairs at or above a threshold.

        Args:
            query: Query text.
            k: Maximum number of candidates requested from the store.
            filter: Optional metadata filter forwarded to the store.
            score_threshold: Minimum relevance score to keep a hit.

        Returns:
            The store's hits whose score is ``>= score_threshold``; this may
            be fewer than ``k``.
        """
        if filter:
            kwargs["filter"] = filter
        raw = self._store.similarity_search_with_relevance_scores(
            query, k=k, **kwargs
        )
        # The 0.3 default is deliberately permissive. The original author's
        # note suggests Qwen3-Embedding scores above ~0.5 usually indicate
        # relevance; pass a higher score_threshold for stricter results.
        return [(doc, score) for doc, score in raw if score >= score_threshold]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = config.RETRIEVAL_TOP_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
    ) -> List[Document]:
        """Run an MMR search, forwarding ``filter`` to the store when given.

        Args:
            query: Query text.
            k: Number of documents to return.
            fetch_k: Number of candidates passed to the MMR step.
            lambda_mult: MMR trade-off parameter passed to the store.
            filter: Optional metadata filter.
        """
        search_kwargs: Dict[str, Any] = {
            "k": k,
            "fetch_k": fetch_k,
            "lambda_mult": lambda_mult,
        }
        if filter:
            search_kwargs["filter"] = filter
        return self._store.max_marginal_relevance_search(query, **search_kwargs)

    # ---- Filtered queries ----
    def search_by_document(
        self, query: str, document_name: str, k: int = config.RETRIEVAL_TOP_K
    ) -> List[Document]:
        """Search only chunks whose ``document_name`` metadata matches."""
        return self.similarity_search(query, k=k, filter={"document_name": document_name})

    def search_by_page_range(
        self, query: str, start_page: int, end_page: int,
        k: int = config.RETRIEVAL_TOP_K,
    ) -> List[Document]:
        """Search only chunks whose ``page`` metadata is in ``[start_page, end_page]``.

        The two bounds are combined with ``$and`` because Chroma accepts
        only one operator per field expression.
        NOTE(review): operator filters on FAISS need a langchain_community
        release that supports them in dict filters — confirm the pinned
        version.
        """
        page_filter = {
            "$and": [
                {"page": {"$gte": start_page}},
                {"page": {"$lte": end_page}},
            ]
        }
        return self.similarity_search(query, k=k, filter=page_filter)

    # ---- Maintenance ----
    def _persist(self):
        """Write the FAISS index to disk; a no-op for other store types."""
        if self.store_type == "faiss":
            index_path = Path(self.persist_directory) / "faiss_index"
            index_path.mkdir(parents=True, exist_ok=True)
            self._store.save_local(str(index_path))

    def clear(self):
        """Remove all stored vectors and replace the store with a new one.

        For Chroma the collection is deleted and recreated. For FAISS the
        persisted ``faiss_index`` directory is removed as well, so the old
        data is not reloaded the next time a manager is constructed.
        """
        if self.store_type == "chroma":
            self._store.delete_collection()
            self._store = VectorStoreFactory.create_chroma(
                persist_directory=Path(self.persist_directory) / "chroma",
                embedding_function=self.embedding_function,
            )
        elif self.store_type == "faiss":
            import shutil  # local import: only needed on this path

            index_path = Path(self.persist_directory) / "faiss_index"
            if index_path.exists():
                shutil.rmtree(index_path)
            self._store = VectorStoreFactory.create_faiss(
                embedding_function=self.embedding_function,
            )
        logger.info("向量数据库已清空")

    def get_document_count(self) -> int:
        """Return the number of stored vectors, or 0 if it cannot be read.

        Failures are logged rather than raised so that ``get_stats`` stays
        usable on a store that cannot report its size.
        """
        try:
            if self.store_type == "chroma":
                # No public count API is used here; this reads the wrapped
                # collection directly.
                return self._store._collection.count()
            if self.store_type == "faiss":
                index = self._store.index
                return index.ntotal if index is not None else 0
        except Exception as exc:
            logger.warning(f"获取文档数量失败: {exc}")
        # Also reached for unknown store types, keeping the int contract.
        return 0

    def get_stats(self) -> Dict[str, Any]:
        """Return store type, persist directory, document count and model name."""
        return {
            "store_type": self.store_type,
            "persist_directory": self.persist_directory,
            "document_count": self.get_document_count(),
            "embedding_model": config.EMBEDDING_MODEL_NAME,
        }
# ============================================================
# 便捷函数
# ============================================================
def build_vector_store(
    documents: List[Document],
    store_type: Optional[str] = None,
    embedding_model: Optional[Embeddings] = None,
    clear_existing: bool = False,
) -> VectorStoreManager:
    """Create a manager, optionally clear it, and ingest ``documents``.

    Args:
        documents: Documents to add to the store.
        store_type: Store backend passed to ``VectorStoreManager``.
        embedding_model: Embedding model passed to ``VectorStoreManager``
            as its ``embedding_function``.
        clear_existing: When true, ``clear()`` is called before ingestion.

    Returns:
        The manager holding the populated store.
    """
    store_manager = VectorStoreManager(
        embedding_function=embedding_model,
        store_type=store_type,
    )
    if clear_existing:
        # Reset the store before the new documents are ingested.
        store_manager.clear()
    store_manager.add_documents(documents)
    return store_manager
def load_vector_store(
    store_type: Optional[str] = None,
    embedding_model: Optional[Embeddings] = None,
) -> VectorStoreManager:
    """Return a ``VectorStoreManager`` for the given backend and embedding model.

    Args:
        store_type: Store backend passed to ``VectorStoreManager``.
        embedding_model: Embedding model passed to ``VectorStoreManager``
            as its ``embedding_function``.
    """
    manager_kwargs = {
        "store_type": store_type,
        "embedding_function": embedding_model,
    }
    return VectorStoreManager(**manager_kwargs)