File size: 3,695 Bytes
cd5aabe d11ff01 cd5aabe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
# vector_store.py
import logging
import os
import pickle
import faiss
import numpy as np
import torch
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Project root: two directories above this file
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# FAISS index directory (overridable via the FAISS_INDEX_DIR env var)
FAISS_INDEX_DIR = os.environ.get('FAISS_INDEX_DIR', os.path.join(PROJECT_ROOT, 'faiss', 'data'))
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)
# Final on-disk paths for the index and its path->row id map
FAISS_INDEX_PATH = os.path.join(FAISS_INDEX_DIR, "index.faiss")
ID_MAP_PATH = os.path.join(FAISS_INDEX_DIR, "id_map.pkl")
# Embedding width: 512 for ViT-B/16; ViT-L/14 is typically 768 or 1024
VECTOR_DIM = int(os.environ.get("VECTOR_DIM", 512))
# Module-level state, populated by init_vector_store()
index = None
id_map = None
def init_vector_store():
    """Load the FAISS index and id map from disk, or create fresh ones.

    Populates the module-level ``index`` and ``id_map`` globals.

    Returns:
        bool: True on success, False if loading/creation failed.
    """
    global index, id_map
    try:
        on_disk = os.path.exists(FAISS_INDEX_PATH) and os.path.exists(ID_MAP_PATH)
        if not on_disk:
            # Inner product over L2-normalized vectors == cosine similarity
            index = faiss.IndexFlatIP(VECTOR_DIM)
            id_map = []
            logger.info("Initializing new vector store")
            return True
        index = faiss.read_index(FAISS_INDEX_PATH)
        with open(ID_MAP_PATH, "rb") as f:
            id_map = pickle.load(f)
        logger.info(f"Vector store loaded successfully path={FAISS_INDEX_DIR}, contains {len(id_map)} vectors")
        return True
    except Exception as e:
        logger.error(f"Vector store initialization failed: {e}")
        return False
def is_vector_store_available():
    """Return True once both the FAISS index and the id map have been set."""
    return not (index is None or id_map is None)
def check_image_exists(image_path: str) -> bool:
    """
    Check whether an image is already present in the vector store.

    Args:
        image_path: image path / identifier used as the id-map key

    Returns:
        bool: True if the image is already indexed, False otherwise
        (including when the store is uninitialized or lookup fails)
    """
    try:
        # Linear membership scan over the id list; False when store not ready.
        return is_vector_store_available() and image_path in id_map
    except Exception as e:
        logger.error(f"Failed to check if image exists: {str(e)}")
        return False
def add_image_vector(image_path: str, vector: torch.Tensor):
    """Add one image embedding to the FAISS index and persist it to disk.

    Args:
        image_path: image path / identifier, appended to the id map so the
            FAISS row index maps back to this image.
        vector: embedding tensor; leading batch dim of 1 is squeezed off.
            Assumed shape (1, VECTOR_DIM) — TODO confirm against caller.

    Raises:
        RuntimeError: if the vector store has not been initialized.
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")
    # Fix: .detach().cpu() so tensors on GPU or carrying grad convert safely;
    # plain .numpy() raises for both cases.
    np_vector = vector.detach().cpu().squeeze(0).numpy().astype('float32')
    index.add(np_vector[np.newaxis, :])
    id_map.append(image_path)
    save_index()
    logger.info(f"Image vector added: {image_path}")
def search_text_vector(vector: torch.Tensor, top_k=5):
    """Search the index with a text embedding and return the best matches.

    Args:
        vector: query embedding tensor; leading batch dim of 1 is squeezed off.
        top_k: maximum number of results to return.

    Returns:
        list[tuple[str, float]]: (image_path, score) pairs; empty when the
        index has no usable hits. Scores are inner products (cosine similarity
        if vectors were normalized before indexing).

    Raises:
        RuntimeError: if the vector store has not been initialized.
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")
    # Fix: .detach().cpu() so tensors on GPU or carrying grad convert safely;
    # plain .numpy() raises for both cases.
    np_vector = vector.detach().cpu().squeeze(0).numpy().astype('float32')
    scores, indices = index.search(np_vector[np.newaxis, :], top_k)
    if indices is None or len(indices[0]) == 0:
        return []
    # FAISS pads missing hits with -1; check the sentinel before bounds.
    results = [
        (id_map[i], float(scores[0][j]))
        for j, i in enumerate(indices[0])
        if i != -1 and i < len(id_map)
    ]
    return results
def save_index():
    """Persist the FAISS index and the id map to their configured paths.

    Best-effort: failures are logged rather than raised.
    """
    try:
        faiss.write_index(index, FAISS_INDEX_PATH)
        with open(ID_MAP_PATH, "wb") as f:
            pickle.dump(id_map, f)
    except Exception as e:
        logger.error(f"Failed to save vector index: {e}")
    else:
        logger.info("Vector index saved")
def get_vector_store_info():
    """Report the vector store's status and basic statistics as a dict."""
    if is_vector_store_available():
        return {
            "status": "available",
            "count": len(id_map),
            "vector_dim": VECTOR_DIM,
            "index_path": FAISS_INDEX_PATH
        }
    return {"status": "not_initialized", "count": 0}
|