File size: 9,209 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 |
import asyncio
from typing import Any, Optional, Sequence
from llama_index.core.async_utils import run_jobs
from llama_index.core.schema import TransformComponent, BaseNode, MetadataMode
from llama_index.core.graph_stores.types import (
EntityNode,
Relation,
KG_NODES_KEY,
KG_RELATIONS_KEY,
)
from evoagentx.core.logging import logger
from evoagentx.models.base_model import BaseLLM, LLMOutputParser
from evoagentx.prompts.rag.graph_extract import ENTITY_EXTRACT_PROMPT, RELATION_EXTRACT_PROMPT
class BasicGraphExtractLLM(TransformComponent):
    """
    A TransformComponent for extracting knowledge graph triplets using an LLM without
    tool-calling capabilities.

    This class performs two-stage extraction per node:
    1. Entity extraction: identifies named entities and their types (e.g., Person, Organization).
    2. Relation extraction: identifies directed relationships between the extracted entities.

    The extracted entities and relations are stored in the node's metadata under
    ``KG_NODES_KEY`` / ``KG_RELATIONS_KEY`` for use in LlamaIndex's PropertyGraphIndex.

    Attributes:
        llm (BaseLLM): The language model for entity and relation extraction.
        entity_extract_prompt (str): Prompt template for entity extraction; must contain
            a ``{text}`` placeholder.
        relation_extract_prompt (str): Prompt template for relation extraction; must contain
            ``{text}`` and ``{entities_json}`` placeholders.
        num_workers (int): Number of concurrent workers for parallel node processing.
    """

    llm: BaseLLM
    entity_extract_prompt: str
    relation_extract_prompt: str
    num_workers: int

    def __init__(
        self,
        llm: BaseLLM,
        entity_extract_prompt: Optional[str] = None,
        relation_extract_prompt: Optional[str] = None,
        num_workers: int = 4,
    ):
        """
        Initialize the BasicGraphExtractLLM.

        Args:
            llm (BaseLLM): The language model to use for extraction.
            entity_extract_prompt (Optional[str]): Custom prompt for entity extraction.
                Defaults to ENTITY_EXTRACT_PROMPT.
            relation_extract_prompt (Optional[str]): Custom prompt for relation extraction.
                Defaults to RELATION_EXTRACT_PROMPT.
            num_workers (int): Number of workers for parallel node processing. Defaults to 4.
        """
        super().__init__(
            llm=llm,
            entity_extract_prompt=entity_extract_prompt or ENTITY_EXTRACT_PROMPT,
            relation_extract_prompt=relation_extract_prompt or RELATION_EXTRACT_PROMPT,
            num_workers=num_workers,
        )

    async def _aextract(self, node: BaseNode, **kwargs: Any) -> BaseNode:
        """
        Asynchronously extract entities and relations from a single node.

        Performs two LLM calls:
        1. Extracts entities and their types using ``entity_extract_prompt``.
        2. Extracts relations between those entities using ``relation_extract_prompt``.

        The results are stored in the node's metadata under KG_NODES_KEY and
        KG_RELATIONS_KEY (appended to any pre-existing entries).

        Args:
            node (BaseNode): The node containing text to process.
            **kwargs: Extra keyword arguments forwarded by ``acall``; currently unused,
                accepted for forward compatibility.

        Returns:
            BaseNode: The node with updated metadata containing extracted entities and relations.

        Raises:
            AssertionError: If the node lacks a 'text' attribute.
        """
        assert hasattr(node, "text"), "Node must have a 'text' attribute"
        text = node.get_content(metadata_mode=MetadataMode.LLM)
        try:
            # Step 1: extract entities and their types.
            entity_prompt = self.entity_extract_prompt.replace("{text}", text)
            llm_response = await self.llm.async_generate(
                prompt=entity_prompt,
                parse_mode="json",
            )
            entities_json = llm_response.content.strip()
            # Map raw entity names (exactly as emitted by the LLM) to their types.
            entity_label_mapping = {
                entity_dict["name"]: entity_dict["type"]
                for entity_dict in LLMOutputParser._parse_json_content(entities_json)["entities"]
            }
            # Step 2: extract relations between the entities found above.
            relation_prompt = self.relation_extract_prompt.replace("{text}", text).replace(
                "{entities_json}", entities_json
            )
            # BUGFIX: was the blocking `self.llm.generate(...)`, which stalls the event
            # loop inside this coroutine and defeats run_jobs' concurrency.
            llm_response = await self.llm.async_generate(
                prompt=relation_prompt,
                parse_mode="json",
            )
            # Parse the relation results into triplets.
            triples = LLMOutputParser._parse_json_content(llm_response.content.strip())["graph"]
        except (ValueError, KeyError) as e:
            # Malformed JSON, or a missing "entities"/"graph" key: degrade gracefully
            # to an empty extraction rather than failing the whole pipeline.
            logger.warning(f"Failed to parse LLM output for node {node.node_id}: {str(e)}. Returning empty triples.")
            entity_label_mapping = {}
            triples = []
        logger.info(f"Extracted triples from chunk: {triples}")
        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
        metadata = node.metadata.copy()
        # Convert extracted triplets into EntityNode and Relation objects.
        for triple in triples:
            subj_raw, rel, obj_raw = triple["source"], triple["relation"], triple["target"]
            # BUGFIX: look up labels with the RAW names — the mapping is keyed by the
            # names the LLM emitted, not by the normalized forms, so looking up the
            # normalized name (as before) nearly always missed and fell back to "entity".
            subj_label = entity_label_mapping.get(subj_raw, "entity")
            obj_label = entity_label_mapping.get(obj_raw, "entity")
            # Normalize names: first letter capitalized, rest lowered, spaces -> underscores.
            subj = subj_raw.capitalize().replace(" ", "_")
            rel = rel.lower().replace(" ", "_")
            obj = obj_raw.capitalize().replace(" ", "_")
            subj_node = EntityNode(
                name=subj,
                label=subj_label,
            )
            obj_node = EntityNode(
                name=obj,
                label=obj_label,
            )
            # Create the directed relation between the two entities.
            rel_node = Relation(
                label=rel,
                source_id=subj_node.id,
                target_id=obj_node.id,
                properties=metadata,
            )
            existing_nodes.extend([subj_node, obj_node])
            existing_relations.append(rel_node)
        # Update node metadata with extracted entities and relations.
        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    def __call__(
        self,
        nodes: Sequence[BaseNode],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> Sequence[BaseNode]:
        """
        Synchronously extract triples from a sequence of nodes.

        This method wraps the asynchronous ``acall`` method for synchronous execution.
        Note: must not be called from inside a running event loop (asyncio.run limitation).

        Args:
            nodes (Sequence[BaseNode]): The nodes to process.
            show_progress (bool): Whether to display a progress bar. Defaults to False.
            **kwargs: Additional keyword arguments passed through to ``acall``.

        Returns:
            Sequence[BaseNode]: The processed nodes with updated metadata.
        """
        return asyncio.run(self.acall(nodes, show_progress=show_progress, **kwargs))

    async def acall(
        self,
        nodes: Sequence[BaseNode],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> Sequence[BaseNode]:
        """
        Asynchronously extract triples from a sequence of nodes.

        Nodes are processed concurrently (up to ``num_workers`` at a time) via run_jobs.

        Args:
            nodes (Sequence[BaseNode]): The nodes to process.
            show_progress (bool): Whether to display a progress bar. Defaults to False.
            **kwargs: Additional keyword arguments forwarded to each ``_aextract`` call.

        Returns:
            Sequence[BaseNode]: The processed nodes with updated metadata.
        """
        jobs = [self._aextract(node, **kwargs) for node in nodes]
        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )

    @classmethod
    def class_name(cls) -> str:
        return "BasicGraphExtractLLM"
if __name__ == "__main__":
    # Demo: extract a small knowledge graph from a short paragraph using OpenAI.
    # Requires OPENAI_API_KEY in the environment (or a .env file).
    import os

    import dotenv

    dotenv.load_dotenv()
    from llama_index.core.schema import TextNode

    from evoagentx.models import OpenAILLM, OpenAILLMConfig

    config = OpenAILLMConfig(
        model="gpt-4o-mini",
        temperature=0.7,
        max_tokens=1000,
        openai_key=os.environ["OPENAI_API_KEY"],
    )
    llm = OpenAILLM(config=config)
    trans = BasicGraphExtractLLM(llm=llm)
    node = TextNode(
        text="Satya Nadella, the CEO of Microsoft, announced a new partnership with OpenAI in 2023. Microsoft, headquartered in Redmond, Washington, will integrate OpenAI’s AI technologies into its Azure cloud platform. OpenAI, based in San Francisco, California, is known for developing ChatGPT."
    )
    # Duplicate the node to exercise the parallel (num_workers) extraction path.
    graph_nodes = trans([node] * 10)