# picpocket2 / vector_store.py
# chawin.chen
# init
# 7a6cb13
# vector_store.py
import logging
import os
import pickle
import faiss
import numpy as np
import torch
# Configure logging for this module (root logger is set to INFO at import time).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Project root: the parent of the directory that contains this file.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Directory holding the FAISS index files; the FAISS_INDEX_DIR environment
# variable overrides the default of <PROJECT_ROOT>/faiss/data.
FAISS_INDEX_DIR = os.environ.get('FAISS_INDEX_DIR', os.path.join(PROJECT_ROOT, 'faiss', 'data'))
# The directory is created at import time if it does not exist yet.
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)
# Final on-disk paths of the index and of the pickled id map.
FAISS_INDEX_PATH = os.path.join(FAISS_INDEX_DIR, "index.faiss")
ID_MAP_PATH = os.path.join(FAISS_INDEX_DIR, "id_map.pkl")
# Embedding dimension: 512 for ViT-B/16; ViT-L/14 is usually 768 or 1024.
# Overridable through the VECTOR_DIM environment variable.
VECTOR_DIM = int(os.environ.get("VECTOR_DIM", 512))
# Module-level state; both stay None until init_vector_store() assigns them.
index = None
id_map = None
def init_vector_store():
    """Initialize the module-level vector store.

    When both the index file and the id-map file exist they are loaded from
    disk; otherwise an empty inner-product index of dimension ``VECTOR_DIM``
    and an empty id map are created.

    The globals ``index`` and ``id_map`` are replaced together, only after
    everything has been loaded, so a failure half-way through never leaves a
    freshly loaded index paired with a stale id map.

    Returns:
        bool: True on success, False if loading or creating the store raised
        (the error is logged, not propagated).
    """
    global index, id_map
    try:
        if os.path.exists(FAISS_INDEX_PATH) and os.path.exists(ID_MAP_PATH):
            loaded_index = faiss.read_index(FAISS_INDEX_PATH)
            # NOTE(security): pickle.load can execute arbitrary code; the
            # id-map file must only ever come from a trusted location.
            with open(ID_MAP_PATH, "rb") as f:
                loaded_id_map = pickle.load(f)

            # Surface on-disk inconsistencies instead of silently carrying on.
            if loaded_index.ntotal != len(loaded_id_map):
                logger.warning(
                    "Vector store inconsistency: index holds %d vectors but id map holds %d entries",
                    loaded_index.ntotal,
                    len(loaded_id_map),
                )
            if loaded_index.d != VECTOR_DIM:
                logger.warning(
                    "Loaded index dimension %d differs from configured VECTOR_DIM %d",
                    loaded_index.d,
                    VECTOR_DIM,
                )

            # Publish both pieces of state in one step.
            index, id_map = loaded_index, loaded_id_map
            logger.info(f"Vector store loaded successfully path={FAISS_INDEX_DIR}, contains {len(id_map)} vectors")
        else:
            # After normalization, inner product can stand in for cosine similarity.
            index, id_map = faiss.IndexFlatIP(VECTOR_DIM), []
            logger.info("Initializing new vector store")
        return True
    except Exception as e:
        logger.error(f"Vector store initialization failed: {e}")
        return False
def is_vector_store_available():
    """Return True when both the index and the id map have been set."""
    # Unavailable as soon as either piece of module state is still None.
    return not (index is None or id_map is None)
def check_image_exists(image_path: str) -> bool:
    """
    Check whether an image is already present in the vector store.

    Args:
        image_path: image path / identifier to look up in the id map.

    Returns:
        bool: True if the identifier is in the id map; False if it is not,
        if the store is not available, or if the lookup raised (the error
        is logged).
    """
    try:
        # Short-circuits to False when the store has not been initialized.
        found = is_vector_store_available() and image_path in id_map
    except Exception as e:
        logger.error(f"Failed to check if image exists: {str(e)}")
        return False
    return found
