Spaces:
Paused
Paused
| # vector_store.py | |
| import logging | |
| import os | |
| import pickle | |
| import faiss | |
| import numpy as np | |
| import torch | |
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Project root: the directory one level above the folder containing this file
PROJECT_ROOT = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)

# Directory for the FAISS files; the FAISS_INDEX_DIR environment variable
# overrides the default of <PROJECT_ROOT>/faiss/data
FAISS_INDEX_DIR = os.getenv(
    "FAISS_INDEX_DIR",
    os.path.join(PROJECT_ROOT, "faiss", "data"),
)
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)

# Final on-disk locations
FAISS_INDEX_PATH = os.path.join(FAISS_INDEX_DIR, "index.faiss")
ID_MAP_PATH = os.path.join(FAISS_INDEX_DIR, "id_map.pkl")

# ViT-B/16 is 512; ViT-L/14 is usually 768 or 1024
VECTOR_DIM = int(os.getenv("VECTOR_DIM", 512))

# Module-level state; both stay None until init_vector_store() assigns them
index = None
id_map = None
def init_vector_store():
    """Initialize the vector store, loading it from disk when present.

    When both the FAISS index file and the ID-map pickle exist they are
    loaded; otherwise a new, empty inner-product index of dimension
    ``VECTOR_DIM`` and an empty ID list are created.

    The module globals ``index`` and ``id_map`` are replaced together and
    only after every load step succeeded, so a failed (re)initialization
    never pairs a freshly loaded index with a stale ID map.

    Returns:
        bool: True on success, False if any step raised (the error is logged).
    """
    global index, id_map
    try:
        if os.path.exists(FAISS_INDEX_PATH) and os.path.exists(ID_MAP_PATH):
            loaded_index = faiss.read_index(FAISS_INDEX_PATH)
            # NOTE(security): pickle.load can execute arbitrary code embedded
            # in the file; ID_MAP_PATH must only ever point at trusted data.
            with open(ID_MAP_PATH, "rb") as f:
                loaded_ids = pickle.load(f)

            # Search results are mapped to IDs by position, so the two
            # structures must describe the same number of vectors.
            if loaded_index.ntotal != len(loaded_ids):
                logger.warning(
                    f"Vector store is inconsistent: index holds {loaded_index.ntotal} vectors "
                    f"but id map holds {len(loaded_ids)} entries"
                )
            if loaded_index.d != VECTOR_DIM:
                logger.warning(
                    f"Loaded index dimension {loaded_index.d} differs from VECTOR_DIM={VECTOR_DIM}"
                )

            index, id_map = loaded_index, loaded_ids
            logger.info(f"Vector store loaded successfully path={FAISS_INDEX_DIR}, contains {len(id_map)} vectors")
        else:
            # With normalized vectors the inner product can stand in for
            # cosine similarity.
            index, id_map = faiss.IndexFlatIP(VECTOR_DIM), []
            logger.info("Initializing new vector store")
        return True
    except Exception as e:
        logger.error(f"Vector store initialization failed: {e}")
        return False
def is_vector_store_available():
    """Report whether both the index and the ID map have been set."""
    return not (index is None or id_map is None)
def check_image_exists(image_path: str) -> bool:
    """
    Check whether an image is already present in the vector store.

    Args:
        image_path: Image path / identifier to look up in the ID map.

    Returns:
        bool: True if the identifier is in the ID map; False if it is not,
        if the store is not initialized, or if the lookup raised.
    """
    try:
        store_ready = is_vector_store_available()
        return store_ready and image_path in id_map
    except Exception as e:
        logger.error(f"Failed to check if image exists: {str(e)}")
        return False
