| |
| |
| """ |
| Reference: |
| - [graphrag](https://github.com/microsoft/graphrag) |
| """ |
|
|
| import json |
| import logging |
| import re |
| import traceback |
| from dataclasses import dataclass |
| from typing import Any, List, Callable |
| import networkx as nx |
| import pandas as pd |
| from graphrag import leiden |
| from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT |
| from graphrag.leiden import add_community_info2graph |
| from rag.llm.chat_model import Base as CompletionLLM |
| from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types |
| from rag.utils import num_tokens_from_string |
| from timeit import default_timer as timer |
|
|
| log = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class CommunityReportsResult:
    """Container for the community reports produced from one graph.

    Both lists are filled in the same order by the extractor, so the entry at
    a given index in ``output`` is the text rendering of the entry at the same
    index in ``structured_output``.
    """

    # Rendered text of each report (title, summary and finding sections).
    output: List[str]
    # Parsed report dictionaries that the text renderings were built from.
    structured_output: List[dict]
|
|
|
|
class CommunityReportsExtractor:
    """Produce one LLM-written report per community found in a graph.

    Calling an instance partitions the graph with ``leiden.run``, prompts the
    LLM once per community with that community's entities and relations
    rendered as CSV, validates the JSON answer and collects both the parsed
    report and a text rendering of it.
    """

    _llm: CompletionLLM
    _extraction_prompt: str
    # Declared for typing only; no method of this class assigns or reads it.
    _output_formatter_prompt: str
    _on_error: ErrorHandlerFn
    # Stored by __init__ but not consulted by any method of this class.
    _max_report_length: int

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        extraction_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        max_report_length: int | None = None,
    ):
        """Store the collaborators, falling back to defaults for falsy values.

        Args:
            llm_invoker: Chat model whose ``chat`` method is used per community.
            extraction_prompt: Prompt template; ``COMMUNITY_REPORT_PROMPT``
                is used when omitted or empty.
            on_error: Called as ``on_error(exc, traceback_text, None)`` when a
                community fails; a no-op is used when omitted.
            max_report_length: Kept on the instance; 1500 when omitted or 0.
        """
        self._llm = llm_invoker
        self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._max_report_length = max_report_length or 1500

    def __call__(self, graph: nx.Graph, callback: Callable | None = None):
        """Generate reports for every community of ``graph``.

        Args:
            graph: Graph whose nodes and edges (with their attributes) are fed
                to the LLM. Each successful community is also passed, together
                with its report title, to ``add_community_info2graph``.
            callback: Optional progress hook, invoked as ``callback(msg=...)``
                after each successfully generated report.

        Returns:
            CommunityReportsResult with the parsed reports and their text
            renderings, in generation order. Communities whose LLM call,
            JSON parsing or schema validation fails are skipped.
        """
        communities = leiden.run(graph, {})
        total = sum(len(comm) for comm in communities.values())

        relations_df = pd.DataFrame(
            [{"source": s, "target": t, **attr} for s, t, attr in graph.edges(data=True)]
        )
        if relations_df.empty:
            # An edge-less graph yields a frame without columns; give it the
            # two columns the per-community filter below relies on.
            relations_df = pd.DataFrame(columns=["source", "target"])

        res_str = []
        res_dict = []
        over, token_count = 0, 0
        st = timer()
        for comm in communities.values():
            for community in comm.values():
                weight = community["weight"]
                nodes = community["nodes"]
                ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in nodes])
                # Keep every relation that touches the community on either end.
                touches = relations_df["source"].isin(nodes) | relations_df["target"].isin(nodes)
                rela_df = relations_df[touches].reset_index(drop=True)

                prompt_variables = {
                    "entity_df": ent_df.to_csv(index_label="id"),
                    "relation_df": rela_df.to_csv(index_label="id"),
                }
                text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
                # Built per call on purpose: the chat client may mutate it.
                gen_conf = {"temperature": 0.3}
                try:
                    response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
                    token_count += num_tokens_from_string(text + response)
                    payload = self._strip_to_json(response)
                    log.debug("Community report raw JSON: %s", payload)
                    report = json.loads(payload)

                    # NOTE(review): assumes dict_has_keys_with_types performs a
                    # strict isinstance check, so a whole-number rating such as
                    # ``8`` would otherwise be rejected -- confirm.
                    rating = report.get("rating") if isinstance(report, dict) else None
                    if isinstance(rating, int) and not isinstance(rating, bool):
                        report["rating"] = float(rating)

                    if not dict_has_keys_with_types(report, [
                        ("title", str),
                        ("summary", str),
                        ("findings", list),
                        ("rating", float),
                        ("rating_explanation", str),
                    ]):
                        log.warning("Skipping community report with unexpected schema: %s", payload)
                        continue
                    report["weight"] = weight
                    report["entities"] = nodes
                except Exception as e:
                    # Best effort: one bad community must not abort the others.
                    log.exception("Failed to generate community report")
                    self._on_error(e, traceback.format_exc(), None)
                    continue

                add_community_info2graph(graph, nodes, report["title"])
                res_str.append(self._get_text_output(report))
                res_dict.append(report)
                over += 1
                if callback:
                    callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")

        return CommunityReportsResult(
            structured_output=res_dict,
            output=res_str,
        )

    @staticmethod
    def _strip_to_json(text: str) -> str:
        """Drop everything before the first ``{`` and after the last ``}``."""
        text = re.sub(r"^[^\{]*", "", text)
        return re.sub(r"[^\}]*$", "", text)

    def _get_text_output(self, parsed_output: dict) -> str:
        """Render a parsed report as text with ``#``/``##`` headings.

        Args:
            parsed_output: Report dict; ``title`` defaults to ``"Report"``,
                ``summary`` to ``""`` and ``findings`` to no sections. Each
                finding is either a plain string (used as the heading) or a
                dict with optional ``summary`` and ``explanation`` keys.

        Returns:
            The title heading, the summary, then one section per finding.
        """
        title = parsed_output.get("title", "Report")
        summary = parsed_output.get("summary", "")
        # ``or []`` also covers an explicit ``"findings": null``.
        findings = parsed_output.get("findings") or []

        def finding_summary(finding) -> str:
            if isinstance(finding, dict):
                # Missing or null values render as empty text, not "None".
                return finding.get("summary") or ""
            return finding if isinstance(finding, str) else str(finding)

        def finding_explanation(finding) -> str:
            if isinstance(finding, dict):
                return finding.get("explanation") or ""
            return ""

        report_sections = "\n\n".join(
            f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
        )

        return f"# {title}\n\n{summary}\n\n{report_sections}"
|
|