| | import json
|
| | import numpy as np
|
| | import faiss
|
| | from transformers import AutoTokenizer, AutoModel
|
| | import torch
|
| |
|
class EntityLevelRetriever:
    """Nearest-neighbour retrieval of entities by contextual embedding.

    Each entity mention is embedded by mean-pooling the encoder's last hidden
    states over the tokens covering the mention in its sentence. Embeddings
    are stored in a flat FAISS L2 index; ``metadata[i]`` describes the vector
    stored at index position ``i``.
    """

    def __init__(self, model_name='bert-base-chinese'):
        """Load the tokenizer/encoder and create an empty L2 index.

        Args:
            model_name: Name or path passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Size the index from the loaded encoder instead of hard-coding 768,
        # so a model with a different hidden size still works. 768 remains
        # the fallback for configs that do not expose ``hidden_size``.
        hidden_size = getattr(self.model.config, 'hidden_size', 768)
        self.index = faiss.IndexFlatL2(hidden_size)
        self.entity_db = []
        self.metadata = []

    def _get_entity_span(self, text, entity):
        """Locate the first exact occurrence of ``entity`` in ``text``.

        Args:
            text: Sentence to search in.
            entity: Entity surface string.

        Returns:
            ``(start, end)`` character offsets (end exclusive), or ``None``
            when the entity is empty or does not occur in the text.
        """
        # An empty string is "found" at offset 0 by str.find, which would
        # produce the useless span (0, 0); treat it as not found.
        if not entity:
            return None
        start = text.find(entity)
        if start == -1:
            return None
        return (start, start + len(entity))

    def _generate_entity_embedding(self, text, entity):
        """Build a context-aware embedding for ``entity`` inside ``text``.

        Args:
            text: Sentence containing the entity.
            entity: Entity surface string.

        Returns:
            A float32 NumPy vector (mean of the last-hidden-state rows of the
            tokens covering the entity), or ``None`` when the entity is not
            in the text or its characters map to no token (for example when
            the mention lies beyond the truncation limit).
        """
        span = self._get_entity_span(text, entity)
        if span is None:
            return None

        inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)

        start_token = inputs.char_to_token(span[0])
        end_token = inputs.char_to_token(span[1] - 1)
        # Compare against None explicitly: token index 0 is a valid index but
        # is falsy, so a truthiness test could reject a real match.
        if start_token is None or end_token is None:
            return None

        token_vectors = outputs.last_hidden_state[0, start_token:end_token + 1]
        return token_vectors.mean(dim=0).cpu().numpy().astype('float32')

    def build_index(self, train_path, start=500, end=1000):
        """Embed the entities of a training file and add them to the index.

        Args:
            train_path: Path to a JSON file holding a list of items, each with
                a ``'text'`` string and a ``'triple_list'`` of triples whose
                elements 0 and 2 are entities and element 1 is stored as the
                entry's ``'type'``.
            start: First item of the dataset slice to index (default 500).
            end: End (exclusive) of the dataset slice (default 1000). Pass
                ``None`` for ``start``/``end`` to leave that side unbounded.
        """
        with open(train_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)

        dataset = dataset[start:end]
        embeddings = []
        meta_info = []

        for idx, item in enumerate(dataset):
            if idx % 100 == 0:
                print(f"处理进度: {idx}/{len(dataset)}")

            text = item['text']
            for triple in item['triple_list']:
                for entity in (triple[0], triple[2]):
                    # Best effort: one bad entity must not abort the build.
                    try:
                        embedding = self._generate_entity_embedding(text, entity)
                    except Exception as e:
                        print(f"处理实体 {entity} 时出错: {str(e)}")
                        continue
                    if embedding is not None:
                        embeddings.append(embedding)
                        meta_info.append({
                            'entity': entity,
                            'type': triple[1],
                            'context': text,
                        })

        if embeddings:
            # Add to FAISS first, then extend (not overwrite) the bookkeeping
            # lists: index ids keep growing across calls, so metadata must
            # grow in lockstep to stay aligned with them.
            self.index.add(np.asarray(embeddings, dtype='float32'))
            self.entity_db.extend(embeddings)
            self.metadata.extend(meta_info)
            print(f"索引构建完成 - 向量数: {len(self.entity_db)}, 元数据数: {len(self.metadata)}")
            print(f"索引维度: {self.index.d}, 存储数量: {self.index.ntotal}")
        else:
            print("警告:没有有效的实体嵌入被添加到索引中")

    def search_entities(self, test_path, top_k=3, batch_size=32):
        """Retrieve the nearest indexed entities for every test entity.

        Args:
            test_path: Path to a JSON file with the same item layout that
                ``build_index`` reads.
            top_k: Number of neighbours requested per entity.
            batch_size: Number of query embeddings sent to the index at once.

        Returns:
            A list with one dict per test item:
            ``{'text': ..., 'entity_matches': {entity: [neighbour, ...]}}``.
        """
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        results = []
        for item_idx, item in enumerate(test_data):
            if item_idx % 10 == 0:
                print(f"检索进度: {item_idx}/{len(test_data)}")

            text = item['text']
            entity_results = {}
            batch_embeddings = []
            batch_entities = []

            for triple in item['triple_list']:
                for entity in (triple[0], triple[2]):
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is not None:
                        batch_embeddings.append(embedding)
                        batch_entities.append(entity)

                    if len(batch_embeddings) >= batch_size:
                        self._process_batch(batch_embeddings, batch_entities, entity_results, top_k)
                        batch_embeddings = []
                        batch_entities = []

            # Flush whatever is left over for this item.
            if batch_embeddings:
                self._process_batch(batch_embeddings, batch_entities, entity_results, top_k)

            results.append({
                'text': text,
                'entity_matches': entity_results,
            })
        return results

    def _process_batch(self, embeddings, entities, entity_results, top_k):
        """Search the index for a batch of queries and record the neighbours.

        Args:
            embeddings: Sequence of query vectors.
            entities: Entity strings, parallel to ``embeddings``.
            entity_results: Dict updated in place; maps each entity to its
                list of neighbour dicts (``entity``, ``relation``,
                ``context``, ``distance``).
            top_k: Number of neighbours requested per query.
        """
        if not embeddings:
            return

        queries = np.asarray(embeddings, dtype='float32')
        distances, indices = self.index.search(queries, top_k)

        for entity, dist_row, idx_row in zip(entities, distances, indices):
            neighbors = []
            for distance, index in zip(dist_row, idx_row):
                # Skip ids that have no metadata entry (e.g. negative ids).
                if 0 <= index < len(self.metadata):
                    meta = self.metadata[index]
                    neighbors.append({
                        'entity': meta['entity'],
                        'relation': meta['type'],
                        'context': meta['context'],
                        'distance': float(distance),
                    })
            entity_results[entity] = neighbors
|
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: index the training triples, look up the nearest
    # indexed entities for the test triples, and save the matches as JSON.
    train_file = './data/train_triples.json'
    test_file = './data/test_triples.json'
    output_file = './data/entity_search_results.json'

    entity_retriever = EntityLevelRetriever()

    print("Building training index...")
    entity_retriever.build_index(train_file)

    print("\nSearching similar entities...")
    matches = entity_retriever.search_entities(test_file)

    # ensure_ascii=False keeps non-ASCII text readable in the output file.
    with open(output_file, 'w', encoding='utf-8') as out_file:
        json.dump(matches, out_file, ensure_ascii=False, indent=2)
|
| |
|
| |
|
| |
|