import json

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel


class TripleRetrievalSystem:
    """Find, for each test text, the most similar training text.

    Each text is encoded with a Hugging Face encoder, its per-token hidden
    states are mean-pooled into one vector, and the nearest training vector
    is looked up with a cosine-distance ``NearestNeighbors`` index.
    """

    # Largest sequence length ever requested from the tokenizer. The
    # effective limit is the smaller of this and the tokenizer's own
    # ``model_max_length`` (see ``_generate_embeddings``).
    MAX_LENGTH = 4096

    def __init__(self, model_name='bert-base-uncased'):
        """Load the tokenizer and encoder identified by *model_name*.

        Args:
            model_name: Name or path passed to ``from_pretrained`` for both
                the tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

        self.train_embeddings = []
        self.train_texts = []
        # Fitted NearestNeighbors index; stays None until load_train_data().
        self.nbrs = None

    def _generate_embeddings(self, text):
        """Return the context-sensitive token embeddings of *text*.

        Args:
            text: The string to encode.

        Returns:
            A NumPy array holding the last hidden state of the first (only)
            sequence in the batch, one row per token.
        """
        # Truncating to a fixed 4096 tokens lets inputs through that are
        # longer than what models such as bert-base-uncased accept, so
        # respect the tokenizer's advertised limit. Tokenizers without a
        # known limit report a huge sentinel, in which case MAX_LENGTH wins.
        max_length = min(self.MAX_LENGTH, self.tokenizer.model_max_length)
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=max_length,
        )

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = self.model(**inputs)

        return outputs.last_hidden_state.detach().cpu().numpy()[0]

    def load_train_data(self, train_path, start=500, end=1000):
        """Embed a slice of the training data and build the search index.

        Calling this again replaces the previously loaded texts,
        embeddings and index rather than appending to them.

        Args:
            train_path: Path to a UTF-8 JSON file containing a list of
                objects that each have a ``'text'`` key.
            start: Index of the first training sample to use. Defaults to
                500, the value that was previously hard-coded.
            end: Index one past the last training sample to use. Defaults
                to 1000, the value that was previously hard-coded.

        Raises:
            ValueError: If the selected slice contains no samples.
        """
        with open(train_path, encoding='utf-8') as f:
            train_data = json.load(f)

        print("Processing training data...")

        train_data = train_data[start:end]
        if not train_data:
            # Fail early with a clear message instead of letting the index
            # fit fail on an empty array.
            raise ValueError(
                f"No training samples in slice [{start}:{end}] of {train_path!r}"
            )

        # Accumulate locally so that a repeated call never tries to
        # .append() to the ndarray stored by a previous call.
        texts = []
        pooled = []
        for item in tqdm(train_data):
            text = item['text']
            # Mean over the token axis gives one vector per text.
            pooled.append(self._generate_embeddings(text).mean(axis=0))
            texts.append(text)

        self.train_texts = texts
        self.train_embeddings = np.array(pooled)
        self.nbrs = NearestNeighbors(n_neighbors=1, metric='cosine').fit(
            self.train_embeddings
        )

    def retrieve_similar(self, test_path, output_path):
        """Match every test text to its nearest training text and save them.

        Args:
            test_path: Path to a UTF-8 JSON file containing a list of
                objects that each have a ``'text'`` key.
            output_path: Path of the JSON file to write. It receives a list
                of ``{"test_text", "relevant_train_texts"}`` objects.

        Raises:
            RuntimeError: If ``load_train_data`` has not been called yet.
        """
        if getattr(self, 'nbrs', None) is None:
            raise RuntimeError(
                "load_train_data() must be called before retrieve_similar()"
            )

        with open(test_path, encoding='utf-8') as f:
            test_data = json.load(f)

        results = []
        print("Processing test data...")

        for item in tqdm(test_data):
            test_text = item['text']
            test_embed = self._generate_embeddings(test_text).mean(axis=0)

            # kneighbors expects a 2-D batch; the distances are not needed.
            _, indices = self.nbrs.kneighbors([test_embed])
            relevant = [self.train_texts[i] for i in indices[0]]

            results.append({
                "test_text": test_text,
                "relevant_train_texts": relevant
            })

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    # Index the training file first, then write the nearest training text
    # for every entry of the test file.
    retriever = TripleRetrievalSystem()
    retriever.load_train_data('./data/train_triples.json')
    retriever.retrieve_similar('./data/test_triples.json', './data/output_results.json')