Kevin Hu
committed on
Commit
·
6d4f792
1
Parent(s):
b6ce919
refine log info about graphrag progress (#1823)
Browse files

### What problem does this PR solve?
### Type of change
- [x] Refactoring
- api/db/services/document_service.py +2 -1
- graphrag/community_reports_extractor.py +11 -5
- graphrag/graph_extractor.py +15 -5
- graphrag/index.py +3 -3
- rag/nlp/search.py +1 -1
api/db/services/document_service.py
CHANGED
|
@@ -317,7 +317,8 @@ class DocumentService(CommonService):
|
|
| 317 |
if 0 <= t.progress < 1:
|
| 318 |
finished = False
|
| 319 |
prg += t.progress if t.progress >= 0 else 0
|
| 320 |
-
|
|
|
|
| 321 |
if t.progress == -1:
|
| 322 |
bad += 1
|
| 323 |
prg /= len(tsks)
|
|
|
|
| 317 |
if 0 <= t.progress < 1:
|
| 318 |
finished = False
|
| 319 |
prg += t.progress if t.progress >= 0 else 0
|
| 320 |
+
if t.progress_msg not in msg:
|
| 321 |
+
msg.append(t.progress_msg)
|
| 322 |
if t.progress == -1:
|
| 323 |
bad += 1
|
| 324 |
prg /= len(tsks)
|
graphrag/community_reports_extractor.py
CHANGED
|
@@ -23,16 +23,16 @@ import logging
|
|
| 23 |
import re
|
| 24 |
import traceback
|
| 25 |
from dataclasses import dataclass
|
| 26 |
-
from typing import Any, List
|
| 27 |
-
|
| 28 |
import networkx as nx
|
| 29 |
import pandas as pd
|
| 30 |
-
|
| 31 |
from graphrag import leiden
|
| 32 |
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
| 33 |
from graphrag.leiden import add_community_info2graph
|
| 34 |
from rag.llm.chat_model import Base as CompletionLLM
|
| 35 |
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
|
|
|
|
|
|
|
| 36 |
|
| 37 |
log = logging.getLogger(__name__)
|
| 38 |
|
|
@@ -67,11 +67,14 @@ class CommunityReportsExtractor:
|
|
| 67 |
self._on_error = on_error or (lambda _e, _s, _d: None)
|
| 68 |
self._max_report_length = max_report_length or 1500
|
| 69 |
|
| 70 |
-
def __call__(self, graph: nx.Graph):
|
| 71 |
communities: dict[str, dict[str, List]] = leiden.run(graph, {})
|
|
|
|
| 72 |
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
|
| 73 |
res_str = []
|
| 74 |
res_dict = []
|
|
|
|
|
|
|
| 75 |
for level, comm in communities.items():
|
| 76 |
for cm_id, ents in comm.items():
|
| 77 |
weight = ents["weight"]
|
|
@@ -84,9 +87,10 @@ class CommunityReportsExtractor:
|
|
| 84 |
"relation_df": rela_df.to_csv(index_label="id")
|
| 85 |
}
|
| 86 |
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
|
| 87 |
-
gen_conf = {"temperature": 0.
|
| 88 |
try:
|
| 89 |
response = self._llm.chat(text, [], gen_conf)
|
|
|
|
| 90 |
response = re.sub(r"^[^\{]*", "", response)
|
| 91 |
response = re.sub(r"[^\}]*$", "", response)
|
| 92 |
print(response)
|
|
@@ -108,6 +112,8 @@ class CommunityReportsExtractor:
|
|
| 108 |
add_community_info2graph(graph, ents, response["title"])
|
| 109 |
res_str.append(self._get_text_output(response))
|
| 110 |
res_dict.append(response)
|
|
|
|
|
|
|
| 111 |
|
| 112 |
return CommunityReportsResult(
|
| 113 |
structured_output=res_dict,
|
|
|
|
| 23 |
import re
|
| 24 |
import traceback
|
| 25 |
from dataclasses import dataclass
|
| 26 |
+
from typing import Any, List, Callable
|
|
|
|
| 27 |
import networkx as nx
|
| 28 |
import pandas as pd
|
|
|
|
| 29 |
from graphrag import leiden
|
| 30 |
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
| 31 |
from graphrag.leiden import add_community_info2graph
|
| 32 |
from rag.llm.chat_model import Base as CompletionLLM
|
| 33 |
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
|
| 34 |
+
from rag.utils import num_tokens_from_string
|
| 35 |
+
from timeit import default_timer as timer
|
| 36 |
|
| 37 |
log = logging.getLogger(__name__)
|
| 38 |
|
|
|
|
| 67 |
self._on_error = on_error or (lambda _e, _s, _d: None)
|
| 68 |
self._max_report_length = max_report_length or 1500
|
| 69 |
|
| 70 |
+
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
|
| 71 |
communities: dict[str, dict[str, List]] = leiden.run(graph, {})
|
| 72 |
+
total = sum([len(comm.items()) for _, comm in communities.items()])
|
| 73 |
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
|
| 74 |
res_str = []
|
| 75 |
res_dict = []
|
| 76 |
+
over, token_count = 0, 0
|
| 77 |
+
st = timer()
|
| 78 |
for level, comm in communities.items():
|
| 79 |
for cm_id, ents in comm.items():
|
| 80 |
weight = ents["weight"]
|
|
|
|
| 87 |
"relation_df": rela_df.to_csv(index_label="id")
|
| 88 |
}
|
| 89 |
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
|
| 90 |
+
gen_conf = {"temperature": 0.3}
|
| 91 |
try:
|
| 92 |
response = self._llm.chat(text, [], gen_conf)
|
| 93 |
+
token_count += num_tokens_from_string(text + response)
|
| 94 |
response = re.sub(r"^[^\{]*", "", response)
|
| 95 |
response = re.sub(r"[^\}]*$", "", response)
|
| 96 |
print(response)
|
|
|
|
| 112 |
add_community_info2graph(graph, ents, response["title"])
|
| 113 |
res_str.append(self._get_text_output(response))
|
| 114 |
res_dict.append(response)
|
| 115 |
+
over += 1
|
| 116 |
+
if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
|
| 117 |
|
| 118 |
return CommunityReportsResult(
|
| 119 |
structured_output=res_dict,
|
graphrag/graph_extractor.py
CHANGED
|
@@ -21,13 +21,14 @@ import numbers
|
|
| 21 |
import re
|
| 22 |
import traceback
|
| 23 |
from dataclasses import dataclass
|
| 24 |
-
from typing import Any, Mapping
|
| 25 |
import tiktoken
|
| 26 |
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
| 27 |
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
|
| 28 |
from rag.llm.chat_model import Base as CompletionLLM
|
| 29 |
import networkx as nx
|
| 30 |
from rag.utils import num_tokens_from_string
|
|
|
|
| 31 |
|
| 32 |
DEFAULT_TUPLE_DELIMITER = "<|>"
|
| 33 |
DEFAULT_RECORD_DELIMITER = "##"
|
|
@@ -103,7 +104,9 @@ class GraphExtractor:
|
|
| 103 |
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
|
| 104 |
|
| 105 |
def __call__(
|
| 106 |
-
self, texts: list[str],
|
|
|
|
|
|
|
| 107 |
) -> GraphExtractionResult:
|
| 108 |
"""Call method definition."""
|
| 109 |
if prompt_variables is None:
|
|
@@ -127,12 +130,17 @@ class GraphExtractor:
|
|
| 127 |
),
|
| 128 |
}
|
| 129 |
|
|
|
|
|
|
|
|
|
|
| 130 |
for doc_index, text in enumerate(texts):
|
| 131 |
try:
|
| 132 |
# Invoke the entity extraction
|
| 133 |
-
result = self._process_document(text, prompt_variables)
|
| 134 |
source_doc_map[doc_index] = text
|
| 135 |
all_records[doc_index] = result
|
|
|
|
|
|
|
| 136 |
except Exception as e:
|
| 137 |
logging.exception("error extracting graph")
|
| 138 |
self._on_error(
|
|
@@ -162,9 +170,11 @@ class GraphExtractor:
|
|
| 162 |
**prompt_variables,
|
| 163 |
self._input_text_key: text,
|
| 164 |
}
|
|
|
|
| 165 |
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
| 166 |
-
gen_conf = {"temperature": 0.
|
| 167 |
response = self._llm.chat(text, [], gen_conf)
|
|
|
|
| 168 |
|
| 169 |
results = response or ""
|
| 170 |
history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
|
|
@@ -185,7 +195,7 @@ class GraphExtractor:
|
|
| 185 |
if continuation != "YES":
|
| 186 |
break
|
| 187 |
|
| 188 |
-
return results
|
| 189 |
|
| 190 |
def _process_results(
|
| 191 |
self,
|
|
|
|
| 21 |
import re
|
| 22 |
import traceback
|
| 23 |
from dataclasses import dataclass
|
| 24 |
+
from typing import Any, Mapping, Callable
|
| 25 |
import tiktoken
|
| 26 |
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
| 27 |
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
|
| 28 |
from rag.llm.chat_model import Base as CompletionLLM
|
| 29 |
import networkx as nx
|
| 30 |
from rag.utils import num_tokens_from_string
|
| 31 |
+
from timeit import default_timer as timer
|
| 32 |
|
| 33 |
DEFAULT_TUPLE_DELIMITER = "<|>"
|
| 34 |
DEFAULT_RECORD_DELIMITER = "##"
|
|
|
|
| 104 |
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
|
| 105 |
|
| 106 |
def __call__(
|
| 107 |
+
self, texts: list[str],
|
| 108 |
+
prompt_variables: dict[str, Any] | None = None,
|
| 109 |
+
callback: Callable | None = None
|
| 110 |
) -> GraphExtractionResult:
|
| 111 |
"""Call method definition."""
|
| 112 |
if prompt_variables is None:
|
|
|
|
| 130 |
),
|
| 131 |
}
|
| 132 |
|
| 133 |
+
st = timer()
|
| 134 |
+
total = len(texts)
|
| 135 |
+
total_token_count = 0
|
| 136 |
for doc_index, text in enumerate(texts):
|
| 137 |
try:
|
| 138 |
# Invoke the entity extraction
|
| 139 |
+
result, token_count = self._process_document(text, prompt_variables)
|
| 140 |
source_doc_map[doc_index] = text
|
| 141 |
all_records[doc_index] = result
|
| 142 |
+
total_token_count += token_count
|
| 143 |
+
if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
|
| 144 |
except Exception as e:
|
| 145 |
logging.exception("error extracting graph")
|
| 146 |
self._on_error(
|
|
|
|
| 170 |
**prompt_variables,
|
| 171 |
self._input_text_key: text,
|
| 172 |
}
|
| 173 |
+
token_count = 0
|
| 174 |
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
| 175 |
+
gen_conf = {"temperature": 0.3}
|
| 176 |
response = self._llm.chat(text, [], gen_conf)
|
| 177 |
+
token_count = num_tokens_from_string(text + response)
|
| 178 |
|
| 179 |
results = response or ""
|
| 180 |
history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
|
|
|
|
| 195 |
if continuation != "YES":
|
| 196 |
break
|
| 197 |
|
| 198 |
+
return results, token_count
|
| 199 |
|
| 200 |
def _process_results(
|
| 201 |
self,
|
graphrag/index.py
CHANGED
|
@@ -86,7 +86,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
|
|
| 86 |
for i in range(len(chunks)):
|
| 87 |
tkn_cnt = num_tokens_from_string(chunks[i])
|
| 88 |
if cnt+tkn_cnt >= left_token_count and texts:
|
| 89 |
-
threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
|
| 90 |
texts = []
|
| 91 |
cnt = 0
|
| 92 |
texts.append(chunks[i])
|
|
@@ -98,7 +98,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
|
|
| 98 |
graphs = []
|
| 99 |
for i, _ in enumerate(threads):
|
| 100 |
graphs.append(_.result().output)
|
| 101 |
-
callback(0.5 + 0.1*i/len(threads))
|
| 102 |
|
| 103 |
graph = reduce(graph_merge, graphs)
|
| 104 |
er = EntityResolution(llm_bdl)
|
|
@@ -125,7 +125,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
|
|
| 125 |
|
| 126 |
callback(0.6, "Extracting community reports.")
|
| 127 |
cr = CommunityReportsExtractor(llm_bdl)
|
| 128 |
-
cr = cr(graph)
|
| 129 |
for community, desc in zip(cr.structured_output, cr.output):
|
| 130 |
chunk = {
|
| 131 |
"title_tks": rag_tokenizer.tokenize(community["title"]),
|
|
|
|
| 86 |
for i in range(len(chunks)):
|
| 87 |
tkn_cnt = num_tokens_from_string(chunks[i])
|
| 88 |
if cnt+tkn_cnt >= left_token_count and texts:
|
| 89 |
+
threads.append(exe.submit(ext, texts, {"entity_types": entity_types}, callback))
|
| 90 |
texts = []
|
| 91 |
cnt = 0
|
| 92 |
texts.append(chunks[i])
|
|
|
|
| 98 |
graphs = []
|
| 99 |
for i, _ in enumerate(threads):
|
| 100 |
graphs.append(_.result().output)
|
| 101 |
+
callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
|
| 102 |
|
| 103 |
graph = reduce(graph_merge, graphs)
|
| 104 |
er = EntityResolution(llm_bdl)
|
|
|
|
| 125 |
|
| 126 |
callback(0.6, "Extracting community reports.")
|
| 127 |
cr = CommunityReportsExtractor(llm_bdl)
|
| 128 |
+
cr = cr(graph, callback=callback)
|
| 129 |
for community, desc in zip(cr.structured_output, cr.output):
|
| 130 |
chunk = {
|
| 131 |
"title_tks": rag_tokenizer.tokenize(community["title"]),
|
rag/nlp/search.py
CHANGED
|
@@ -138,7 +138,7 @@ class Dealer:
|
|
| 138 |
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
|
| 139 |
if self.es.getTotal(res) == 0 and "knn" in s:
|
| 140 |
bqry, _ = self.qryr.question(qst, min_match="10%")
|
| 141 |
-
bqry = self._add_filters(bqry)
|
| 142 |
s["query"] = bqry.to_dict()
|
| 143 |
s["knn"]["filter"] = bqry.to_dict()
|
| 144 |
s["knn"]["similarity"] = 0.17
|
|
|
|
| 138 |
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
|
| 139 |
if self.es.getTotal(res) == 0 and "knn" in s:
|
| 140 |
bqry, _ = self.qryr.question(qst, min_match="10%")
|
| 141 |
+
bqry = self._add_filters(bqry, req)
|
| 142 |
s["query"] = bqry.to_dict()
|
| 143 |
s["knn"]["filter"] = bqry.to_dict()
|
| 144 |
s["knn"]["similarity"] = 0.17
|