| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
"""
Extract an entity/relationship graph from text with an LLM and merge the
results into a networkx graph.

Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""
| import logging |
| import numbers |
| import re |
| import traceback |
| from dataclasses import dataclass |
| from typing import Any, Mapping, Callable |
| import tiktoken |
| from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT |
| from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str |
| from rag.llm.chat_model import Base as CompletionLLM |
| import networkx as nx |
| from rag.utils import num_tokens_from_string |
| from timeit import default_timer as timer |
|
|
# Delimiters substituted into the extraction prompt whenever the caller's
# prompt_variables do not provide their own (see GraphExtractor.__call__).
DEFAULT_TUPLE_DELIMITER = "<|>"  # separates fields inside one record
DEFAULT_RECORD_DELIMITER = "##"  # separates records from each other
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"  # marks the end of a reply

# Entity types offered to the LLM when the caller supplies none.
DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]

# Number of follow-up "continue" rounds used when max_gleanings is None.
ENTITY_EXTRACTION_MAX_GLEANINGS = 1
|
|
|
|
@dataclass
class GraphExtractionResult:
    """Container returned by a unipartite graph extraction run.

    Attributes (as populated by GraphExtractor.__call__):
        output: the merged, undirected entity/relationship graph.
        source_docs: maps each successfully processed document's index in
            the input list to that document's text.
    """

    output: nx.Graph
    source_docs: dict[Any, Any]
|
|
|
|
class GraphExtractor:
    """Unipartite graph extractor class definition.

    Sends each input text through an LLM extraction prompt, followed by
    optional "gleaning" rounds that ask the model for anything it missed.
    It then parses the delimited entity/relationship records in the replies
    and merges them into one undirected networkx graph.
    """

    _llm: CompletionLLM
    _join_descriptions: bool
    _tuple_delimiter_key: str
    _record_delimiter_key: str
    _entity_types_key: str
    _input_text_key: str
    _completion_delimiter_key: str
    # NOTE(review): _entity_name_key, _input_descriptions_key and
    # _summarization_prompt are declared here but never assigned by this class.
    _entity_name_key: str
    _input_descriptions_key: str
    _extraction_prompt: str
    _summarization_prompt: str
    _loop_args: dict[str, Any]
    _max_gleanings: int
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        prompt: str | None = None,
        tuple_delimiter_key: str | None = None,
        record_delimiter_key: str | None = None,
        input_text_key: str | None = None,
        entity_types_key: str | None = None,
        completion_delimiter_key: str | None = None,
        join_descriptions=True,
        encoding_model: str | None = None,
        max_gleanings: int | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Configure the extractor.

        Args:
            llm_invoker: chat model used for extraction, gleaning and
                loop-check calls.
            prompt: extraction prompt template; defaults to
                GRAPH_EXTRACTION_PROMPT.
            tuple_delimiter_key, record_delimiter_key, input_text_key,
            entity_types_key, completion_delimiter_key: names of the template
                variables inside the prompt. Each falls back to a fixed
                default name.
            join_descriptions: when True, repeated entities and edges
                accumulate every distinct description. When False, entities
                keep their longest description and edges keep the latest one.
            encoding_model: tiktoken encoding used to look up the YES/NO
                token ids for the loop-check logit bias
                (default "cl100k_base").
            max_gleanings: follow-up "continue" rounds per document; defaults
                to ENTITY_EXTRACTION_MAX_GLEANINGS.
            on_error: called as ``on_error(exc, traceback_str, context)``
                when a document fails; defaults to a no-op.
        """
        self._llm = llm_invoker
        self._join_descriptions = join_descriptions
        self._input_text_key = input_text_key or "input_text"
        self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
        self._completion_delimiter_key = (
            completion_delimiter_key or "completion_delimiter"
        )
        self._entity_types_key = entity_types_key or "entity_types"
        self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
        self._max_gleanings = (
            max_gleanings
            if max_gleanings is not None
            else ENTITY_EXTRACTION_MAX_GLEANINGS
        )
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)

        # Bias the one-token loop-check reply toward the YES/NO tokens.
        encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
        yes = encoding.encode("YES")
        no = encoding.encode("NO")
        self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}

    def __call__(
        self, texts: list[str],
        prompt_variables: dict[str, Any] | None = None,
        callback: Callable | None = None
    ) -> GraphExtractionResult:
        """Extract and merge a graph from every text in ``texts``.

        Args:
            texts: documents to process. A document's index in this list is
                used as its source id in the graph.
            prompt_variables: extra template variables. Missing delimiters
                and entity types are filled in with the module defaults.
                Entity types may be given as a list or an already-joined
                string.
            callback: optional progress hook, called as ``callback(msg=...)``
                after each document that is processed successfully.

        Returns:
            GraphExtractionResult holding the merged graph and a map from
            document index to text for every successful document. Failed
            documents are logged, reported to ``on_error`` and skipped.
        """
        if prompt_variables is None:
            prompt_variables = {}
        all_records: dict[int, str] = {}
        source_doc_map: dict[int, str] = {}

        entity_types = prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
        # A plain string is already in prompt form; joining it would insert
        # a comma between every character.
        if not isinstance(entity_types, str):
            entity_types = ",".join(entity_types)

        prompt_variables = {
            **prompt_variables,
            self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
            or DEFAULT_TUPLE_DELIMITER,
            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
            or DEFAULT_RECORD_DELIMITER,
            self._completion_delimiter_key: prompt_variables.get(
                self._completion_delimiter_key
            )
            or DEFAULT_COMPLETION_DELIMITER,
            self._entity_types_key: entity_types,
        }

        st = timer()
        total = len(texts)
        total_token_count = 0
        for doc_index, text in enumerate(texts):
            try:
                result, token_count = self._process_document(text, prompt_variables)
                source_doc_map[doc_index] = text
                all_records[doc_index] = result
                total_token_count += token_count
                if callback:
                    callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
            except Exception as e:
                logging.exception("error extracting graph")
                self._on_error(
                    e,
                    traceback.format_exc(),
                    {
                        "doc_index": doc_index,
                        "text": text,
                    },
                )

        output = self._process_results(
            all_records,
            prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
            prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
            completion_delimiter=prompt_variables.get(
                self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
            ),
        )

        return GraphExtractionResult(
            output=output,
            source_docs=source_doc_map,
        )

    def _process_document(
        self, text: str, prompt_variables: dict[str, str]
    ) -> tuple[str, int]:
        """Run the extraction prompt, plus gleaning rounds, for one document.

        Returns:
            ``(raw_output, token_count)``. ``raw_output`` concatenates the
            initial reply and every gleaning reply. ``token_count`` estimates
            the tokens exchanged: each prompt and reply is counted once, and
            the history re-sent with every follow-up call is not included.
        """
        variables = {
            **prompt_variables,
            self._input_text_key: text,
        }
        prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
        gen_conf = {"temperature": 0.3}
        # The model may return None; normalise before counting or concatenating.
        response = self._llm.chat(prompt, [], gen_conf) or ""
        token_count = num_tokens_from_string(prompt + response)

        results = response
        history = [{"role": "system", "content": prompt}, {"role": "assistant", "content": response}]

        # Gleaning: ask the model for entities it missed, up to _max_gleanings times.
        for i in range(self._max_gleanings):
            continue_prompt = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
            history.append({"role": "user", "content": continue_prompt})
            response = self._llm.chat("", history, gen_conf) or ""
            token_count += num_tokens_from_string(continue_prompt + response)
            results += response

            # The final glean needs no "should I continue?" check.
            if i >= self._max_gleanings - 1:
                break
            history.append({"role": "assistant", "content": response})
            history.append({"role": "user", "content": LOOP_PROMPT})
            continuation = self._llm.chat("", history, self._loop_args) or ""
            token_count += num_tokens_from_string(LOOP_PROMPT + continuation)
            if continuation.strip() != "YES":
                break

        return results, token_count

    def _process_results(
        self,
        results: dict[int, str],
        tuple_delimiter: str,
        record_delimiter: str,
        completion_delimiter: str | None = None,
    ) -> nx.Graph:
        """Parse the raw extraction output into an undirected unipartite graph.

        Args:
            results: document index -> raw LLM output for that document.
            tuple_delimiter: separator between fields inside one record,
                default '<|>'.
            record_delimiter: separator between records, default '##'.
            completion_delimiter: optional end-of-reply marker. Each
                occurrence is treated as a record boundary.

        Returns:
            A ``networkx.Graph`` whose nodes carry entity_type, description,
            source_id, weight and rank (degree). Its edges carry weight,
            description and source_id.
        """
        graph = nx.Graph()
        for source_doc_id, extracted_data in results.items():
            if completion_delimiter:
                # Replies end with the completion marker. Turning it into a
                # record boundary keeps it off the last record and separates
                # the initial reply from the gleaning output appended after it.
                extracted_data = extracted_data.replace(completion_delimiter, record_delimiter)
            for raw_record in extracted_data.split(record_delimiter):
                # Records look like ("entity"<|>NAME<|>TYPE<|>DESC); drop the parens.
                record = re.sub(r"^\(|\)$", "", raw_record.strip())
                record_attributes = record.split(tuple_delimiter)
                record_type = record_attributes[0].strip()

                if record_type == '"entity"' and len(record_attributes) >= 4:
                    self._merge_entity(graph, record_attributes, source_doc_id)
                elif record_type == '"relationship"' and len(record_attributes) >= 5:
                    self._merge_relationship(graph, record_attributes, source_doc_id)

        # Rank every node by its degree in the merged graph.
        for node, degree in graph.degree:
            graph.nodes[node]["rank"] = int(degree)
        return graph

    def _merge_entity(
        self, graph: nx.Graph, record_attributes: list[str], source_doc_id: int
    ) -> None:
        """Add an entity record to ``graph``, or merge it into an existing node."""
        entity_name = clean_str(record_attributes[1].upper())
        if not entity_name:
            return
        entity_type = clean_str(record_attributes[2].upper())
        entity_description = clean_str(record_attributes[3])

        if entity_name not in graph:
            graph.add_node(
                entity_name,
                entity_type=entity_type,
                description=entity_description,
                source_id=str(source_doc_id),
                weight=1,
            )
            return

        node = graph.nodes[entity_name]
        if self._join_descriptions:
            node["description"] = self._join_unique(
                _unpack_descriptions(node), entity_description, "\n"
            )
        elif len(entity_description) > len(node["description"]):
            node["description"] = entity_description
        node["source_id"] = self._join_unique(
            _unpack_source_ids(node), str(source_doc_id), ", "
        )
        # Keep the previously known type when this record leaves it blank.
        if entity_type != "":
            node["entity_type"] = entity_type

    def _merge_relationship(
        self, graph: nx.Graph, record_attributes: list[str], source_doc_id: int
    ) -> None:
        """Add a relationship record as an edge, or merge it into an existing edge.

        Endpoints not yet in the graph are created as placeholder nodes with
        empty type and description.
        """
        source = clean_str(record_attributes[1].upper())
        target = clean_str(record_attributes[2].upper())
        if not source or not target:
            return
        edge_description = clean_str(record_attributes[3])
        edge_source_id = clean_str(str(source_doc_id))
        weight = self._parse_weight(record_attributes[-1])

        for endpoint in (source, target):
            if endpoint not in graph:
                graph.add_node(
                    endpoint,
                    entity_type="",
                    description="",
                    source_id=edge_source_id,
                    weight=1,
                )

        # get_edge_data returns None when the edge does not exist yet.
        edge_data = graph.get_edge_data(source, target)
        if edge_data is not None:
            weight += edge_data["weight"]
            if self._join_descriptions:
                edge_description = self._join_unique(
                    _unpack_descriptions(edge_data), edge_description, "\n"
                )
            edge_source_id = self._join_unique(
                _unpack_source_ids(edge_data), str(source_doc_id), ", "
            )
        graph.add_edge(
            source,
            target,
            weight=weight,
            description=edge_description,
            source_id=edge_source_id,
        )

    @staticmethod
    def _parse_weight(raw: str) -> float:
        """Parse a relationship-strength field such as ``9`` or ``"9"``.

        Fields are strings produced by ``str.split``, so they must be parsed
        rather than type-checked. Unparseable values fall back to 1.0.
        """
        try:
            return float(raw.strip().strip('"'))
        except ValueError:
            return 1.0

    @staticmethod
    def _join_unique(existing: list[str], new_value: str, separator: str) -> str:
        """Join ``existing`` plus ``new_value`` with ``separator``.

        Duplicates and empty strings are dropped, and first-seen order is
        kept so the output is deterministic across runs.
        """
        return separator.join(
            value for value in dict.fromkeys([*existing, new_value]) if value
        )
|
|
|
|
| def _unpack_descriptions(data: Mapping) -> list[str]: |
| value = data.get("description", None) |
| return [] if value is None else value.split("\n") |
|
|
|
|
| def _unpack_source_ids(data: Mapping) -> list[str]: |
| value = data.get("source_id", None) |
| return [] if value is None else value.split(", ") |
|
|
|
|
|
|
|
|