File size: 5,645 Bytes
3487f22 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | """Knowledge Graph Agent: extracts entities and relationships from contract text.
Uses an LLM to extract the main entities (parties, dates, amounts, products, etc.) and
their relationships from the contract's opening section, then builds and saves a
NetworkX graph visualization.
Inserted into the pipeline between ingestion and classification.
Output stored in state as: entities, relationships, graph_image_path.
"""
import json
import os
import re
from observability import get_logger
import matplotlib
matplotlib.use("Agg") # non-interactive backend (no display)
import matplotlib.pyplot as plt
import networkx as nx
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from state import ContractState
# Directory that receives generated artifacts such as the knowledge-graph PNG.
OUTPUT_DIR = "outputs"

# Parties, dates and obligations are expected near the top of a contract, so
# only roughly the first 5000 characters are sent to the LLM.
TEXT_EXCERPT_CHARS = 5_000
# System prompt for the entity/relationship extractor. It is used as a
# ChatPromptTemplate message, where single braces mark template variables, so
# the literal JSON braces in the example are escaped as {{ and }}.
# The entity "type" values listed here are the same keys ENTITY_COLORS uses to
# colour graph nodes; keep the two in sync.
SYSTEM_PROMPT = """You are a legal entity extractor for commercial contracts.
Given the opening section of a contract, extract the most important entities and
relationships. Focus on who the parties are, what they are agreeing to do, and
any key constraints (dates, amounts, exclusivity, territory, etc.).
Respond with ONLY valid JSON in this exact format:
{{
"entities": [
{{"name": "short name", "type": "party|date|amount|product|location|other"}}
],
"relationships": [
{{"source": "entity name", "relation": "short verb phrase", "target": "entity name"}}
]
}}
Aim for 5-10 entities and 5-10 relationships. Keep entity names short (2-4 words max).
Only include relationships between entities you listed."""
# LLM client and prompt are built once at import time and shared by every
# extraction call. Piping the prompt into the model yields a runnable that
# fills in {contract_text} and sends the rendered messages to the LLM.
llm = ChatAnthropic(max_tokens=1024, model="claude-haiku-4-5-20251001")
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_PROMPT),
        ("human", "Contract opening section:\n\n{contract_text}"),
    ]
)
chain = prompt | llm
# Node fill colour for each entity type in the graph visualisation. The
# "other" entry is also used as the fallback colour for unknown types.
ENTITY_COLORS = dict(
    party="#4C72B0",
    date="#DD8452",
    amount="#55A868",
    product="#8172B2",
    location="#937860",
    other="#C44E52",
)
# Matches a Markdown code fence (optionally tagged ``json``) around the reply.
_FENCED_JSON_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)


def _message_text(content) -> str:
    """Flatten a LangChain message ``content`` value to plain text.

    ``content`` may be a string or a list of content blocks (strings or
    dicts); only string blocks and ``{"type": "text"}`` blocks are kept.
    """
    if isinstance(content, str):
        return content
    parts = []
    for block in content or []:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            parts.append(str(block.get("text", "")))
    return "".join(parts)


