|
|
import json

import numpy as np
import pandas as pd
import torch
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
|
|
|
|
|
|
|
|
|
class TripleRetrievalSystem:
    """Retrieve the most similar (text, question, answer) training triples
    for each test query.

    Each sample is embedded by mean-pooling BERT's last-layer token states;
    similar training samples are found with cosine-distance kNN.
    """

    def __init__(self, model_name='bert-base-uncased', n_neighbors=3):
        """Load the encoder.

        Args:
            model_name: HuggingFace model id used for tokenizer + encoder.
            n_neighbors: how many similar training samples to retrieve per
                test query (default 3, matching the previous hard-coded value).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Inference only: disable dropout so embeddings are deterministic.
        self.model.eval()
        self.n_neighbors = n_neighbors
        self.train_embeddings = []
        self.train_full_data = []

    def _generate_embeddings(self, text):
        """Generate context-sensitive token embeddings.

        Returns:
            ndarray of shape (num_tokens, hidden_size): last-layer hidden
            states for the (possibly truncated) input.
        """
        # BUGFIX: bert-base-uncased has only 512 position embeddings; the
        # previous max_length=4096 crashed at inference time on any input
        # longer than 512 tokens. Clamp to the model's supported length.
        max_len = min(getattr(self.tokenizer, 'model_max_length', 512), 512)
        inputs = self.tokenizer(text, return_tensors='pt',
                                truncation=True, max_length=max_len)
        # We never backpropagate, so skip autograd bookkeeping entirely.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state.numpy()[0]

    def load_train_data(self, train_path):
        """Read training data from Excel and build the kNN index.

        Args:
            train_path: path to the Excel workbook.
        """
        # NOTE(review): sheet 'Factoid Test' for *training* data looks swapped
        # with the 'Factoid Train' sheet read in retrieve_similar — confirm
        # against the workbook before changing either string.
        train_df = pd.read_excel(train_path, sheet_name='Factoid Test')

        print("Processing training data...")
        self.train_embeddings = []
        self.train_full_data = []

        for _, row in tqdm(train_df.iterrows(), total=len(train_df)):
            full_text = f"{row['Text']} {row['Question']}"
            embeddings = self._generate_embeddings(full_text)
            # Mean-pool token vectors into one fixed-size sample vector.
            self.train_embeddings.append(embeddings.mean(axis=0))
            self.train_full_data.append({
                "text": row['Text'],
                "question": row['Question'],
                "answer": row['Answer']
            })

        self.train_embeddings = np.array(self.train_embeddings)
        # Robustness: never request more neighbours than training rows exist
        # (NearestNeighbors raises otherwise on tiny datasets).
        k = min(self.n_neighbors, len(self.train_full_data))
        self.nbrs = NearestNeighbors(n_neighbors=k, metric='cosine').fit(self.train_embeddings)

    def retrieve_similar(self, test_path, output_path):
        """Embed each test query, find its nearest training triples, and
        dump the matches to a JSON file.

        Args:
            test_path: path to the Excel workbook holding test rows.
            output_path: destination JSON file.
        """
        # NOTE(review): sheet name likely swapped with load_train_data — see
        # the note there.
        test_df = pd.read_excel(test_path, sheet_name='Factoid Train')

        results = []
        print("Processing test data...")
        for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
            test_text = f"{row['Text']} {row['Question']}"
            test_embed = self._generate_embeddings(test_text).mean(axis=0)

            distances, indices = self.nbrs.kneighbors([test_embed])

            matched = []
            for dist, idx in zip(distances[0], indices[0]):
                matched_data = self.train_full_data[idx]
                matched.append({
                    "context": matched_data['text'],
                    "question": matched_data['question'],
                    "answer": matched_data['answer'],
                    # Cosine distance -> similarity in [-1, 1].
                    "score": float(1 - dist)
                })

            results.append({
                "test_query": {
                    "text": row['Text'],
                    "question": row['Question'],
                },
                "matched_samples": matched
            })

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Both the training index and the test queries come from the same
    # workbook; only the sheet differs (selected inside the class).
    data_path = './data/data.xlsx'

    retrieval = TripleRetrievalSystem()
    retrieval.load_train_data(data_path)
    retrieval.retrieve_similar(data_path, './data/knn_factoid.json')
|
|
|
|