from typing import List, Tuple
import os
import faiss
import numpy as np
import pandas as pd
from tqdm import tqdm
#from sklearn.preprocessing import normalize
from text2vec import SentenceModel


class ExcelIndexer(object):
    """FAISS-backed retriever over one text column of an Excel spreadsheet."""

    def __init__(self, vector_sz: int, n_subquantizers=0, n_bits=8,
                 model: SentenceModel = None, **kwargs):
        """
        Initialize the indexer and choose the FAISS index type.

        :param vector_sz: dimensionality of the embedding vectors
        :param n_subquantizers: number of PQ sub-quantizers; 0 selects a flat index
        :param n_bits: bits per sub-vector (only used by IndexPQ)
        :param model: SentenceModel used to encode row texts and queries
        """
        if n_subquantizers > 0:
            # Product-quantized index: smaller memory footprint, approximate search.
            self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits,
                                       faiss.METRIC_INNER_PRODUCT)
        else:
            # Exact inner-product (dot-product) search.
            self.index = faiss.IndexFlatIP(vector_sz)
        self.index_id_to_row_id = []  # FAISS index id -> Excel row id mapping
        self.data_frame = None        # loaded Excel data
        self.model = model
        print(f'Initialized FAISS index of type {type(self.index)}')

    def load_excel(self, dataset_path: str, retrieve_column: str, batch_size: int = 2048,
                   embeddings_file: str = None, **kwargs):
        """
        Load an Excel file and build the FAISS index automatically.

        :param dataset_path: path to the Excel file
        :param retrieve_column: name of the column holding the searchable text
        :param batch_size: number of rows encoded per batch
        :param embeddings_file: optional .npy cache path; loaded if it exists,
            otherwise embeddings are generated and saved there
        :raises ValueError: if a cached embeddings file does not match the
            spreadsheet's row count (stale cache)
        """
        self.retrieve_column = retrieve_column
        self.embeddings_file = embeddings_file

        print(f'📂 Loading Excel file: {dataset_path}...')
        self.data_frame = pd.read_excel(dataset_path)
        length = len(self.data_frame)
        print(f'✅ Loaded {length} rows from {dataset_path}.')

        if self.embeddings_file and os.path.exists(self.embeddings_file):
            print(f'📂 Found embeddings file: {self.embeddings_file}. Loading...')
            self.embeddings = np.load(self.embeddings_file)
            # Guard against a stale cache that no longer matches the spreadsheet;
            # indexing it anyway would silently corrupt the id mapping.
            if self.embeddings.shape[0] != length:
                raise ValueError(
                    f'Embeddings file has {self.embeddings.shape[0]} rows '
                    f'but the Excel file has {length} rows.')
            print(f'✅ Embeddings loaded, shape: {self.embeddings.shape}. Indexing data...')
            print(f'⚙️ Indexing {length} rows...')
            self.index_data(range(length), self.embeddings)
        else:
            print(f'⚙️ Generating embeddings and indexing data...')
            self.embeddings = np.empty((length, self.index.d), dtype=np.float32)
            n_batches = (length + batch_size - 1) // batch_size  # ceil division
            # Step loop never yields an empty trailing batch (the old
            # `length // batch_size + 1` loop did when length % batch_size == 0,
            # and faiss rejects a 1-D empty array).
            for batch_no, start in enumerate(range(0, length, batch_size), start=1):
                end = min(start + batch_size, length)
                print(f'🔄 Processing batch {batch_no}/{n_batches} processed: {start}/{length}', end='\r')
                embeddings = np.array(
                    [self.model.encode(self.data_frame[self.retrieve_column][i])
                     for i in range(start, end)]).astype('float32')
                # Keep a full copy in memory so it can be cached to disk below.
                self.embeddings[start:end] = embeddings
                self.index_data(range(start, end), embeddings)
            print('✅ All batches processed. Saving embeddings...')
            if self.embeddings_file:
                np.save(self.embeddings_file, self.embeddings)
                print(f'✅ Embeddings saved to {self.embeddings_file}.')
            print('🎉 Indexing complete!')

    def index_data(self, ids: List[int], embeddings: np.array, **kwargs):
        """
        Add a batch of row embeddings to the FAISS index.

        :param ids: Excel row ids corresponding to the embeddings, in order
        :param embeddings: (n, d) embedding matrix for those rows
        """
        self._update_id_mapping(ids)
        embeddings = embeddings.astype('float32')
        # IndexPQ must be trained before vectors can be added; IndexFlatIP
        # reports is_trained == True from the start, so this is a no-op for it.
        if not self.index.is_trained:
            print('⚙️ Training FAISS index...')
            self.index.train(embeddings)
            print('✅ FAISS index trained.')
        self.index.add(embeddings)
        print(' ' * 80, end='\r', flush=True)  # blank out the progress line
        print(f'✅ Indexed {len(self.index_id_to_row_id)} rows.')

    def _update_id_mapping(self, row_ids: List[int]):
        """Extend the FAISS-id -> Excel-row-id mapping with a new batch."""
        self.index_id_to_row_id.extend(row_ids)

    def search_return_text(self, query: str, top_docs: int, index_batch_size: int = 10000,
                           **kwargs) -> Tuple[List[str], List[float]]:
        """
        Search and return matching texts together with their scores.

        Rows whose text is 20 characters or shorter are filtered out, and
        their scores are dropped as well so the two lists stay aligned
        (the previous version filtered only the texts, desynchronizing them).

        :param query: query string
        :param top_docs: number of nearest neighbours to retrieve
        :param index_batch_size: unused; kept for interface compatibility
        :return: (texts, scores), position-aligned
        """
        db_ids, scores = self.search_dp(query, top_docs)
        texts, kept_scores = [], []
        for row_id, score in zip(db_ids, scores):
            text = self.data_frame[self.retrieve_column][int(row_id)]
            if len(text) > 20:  # drop rows with very short descriptions
                texts.append(text)
                kept_scores.append(score)
        return texts, kept_scores

    def search_dp(self, query: str, top_docs: int, **kwargs) -> Tuple[List[int], List[float]]:
        """
        Run a dot-product search and return Excel row ids with similarity scores.

        :param query: query string, encoded with self.model
        :param top_docs: number of nearest neighbours to return
        :return: (row_ids, scores) for the single query
        """
        # FAISS expects a 2-D float32 matrix; reshape the single query to (1, d).
        query_vectors = self.model.encode(query).astype('float32').reshape(1, -1)
        # index.search returns (scores, indexes), each shaped (n_queries, top_docs).
        scores, indexes = self.index.search(query_vectors, top_docs)
        scores = scores[0]
        indexes = indexes[0]
        # Map FAISS index ids back to Excel row ids.
        db_ids = [self.index_id_to_row_id[i] for i in indexes]
        return db_ids, scores