def _parse_json_object(text: str) -> dict:
    """Decode the JSON object in an LLM reply; return ``{}`` if none is found.

    Tries, in order: the contents of a Markdown code fence, the whole reply,
    and the outermost ``{...}`` span (for replies with prose around the JSON).
    Only a decoded ``dict`` is accepted.
    """
    text = text.strip()
    fenced = _FENCED_JSON_RE.search(text)
    candidates = [fenced.group(1)] if fenced else []
    candidates.append(text)
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        candidates.append(text[start : end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return {}


def _dict_items(value) -> list:
    """Return the dict elements of ``value`` if it is a list, otherwise ``[]``."""
    if not isinstance(value, list):
        return []
    return [item for item in value if isinstance(item, dict)]


def extract_entities_and_relationships(contract_text: str) -> tuple[list, list]:
    """Ask the LLM for the key entities and relationships in a contract.

    Only the first TEXT_EXCERPT_CHARS characters (after leading whitespace)
    are sent to the model. The reply is parsed leniently; when no JSON object
    can be recovered, two empty lists are returned instead of raising.
    Exceptions from the LLM call itself propagate to the caller.

    Args:
        contract_text: Full contract text. Empty or whitespace-only input
            (or ``None``) returns ``([], [])`` without calling the LLM.

    Returns:
        ``(entities, relationships)``. Every entity is a dict with a
        non-blank string ``"name"``; every relationship is a dict. Other
        fields are passed through exactly as the model returned them.
    """
    # Drop leading whitespace so it does not consume the excerpt budget, and
    # skip the (paid) LLM call entirely when there is nothing to analyse.
    text = (contract_text or "").lstrip()
    if not text:
        return [], []
    excerpt = text[:TEXT_EXCERPT_CHARS].rstrip()

    response = chain.invoke({"contract_text": excerpt})
    result = _parse_json_object(_message_text(response.content))

    # LLM output is not schema-checked: keep only well-formed entries so
    # downstream consumers can rely on dicts and on entity["name"].
    entities = [
        entity
        for entity in _dict_items(result.get("entities"))
        if isinstance(entity.get("name"), str) and entity["name"].strip()
    ]
    relationships = _dict_items(result.get("relationships"))
    return entities, relationships
def _graph_nodes(entities: list) -> dict:
    """Map each usable entity name to its type, skipping malformed entries.

    Entries that are not dicts, or lack a non-blank string ``name``, are
    dropped. A missing or non-string ``type`` becomes ``"other"``. For
    duplicate names the last entry's type wins.
    """
    nodes = {}
    for entity in entities or []:
        if not isinstance(entity, dict):
            continue
        name = entity.get("name")
        if not isinstance(name, str) or not name.strip():
            continue
        entity_type = entity.get("type")
        nodes[name] = entity_type if isinstance(entity_type, str) else "other"
    return nodes


def _graph_edges(relationships: list, nodes: dict) -> list:
    """Return ``(source, target, label)`` triples whose endpoints are known nodes."""
    edges = []
    for rel in relationships or []:
        if not isinstance(rel, dict):
            continue
        src, tgt = rel.get("source"), rel.get("target")
        # The isinstance checks also protect the dict lookup from unhashable values.
        if isinstance(src, str) and isinstance(tgt, str) and src in nodes and tgt in nodes:
            relation = rel.get("relation")
            edges.append((src, tgt, "" if relation is None else str(relation)))
    return edges


def _render_graph(nodes: dict, edges: list, output_path: str) -> None:
    """Draw the directed graph with matplotlib and save it as a PNG."""
    graph = nx.DiGraph()
    for name, entity_type in nodes.items():
        graph.add_node(name, entity_type=entity_type)
    for src, tgt, label in edges:
        graph.add_edge(src, tgt, label=label)

    node_size = 4000
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        pos = nx.spring_layout(graph, seed=42, k=2.5)
        node_colors = [
            ENTITY_COLORS.get(graph.nodes[n]["entity_type"], ENTITY_COLORS["other"])
            for n in graph.nodes
        ]
        nx.draw_networkx_nodes(
            graph, pos, ax=ax, node_color=node_colors, node_size=node_size, alpha=0.9
        )
        nx.draw_networkx_labels(
            graph, pos, ax=ax, font_size=8, font_color="#1a1a1a", font_weight="bold"
        )
        # Edges must know the real node size so arrowheads stop at the node
        # border instead of being hidden underneath the large node markers.
        nx.draw_networkx_edges(
            graph,
            pos,
            ax=ax,
            node_size=node_size,
            arrows=True,
            arrowsize=20,
            edge_color="#888888",
            width=1.5,
        )
        edge_labels = {(u, v): d["label"][:25] for u, v, d in graph.edges(data=True)}
        nx.draw_networkx_edge_labels(
            graph, pos, edge_labels=edge_labels, ax=ax, font_size=7, font_color="#333333"
        )

        legend_handles = [
            plt.Line2D([0], [0], marker="o", color="w", markerfacecolor=color,
                       markersize=10, label=etype)
            for etype, color in ENTITY_COLORS.items()
        ]
        ax.legend(handles=legend_handles, loc="upper left", fontsize=8)
        ax.set_title("Contract Knowledge Graph", fontsize=14, pad=20)
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release this figure, even if drawing or saving fails, so a
        # long-running pipeline does not accumulate open figures.
        plt.close(fig)


def build_and_save_graph(entities: list, relationships: list) -> str:
    """Build a directed entity/relationship graph and save it as a PNG.

    Malformed entities and relationships (non-dicts, missing or non-string
    names/endpoints) are skipped rather than raising. Relationships whose
    source or target is not a listed entity are ignored.

    Args:
        entities: Entity dicts with ``"name"`` and optional ``"type"``.
        relationships: Dicts with ``"source"``, ``"target"`` and optional
            ``"relation"`` (used as the edge label, truncated to 25 chars).

    Returns:
        Path of the saved PNG. The filename under OUTPUT_DIR is fixed, so each
        call overwrites the previous image. Returns ``""`` when there are no
        usable entities; in that case nothing is drawn or created on disk.
    """
    nodes = _graph_nodes(entities)
    if not nodes:
        return ""
    edges = _graph_edges(relationships, nodes)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    output_path = os.path.join(OUTPUT_DIR, "knowledge_graph.png")
    _render_graph(nodes, edges, output_path)
    return output_path
def knowledge_graph_node(state: ContractState) -> dict:
    """LangGraph node: extract entities/relationships and render the graph.

    Reads ``state["raw_text"]`` and returns a partial state update with
    ``entities``, ``relationships`` and ``graph_image_path`` (``""`` when no
    graph image was produced). When an observability (Braintrust) logger is
    available, a ``knowledge_graph_node`` span summarising the result is logged.
    """
    raw_text = state["raw_text"]
    entities, relationships = extract_entities_and_relationships(raw_text)
    graph_path = build_and_save_graph(entities, relationships)

    # observability: log knowledge graph results as a Braintrust span
    logger = get_logger()
    if logger:
        # Sorted + stringified: set iteration order over str varies between
        # runs (hash randomisation), and a non-hashable "type" from the LLM
        # would otherwise raise TypeError while building the set.
        entity_types = sorted(
            {str(e.get("type") or "other") for e in entities if isinstance(e, dict)}
        )
        with logger.start_span("knowledge_graph_node") as span:
            span.log(
                input={"contract_excerpt_chars": min(TEXT_EXCERPT_CHARS, len(raw_text))},
                output={
                    "entities_count": len(entities),
                    "relationships_count": len(relationships),
                    "graph_saved": bool(graph_path),
                },
                metadata={
                    "entity_types": entity_types,
                    "graph_image_path": graph_path,
                },
            )

    return {
        "entities": entities,
        "relationships": relationships,
        "graph_image_path": graph_path,
    }
|