File size: 3,695 Bytes
cd5aabe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d11ff01
cd5aabe
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# vector_store.py
import logging
import os
import pickle

import faiss
import numpy as np
import torch

# Configure logging (runs at import time; root logger level is set to INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Project root: the parent of the directory that contains this file
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# FAISS index directory: taken from the FAISS_INDEX_DIR environment variable,
# falling back to <PROJECT_ROOT>/faiss/data. The directory is created at
# import time if it does not exist yet.
FAISS_INDEX_DIR = os.environ.get('FAISS_INDEX_DIR', os.path.join(PROJECT_ROOT, 'faiss', 'data'))
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)

# Final on-disk paths of the FAISS index and of the pickled id map
FAISS_INDEX_PATH = os.path.join(FAISS_INDEX_DIR, "index.faiss")
ID_MAP_PATH = os.path.join(FAISS_INDEX_DIR, "id_map.pkl")

# Embedding dimension, overridable through the VECTOR_DIM environment variable.
# Per the original author: ViT-B/16 is 512; ViT-L/14 is usually 768 or 1024.
VECTOR_DIM = int(os.environ.get("VECTOR_DIM", 512))

# Module-level state; both stay None until init_vector_store() assigns them
index = None
id_map = None

def init_vector_store():
    """Load the persisted vector store, or create a new empty one.

    When both ``FAISS_INDEX_PATH`` and ``ID_MAP_PATH`` exist, the FAISS index
    and the pickled id map are read from disk. Otherwise a fresh
    ``faiss.IndexFlatIP(VECTOR_DIM)`` and an empty id map are created.

    The module globals ``index`` and ``id_map`` are only replaced after the
    whole load/create step succeeded, so a failure half-way through never
    leaves a new index paired with a stale (or missing) id map.

    Returns:
        bool: True on success, False if any step raised (the error is logged).
    """
    global index, id_map
    try:
        if os.path.exists(FAISS_INDEX_PATH) and os.path.exists(ID_MAP_PATH):
            loaded_index = faiss.read_index(FAISS_INDEX_PATH)
            # NOTE(security): pickle.load can execute arbitrary code. This is
            # only acceptable while ID_MAP_PATH is written exclusively by this
            # module; never point FAISS_INDEX_DIR at untrusted data.
            with open(ID_MAP_PATH, "rb") as f:
                loaded_map = pickle.load(f)

            # Row i of the index is resolved through id_map[i]; a size
            # mismatch means the two files were not saved together.
            if loaded_index.ntotal != len(loaded_map):
                logger.warning(
                    "Vector store is inconsistent: index holds %d vectors "
                    "but id map holds %d entries",
                    loaded_index.ntotal,
                    len(loaded_map),
                )

            # Publish both halves together, only after everything loaded.
            index, id_map = loaded_index, loaded_map
            logger.info(f"Vector store loaded successfully path={FAISS_INDEX_DIR}, contains {len(id_map)} vectors")
        else:
            # With normalized vectors, inner product equals cosine similarity.
            fresh_index = faiss.IndexFlatIP(VECTOR_DIM)
            index, id_map = fresh_index, []
            logger.info("Initializing new vector store")
        return True
    except Exception as e:
        # logger.exception keeps the traceback, which the plain message loses.
        logger.exception(f"Vector store initialization failed: {e}")
        return False

def is_vector_store_available():
    """Report whether both the index and the id map have been set.

    Returns:
        bool: False while either module global ``index`` or ``id_map`` is
        still None, True otherwise.
    """
    missing = index is None or id_map is None
    return not missing

def check_image_exists(image_path: str) -> bool:
    """Tell whether an image is already registered in the vector store.

    Args:
        image_path: Image path / identifier to look up in the id map.

    Returns:
        bool: True if ``image_path`` is found in ``id_map``; False if it is
        not found, if the store is not initialized, or if the lookup raised
        (the error is logged).
    """
    try:
        found = is_vector_store_available() and image_path in id_map
    except Exception as e:
        logger.error(f"Failed to check if image exists: {str(e)}")
        return False
    return found