def add_image_vector(image_path: str, vector: torch.Tensor):
    """Add one image vector to the store and persist the store.

    Args:
        image_path: Image path / identifier recorded in the ID map at the
            same position the vector takes in the index.
        vector: Embedding tensor; a leading dimension of size 1 is dropped
            before the vector is added as a single row.
            NOTE(review): the index is inner-product based, so the vector is
            presumably expected to be L2-normalized by the caller -- confirm.

    Raises:
        RuntimeError: If the vector store has not been initialized.
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")

    # Tensor.numpy() refuses tensors that require grad, live on a non-CPU
    # device, or use bfloat16; detach/cpu/float makes all of those convertible
    # and yields the float32 data the index expects.
    np_vector = vector.detach().cpu().float().squeeze(0).numpy()

    index.add(np_vector[np.newaxis, :])
    id_map.append(image_path)
    save_index()
    logger.info(f"Image vector added: {image_path}")
def search_text_vector(vector: torch.Tensor, top_k=5):
    """Search the store for the entries most similar to a text vector.

    Args:
        vector: Query embedding tensor; a leading dimension of size 1 is
            dropped before the vector is used as a single query row.
        top_k: Maximum number of results to return.

    Returns:
        list[tuple[str, float]]: ``(identifier, score)`` pairs in the order
        returned by the index. Empty when ``top_k`` is not positive, when the
        index holds no vectors, or when no valid hit is found.

    Raises:
        RuntimeError: If the vector store has not been initialized.
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")

    # Nothing can be returned for a non-positive k or an empty index; answer
    # directly instead of handing the index an invalid request.
    if top_k <= 0 or index.ntotal == 0:
        return []

    # Tensor.numpy() refuses tensors that require grad, live on a non-CPU
    # device, or use bfloat16; detach/cpu/float makes all of those convertible
    # and yields the float32 data the index expects.
    np_vector = vector.detach().cpu().float().squeeze(0).numpy()

    scores, indices = index.search(np_vector[np.newaxis, :], top_k)

    # Missing neighbours are reported as -1; requiring 0 <= label also keeps
    # any negative label from indexing id_map from the end, and the upper
    # bound protects against an index that is larger than the ID map.
    return [
        (id_map[label], float(score))
        for score, label in zip(scores[0], indices[0])
        if 0 <= label < len(id_map)
    ]
def save_index():
    """Persist the FAISS index and the ID map to disk.

    Both files are first written to temporary siblings and then moved into
    place with ``os.replace``, so an interrupted or failed write cannot leave
    a truncated index or ID map at the real paths. The two renames are still
    separate steps, so this narrows -- but does not fully remove -- the window
    in which the two files can disagree.

    Failures are logged and not re-raised (best-effort persistence).
    """
    tmp_index_path = FAISS_INDEX_PATH + ".tmp"
    tmp_map_path = ID_MAP_PATH + ".tmp"
    try:
        faiss.write_index(index, tmp_index_path)
        with open(tmp_map_path, "wb") as f:
            pickle.dump(id_map, f)

        # Swap the finished files in only after both were written completely.
        os.replace(tmp_index_path, FAISS_INDEX_PATH)
        os.replace(tmp_map_path, ID_MAP_PATH)
        logger.info("Vector index saved")
    except Exception as e:
        logger.error(f"Failed to save vector index: {e}")
        # Best-effort cleanup of partial temporary files; a leftover that
        # cannot be removed is simply overwritten by the next save.
        for leftover in (tmp_index_path, tmp_map_path):
            try:
                os.remove(leftover)
            except OSError:
                pass
def get_vector_store_info():
    """Return a small status dictionary describing the vector store.

    Returns:
        dict: ``{"status": "not_initialized", "count": 0}`` while the store
        is unavailable; otherwise the status, the number of ID-map entries,
        the configured vector dimension and the index file path.
    """
    if is_vector_store_available():
        info = {
            "status": "available",
            "count": len(id_map),
            "vector_dim": VECTOR_DIM,
            "index_path": FAISS_INDEX_PATH,
        }
        return info

    return {"status": "not_initialized", "count": 0}