| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import logging |
| import itertools |
| import re |
| import traceback |
| from dataclasses import dataclass |
| from typing import Any |
|
|
| import networkx as nx |
| from rag.nlp import is_english |
| import editdistance |
| from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT |
| from rag.llm.chat_model import Base as CompletionLLM |
| from graphrag.utils import ErrorHandlerFn, perform_variable_replacements |
|
|
# Fallback delimiter values injected into the prompt variables (and used again
# when parsing the LLM answer) whenever the caller does not supply their own.

# Separates one answer record from the next in the LLM output.
DEFAULT_RECORD_DELIMITER: str = "##"
# Wraps the question number inside a record, e.g. ``<|>3<|>``.
DEFAULT_ENTITY_INDEX_DELIMITER: str = "<|>"
# Wraps the yes/no verdict inside a record, e.g. ``&&yes&&``.
DEFAULT_RESOLUTION_RESULT_DELIMITER: str = "&&"
|
|
|
|
@dataclass
class EntityResolutionResult:
    """Value object returned by an entity-resolution run.

    It carries a single field, the graph produced by the run.
    """

    # Graph handed back to the caller once resolution has finished.
    output: nx.Graph
|
|
|
|
class EntityResolution:
    """Merge graph nodes that an LLM judges to be the same entity.

    Calling an instance:

    1. groups the graph's nodes by their ``entity_type`` attribute,
    2. shortlists look-alike pairs inside each group with
       :meth:`is_similarity`,
    3. sends one batched list of yes/no questions per entity type to the LLM,
    4. merges every connected set of confirmed duplicates into one node and
       refreshes each node's ``rank`` attribute with its degree.
    """

    _llm: CompletionLLM
    _resolution_prompt: str
    _output_formatter_prompt: str  # declared only; never assigned in this class
    _on_error: ErrorHandlerFn
    _record_delimiter_key: str
    _entity_index_delimiter_key: str
    _resolution_result_delimiter_key: str
    _input_text_key: str

    def __init__(
            self,
            llm_invoker: CompletionLLM,
            resolution_prompt: str | None = None,
            on_error: ErrorHandlerFn | None = None,
            record_delimiter_key: str | None = None,
            entity_index_delimiter_key: str | None = None,
            resolution_result_delimiter_key: str | None = None,
            input_text_key: str | None = None
    ):
        """Store the LLM handle, the prompt template and the prompt-variable names.

        Args:
            llm_invoker: Chat model whose ``chat(system, history, gen_conf)``
                method is called once per entity type.
            resolution_prompt: Prompt template; falls back to
                ``ENTITY_RESOLUTION_PROMPT`` when omitted or empty.
            on_error: Callback invoked as ``on_error(exc, traceback_text, None)``
                when resolving one entity type fails; defaults to a no-op.
            record_delimiter_key: Name of the prompt variable that holds the
                record delimiter.
            entity_index_delimiter_key: Name of the prompt variable that holds
                the question-index delimiter.
            resolution_result_delimiter_key: Name of the prompt variable that
                holds the verdict delimiter.
            input_text_key: Name of the prompt variable that receives the
                generated question list.
        """
        self._llm = llm_invoker
        self._resolution_prompt = resolution_prompt or ENTITY_RESOLUTION_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
        # The attribute name used to be misspelled ("dilimiter"), which left
        # the attribute declared on the class permanently unset.
        self._entity_index_delimiter_key = entity_index_delimiter_key or "entity_index_delimiter"
        self._resolution_result_delimiter_key = resolution_result_delimiter_key or "resolution_result_delimiter"
        self._input_text_key = input_text_key or "input_text"

    def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
        """Resolve duplicate entities in ``graph`` and merge them in place.

        Args:
            graph: Graph whose nodes carry ``entity_type``, ``description`` and
                ``weight`` attributes and whose edges carry ``weight`` and
                ``description`` (all are indexed directly, so a missing
                attribute raises ``KeyError``). The graph is mutated.
                NOTE(review): node ids are presumably entity-name strings,
                since they are compared character-wise and interpolated into
                the prompt - confirm against callers.
            prompt_variables: Extra template variables; delimiter entries that
                are missing or falsy are replaced by the module defaults.

        Returns:
            An ``EntityResolutionResult`` wrapping the same (mutated) graph.

        A failure while querying or parsing the LLM answer for one entity type
        is logged and reported through ``on_error``; the remaining types are
        still processed.
        """
        if prompt_variables is None:
            prompt_variables = {}

        # Wire defaults into the prompt variables.
        record_delimiter = prompt_variables.get(self._record_delimiter_key) or DEFAULT_RECORD_DELIMITER
        entity_index_delimiter = (prompt_variables.get(self._entity_index_delimiter_key)
                                  or DEFAULT_ENTITY_INDEX_DELIMITER)
        resolution_result_delimiter = (prompt_variables.get(self._resolution_result_delimiter_key)
                                       or DEFAULT_RESOLUTION_RESULT_DELIMITER)
        prompt_variables = {
            **prompt_variables,
            self._record_delimiter_key: record_delimiter,
            self._entity_index_delimiter_key: entity_index_delimiter,
            self._resolution_result_delimiter_key: resolution_result_delimiter,
        }

        # Group nodes by entity type. A dict keeps first-seen order, so the
        # LLM is queried in a reproducible sequence (a set would not be).
        node_clusters: dict[Any, list] = {}
        for node in graph.nodes:
            node_clusters.setdefault(graph.nodes[node]['entity_type'], []).append(node)

        # Only pairs of the same type that look alike are worth asking about.
        candidate_resolution = {
            entity_type: [(a, b) for a, b in itertools.combinations(cluster, 2) if self.is_similarity(a, b)]
            for entity_type, cluster in node_clusters.items()
        }

        gen_conf = {"temperature": 0.5}
        resolution_result = set()
        for entity_type, candidates in candidate_resolution.items():
            if not candidates:
                continue
            try:
                variables = {
                    **prompt_variables,
                    self._input_text_key: self._build_pair_prompt(entity_type, candidates)
                }
                text = perform_variable_replacements(self._resolution_prompt, variables=variables)

                response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
                confirmed = self._process_results(len(candidates), response,
                                                  record_delimiter,
                                                  entity_index_delimiter,
                                                  resolution_result_delimiter)
                for question_index, _verdict in confirmed:
                    # Question numbers are 1-based; the candidate list is 0-based.
                    resolution_result.add(candidates[question_index - 1])
            except Exception as e:
                # Deliberate best effort: one failing entity type must not
                # abort resolution of the others.
                logging.exception("error entity resolution")
                self._on_error(e, traceback.format_exc(), None)

        self._merge_resolved_entities(graph, resolution_result)

        for node, degree in graph.degree:
            # Index with the node itself; wrapping it in str() raised KeyError
            # for any node id that is not already a string.
            graph.nodes[node]["rank"] = int(degree)

        return EntityResolutionResult(
            output=graph,
        )

    @staticmethod
    def _build_pair_prompt(entity_type, candidates) -> str:
        """Render the numbered yes/no question list for one entity type."""
        lines = [
            f'When determining whether two {entity_type}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
        for index, (name_a, name_b) in enumerate(candidates, start=1):
            lines.append(
                f'Question {index}: name of {entity_type} A is {name_a}, name of {entity_type} B is {name_b}')
        # Count the questions only; the header line is not one of them.
        question_count = len(candidates)
        sent = 'question above' if question_count == 1 else f'above {question_count} questions'
        lines.append(
            f'\nUse domain knowledge of {entity_type}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {entity_type} A and {entity_type} B are the same {entity_type}./No, {entity_type} A and {entity_type} B are different {entity_type}s. For Question i+1, (repeat the above procedures)')
        return '\n'.join(lines)

    @staticmethod
    def _merge_resolved_entities(graph: nx.Graph, resolved_pairs) -> None:
        """Collapse every connected group of confirmed duplicates into one node.

        Confirmed pairs are treated as edges of a helper graph, so chains such
        as A=B and B=C end up in a single group. One member of each group
        survives; it absorbs the descriptions, weights and edges of the rest.
        """
        connect_graph = nx.Graph()
        connect_graph.add_edges_from(resolved_pairs)
        for component in nx.connected_components(connect_graph):
            remove_nodes = list(connect_graph.subgraph(component).nodes)
            keep_node = remove_nodes.pop()
            for remove_node in remove_nodes:
                graph.nodes[keep_node]['description'] += graph.nodes[remove_node]['description']
                graph.nodes[keep_node]['weight'] += graph.nodes[remove_node]['weight']
                # Snapshot the neighbours: edges are added while we iterate.
                for neighbor in list(graph[remove_node]):
                    if neighbor == keep_node or neighbor == remove_node:
                        # The duplicate-to-survivor edge (and any self-loop)
                        # simply disappears when the node is removed below.
                        continue
                    absorbed_edge = graph[remove_node][neighbor]
                    if graph.has_edge(keep_node, neighbor):
                        graph[keep_node][neighbor]['weight'] += absorbed_edge['weight']
                        graph[keep_node][neighbor]['description'] += absorbed_edge['description']
                    else:
                        graph.add_edge(keep_node, neighbor,
                                       weight=absorbed_edge['weight'],
                                       description=absorbed_edge['description'],
                                       source_id="")
                # remove_node() also drops every edge still attached to the node.
                graph.remove_node(remove_node)

    def _process_results(
            self,
            records_length: int,
            results: str,
            record_delimiter: str,
            entity_index_delimiter: str,
            resolution_result_delimiter: str
    ) -> list:
        """Extract the questions the LLM answered with "yes".

        Args:
            records_length: Number of questions that were asked; answers whose
                index lies outside ``1..records_length`` are ignored.
            results: Raw LLM output.
            record_delimiter: Separator between answer records.
            entity_index_delimiter: Marker surrounding the question number.
            resolution_result_delimiter: Marker surrounding the verdict word.

        Returns:
            A list of ``(question_index, "yes")`` tuples in output order.
        """
        if not results:
            return []

        # Raw f-strings keep ``\d`` a regex escape instead of an invalid string
        # escape; the patterns are loop-invariant, so build them once.
        index_pattern = re.compile(
            rf"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}")
        verdict_pattern = re.compile(
            rf"{re.escape(resolution_result_delimiter)}([a-zA-Z]+){re.escape(resolution_result_delimiter)}")

        ans_list = []
        for record in results.split(record_delimiter):
            record = record.strip()
            index_match = index_pattern.search(record)
            if index_match is None:
                continue
            question_index = int(index_match.group(1))
            if not 1 <= question_index <= records_length:
                continue

            verdict_match = verdict_pattern.search(record)
            if verdict_match and verdict_match.group(1).lower() == 'yes':
                ans_list.append((question_index, "yes"))

        return ans_list

    def is_similarity(self, a, b):
        """Cheap pre-filter deciding whether a pair is worth asking the LLM about.

        When ``is_english`` accepts both names, the pair qualifies only if the
        edit distance is at most half the length of the shorter name.
        Otherwise the pair qualifies when the names share at least one
        character.

        The edit-distance verdict is final for English names. Previously a
        failed check fell through to the shared-character test, so the
        edit-distance filter had no effect.
        """
        if is_english(a) and is_english(b):
            return editdistance.eval(a, b) <= min(len(a), len(b)) // 2

        return bool(set(a) & set(b))
|
|