def add_image_vector(image_path: str, vector: torch.Tensor):
    """Append one image embedding to the index and persist the store.

    Args:
        image_path: Image path / identifier stored in ``id_map`` at the same
            position as the vector's row in the index.
        vector: Embedding tensor; a leading axis of size 1 is squeezed away
            before the vector is added as a single row.
            NOTE(review): assumed to be shaped (1, dim) or (dim,) with
            dim matching the index dimension - confirm against callers.

    Raises:
        RuntimeError: If the vector store has not been initialized.
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")

    # Tensor.numpy() refuses tensors that require grad, live on a non-CPU
    # device, or use bfloat16; detach and convert on the torch side first.
    cpu_vector = vector.detach().to(device="cpu", dtype=torch.float32)
    # ascontiguousarray guarantees the float32, C-contiguous row FAISS reads.
    np_vector = np.ascontiguousarray(cpu_vector.squeeze(0).numpy(), dtype=np.float32)
    index.add(np_vector[np.newaxis, :])
    id_map.append(image_path)
    save_index()
    logger.info(f"Image vector added: {image_path}")

def search_text_vector(vector: torch.Tensor, top_k=5):
    """Search the index with a query embedding.

    Args:
        vector: Query embedding tensor; a leading axis of size 1 is squeezed
            away before the vector is used as a single query row.
            NOTE(review): assumed to be shaped (1, dim) or (dim,) with
            dim matching the index dimension - confirm against callers.
        top_k: Maximum number of hits requested from the index.

    Returns:
        list[tuple]: ``(id_map entry, score)`` pairs in the order the index
        returned them. Slots that do not map to an ``id_map`` entry (for
        example the -1 padding used when fewer than ``top_k`` hits exist)
        are dropped. Empty when ``top_k`` is not positive.

    Raises:
        RuntimeError: If the vector store has not been initialized.
    """
    if not is_vector_store_available():
        raise RuntimeError("向量存储未初始化")

    # A non-positive k cannot yield any hit; answer directly instead of
    # handing an invalid k to the index.
    if top_k <= 0:
        return []

    # Tensor.numpy() refuses tensors that require grad, live on a non-CPU
    # device, or use bfloat16; detach and convert on the torch side first.
    cpu_vector = vector.detach().to(device="cpu", dtype=torch.float32)
    np_vector = np.ascontiguousarray(cpu_vector.squeeze(0).numpy(), dtype=np.float32)
    scores, indices = index.search(np_vector[np.newaxis, :], top_k)

    if indices is None or len(indices[0]) == 0:
        return []

    known = len(id_map)
    # The lower bound matters: any negative index (not only -1) would
    # otherwise wrap around and silently pick an entry from the list's end.
    return [
        (id_map[i], float(scores[0][j]))
        for j, i in enumerate(indices[0])
        if 0 <= i < known
    ]

def save_index():
    """Persist the FAISS index and the id map to disk.

    Both files are first written next to their final location with a
    ``.tmp`` suffix and then moved into place with ``os.replace``. A crash or
    error while writing therefore cannot leave a truncated file at
    ``FAISS_INDEX_PATH`` / ``ID_MAP_PATH``.

    Errors are logged and not re-raised, so saving stays best-effort for
    callers.
    """
    tmp_index_path = FAISS_INDEX_PATH + ".tmp"
    tmp_map_path = ID_MAP_PATH + ".tmp"
    try:
        faiss.write_index(index, tmp_index_path)
        with open(tmp_map_path, "wb") as f:
            pickle.dump(id_map, f)
        # Swap in the finished files only once both were written completely.
        os.replace(tmp_index_path, FAISS_INDEX_PATH)
        os.replace(tmp_map_path, ID_MAP_PATH)
        logger.info("Vector index saved")
    except Exception as e:
        logger.exception(f"Failed to save vector index: {e}")
        # Best-effort cleanup of partial temp files; a leftover that is
        # already gone (or cannot be removed) is not worth failing over.
        for leftover in (tmp_index_path, tmp_map_path):
            try:
                os.remove(leftover)
            except OSError:
                pass

def get_vector_store_info():
    """Summarize the current state of the vector store.

    Returns:
        dict: ``{"status": "not_initialized", "count": 0}`` while the store
        is unavailable; otherwise the status ``"available"`` together with
        the number of id-map entries, the configured ``VECTOR_DIM`` and the
        index file path.
    """
    if is_vector_store_available():
        info = {
            "status": "available",
            "count": len(id_map),
            "vector_dim": VECTOR_DIM,
            "index_path": FAISS_INDEX_PATH,
        }
    else:
        info = {"status": "not_initialized", "count": 0}
    return info