Spaces:
Paused
Paused
| # | |
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| """ | |
| Reference: | |
| - [graphrag](https://github.com/microsoft/graphrag) | |
| """ | |
| import json | |
| import logging | |
| import re | |
| import traceback | |
| from dataclasses import dataclass | |
| from typing import Any, List, Callable | |
| import networkx as nx | |
| import pandas as pd | |
| from graphrag import leiden | |
| from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT | |
| from graphrag.leiden import add_community_info2graph | |
| from rag.llm.chat_model import Base as CompletionLLM | |
| from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types | |
| from rag.utils import num_tokens_from_string | |
| from timeit import default_timer as timer | |
# Logger for this module; named after the module so records can be filtered per file.
log = logging.getLogger(name=__name__)
@dataclass
class CommunityReportsResult:
    """Community reports produced by ``CommunityReportsExtractor``.

    Declared as a dataclass so it can be constructed with keyword arguments
    (``CommunityReportsResult(output=..., structured_output=...)``).

    Attributes:
        output: Text rendering of each generated report.
        structured_output: Parsed dict of each generated report, in the same
            order as ``output``.
    """

    output: List[str]
    structured_output: List[dict]
class CommunityReportsExtractor:
    """Generate an LLM-written report for every community detected in a graph.

    Communities come from ``leiden.run``. For each community, its entities and
    the edges touching them are rendered as CSV and substituted into the
    extraction prompt, which is sent to the chat model. Replies that contain a
    JSON object with the required report fields are collected.
    """

    _llm: CompletionLLM
    _extraction_prompt: str
    _output_formatter_prompt: str
    _on_error: ErrorHandlerFn
    _max_report_length: int

    # (field, type) pairs a parsed report must satisfy to be kept.
    _REQUIRED_REPORT_FIELDS = [
        ("title", str),
        ("summary", str),
        ("findings", list),
        ("rating", float),
        ("rating_explanation", str),
    ]
    # Strip any text the model emits before the first "{" and after the last "}".
    _LEADING_NON_JSON = re.compile(r"^[^\{]*")
    _TRAILING_NON_JSON = re.compile(r"[^\}]*$")

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        extraction_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        max_report_length: int | None = None,
    ):
        """Configure the extractor.

        Args:
            llm_invoker: Chat model used to write each report.
            extraction_prompt: Prompt template; defaults to COMMUNITY_REPORT_PROMPT.
            on_error: Called as ``on_error(exception, traceback_text, None)`` when
                generating a report fails; defaults to a no-op.
            max_report_length: Stored as ``_max_report_length`` (default 1500);
                not read by this class.
        """
        self._llm = llm_invoker
        self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._max_report_length = max_report_length or 1500

    def __call__(self, graph: nx.Graph, callback: Callable | None = None):
        """Build a report for every community of ``graph``.

        Args:
            graph: Graph whose node and edge attributes are fed to the prompt.
                Each community that yields a valid report is tagged on the
                graph via ``add_community_info2graph``.
            callback: Optional progress hook, called with a ``msg`` keyword
                after each successfully generated report.

        Returns:
            CommunityReportsResult holding the text rendering (``output``) and
            the parsed dict (``structured_output``) of each report.
        """
        communities: dict[str, dict[str, List]] = leiden.run(graph, {})
        total = sum(len(comm) for comm in communities.values())
        relations_df = pd.DataFrame(
            [{"source": s, "target": t, **attr} for s, t, attr in graph.edges(data=True)]
        )
        if relations_df.empty:
            # With no edges the frame has no columns at all; give it the two
            # columns the per-community filter indexes so it does not KeyError.
            relations_df = pd.DataFrame(columns=["source", "target"])

        res_str = []
        res_dict = []
        over, token_count = 0, 0
        st = timer()
        for level, comm in communities.items():
            for cm_id, community in comm.items():
                weight = community["weight"]
                ents = community["nodes"]
                text = self._build_prompt(graph, ents, relations_df)
                # Built per call rather than hoisted, in case the chat backend mutates it.
                gen_conf = {"temperature": 0.3}
                try:
                    response = self._llm.chat(text, [], gen_conf)
                    token_count += num_tokens_from_string(text + response)
                    report = self._parse_report(response)
                    if report is None:
                        log.warning(
                            "Skipping community %s (level %s): LLM report is missing "
                            "required fields or has wrongly typed ones",
                            cm_id,
                            level,
                        )
                        continue
                    report["weight"] = weight
                    report["entities"] = ents
                except Exception as e:
                    trace = traceback.format_exc()
                    log.exception(
                        "Failed to generate report for community %s (level %s)", cm_id, level
                    )
                    self._on_error(e, trace, None)
                    continue
                add_community_info2graph(graph, ents, report["title"])
                res_str.append(self._get_text_output(report))
                res_dict.append(report)
                over += 1
                if callback:
                    callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
        return CommunityReportsResult(
            structured_output=res_dict,
            output=res_str,
        )

    def _build_prompt(self, graph: nx.Graph, ents: List, relations_df: pd.DataFrame) -> str:
        """Fill the extraction prompt with a community's entities and touching relations as CSV."""
        ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
        # Keep every edge with at least one endpoint inside the community.
        touches = relations_df["source"].isin(ents) | relations_df["target"].isin(ents)
        rela_df = relations_df[touches].reset_index(drop=True)
        prompt_variables = {
            "entity_df": ent_df.to_csv(index_label="id"),
            "relation_df": rela_df.to_csv(index_label="id"),
        }
        return perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)

    def _parse_report(self, response: str) -> dict | None:
        """Extract the JSON object from an LLM reply and validate its fields.

        Returns:
            The parsed report dict, or None when ``dict_has_keys_with_types``
            rejects it.

        Raises:
            json.JSONDecodeError: if no JSON object can be decoded from the reply.
        """
        response = self._LEADING_NON_JSON.sub("", response)
        response = self._TRAILING_NON_JSON.sub("", response)
        log.debug("Community report LLM response: %s", response)
        report = json.loads(response)
        rating = report.get("rating")
        # JSON does not distinguish 8 from 8.0; a whole-number rating decodes
        # as int and would presumably fail the float type requirement, so
        # normalise it (bool is an int subclass and is deliberately excluded).
        if isinstance(rating, int) and not isinstance(rating, bool):
            report["rating"] = float(rating)
        if not dict_has_keys_with_types(report, self._REQUIRED_REPORT_FIELDS):
            return None
        return report

    def _get_text_output(self, parsed_output: dict) -> str:
        """Render a parsed report as text: "# title", the summary, then one "## " section per finding."""
        title = parsed_output.get("title", "Report")
        summary = parsed_output.get("summary", "")
        findings = parsed_output.get("findings", [])

        def finding_summary(finding: Any) -> str:
            # Findings may be dicts or bare values; stringify anything that is
            # not a dict instead of raising AttributeError outside the caller's try.
            if isinstance(finding, dict):
                return finding.get("summary") or ""
            return str(finding)

        def finding_explanation(finding: Any) -> str:
            if isinstance(finding, dict):
                return finding.get("explanation") or ""
            return ""

        report_sections = "\n\n".join(
            f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
        )
        return f"# {title}\n\n{summary}\n\n{report_sections}"