File size: 5,645 Bytes
3487f22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""Knowledge Graph Agent: extracts entities and relationships from contract text.

Uses an LLM to extract the main entities (parties, dates, amounts, products, etc.)
and their relationships from the contract's opening section, then builds and saves
a NetworkX graph visualization.

Inserted into the pipeline between ingestion and classification.
Output stored in state as: entities, relationships, graph_image_path.
"""

import json
import os
import re

from observability import get_logger 
import matplotlib
matplotlib.use("Agg") # non-interactive backend (no display)
import matplotlib.pyplot as plt
import networkx as nx
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate

from state import ContractState

# Directory the graph image is written to (a relative path, so it resolves
# against the process's current working directory).
OUTPUT_DIR = "outputs"

# Only the first TEXT_EXCERPT_CHARS characters of the contract are sent to the LLM.
# NOTE(review): this presumes parties/dates/obligations are introduced in the
# opening section of the contract - confirm for the document set in use.
TEXT_EXCERPT_CHARS = 5000

# System prompt for the extraction call. The JSON braces are doubled ({{ }}) so
# the prompt template treats them as literal braces, not as template variables.
SYSTEM_PROMPT = """You are a legal entity extractor for commercial contracts.

Given the opening section of a contract, extract the most important entities and
relationships. Focus on who the parties are, what they are agreeing to do, and
any key constraints (dates, amounts, exclusivity, territory, etc.).

Respond with ONLY valid JSON in this exact format:
{{
    "entities": [
        {{"name": "short name", "type": "party|date|amount|product|location|other"}}
    ],
    "relationships": [
        {{"source": "entity name", "relation": "short verb phrase", "target": "entity name"}}
    ]
}}

