import json
from collections import defaultdict

import numpy as np
import faiss
import torch
from transformers import AutoTokenizer, AutoModel

class EntityLevelRetriever:
    """Entity-level retriever: indexes contextual entity embeddings with FAISS."""

    def __init__(self, model_name='bert-base-chinese'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()  # inference only; disables dropout
        # 768 = hidden size of bert-base-chinese; flat L2 (Euclidean) index
        self.index = faiss.IndexFlatL2(768)
        self.entity_db = []  # entity embedding vectors (float32)
        self.metadata = []   # parallel list: entity string, relation, source text

    def _get_entity_span(self, text, entity):
        """Locate the entity in the text via exact string matching.

        Returns a (start, end) character span, or None if the entity
        does not occur in the text.
        """
        start = text.find(entity)
        if start == -1:
            return None
        return (start, start + len(entity))
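
    # Example (illustrative): _get_entity_span("北京是中国的首都", "中国")
    # returns (3, 5) -- character offsets, end-exclusive. Only the first
    # occurrence of the entity is used.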

    def _generate_entity_embedding(self, text, entity):
        """Generate a contextual entity-level embedding.

        The full sentence is encoded once; the embedding is the mean of the
        last-hidden-state vectors over the tokens covering the entity span.
        """
        span = self._get_entity_span(text, entity)
        if span is None:
            return None

        inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Map character offsets to token indices (requires a fast tokenizer).
        start_token = inputs.char_to_token(span[0])
        end_token = inputs.char_to_token(span[1] - 1)

        # char_to_token returns None for characters dropped by truncation;
        # compare against None explicitly so that token index 0 is not rejected.
        if start_token is None or end_token is None:
            return None

        # Mean-pool the hidden states over the entity's tokens.
        entity_embedding = outputs.last_hidden_state[0, start_token:end_token + 1].mean(dim=0).numpy()
        return entity_embedding.astype('float32')

    def build_index(self, train_path):
        """Build the entity index from a training file of (text, triple_list) items."""
        with open(train_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)

        # NOTE: only items 500-999 are indexed, keeping the index small.
        dataset = dataset[500:1000]
        for item in dataset:
            text = item['text']
            for triple in item['triple_list']:
                # triple = (subject, relation, object); index both entities
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is not None:
                        self.entity_db.append(embedding)
                        self.metadata.append({
                            'entity': entity,
                            'type': triple[1],  # the triple's relation
                            'context': text
                        })

        print(f"Entity count check - vectors: {len(self.entity_db)}, metadata entries: {len(self.metadata)}")
        if self.entity_db:  # guard against adding an empty batch
            self.index.add(np.array(self.entity_db))
        print(f"Index dimension: {self.index.d}, stored vectors: {self.index.ntotal}")
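
    # Optional persistence sketch (an assumption, not part of the original
    # flow): a built index can be serialized with FAISS's helpers, e.g.
    #   faiss.write_index(self.index, './data/entity_index.faiss')
    #   self.index = faiss.read_index('./data/entity_index.faiss')
    # The metadata list must be persisted separately (e.g. as JSON) and kept
    # aligned with the vector insertion order.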

    def search_texts(self, test_path, top_k=3, score_mode='weighted'):
        """Text-level retrieval by aggregating per-entity search results.

        :param score_mode: scoring mode, either 'simple' (plain hit count)
                           or 'weighted' (distance-weighted, 1 / (1 + d))
        """
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        results = []
        for item in test_data:
            text = item['text']
            context_scores = defaultdict(float)
            context_hits = defaultdict(int)

            # Search the index once per entity in the query's triples.
            for triple in item['triple_list']:
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is None:
                        continue

                    distances, indices = self.index.search(np.array([embedding]), top_k)

                    for j in range(top_k):
                        idx = indices[0][j]
                        if 0 <= idx < len(self.metadata):
                            ctx_info = self.metadata[idx]
                            distance = distances[0][j]

                            if score_mode == 'simple':
                                context_scores[ctx_info['context']] += 1
                            elif score_mode == 'weighted':
                                context_scores[ctx_info['context']] += 1 / (1 + distance)

                            context_hits[ctx_info['context']] += 1

            # Normalize each context's score by its hit count.
            scored_contexts = []
            for ctx, score in context_scores.items():
                normalized_score = score / context_hits[ctx] if context_hits[ctx] > 0 else 0
                scored_contexts.append((ctx, normalized_score))

            scored_contexts.sort(key=lambda x: x[1], reverse=True)
            final_results = [{'context': ctx, 'score': float(score)}
                             for ctx, score in scored_contexts[:top_k]]

            results.append({
                'query_text': text,
                'matched_texts': final_results,
                'total_hits': sum(context_hits.values())
            })

        return results
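
    # Worked example of 'weighted' scoring (illustrative numbers): if two
    # entity searches hit the same context at L2 distances 0.5 and 1.0, its
    # raw score is 1/(1+0.5) + 1/(1+1.0) = 1.1667 over 2 hits, i.e. a
    # normalized score of ~0.583. Note that under 'simple' every context
    # normalizes to exactly 1.0 (score equals hit count), so the ranking is
    # only meaningful in 'weighted' mode.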


if __name__ == "__main__":
    retriever = EntityLevelRetriever()

    print("Building training index...")
    retriever.build_index('./data/train_triples.json')

    print("\nSearching similar entities...")
    text_results = retriever.search_texts('./data/GT_500.json', top_k=3)

    with open('./data/text_retrieval_results.json', 'w', encoding='utf-8') as f:
        json.dump(text_results, f, ensure_ascii=False, indent=2)

    print("Results written to ./data/text_retrieval_results.json")
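
# Shape of each entry in text_retrieval_results.json (values illustrative):
# {
#   "query_text": "...",
#   "matched_texts": [{"context": "...", "score": 0.58}, ...],
#   "total_hits": 12
# }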