# contract-clause-analyzer / agents/knowledge_graph_agent.py
# Author: satomitheito
# Commit message: "Add new agents and observability, fix sys.path for HF Space"
# Commit: 3487f22
"""Knowledge Graph Agent: extracts entities and relationships from contract text.
Uses an LLM to gather main entities (parties, dates, amounts, products etc) and their
relationships from the contract opening section, builds & saves a NetworkX graph
visualization.
Inserted into the pipeline between ingestion and classification.
Output stored in state as: entities, relationships, graph_image_path.
"""
import json
import os
import re
from observability import get_logger
import matplotlib
matplotlib.use("Agg") # non-interactive backend (no display)
import matplotlib.pyplot as plt
import networkx as nx
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from state import ContractState
# Directory (relative to the current working directory) that receives the
# rendered knowledge-graph PNG.
OUTPUT_DIR: str = "outputs"

# Number of leading characters of the contract sent to the LLM. The opening
# section is where parties, dates and obligations are expected to appear, and
# a short excerpt keeps the prompt small.
TEXT_EXCERPT_CHARS: int = 5_000
# System message for the extraction LLM (used by `prompt` below).
# Literal JSON braces are doubled ({{ / }}) because ChatPromptTemplate treats
# single-brace {name} sequences as template variables (e.g. {contract_text} in
# the human message); after formatting, the model sees ordinary single braces.
# The requested entity "type" values match the keys of ENTITY_COLORS.
SYSTEM_PROMPT = """You are a legal entity extractor for commercial contracts.
Given the opening section of a contract, extract the most important entities and
relationships. Focus on who the parties are, what they are agreeing to do, and
any key constraints (dates, amounts, exclusivity, territory, etc.).
Respond with ONLY valid JSON in this exact format:
{{
"entities": [
{{"name": "short name", "type": "party|date|amount|product|location|other"}}
],
"relationships": [
{{"source": "entity name", "relation": "short verb phrase", "target": "entity name"}}
]
}}
Aim for 5-10 entities and 5-10 relationships. Keep entity names short (2-4 words max).
Only include relationships between entities you listed."""
# Anthropic Haiku model used for extraction. max_tokens caps the reply, which is
# expected to be only the small JSON object requested by SYSTEM_PROMPT.
llm = ChatAnthropic(
    model="claude-haiku-4-5-20251001",
    max_tokens=1024,
)

# System message carries the extraction instructions; the human message injects
# the contract excerpt through the {contract_text} template variable.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_PROMPT),
        ("human", "Contract opening section:\n\n{contract_text}"),
    ]
)

