File size: 3,170 Bytes
3487f22
908ff10
 
 
 
 
 
 
 
 
 
3487f22
908ff10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc12fa1
 
908ff10
 
 
 
 
 
 
3487f22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
908ff10
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Classification Agent: tags contract clauses using the CUAD taxonomy"""

import json
import re
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from state import ContractState
from pathlib import Path

from dotenv import load_dotenv
load_dotenv()
from observability import get_logger

# Load the CUAD taxonomy once at import time. The path is resolved against the
# process working directory. The encoding is pinned to UTF-8 so the file is
# decoded the same way on every platform instead of using the locale default.
with open("data/cuad/taxonomy.json", encoding="utf-8") as f:
    TAXONOMY = json.load(f)

# Clause type names, in the order they appear in the taxonomy file.
CUAD_CLAUSE_TYPES = [entry["name"] for entry in TAXONOMY]

# System prompt for the clause classifier. ``{clause_types}`` is a template
# placeholder that classify_clause fills with the rendered taxonomy list; the
# doubled braces around the JSON example are escapes so the rendered prompt
# contains literal ``{`` and ``}``.
# NOTE(review): the instructions say "one or more" clause types, but the
# requested JSON carries a single "clause_type" -- confirm which is intended.
SYSTEM_PROMPT = """You are a legal clause classifier for commercial contracts.

Given a contract clause, classify it into one or more of the following CUAD clause types:

{clause_types}

If the clause does not match any type, classify it as "Other".

Respond with ONLY valid JSON in this exact format:
{{
    "clause_type": "the primary clause type",
    "confidence": 0.0 to 1.0,
    "reasoning": "brief explanation"
}}"""

# Chat model used for classification; replies are capped at 256 tokens.
llm = ChatAnthropic(
    model="claude-haiku-4-5-20251001",
    max_tokens=256,
)

# Two-message conversation: the system instructions, then the clause itself.
_prompt_messages = [
    ("system", SYSTEM_PROMPT),
    ("human", "Classify this clause:\n\n{clause_text}"),
]
prompt = ChatPromptTemplate.from_messages(_prompt_messages)

# Pipeline: render the prompt, then hand it to the model.
chain = prompt | llm


def classify_clause(clause_text: str) -> dict:
    """Classify a single clause and return a structured result.

    The clause is sent through the module-level ``chain`` together with the
    taxonomy rendered as a ``- name: description`` bullet list, and the
    model's reply is parsed as JSON.

    Args:
        clause_text: Text of the clause to classify.

    Returns:
        The JSON object produced by the model; it always contains a
        ``clause_type`` key. If the reply holds no JSON object with that
        key, a fallback dict with ``clause_type`` "Other", ``confidence``
        0.0 and an explanatory ``reasoning`` is returned instead.
    """
    response = chain.invoke({
        "clause_types": "\n".join(
            # Keep only the text after "Details: " when the question has it.
            f"- {entry['name']}: {entry['question'].split('Details: ')[-1]}"
            for entry in TAXONOMY
        ),
        "clause_text": clause_text,
    })

    result = _parse_classification(_response_text(response.content))
    if result is None:
        result = {
            "clause_type": "Other",
            "confidence": 0.0,
            "reasoning": "Failed to parse LLM response",
        }

    return result


def _response_text(content) -> str:
    """Flatten a chat message's content (str or list of blocks) to text."""
    if isinstance(content, str):
        return content
    pieces = []
    for part in content or []:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict) and isinstance(part.get("text"), str):
            pieces.append(part["text"])
    return "".join(pieces)


def _parse_classification(text: str):
    """Return the JSON object with a ``clause_type`` found in *text*, else None."""
    text = text.strip()
    # Models often wrap their JSON in a Markdown code fence; unwrap it.
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
    if fenced:
        text = fenced.group(1)

    candidates = [text]
    # Second chance: the outermost {...} span, for replies that surround the
    # JSON with prose or wrap the object in a list.
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Callers index the result by key, so only accept a usable object.
        if isinstance(parsed, dict) and "clause_type" in parsed:
            return parsed
    return None


def classification_node(state: ContractState) -> dict:
    """LangGraph node: classify clauses produced by the ingestion agent.

    Each processed clause is copied and extended with the ``clause_type``
    and ``confidence`` returned by ``classify_clause``. If a logger is
    available, a summary of the run is logged as a span.

    Args:
        state: Graph state; ``state["clauses"]`` is read as a sequence of
            dicts that each have a ``"text"`` key.

    Returns:
        A dict with one key, ``"classified_clauses"``, holding the list of
        enriched clause dicts.
    """
    # Only the first few clauses are classified, to conserve API credits.
    max_clauses = 3

    classified = []
    for clause in state["clauses"][:max_clauses]:
        result = classify_clause(clause["text"])
        if not isinstance(result, dict):
            # A non-object reply carries no usable fields; use the defaults.
            result = {}

        # Model output is untrusted: confidence may be missing, None, or text.
        try:
            confidence = float(result.get("confidence", 0.0))
        except (TypeError, ValueError):
            confidence = 0.0

        classified.append({
            **clause,
            # Fall back to "Other" rather than crash when the field is absent.
            "clause_type": result.get("clause_type", "Other"),
            "confidence": confidence,
        })

    # observability: log classification summary as Braintrust span
    logger = get_logger()
    if logger:
        with logger.start_span("classification_node") as span:
            span.log(
                input={"clauses_received": len(state["clauses"])},
                output={
                    "clauses_classified": len(classified),
                    "clause_types": [c["clause_type"] for c in classified],
                },
                metadata={
                    "avg_confidence": (
                        sum(c["confidence"] for c in classified) / len(classified)
                        if classified else 0.0
                    ),
                },
            )

    return {"classified_clauses": classified}