|
|
import asyncio |
|
|
from typing import Any, Optional, Sequence |
|
|
|
|
|
from llama_index.core.async_utils import run_jobs |
|
|
from llama_index.core.schema import TransformComponent, BaseNode, MetadataMode |
|
|
from llama_index.core.graph_stores.types import ( |
|
|
EntityNode, |
|
|
Relation, |
|
|
KG_NODES_KEY, |
|
|
KG_RELATIONS_KEY, |
|
|
) |
|
|
|
|
|
from evoagentx.core.logging import logger |
|
|
from evoagentx.models.base_model import BaseLLM, LLMOutputParser |
|
|
from evoagentx.prompts.rag.graph_extract import ENTITY_EXTRACT_PROMPT, RELATION_EXTRACT_PROMPT |
|
|
|
|
|
|
|
|
class BasicGraphExtractLLM(TransformComponent): |
|
|
""" |
|
|
A TransformComponent for extracting knowledge graph triplets using an LLM without tool-calling capabilities. |
|
|
|
|
|
This class performs two-stage extraction: |
|
|
1. Entity extraction: Identifies named entities and their types (e.g., Person, Organization). |
|
|
2. Relation extraction: Identifies directed relationships between extracted entities. |
|
|
|
|
|
The extracted entities and relations are stored in the node's metadata for use in LlamaIndex's PropertyGraphIndex. |
|
|
|
|
|
Attributes: |
|
|
llm (BaseLLM): The language model for entity and relation extraction. |
|
|
entity_extract_prompt (str): Prompt template for entity extraction. |
|
|
relation_extract_prompt (str): Prompt template for relation extraction. |
|
|
num_workers (int): Number of workers for parallel processing of nodes. |
|
|
""" |
|
|
|
|
|
llm: BaseLLM |
|
|
entity_extract_prompt: str |
|
|
relation_extract_prompt: str |
|
|
num_workers: int |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
llm: BaseLLM, |
|
|
entity_extract_prompt: Optional[str] = None, |
|
|
relation_extract_prompt: Optional[str] = None, |
|
|
num_workers: int = 4, |
|
|
): |
|
|
""" |
|
|
Initialize the BasicGraphExtractLLM. |
|
|
|
|
|
Args: |
|
|
llm (BaseLLM): The language model to use for extraction. |
|
|
entity_extract_prompt (Optional[str]): Custom prompt for entity extraction. Defaults to ENTITY_EXTRACT_PROMPT. |
|
|
relation_extract_prompt (Optional[str]): Custom prompt for relation extraction. Defaults to RELATION_EXTRACT_PROMPT. |
|
|
num_workers (int): Number of workers for parallel node processing. Defaults to 4. |
|
|
""" |
|
|
|
|
|
super().__init__( |
|
|
llm=llm, |
|
|
entity_extract_prompt=entity_extract_prompt or ENTITY_EXTRACT_PROMPT, |
|
|
relation_extract_prompt=relation_extract_prompt or RELATION_EXTRACT_PROMPT, |
|
|
num_workers=num_workers, |
|
|
) |
|
|
|
|
|
async def _aextract(self, node: BaseNode) -> BaseNode: |
|
|
""" |
|
|
Asynchronously extract entities and relations from a single node. |
|
|
|
|
|
This method performs two LLM calls: |
|
|
1. Extracts entities and their types using the entity_extract_prompt. |
|
|
2. Extracts relations between entities using the relation_extract_prompt. |
|
|
|
|
|
The results are stored in the node's metadata under KG_NODES_KEY and KG_RELATIONS_KEY. |
|
|
|
|
|
Args: |
|
|
node (BaseNode): The node containing text to process. |
|
|
|
|
|
Returns: |
|
|
BaseNode: The node with updated metadata containing extracted entities and relations. |
|
|
|
|
|
Raises: |
|
|
AssertionError: If the node lacks a 'text' attribute. |
|
|
ValueError: If JSON parsing of LLM output fails (handled with empty fallback). |
|
|
""" |
|
|
assert hasattr(node, "text"), "Node must have a 'text' attribute" |
|
|
|
|
|
text = node.get_content(metadata_mode=MetadataMode.LLM) |
|
|
|
|
|
try: |
|
|
|
|
|
extract_prompt = self.entity_extract_prompt.replace("{text}", text) |
|
|
llm_response = await self.llm.async_generate( |
|
|
prompt=extract_prompt, |
|
|
parse_mode="json", |
|
|
) |
|
|
|
|
|
json_string = llm_response.content.strip() |
|
|
|
|
|
entity_label_mapping = { |
|
|
entity_dict["name"]: entity_dict["type"] |
|
|
for entity_dict in LLMOutputParser._parse_json_content(json_string)["entities"] |
|
|
} |
|
|
|
|
|
|
|
|
relation_extract_prompt = self.relation_extract_prompt.replace("{text}", text).replace( |
|
|
"{entities_json}", json_string |
|
|
) |
|
|
llm_response = self.llm.generate( |
|
|
prompt=relation_extract_prompt, |
|
|
parse_mode="json", |
|
|
) |
|
|
|
|
|
|
|
|
triples = LLMOutputParser._parse_json_content(llm_response.content.strip())["graph"] |
|
|
|
|
|
except ValueError as e: |
|
|
logger.warning(f"Failed to parse LLM output for node {node.node_id}: {str(e)}. Returning empty triples.") |
|
|
entity_label_mapping = {} |
|
|
triples = [] |
|
|
|
|
|
logger.info(f"Extracted triples from chunk: {triples}") |
|
|
|
|
|
existing_nodes = node.metadata.pop(KG_NODES_KEY, []) |
|
|
existing_relations = node.metadata.pop(KG_RELATIONS_KEY, []) |
|
|
metadata = node.metadata.copy() |
|
|
|
|
|
|
|
|
for triple in triples: |
|
|
subj, rel, obj = triple["source"], triple["relation"], triple["target"] |
|
|
|
|
|
subj = subj.capitalize().replace(" ", "_") |
|
|
rel = rel.lower().replace(" ", "_") |
|
|
obj = obj.capitalize().replace(" ", "_") |
|
|
|
|
|
subj_node = EntityNode( |
|
|
name=subj, |
|
|
label=entity_label_mapping.get(subj, "entity"), |
|
|
) |
|
|
obj_node = EntityNode( |
|
|
name=obj, |
|
|
label=entity_label_mapping.get(obj, "entity"), |
|
|
) |
|
|
|
|
|
rel_node = Relation( |
|
|
label=rel, |
|
|
source_id=subj_node.id, |
|
|
target_id=obj_node.id, |
|
|
properties=metadata, |
|
|
) |
|
|
|
|
|
existing_nodes.extend([subj_node, obj_node]) |
|
|
existing_relations.append(rel_node) |
|
|
|
|
|
|
|
|
node.metadata[KG_NODES_KEY] = existing_nodes |
|
|
node.metadata[KG_RELATIONS_KEY] = existing_relations |
|
|
|
|
|
return node |
|
|
|
|
|
def __call__( |
|
|
self, |
|
|
nodes: Sequence[BaseNode], |
|
|
show_progress: bool = False, |
|
|
**kwargs: Any, |
|
|
) -> Sequence[BaseNode]: |
|
|
""" |
|
|
Synchronously extract triples from a sequence of nodes. |
|
|
|
|
|
This method wraps the asynchronous acall method for synchronous execution. |
|
|
|
|
|
Args: |
|
|
nodes (Sequence[BaseNode]): The nodes to process. |
|
|
show_progress (bool): Whether to display a progress bar. Defaults to False. |
|
|
**kwargs: Additional keyword arguments passed to acall. |
|
|
|
|
|
Returns: |
|
|
Sequence[BaseNode]: The processed nodes with updated metadata. |
|
|
""" |
|
|
return asyncio.run(self.acall(nodes, show_progress=show_progress, **kwargs)) |
|
|
|
|
|
async def acall( |
|
|
self, |
|
|
nodes: Sequence[BaseNode], |
|
|
show_progress: bool = False, |
|
|
**kwargs: Any, |
|
|
) -> Sequence[BaseNode]: |
|
|
""" |
|
|
Asynchronously extract triples from a sequence of nodes. |
|
|
|
|
|
This method processes nodes in parallel using run_jobs for efficiency. |
|
|
|
|
|
Args: |
|
|
nodes (Sequence[BaseNode]): The nodes to process. |
|
|
show_progress (bool): Whether to display a progress bar. Defaults to False. |
|
|
**kwargs: Additional keyword arguments passed to run_jobs. |
|
|
|
|
|
Returns: |
|
|
Sequence[BaseNode]: The processed nodes with updated metadata. |
|
|
""" |
|
|
jobs = [self._aextract(node, **kwargs) for node in nodes] |
|
|
return await run_jobs( |
|
|
jobs, |
|
|
workers=self.num_workers, |
|
|
show_progress=show_progress, |
|
|
desc="Extracting paths from text", |
|
|
) |
|
|
|
|
|
@classmethod |
|
|
def class_name(cls) -> str: |
|
|
return "BasicGraphExtractLLM" |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
import os |
|
|
import dotenv |
|
|
|
|
|
dotenv.load_dotenv() |
|
|
|
|
|
from llama_index.core.schema import TextNode |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from evoagentx.models import OpenAILLMConfig, OpenAILLM |
|
|
|
|
|
config = OpenAILLMConfig( |
|
|
model="gpt-4o-mini", |
|
|
temperature=0.7, |
|
|
max_tokens=1000, |
|
|
openai_key=os.environ["OPENAI_API_KEY"], |
|
|
) |
|
|
|
|
|
llm = OpenAILLM(config=config) |
|
|
|
|
|
trans = BasicGraphExtractLLM(llm=llm) |
|
|
node = TextNode( |
|
|
text="Satya Nadella, the CEO of Microsoft, announced a new partnership with OpenAI in 2023. Microsoft, headquartered in Redmond, Washington, will integrate OpenAI’s AI technologies into its Azure cloud platform. OpenAI, based in San Francisco, California, is known for developing ChatGPT." |
|
|
) |
|
|
graph_nodes = trans([node] * 10) |