# prompt -> LLM pipeline; invoke with {"contract_text": <excerpt>}.
chain = prompt | llm
# Node fill colour per entity type, also used to build the legend. Types not in
# this map are drawn with the "other" colour.
ENTITY_COLORS = dict(
    party="#4C72B0",
    date="#DD8452",
    amount="#55A868",
    product="#8172B2",
    location="#937860",
    other="#C44E52",
)
def _message_text(content) -> str:
    """Return the plain-text part of a chat message's ``content``.

    LangChain message content is either a string or a list of content blocks
    (strings, or dicts such as ``{"type": "text", "text": ...}``); non-text
    blocks are ignored.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for block in content:
            if isinstance(block, str):
                pieces.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                pieces.append(str(block.get("text", "")))
        return "".join(pieces)
    return "" if content is None else str(content)


def _parse_extraction_json(text: str) -> tuple[list, list]:
    """Parse the model's reply into ``(entities, relationships)`` lists of dicts.

    Handles bare JSON, JSON wrapped in a Markdown code fence, and JSON with
    surrounding prose. An unparseable reply yields ``([], [])`` so a bad LLM
    answer degrades to an empty graph rather than crashing the pipeline.
    """
    candidate = text.strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", candidate, re.DOTALL)
    if fenced:
        candidate = fenced.group(1)
    start = candidate.find("{")
    if start == -1:
        return [], []
    try:
        # raw_decode stops after the first complete JSON value, so leading prose
        # (skipped via `start`) and trailing commentary do not break parsing.
        # Because the text starts with "{", a successful decode is always a dict.
        result, _ = json.JSONDecoder().raw_decode(candidate[start:])
    except json.JSONDecodeError:
        return [], []
    entities = result.get("entities")
    relationships = result.get("relationships")
    # Keep only dict items; downstream code reads fields with .get()/[...].
    clean_entities = (
        [e for e in entities if isinstance(e, dict)] if isinstance(entities, list) else []
    )
    clean_relationships = (
        [r for r in relationships if isinstance(r, dict)]
        if isinstance(relationships, list)
        else []
    )
    return clean_entities, clean_relationships


def extract_entities_and_relationships(contract_text: str) -> tuple[list, list]:
    """Call the LLM to extract entities and relationships from a contract's opening.

    Only the first ``TEXT_EXCERPT_CHARS`` characters are sent to the model.

    Returns:
        ``(entities, relationships)`` as lists of dicts, shaped as requested in
        ``SYSTEM_PROMPT`` (``{"name", "type"}`` and
        ``{"source", "relation", "target"}``). Items are only checked to be
        dicts, so individual keys may still be missing. ``([], [])`` is returned
        without calling the LLM when the excerpt is blank, and when the reply
        cannot be parsed as JSON.

    Errors raised by the LLM call itself are not caught and propagate.
    """
    excerpt = (contract_text or "")[:TEXT_EXCERPT_CHARS].strip()
    if not excerpt:
        # Nothing to analyse; skip a pointless API call.
        return [], []
    response = chain.invoke({"contract_text": excerpt})
    return _parse_extraction_json(_message_text(response.content))
def build_and_save_graph(entities: list, relationships: list) -> str:
    """Build a directed NetworkX graph of the extracted entities and save it as a PNG.

    Args:
        entities: dicts with a ``"name"`` and optional ``"type"``. Types are
            lower-cased; any type not in ``ENTITY_COLORS`` is coloured as "other".
        relationships: dicts with ``"source"``, ``"target"`` and optional
            ``"relation"``. Edges whose endpoints are not known entity names
            are dropped.

    Returns:
        Path of the written image (``OUTPUT_DIR/knowledge_graph.png``), or
        ``""`` when there is no usable entity to draw. Every call writes to the
        same file, overwriting any previous image.

    Malformed LLM items (non-dicts, missing or blank names, null relation
    labels) are skipped or defaulted instead of raising.
    """
    # name -> normalised entity type. Dict insertion order keeps the node order
    # (and therefore the seeded layout) stable; a later duplicate name only
    # updates the type.
    node_types: dict[str, str] = {}
    for entity in entities or []:
        if not isinstance(entity, dict):
            continue
        name = entity.get("name")
        if name is None or not str(name).strip():
            continue
        node_types[str(name)] = str(entity.get("type") or "other").strip().lower()
    if not node_types:
        return ""

    graph = nx.DiGraph()
    for name, entity_type in node_types.items():
        graph.add_node(name, entity_type=entity_type)
    for rel in relationships or []:
        if not isinstance(rel, dict):
            continue
        src, tgt = rel.get("source"), rel.get("target")
        if src is None or tgt is None:
            continue
        src, tgt = str(src), str(tgt)
        if src in node_types and tgt in node_types:
            # `or ""` guards against an explicit null relation from the LLM.
            graph.add_edge(src, tgt, label=str(rel.get("relation") or ""))

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    output_path = os.path.join(OUTPUT_DIR, "knowledge_graph.png")

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        # Fixed seed: the same graph gets the same layout on every run.
        pos = nx.spring_layout(graph, seed=42, k=2.5)
        node_colors = [
            ENTITY_COLORS.get(graph.nodes[n]["entity_type"], ENTITY_COLORS["other"])
            for n in graph.nodes
        ]
        nx.draw_networkx_nodes(
            graph, pos, ax=ax, node_color=node_colors, node_size=4000, alpha=0.9
        )
        nx.draw_networkx_labels(
            graph, pos, ax=ax, font_size=8, font_color="#1a1a1a", font_weight="bold"
        )
        nx.draw_networkx_edges(
            graph, pos, ax=ax, arrows=True, arrowsize=20, edge_color="#888888", width=1.5
        )
        # Truncate long relation phrases so edge labels stay readable.
        edge_labels = {(u, v): d["label"][:25] for u, v, d in graph.edges(data=True)}
        nx.draw_networkx_edge_labels(
            graph, pos, edge_labels=edge_labels, ax=ax, font_size=7, font_color="#333333"
        )

        # Legend: one coloured marker per known entity type.
        legend_handles = [
            plt.Line2D([0], [0], marker="o", color="w", markerfacecolor=color,
                       markersize=10, label=etype)
            for etype, color in ENTITY_COLORS.items()
        ]
        ax.legend(handles=legend_handles, loc="upper left", fontsize=8)
        ax.set_title("Contract Knowledge Graph", fontsize=14, pad=20)
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving fails, so
        # repeated pipeline runs do not accumulate open figures.
        plt.close(fig)
    return output_path
def knowledge_graph_node(state: ContractState) -> dict:
    """LangGraph node: extract entities/relationships and render the knowledge graph.

    Reads ``state["raw_text"]`` and returns a partial state update with keys
    ``entities``, ``relationships`` (always lists) and ``graph_image_path``
    (``""`` when no graph image was saved). When ``get_logger()`` returns a
    logger, a summary of the results is logged in a ``knowledge_graph_node`` span.
    """
    raw_text = state["raw_text"]
    entities, relationships = extract_entities_and_relationships(raw_text)
    # Normalise so len() below and downstream consumers always get lists.
    entities = entities or []
    relationships = relationships or []
    graph_path = build_and_save_graph(entities, relationships)

    # observability: log knowledge graph results as a Braintrust span
    logger = get_logger()
    if logger:
        # Distinct entity types, sorted so the logged metadata is deterministic.
        # Non-dict items are ignored; a missing type is reported as "other",
        # matching the graph's default.
        entity_types = sorted(
            {str(e.get("type") or "other") for e in entities if isinstance(e, dict)}
        )
        with logger.start_span("knowledge_graph_node") as span:
            span.log(
                input={"contract_excerpt_chars": min(TEXT_EXCERPT_CHARS, len(raw_text or ""))},
                output={
                    "entities_count": len(entities),
                    "relationships_count": len(relationships),
                    "graph_saved": bool(graph_path),
                },
                metadata={
                    "entity_types": entity_types,
                    "graph_image_path": graph_path,
                },
            )
    return {
        "entities": entities,
        "relationships": relationships,
        "graph_image_path": graph_path,
    }