# retriever/src/ExcelIndexer.py
# Author: Yyy0530
# Refactor: moved ExcelIndexer class into src directory and updated import paths (commit a00efca)
from typing import List, Tuple
import os
import faiss
import numpy as np
import pandas as pd
from tqdm import tqdm
#from sklearn.preprocessing import normalize
from text2vec import SentenceModel
class ExcelIndexer(object):
    """
    FAISS-backed retriever over the rows of an Excel spreadsheet.

    One column of an Excel file is embedded with a text2vec
    ``SentenceModel`` and stored in a FAISS inner-product index;
    queries are embedded with the same model and matched against
    the indexed rows.
    """

    def __init__(self, vector_sz: int, n_subquantizers=0, n_bits=8,
                 model: "SentenceModel" = None, **kwargs):
        """
        Initialize the indexer and choose the FAISS index type.

        :param vector_sz: dimensionality of the embedding vectors
        :param n_subquantizers: number of PQ sub-quantizers; 0 selects a
            flat (exact) inner-product index
        :param n_bits: bits per sub-vector (only used by the PQ index)
        :param model: SentenceModel used to embed both rows and queries
        """
        if n_subquantizers > 0:
            # Product-quantized index: approximate but memory-efficient.
            self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
        else:
            # Exact inner-product search.
            self.index = faiss.IndexFlatIP(vector_sz)
        self.index_id_to_row_id = []  # maps FAISS index ids -> Excel row numbers
        self.data_frame = None        # loaded Excel data (set by load_excel)
        self.model = model
        print(f'Initialized FAISS index of type {type(self.index)}')

    def load_excel(self, dataset_path: str, retrieve_column: str, batch_size: int = 2048,
                   embeddings_file: str = None, **kwargs):
        """
        Load an Excel file and build the FAISS index over one column.

        :param dataset_path: path to the Excel file
        :param retrieve_column: name of the column holding the text to index
        :param batch_size: number of rows embedded and indexed per batch
        :param embeddings_file: optional ``.npy`` path; loaded if it exists,
            otherwise embeddings are computed and saved there
        """
        self.retrieve_column = retrieve_column
        self.embeddings_file = embeddings_file  # path of the cached embedding matrix
        # Load the Excel file.
        print(f'📂 Loading Excel file: {dataset_path}...')
        self.data_frame = pd.read_excel(dataset_path)
        length = len(self.data_frame)
        print(f'✅ Loaded {length} rows from {dataset_path}.')
        # Reuse cached embeddings when available.
        if self.embeddings_file and os.path.exists(self.embeddings_file):
            print(f'📂 Found embeddings file: {self.embeddings_file}. Loading...')
            self.embeddings = np.load(self.embeddings_file)
            print(f'✅ Embeddings loaded, shape: {self.embeddings.shape}. Indexing data...')
            print(f'⚙️ Indexing {length} rows...')
            self.index_data(range(length), self.embeddings)
        else:
            print(f'⚙️ Generating embeddings and indexing data...')
            # Pre-allocate the full embedding matrix (index.d = vector size).
            self.embeddings = np.empty((length, self.index.d), dtype=np.float32)
            # Ceiling division: the previous `length // batch_size + 1` produced a
            # spurious EMPTY final batch whenever length was an exact multiple of
            # batch_size, and np.array([]) then had the wrong rank for index.add.
            n_batches = (length + batch_size - 1) // batch_size
            for batch_no in range(n_batches):
                start = batch_no * batch_size
                end = min(start + batch_size, length)
                # Show progress in place.
                print(f'🔄 Processing batch {batch_no + 1}/{n_batches} processed: {start}/{length}', end='\r')
                batch = np.array([
                    self.model.encode(self.data_frame[self.retrieve_column][i])
                    for i in range(start, end)
                ]).astype('float32')
                # Keep a copy of the batch in the in-memory matrix.
                self.embeddings[start:end] = batch
                self.index_data(range(start, end), batch)
            print('✅ All batches processed. Saving embeddings...')
            if self.embeddings_file:
                np.save(self.embeddings_file, self.embeddings)
                print(f'✅ Embeddings saved to {self.embeddings_file}.')
        print('🎉 Indexing complete!')

    def index_data(self, ids: List[int], embeddings: np.ndarray, **kwargs):
        """
        Add a batch of row embeddings to the FAISS index.

        :param ids: Excel row ids corresponding to the embeddings, in order
        :param embeddings: 2-D array of embeddings for those rows
        """
        # Record which rows these new index entries correspond to.
        self._update_id_mapping(ids)
        embeddings = embeddings.astype('float32')
        # PQ indexes need training before vectors can be added; flat
        # indexes report is_trained == True from the start.
        if not self.index.is_trained:
            print('⚙️ Training FAISS index...')
            self.index.train(embeddings)
            print('✅ FAISS index trained.')
        self.index.add(embeddings)
        print(' ' * 80, end='\r', flush=True)  # clear the progress line
        print(f'✅ Indexed {len(self.index_id_to_row_id)} rows.')

    def _update_id_mapping(self, row_ids: List[int]):
        """Append new Excel row ids to the FAISS-id -> row-id mapping."""
        self.index_id_to_row_id.extend(row_ids)

    def search_return_text(self, query: str, top_docs: int, index_batch_size: int = 10000, **kwargs) -> Tuple[List[str], List[float]]:
        """
        Search and return the matched texts with their scores.

        Rows whose text is 20 characters or shorter are dropped, and —
        unlike the previous version — their scores are dropped with them,
        so the returned texts and scores stay aligned pairwise.

        :param query: query string
        :param top_docs: number of nearest neighbours to retrieve
        :return: (texts, scores), possibly shorter than top_docs
        """
        db_ids, scores = self.search_dp(query, top_docs)
        texts, kept_scores = [], []
        for row_id, score in zip(db_ids, scores):
            text = self.data_frame[self.retrieve_column][int(row_id)]
            # Keep only rows with a reasonably long description.
            if len(text) > 20:
                texts.append(text)
                kept_scores.append(score)
        return texts, kept_scores

    def search_dp(self, query: str, top_docs: int, **kwargs) -> Tuple[List[int], List[float]]:
        """
        Embed the query and run a dot-product search on the index.

        :param query: query string
        :param top_docs: number of nearest neighbours to return
        :return: (row ids, scores) for the nearest indexed rows
        """
        query_vectors = self.model.encode(query).astype('float32').reshape(1, -1)  # (1, d)
        # faiss returns (scores, ids), each shaped (n_queries, top_docs).
        scores, indexes = self.index.search(query_vectors, top_docs)
        scores, indexes = scores[0], indexes[0]
        db_ids, kept_scores = [], []
        for idx, score in zip(indexes, scores):
            # FAISS pads with -1 when fewer than top_docs vectors exist;
            # indexing the mapping with -1 would silently wrap to the last row.
            if idx < 0:
                continue
            db_ids.append(self.index_id_to_row_id[idx])
            kept_scores.append(score)
        return db_ids, kept_scores