GeoLLM / Task2 /KNN.py
Pengfa Li
Upload folder using huggingface_hub
badcf3c verified
import json
import numpy as np
from transformers import AutoTokenizer, AutoModel
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
import pandas as pd
# 定义三元组检索系统类
class TripleRetrievalSystem:
def __init__(self, model_name='bert-base-uncased'):
# 初始化BERT分词器和模型(使用预训练的BERT基础模型)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name)
# 初始化训练数据存储结构
self.train_embeddings = []
self.train_full_data = [] # 修改为存储完整三元组数据
def _generate_embeddings(self, text):
"""生成上下文敏感的token嵌入"""
# 对输入文本进行分词和编码(自动截断到512个token)
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=4096)
# 获取BERT模型的隐藏层输出(最后一层)
outputs = self.model(**inputs)
# 将输出转换为numpy数组并去除批次维度
return outputs.last_hidden_state.detach().numpy()[0]
def load_train_data(self, train_path):
"""从Excel读取训练数据"""
# 读取Excel文件中的指定sheet
train_df = pd.read_excel(train_path, sheet_name='Factoid Test')
print("Processing training data...")
self.train_embeddings = []
self.train_full_data = [] # 存储完整的(text, question, answer)
for _, row in tqdm(train_df.iterrows(), total=len(train_df)):
full_text = f"{row['Text']} {row['Question']}"
embeddings = self._generate_embeddings(full_text)
self.train_embeddings.append(embeddings.mean(axis=0))
# 存储完整的三元组数据
self.train_full_data.append({
"text": row['Text'],
"question": row['Question'],
"answer": row['Answer']
})
self.train_embeddings = np.array(self.train_embeddings)
# 修改为查找3个最近邻
self.nbrs = NearestNeighbors(n_neighbors=3, metric='cosine').fit(self.train_embeddings)
def retrieve_similar(self, test_path, output_path):
"""处理测试数据并查找相似样本"""
test_df = pd.read_excel(test_path, sheet_name='Factoid Train')
results = []
print("Processing test data...")
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
test_text = f"{row['Text']} {row['Question']}"
test_embed = self._generate_embeddings(test_text).mean(axis=0)
# 获取前3个匹配结果
distances, indices = self.nbrs.kneighbors([test_embed])
matched = []
for i, (dist, idx) in enumerate(zip(distances[0], indices[0])):
# 获取匹配样本的完整信息
matched_data = self.train_full_data[idx]
matched.append({
"context": matched_data['text'],
"question": matched_data['question'],
"answer": matched_data['answer'], # 新增答案字段
"score": float(1 - dist)
})
results.append({
"test_query": {
"text": row['Text'],
"question": row['Question'],
},
"matched_samples": matched
})
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
# 示例运行
if __name__ == "__main__":
system = TripleRetrievalSystem()
system.load_train_data('./data/data.xlsx') # 训练数据路径
# 注意测试和训练使用同一个文件的不同sheet
system.retrieve_similar(
'./data/data.xlsx', # 测试数据文件
'./data/knn_factoid.json' # 输出文件路径
)