| """Classification Agent: tags contract clauses using the CUAD taxonomy""" |
|
|
| import json |
| import re |
| from langchain_anthropic import ChatAnthropic |
| from langchain_core.prompts import ChatPromptTemplate |
| from state import ContractState |
| from pathlib import Path |
|
|
| from dotenv import load_dotenv |
| load_dotenv() |
| from observability import get_logger |
|
|
| with open("data/cuad/taxonomy.json") as f: |
| TAXONOMY = json.load(f) |
|
|
| CUAD_CLAUSE_TYPES = [entry["name"] for entry in TAXONOMY] |
|
|
| SYSTEM_PROMPT = """You are a legal clause classifier for commercial contracts. |
| |
| Given a contract clause, classify it into one or more of the following CUAD clause types: |
| |
| {clause_types} |
| |
| If the clause does not match any type, classify it as "Other". |
| |
| Respond with ONLY valid JSON in this exact format: |
| {{ |
| "clause_type": "the primary clause type", |
| "confidence": 0.0 to 1.0, |
| "reasoning": "brief explanation" |
| }}""" |
|
|
| llm = ChatAnthropic(model="claude-haiku-4-5-20251001", max_tokens=256) |
|
|
| prompt = ChatPromptTemplate.from_messages([ |
| ("system", SYSTEM_PROMPT), |
| ("human", "Classify this clause:\n\n{clause_text}"), |
| ]) |
|
|
| chain = prompt | llm |
|
|
|
|
| def classify_clause(clause_text: str) -> dict: |
| """Classify a single clause and return structured result.""" |
| response = chain.invoke({ |
| |
| "clause_types": "\n".join( |
| f"- {entry['name']}: {entry['question'].split('Details: ')[-1]}" |
| for entry in TAXONOMY |
| ), |
| "clause_text": clause_text, |
| }) |
|
|
| try: |
| text = response.content.strip() |
| match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL) |
| if match: |
| text = match.group(1) |
| result = json.loads(text) |
| except json.JSONDecodeError: |
| result = { |
| "clause_type": "Other", |
| "confidence": 0.0, |
| "reasoning": "Failed to parse LLM response", |
| } |
|
|
| return result |
|
|
|
|
| def classification_node(state: ContractState) -> dict: |
| """LangGraph node: classify all clauses from the ingestion agent.""" |
| classified = [] |
|
|
| |
| for clause in state["clauses"][:3]: |
| result = classify_clause(clause["text"]) |
| classified.append({ |
| **clause, |
| "clause_type": result["clause_type"], |
| "confidence": result.get("confidence", 0.0), |
| }) |
|
|
| |
| logger = get_logger() |
| if logger: |
| with logger.start_span("classification_node") as span: |
| span.log( |
| input={"clauses_received": len(state["clauses"])}, |
| output={ |
| "clauses_classified": len(classified), |
| "clause_types": [c["clause_type"] for c in classified], |
| }, |
| metadata={ |
| "avg_confidence": ( |
| sum(c["confidence"] for c in classified) / len(classified) |
| if classified else 0.0 |
| ), |
| }, |
| ) |
|
|
| return {"classified_clauses": classified} |
|
|