def add_image_vector(image_path: str, vector: torch.Tensor):
    """Add one image embedding to the store and persist the store.

    The vector is appended to the FAISS index, ``image_path`` is appended to
    the id map at the matching position, and ``save_index()`` is called.

    Args:
        image_path: image path / identifier recorded for this vector.
        vector: a single embedding; any shape that flattens to one row
            (e.g. ``(D,)`` or ``(1, D)``). It may live on any device and may
            require grad.

    Raises:
        RuntimeError: if the vector store has not been initialized.
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")

    # detach()/cpu() make the conversion safe for tensors that require grad or
    # live on an accelerator; float() yields float32 even for half/bfloat16.
    row = vector.detach().cpu().float().numpy().reshape(1, -1)
    # FAISS expects a C-contiguous float32 matrix of shape (n, d).
    row = np.ascontiguousarray(row, dtype=np.float32)

    index.add(row)
    # Append only after the index accepted the vector so positions stay aligned.
    id_map.append(image_path)
    save_index()
    logger.info(f"Image vector added: {image_path}")
def search_text_vector(vector: torch.Tensor, top_k=5):
    """Search the store for the entries most similar to a query embedding.

    Args:
        vector: a single query embedding; any shape that flattens to one row
            (e.g. ``(D,)`` or ``(1, D)``). It may live on any device and may
            require grad.
        top_k: maximum number of results to return.

    Returns:
        list[tuple[str, float]]: ``(identifier, score)`` pairs in the order
        returned by the index. Empty when ``top_k`` is not positive or the
        store holds no entries.

    Raises:
        RuntimeError: if the vector store has not been initialized.
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")
    # Nothing can match: avoid handing FAISS a non-positive k or an empty store.
    if top_k <= 0 or not id_map:
        return []

    # detach()/cpu() make the conversion safe for tensors that require grad or
    # live on an accelerator; float() yields float32 even for half/bfloat16.
    query = vector.detach().cpu().float().numpy().reshape(1, -1)
    query = np.ascontiguousarray(query, dtype=np.float32)

    scores, indices = index.search(query, top_k)
    if indices is None or len(indices[0]) == 0:
        return []

    # FAISS pads missing hits with -1; also skip anything beyond the id map.
    return [
        (id_map[int(i)], float(score))
        for i, score in zip(indices[0], scores[0])
        if 0 <= i < len(id_map)
    ]
def save_index():
    """Persist the FAISS index and the id map to disk.

    Both files are first written to temporary siblings and then moved into
    place with ``os.replace``, so an interrupted write cannot leave a
    truncated index or id-map file behind.

    Errors are logged and not propagated (best-effort persistence).
    """
    index_tmp = FAISS_INDEX_PATH + ".tmp"
    id_map_tmp = ID_MAP_PATH + ".tmp"
    try:
        faiss.write_index(index, index_tmp)
        with open(id_map_tmp, "wb") as f:
            pickle.dump(id_map, f)
        # Swap the finished files in only after both were written completely.
        os.replace(index_tmp, FAISS_INDEX_PATH)
        os.replace(id_map_tmp, ID_MAP_PATH)
        logger.info("Vector index saved")
    except Exception as e:
        logger.error(f"Failed to save vector index: {e}")
        # Best-effort cleanup of partial temporary files; a leftover that is
        # missing or cannot be removed is not worth reporting.
        for leftover in (index_tmp, id_map_tmp):
            try:
                os.remove(leftover)
            except OSError:
                pass
def get_vector_store_info():
    """Return a small status dictionary describing the vector store.

    Returns:
        dict: ``status`` and ``count`` always; when the store is available
        also ``vector_dim`` (the configured ``VECTOR_DIM``) and
        ``index_path``.
    """
    if is_vector_store_available():
        info = {
            "status": "available",
            "count": len(id_map),
            "vector_dim": VECTOR_DIM,
            "index_path": FAISS_INDEX_PATH,
        }
        return info
    # Store not set up yet: report a zero count.
    return {"status": "not_initialized", "count": 0}