# vector_store.py
"""FAISS-backed vector store that maps image identifiers to embedding vectors.

The store is kept in two module-level globals: ``index`` (a FAISS index) and
``id_map`` (a list whose position ``i`` holds the identifier of the vector
stored at FAISS row ``i``). Both are persisted side by side under
``FAISS_INDEX_DIR``.
"""
import logging
import os
import pickle

import faiss
import numpy as np
import torch

# Configure logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Project root: the parent of the directory that contains this file.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Directory holding the FAISS index; overridable via the FAISS_INDEX_DIR env var.
FAISS_INDEX_DIR = os.environ.get('FAISS_INDEX_DIR', os.path.join(PROJECT_ROOT, 'faiss', 'data'))
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)

# Final on-disk paths.
FAISS_INDEX_PATH = os.path.join(FAISS_INDEX_DIR, "index.faiss")
ID_MAP_PATH = os.path.join(FAISS_INDEX_DIR, "id_map.pkl")

# 512 for ViT-B/16; ViT-L/14 is usually 768 or 1024.
VECTOR_DIM = int(os.environ.get("VECTOR_DIM", 512))

# Module-level state, populated by init_vector_store().
index = None
id_map = None


def init_vector_store() -> bool:
    """Load the persisted vector store, or create an empty one.

    Both globals are replaced together only after everything has loaded, so a
    failed (re-)initialisation never leaves a new index paired with an old id
    map.

    Returns:
        True on success, False if loading or creation raised (the error is
        logged and the previous globals are left untouched).
    """
    global index, id_map
    try:
        if os.path.exists(FAISS_INDEX_PATH) and os.path.exists(ID_MAP_PATH):
            loaded_index = faiss.read_index(FAISS_INDEX_PATH)
            # NOTE(security): pickle.load can execute arbitrary code. Only
            # point FAISS_INDEX_DIR at a directory this application controls.
            with open(ID_MAP_PATH, "rb") as f:
                loaded_map = pickle.load(f)

            # Row i of the index must correspond to id_map[i]; a mismatch
            # means searches may drop or mislabel results.
            if loaded_index.ntotal != len(loaded_map):
                logger.warning(
                    "Vector store is inconsistent: index holds %d vectors but id map has %d entries",
                    loaded_index.ntotal,
                    len(loaded_map),
                )
            if loaded_index.d != VECTOR_DIM:
                logger.warning(
                    "Loaded index dimension %d differs from configured VECTOR_DIM %d",
                    loaded_index.d,
                    VECTOR_DIM,
                )

            index, id_map = loaded_index, loaded_map
            logger.info(f"Vector store loaded successfully path={FAISS_INDEX_DIR}, contains {len(id_map)} vectors")
        else:
            # With normalised vectors, inner product equals cosine similarity.
            new_index = faiss.IndexFlatIP(VECTOR_DIM)
            index, id_map = new_index, []
            logger.info("Initializing new vector store")
        return True
    except Exception as e:
        logger.error(f"Vector store initialization failed: {e}")
        return False


def is_vector_store_available() -> bool:
    """Return True when both the index and the id map have been initialised."""
    return index is not None and id_map is not None


def check_image_exists(image_path: str) -> bool:
    """Check whether an image is already present in the vector store.

    Args:
        image_path: Image path / identifier.

    Returns:
        True if the identifier is in the id map; False if it is not, if the
        store is not initialised, or if the lookup raised (the error is
        logged).
    """
    try:
        if not is_vector_store_available():
            return False
        return image_path in id_map
    except Exception as e:
        logger.error(f"Failed to check if image exists: {str(e)}")
        return False


