from typing import List, Tuple

import os

import faiss
import numpy as np
import pandas as pd
from text2vec import SentenceModel

class ExcelIndexer:
    '''
    Excel spreadsheet retriever backed by a FAISS index.
    '''

    def __init__(self, vector_sz: int, n_subquantizers: int = 0, n_bits: int = 8, model: SentenceModel = None, **kwargs):
        """
        Initialize the indexer and choose the FAISS index type.

        :param vector_sz: dimensionality of the embedding vectors
        :param n_subquantizers: number of subquantizers; 0 selects an exact flat index
        :param n_bits: number of bits per subvector
        :param model: SentenceModel used to encode text
        """
        if n_subquantizers > 0:
            # Product quantization: compact, approximate inner-product search.
            self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
        else:
            # Exact inner-product search. Note that inner product equals cosine
            # similarity only if the embeddings are L2-normalized.
            self.index = faiss.IndexFlatIP(vector_sz)
        self.index_id_to_row_id = []
        self.data_frame = None
        self.model = model
        print(f'Initialized FAISS index of type {type(self.index)}')

    def load_excel(self, dataset_path: str, retrieve_column: str, batch_size: int = 2048, embeddings_file: str = None, **kwargs):
        """
        Load an Excel file and build the FAISS index automatically.

        :param dataset_path: path to the Excel file
        :param retrieve_column: name of the column that holds the text to retrieve
        :param batch_size: batch size used when encoding rows
        :param embeddings_file: path to an embeddings file; loaded if it exists,
            otherwise embeddings are generated and saved there
        """
        self.retrieve_column = retrieve_column
        self.embeddings_file = embeddings_file

        print(f'📂 Loading Excel file: {dataset_path}...')
        self.data_frame = pd.read_excel(dataset_path)
        length = len(self.data_frame)
        print(f'✅ Loaded {length} rows from {dataset_path}.')

        if self.embeddings_file and os.path.exists(self.embeddings_file):
            # Reuse precomputed embeddings instead of re-encoding every row.
            print(f'📂 Found embeddings file: {self.embeddings_file}. Loading...')
            self.embeddings = np.load(self.embeddings_file)
            print(f'✅ Embeddings loaded, shape: {self.embeddings.shape}. Indexing data...')
            ids = range(length)
            print(f'⚙️ Indexing {length} rows...')
            self.index_data(ids, self.embeddings)
        else:
            print('⚙️ Generating embeddings and indexing data...')
            self.embeddings = np.empty((length, self.index.d), dtype=np.float32)

            # Ceiling division: a trailing partial batch is still processed, and
            # no empty batch is produced when length is an exact multiple of
            # batch_size (the old `length // batch_size + 1` emitted one).
            n_batches = (length + batch_size - 1) // batch_size
            for times in range(n_batches):
                start = times * batch_size
                end = min((times + 1) * batch_size, length)

                print(f'🔄 Processing batch {times + 1}/{n_batches}, processed: {start}/{length}', end='\r')

                ids = range(start, end)
                # Encode the whole batch in a single call instead of row by row.
                texts = self.data_frame[self.retrieve_column].iloc[start:end].tolist()
                embeddings = self.model.encode(texts).astype('float32')

                self.embeddings[start:end] = embeddings
                self.index_data(ids, embeddings)

            print('✅ All batches processed. Saving embeddings...')
            if self.embeddings_file:
                np.save(self.embeddings_file, self.embeddings)
                print(f'✅ Embeddings saved to {self.embeddings_file}.')

        print('🎉 Indexing complete!')

    def index_data(self, ids: List[int], embeddings: np.ndarray, **kwargs):
        """
        Add a batch of row embeddings to the FAISS index.

        :param ids: row IDs from the Excel file (e.g. a unique-identifier column)
        :param embeddings: embedding vectors of those rows
        """
        self._update_id_mapping(ids)
        embeddings = embeddings.astype('float32')

        # IndexPQ must be trained before vectors can be added; IndexFlatIP is
        # always trained, so this branch is a no-op for the flat index.
        if not self.index.is_trained:
            print('⚙️ Training FAISS index...')
            self.index.train(embeddings)
            print('✅ FAISS index trained.')

        self.index.add(embeddings)
        print(' ' * 80, end='\r', flush=True)  # clear the in-place progress line
        print(f'✅ Indexed {len(self.index_id_to_row_id)} rows.')

    def _update_id_mapping(self, row_ids: List[int]):
        """Append new row IDs to the index-ID-to-row-ID mapping."""
        self.index_id_to_row_id.extend(row_ids)

    def search_return_text(self, query: str, top_docs: int, index_batch_size: int = 10000, **kwargs) -> Tuple[List[str], List[float]]:
        """Search for a query and return the matched texts (longer than 20 characters) with their scores."""
        db_ids, scores = self.search_dp(query, top_docs)
        # Filter out very short rows while keeping texts and scores aligned
        # (the original returned the unfiltered score list, misaligning the two).
        results = [(self.data_frame[self.retrieve_column][int(i)], score)
                   for i, score in zip(db_ids, scores)
                   if len(self.data_frame[self.retrieve_column][int(i)]) > 20]
        texts = [text for text, _ in results]
        kept_scores = [score for _, score in results]
        return texts, kept_scores

    def search_dp(self, query: str, top_docs: int, **kwargs) -> Tuple[List[int], List[float]]:
        """
        Run a dense-retrieval query and return Excel row IDs with similarity scores.

        :param query: the query text
        :param top_docs: number of nearest-neighbour documents to return
        :return: the nearest-neighbour row IDs and their scores for the query
        """
        # FAISS expects a 2-D float32 array, so encode and reshape to (1, dim).
        query_vectors = self.model.encode(query).astype('float32').reshape(1, -1)
        scores, indexes = self.index.search(query_vectors, top_docs)

        # A single query was issued, so unpack the first (and only) result row.
        scores = scores[0]
        indexes = indexes[0]
        db_ids = [self.index_id_to_row_id[i] for i in indexes]
        return db_ids, scores
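

# Minimal usage sketch, kept behind a __main__ guard. Everything here is
# illustrative: the model name, Excel path, column name, save path, and the
# 768-dim assumption are example values, not requirements of ExcelIndexer.
if __name__ == '__main__':
    model = SentenceModel('shibing624/text2vec-base-chinese')  # assumed 768-dim encoder
    indexer = ExcelIndexer(vector_sz=768, model=model)
    indexer.load_excel('data.xlsx', retrieve_column='content', embeddings_file='embeddings.npy')
    texts, scores = indexer.search_return_text('example query', top_docs=5)
    for text, score in zip(texts, scores):
        print(f'{score:.4f}\t{text}')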