File size: 3,170 Bytes
"""Classification Agent: tags contract clauses using the CUAD taxonomy."""
import json
import re
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from state import ContractState
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
from observability import get_logger
# Load the CUAD taxonomy once at import time. The path is relative, so it is
# resolved against the current working directory of the process.
# JSON text is UTF-8; decode it explicitly rather than relying on the
# platform's locale-dependent default encoding.
with open("data/cuad/taxonomy.json", encoding="utf-8") as f:
    TAXONOMY = json.load(f)

# Clause-type names in taxonomy order. Each taxonomy entry is read as a mapping
# with at least a "name" key.
CUAD_CLAUSE_TYPES = [entry["name"] for entry in TAXONOMY]
# System-message template for the clause classifier.
# `{clause_types}` is a template variable supplied by classify_clause() at
# invocation time. The doubled braces (`{{` / `}}`) are escapes so that the
# rendered prompt contains literal braces around the example JSON object
# instead of being treated as template variables.
SYSTEM_PROMPT = """You are a legal clause classifier for commercial contracts.
Given a contract clause, classify it into one or more of the following CUAD clause types:
{clause_types}
If the clause does not match any type, classify it as "Other".
Respond with ONLY valid JSON in this exact format:
{{
"clause_type": "the primary clause type",
"confidence": 0.0 to 1.0,
"reasoning": "brief explanation"
}}"""
# Chat model client used for classification. max_tokens=256 is passed to the
# client; presumably it caps the length of each reply, so the requested JSON
# object has to fit within it (confirm against the provider docs).
llm = ChatAnthropic(model="claude-haiku-4-5-20251001", max_tokens=256)
# Two-message prompt: the system instructions above, followed by a human turn
# that carries the clause text via the `{clause_text}` template variable.
prompt = ChatPromptTemplate.from_messages([
("system", SYSTEM_PROMPT),
("human", "Classify this clause:\n\n{clause_text}"),
])
# Compose prompt rendering and the model call into one runnable; it is invoked
# with "clause_types" and "clause_text" by classify_clause().
chain = prompt | llm
def _fallback_classification() -> dict:
    """Return the placeholder result used when the model reply is unusable."""
    return {
        "clause_type": "Other",
        "confidence": 0.0,
        "reasoning": "Failed to parse LLM response",
    }


def _message_text(content) -> str:
    """Flatten a chat message's content into a single string.

    Message content may be a plain string or a list of parts (strings or
    dicts carrying a "text" entry); parts without text are skipped.
    """
    if isinstance(content, str):
        return content
    parts = []
    for part in content or []:
        if isinstance(part, str):
            parts.append(part)
        elif isinstance(part, dict) and isinstance(part.get("text"), str):
            parts.append(part["text"])
    return "".join(parts)


def _parse_classification(text: str) -> dict:
    """Parse the model reply into a validated classification dict."""
    text = text.strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
    if fenced:
        # The model wrapped its answer in a Markdown code fence.
        text = fenced.group(1)
    else:
        # Tolerate prose around the object: keep the outermost {...} span.
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            text = text[start:end + 1]

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return _fallback_classification()

    # Valid JSON is not necessarily the object we asked for (it could be a
    # list, a bare string, or an object without the required key).
    if not isinstance(parsed, dict):
        return _fallback_classification()
    clause_type = parsed.get("clause_type")
    if not isinstance(clause_type, str) or not clause_type.strip():
        return _fallback_classification()

    # Callers average the confidences, so always hand back a float in [0, 1].
    try:
        confidence = float(parsed.get("confidence", 0.0))
    except (TypeError, ValueError):
        confidence = 0.0
    if confidence != confidence:  # NaN never equals itself
        confidence = 0.0
    confidence = min(1.0, max(0.0, confidence))

    return {**parsed, "clause_type": clause_type, "confidence": confidence}


def classify_clause(clause_text: str) -> dict:
    """Classify a single clause and return a structured result.

    The prompt lists every taxonomy entry as ``- <name>: <description>``,
    where the description is the part of the entry's "question" that follows
    the last "Details: " marker (or the whole question if the marker is
    absent).

    Args:
        clause_text: Raw text of the clause to classify.

    Returns:
        A dict that always contains a string "clause_type" and a float
        "confidence" in [0, 1], plus any other keys the model returned
        (normally "reasoning"). If the reply cannot be parsed into such an
        object, the result is ``{"clause_type": "Other", "confidence": 0.0,
        "reasoning": "Failed to parse LLM response"}``.
    """
    clause_types = "\n".join(
        f"- {entry['name']}: {entry['question'].split('Details: ')[-1]}"
        for entry in TAXONOMY
    )
    response = chain.invoke({
        "clause_types": clause_types,
        "clause_text": clause_text,
    })
    return _parse_classification(_message_text(response.content))
# Upper bound on how many clauses are classified per run, kept small to
# conserve API credits. Set to None to classify every clause (a None slice
# bound selects the whole list).
MAX_CLAUSES_PER_RUN = 3


def classification_node(state: ContractState) -> dict:
    """LangGraph node: classify clauses produced by the ingestion agent.

    Only the first ``MAX_CLAUSES_PER_RUN`` entries of ``state["clauses"]`` are
    sent to the classifier. Each classified entry is the original clause dict
    extended with "clause_type" and "confidence".

    Args:
        state: Graph state; ``state["clauses"]`` must be a sequence of dicts
            that each carry a "text" entry.

    Returns:
        ``{"classified_clauses": [...]}`` with one item per classified clause.
    """
    clauses = state["clauses"]
    classified = []
    for clause in clauses[:MAX_CLAUSES_PER_RUN]:
        result = classify_clause(clause["text"])
        # The classifier output originates from an LLM; do not let a
        # malformed result (wrong type, missing key) abort the whole node.
        if not isinstance(result, dict):
            result = {}
        try:
            confidence = float(result.get("confidence", 0.0))
        except (TypeError, ValueError):
            confidence = 0.0
        classified.append({
            **clause,
            "clause_type": result.get("clause_type", "Other"),
            "confidence": confidence,
        })

    # observability: log classification summary as Braintrust span
    logger = get_logger()
    if logger:
        with logger.start_span("classification_node") as span:
            span.log(
                input={"clauses_received": len(clauses)},
                output={
                    "clauses_classified": len(classified),
                    "clause_types": [c["clause_type"] for c in classified],
                },
                metadata={
                    "avg_confidence": (
                        sum(c["confidence"] for c in classified) / len(classified)
                        if classified else 0.0
                    ),
                },
            )
    return {"classified_clauses": classified}
|