def _tensor_to_row(vector: torch.Tensor) -> np.ndarray:
    """Convert an embedding tensor into a (1, dim) float32 array for FAISS.

    Args:
        vector: Tensor of shape (dim,) or (1, dim).

    Returns:
        A C-contiguous float32 array of shape (1, dim).

    Raises:
        ValueError: If the tensor is not a single vector of the index's
            dimension.
    """
    # detach()/cpu() make this safe for tensors that require grad or live on
    # a GPU; the float32 cast also covers dtypes numpy cannot represent.
    flat = vector.detach().cpu().squeeze(0).to(torch.float32).numpy()
    if flat.ndim != 1 or flat.shape[0] != index.d:
        raise ValueError(
            f"Expected a single vector of dimension {index.d}, got shape {tuple(vector.shape)}"
        )
    return np.ascontiguousarray(flat[np.newaxis, :], dtype=np.float32)


def add_image_vector(image_path: str, vector: torch.Tensor) -> None:
    """Add an image vector to the store and persist the store to disk.

    Args:
        image_path: Image path / identifier recorded for the vector.
        vector: Embedding tensor of shape (dim,) or (1, dim).

    Raises:
        RuntimeError: If the vector store has not been initialised.
        ValueError: If the vector does not match the index dimension.
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")

    # Validate before mutating anything so index and id_map stay in step.
    row = _tensor_to_row(vector)
    index.add(row)
    id_map.append(image_path)
    save_index()
    logger.info(f"Image vector added: {image_path}")


def search_text_vector(vector: torch.Tensor, top_k: int = 5) -> list:
    """Search the store for the vectors most similar to a text embedding.

    Args:
        vector: Query tensor of shape (dim,) or (1, dim).
        top_k: Maximum number of results to return.

    Returns:
        A list of ``(image_path, score)`` tuples in the order FAISS returned
        them; empty when the store holds no vectors or ``top_k`` is not
        positive.

    Raises:
        RuntimeError: If the vector store has not been initialised.
        ValueError: If the query does not match the index dimension.
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")

    query = _tensor_to_row(vector)
    if top_k <= 0 or index.ntotal == 0:
        return []

    # Asking for more neighbours than exist only yields -1 padding.
    k = min(top_k, index.ntotal)
    scores, indices = index.search(query, k)

    # FAISS marks missing neighbours with -1; the bounds check also guards
    # against an id map that is shorter than the index.
    return [
        (id_map[int(i)], float(score))
        for score, i in zip(scores[0], indices[0])
        if 0 <= i < len(id_map)
    ]


def save_index() -> None:
    """Persist the index and the id map to disk.

    Each file is written to a temporary sibling and then moved into place
    with ``os.replace``, so an interrupted write cannot leave a truncated
    file at the final path. Failures are logged rather than raised
    (best-effort persistence).
    """
    tmp_index_path = FAISS_INDEX_PATH + ".tmp"
    tmp_map_path = ID_MAP_PATH + ".tmp"
    try:
        faiss.write_index(index, tmp_index_path)
        with open(tmp_map_path, "wb") as f:
            pickle.dump(id_map, f)
        # Only swap the files in once both have been written completely.
        os.replace(tmp_index_path, FAISS_INDEX_PATH)
        os.replace(tmp_map_path, ID_MAP_PATH)
        logger.info("Vector index saved")
    except Exception as e:
        logger.error(f"Failed to save vector index: {e}")
        for leftover in (tmp_index_path, tmp_map_path):
            try:
                os.remove(leftover)
            except OSError:
                # Nothing to clean up (or cleanup is impossible); the
                # original failure has already been logged.
                pass


def get_vector_store_info() -> dict:
    """Return a summary of the vector store.

    Returns:
        ``{"status": "not_initialized", "count": 0}`` when the store is not
        initialised; otherwise a dict with ``status``, ``count`` (id map
        length), ``vector_dim`` (dimension of the live index) and
        ``index_path``.
    """
    if not is_vector_store_available():
        return {"status": "not_initialized", "count": 0}

    return {
        "status": "available",
        "count": len(id_map),
        # Report the live index's dimension: a loaded index may differ from
        # the configured VECTOR_DIM.
        "vector_dim": index.d,
        "index_path": FAISS_INDEX_PATH,
    }