Aim for 5-10 entities and 5-10 relationships. Keep entity names short (2-4 words max).
Only include relationships between entities you listed."""

# Model client, created once when this module is imported and shared by all calls.
llm = ChatAnthropic(model="claude-haiku-4-5-20251001", max_tokens=1024)

# Two-message prompt: the fixed system instructions plus a human message that
# carries the contract excerpt in the {contract_text} variable.
prompt = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    ("human", "Contract opening section:\n\n{contract_text}"),
])

# Pipeline invoked by extract_entities_and_relationships: format the prompt, then call the model.
chain = prompt | llm

# Node fill colors for the graph visualization, keyed by the entity types listed
# in SYSTEM_PROMPT. "other" doubles as the fallback color for unrecognized types.
ENTITY_COLORS = {
    "party": "#4C72B0",
    "date": "#DD8452",
    "amount": "#55A868",
    "product": "#8172B2",
    "location": "#937860",
    "other": "#C44E52",
}

def extract_entities_and_relationships(contract_text: str) -> tuple[list, list]:
    """Call the LLM to extract entities and relationships from the contract opening.

    Only the first ``TEXT_EXCERPT_CHARS`` characters of ``contract_text`` are
    sent to the model. The reply is expected to be a JSON object with
    ``entities`` and ``relationships`` arrays, optionally wrapped in a Markdown
    code fence or surrounded by a little prose.

    Args:
        contract_text: Full contract text. Empty or ``None`` input skips the
            LLM call entirely.

    Returns:
        ``(entities, relationships)`` as two lists of dicts. Both are empty
        when the reply cannot be parsed into a JSON object; a key whose value
        is not a list yields an empty list, and non-dict items are dropped so
        callers can rely on ``.get`` / item access.

    Raises:
        Whatever ``chain.invoke`` raises - LLM call failures are not handled here.
    """
    excerpt = (contract_text or "")[:TEXT_EXCERPT_CHARS].strip()
    if not excerpt:
        # Nothing to analyse; avoid paying for a call that can only hallucinate.
        return [], []

    response = chain.invoke({"contract_text": excerpt})

    content = response.content
    if not isinstance(content, str):
        # Message content can also be a list of content blocks (plain strings
        # or dicts carrying a "text" field); flatten the textual parts.
        pieces = []
        for block in content or []:
            if isinstance(block, str):
                pieces.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                pieces.append(block["text"])
        content = "".join(pieces)

    text = content.strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
    if fenced:
        text = fenced.group(1)

    # Try the text as-is first, then fall back to the outermost {...} span in
    # case the model added commentary around the JSON.
    candidates = [text]
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])

    result = None
    for candidate in candidates:
        try:
            result = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        break

    if not isinstance(result, dict):
        # Unparseable reply, or valid JSON that is not an object (list, string, null).
        return [], []

    def _dict_items(value) -> list:
        """Return only the dict elements of *value*, or [] if it is not a list."""
        if not isinstance(value, list):
            return []
        return [item for item in value if isinstance(item, dict)]

    return _dict_items(result.get("entities")), _dict_items(result.get("relationships"))

def build_and_save_graph(entities: list, relationships: list) -> str:
    """Build a directed NetworkX graph from the extraction and save it as a PNG.

    The inputs originate from LLM output, so malformed items are skipped
    instead of raising: an entity must be a dict with a non-empty string
    ``name``; a relationship must be a dict whose ``source`` and ``target``
    both name a kept entity.

    Args:
        entities: Dicts with a ``name`` and an optional ``type``. Missing or
            unrecognized types are drawn with the "other" color.
        relationships: Dicts with ``source``, ``target`` and an optional
            ``relation`` used as the edge label (truncated to 25 characters).

    Returns:
        Path of the saved image (``knowledge_graph.png`` inside ``OUTPUT_DIR``,
        overwritten on every call), or ``""`` when there is no usable entity.
    """
    # Validate first so that bad input never touches the filesystem or matplotlib.
    # A dict keeps insertion order and lets a repeated name overwrite its type,
    # which mirrors what repeated G.add_node calls would do.
    node_types = {}
    for entity in entities or []:
        if not isinstance(entity, dict):
            continue
        name = entity.get("name")
        if not isinstance(name, str) or not name.strip():
            continue
        etype = entity.get("type", "other")
        node_types[name] = etype if isinstance(etype, str) else "other"

    if not node_types:
        return ""

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    output_path = os.path.join(OUTPUT_DIR, "knowledge_graph.png")

    G = nx.DiGraph()
    for name, etype in node_types.items():
        G.add_node(name, entity_type=etype)

    for rel in relationships or []:
        if not isinstance(rel, dict):
            continue
        src, tgt = rel.get("source"), rel.get("target")
        # Only connect entities that were actually listed; dangling references are dropped.
        if isinstance(src, str) and isinstance(tgt, str) and src in node_types and tgt in node_types:
            relation = rel.get("relation")
            G.add_edge(src, tgt, label="" if relation is None else str(relation))

    fig = plt.figure(figsize=(12, 8))
    try:
        # Fixed seed keeps the layout reproducible between runs.
        pos = nx.spring_layout(G, seed=42, k=2.5)

        node_colors = [
            ENTITY_COLORS.get(G.nodes[n]["entity_type"], ENTITY_COLORS["other"])
            for n in G.nodes
        ]

        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=4000, alpha=0.9)
        nx.draw_networkx_labels(G, pos, font_size=8, font_color="#1a1a1a", font_weight="bold")
        nx.draw_networkx_edges(G, pos, arrows=True, arrowsize=20, edge_color="#888888", width=1.5)
        edge_labels = {(u, v): d["label"][:25] for u, v, d in G.edges(data=True)}
        nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=7, font_color="#333333")

        # Legend: one marker per known entity type.
        legend_handles = [
            plt.Line2D([0], [0], marker="o", color="w", markerfacecolor=color,
                       markersize=10, label=etype)
            for etype, color in ENTITY_COLORS.items()
        ]
        plt.legend(handles=legend_handles, loc="upper left", fontsize=8)

        plt.title("Contract Knowledge Graph", fontsize=14, pad=20)
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if layout/drawing/saving fails;
        # pyplot otherwise keeps it alive for the life of the process.
        plt.close(fig)

    return output_path

def knowledge_graph_node(state: ContractState) -> dict:
    """LangGraph node: extract entities/relationships and render the knowledge graph.

    Reads ``state["raw_text"]``, runs the extraction and graph rendering, logs
    summary counts to the observability logger when one is available, and
    returns the state update with ``entities``, ``relationships`` and
    ``graph_image_path``.
    """
    contract_text = state["raw_text"]
    found_entities, found_relationships = extract_entities_and_relationships(contract_text)
    image_path = build_and_save_graph(found_entities, found_relationships)

    # observability: log knowledge graph results as Braintrust span
    tracer = get_logger()
    if tracer:
        with tracer.start_span("knowledge_graph_node") as span:
            span_input = {
                "contract_excerpt_chars": min(TEXT_EXCERPT_CHARS, len(contract_text)),
            }
            span_output = {
                "entities_count": len(found_entities),
                "relationships_count": len(found_relationships),
                "graph_saved": bool(image_path),
            }
            # Distinct entity types seen in this extraction.
            distinct_types = set(item.get("type") for item in found_entities)
            span_metadata = {
                "entity_types": list(distinct_types),
                "graph_image_path": image_path,
            }
            span.log(input=span_input, output=span_output, metadata=span_metadata)

    state_update = {
        "entities": found_entities,
        "relationships": found_relationships,
        "graph_image_path": image_path,
    }
    return state_update