| """Risk Analysis Agent: scores clause risk using CUAD taxonomy as a structured rubric.""" |
|
|
| import json |
| import os |
| import re |
| from langchain_anthropic import ChatAnthropic |
| from langchain_core.prompts import ChatPromptTemplate |
| from state import ContractState |
| from observability import get_logger |
|
|
# Location of the CUAD clause taxonomy, resolved relative to this module's directory.
TAXONOMY_PATH = os.path.join(os.path.dirname(__file__), "..", "data", "cuad", "taxonomy.json")


def load_cuad_rubric(path: str = TAXONOMY_PATH) -> dict:
    """Load the CUAD taxonomy as a ``{clause name: review question}`` mapping.

    The file is expected to contain a JSON array of objects, each carrying
    ``"name"`` and ``"question"`` keys. A missing file yields an empty mapping
    so the agent can still fall back to general (rubric-less) analysis.

    Args:
        path: Path to the taxonomy JSON file. Defaults to ``TAXONOMY_PATH``.

    Returns:
        Dict mapping clause-type name to its review question. Later duplicate
        names overwrite earlier ones.

    Raises:
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if an entry lacks ``"name"`` or ``"question"``.
    """
    try:
        # Explicit UTF-8: the default text encoding is locale-dependent, and
        # the taxonomy text may contain non-ASCII characters.
        with open(path, encoding="utf-8") as fh:
            entries = json.load(fh)
    except FileNotFoundError:
        return {}
    return {entry["name"]: entry["question"] for entry in entries}


# Clause type -> rubric question; empty when the taxonomy file is absent.
CUAD_RUBRIC = load_cuad_rubric()
|
|
# System prompt for the risk model, rendered by ChatPromptTemplate (f-string
# placeholder syntax). {rubric_section} is filled in on every call by
# analyze_risk. The doubled braces {{ }} are escapes that render as literal
# braces, so the model sees a plain JSON-shaped example.
# NOTE(review): the "risk_score" line describes the expected value rather than
# being valid JSON itself; the reply parser must tolerate whatever the model emits.
SYSTEM_PROMPT = """You are a legal risk analyst for commercial contracts.

Given a contract clause, its classified type, and a legal review rubric for that clause type,
evaluate the risk level by considering:
1. Ambiguous or vague language that could lead to unfavorable interpretations
2. Missing protective provisions that are standard for this clause type
3. Deviation from standard phrasing or industry norms
4. The specific legal review criteria provided in the rubric

{rubric_section}

Respond with ONLY valid JSON in this exact format:
{{
"risk_score": 0.0 to 1.0 (0 = low risk, 1 = high risk),
"risk_factors": ["factor 1", "factor 2"],
"reasoning": "brief explanation of the overall risk assessment"
}}"""
|
|
# Shared chat model used for every clause; max_tokens caps the length of each reply.
llm = ChatAnthropic(
    model="claude-haiku-4-5-20251001",
    max_tokens=512,
)

# Two-turn template: the system turn carries the rubric-bearing instructions,
# the human turn carries the clause under review.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_PROMPT),
        ("human", "Clause type: {clause_type}\n\nClause text:\n{clause_text}"),
    ]
)

# Pipeline: render the template, then send the messages to the model.
chain = prompt | llm
|
|
|
|
# Matches a ```json ... ``` (or bare ``` ... ```) fenced block in a model reply.
_FENCED_BLOCK_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)


def _extract_json_text(text: str) -> str:
    """Return the JSON payload of a model reply, stripping code fences or surrounding prose."""
    text = text.strip()
    match = _FENCED_BLOCK_RE.search(text)
    if match:
        return match.group(1)
    # No fence: tolerate a short preamble/epilogue around a bare JSON object.
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        return text[start:end + 1]
    return text


def _normalize_risk_result(data: object) -> dict:
    """Validate a decoded reply and coerce its fields to the types callers rely on.

    ``risk_score`` (when present) becomes a float clamped to [0.0, 1.0], and
    ``risk_factors`` (when present) becomes a list. Extra keys are kept as-is.

    Raises:
        ValueError: if ``data`` is not a dict, or ``risk_score`` is
            non-numeric or NaN.
        TypeError: if ``risk_score`` has a type ``float()`` cannot convert.
    """
    import math

    if not isinstance(data, dict):
        raise ValueError("expected a JSON object")
    if "risk_score" in data:
        score = float(data["risk_score"])
        if math.isnan(score):
            raise ValueError("risk_score is NaN")
        # Downstream code compares and averages scores, so keep them in range.
        data["risk_score"] = min(1.0, max(0.0, score))
    factors = data.get("risk_factors")
    if isinstance(factors, str):
        data["risk_factors"] = [factors]
    elif factors is not None and not isinstance(factors, list):
        data["risk_factors"] = []
    return data


def analyze_risk(clause_text: str, clause_type: str) -> dict:
    """Analyze risk for a single clause using the CUAD rubric.

    Args:
        clause_text: Raw text of the clause.
        clause_type: Classified clause type, used to look up a rubric question
            in ``CUAD_RUBRIC``. Unknown types get a generic instruction instead.

    Returns:
        Dict with ``risk_score`` (float in [0, 1] when present),
        ``risk_factors`` (list) and ``reasoning``, as returned by the model.
        If the reply cannot be parsed into a valid object, a fallback dict with
        ``risk_score`` 0.0 and a "Failed to parse LLM response" reasoning is
        returned instead.
    """
    rubric_question = CUAD_RUBRIC.get(clause_type, "")
    if rubric_question:
        rubric_section = f"Legal review rubric for this clause type:\n{rubric_question}"
    else:
        rubric_section = "No specific rubric available for this clause type. Use general legal risk analysis."

    response = chain.invoke({
        "rubric_section": rubric_section,
        "clause_type": clause_type,
        "clause_text": clause_text,
    })

    try:
        return _normalize_risk_result(json.loads(_extract_json_text(response.content)))
    except (ValueError, TypeError):
        # json.JSONDecodeError subclasses ValueError, so this also covers malformed JSON.
        # NOTE(review): a 0.0 score makes an unanalysed clause look low-risk;
        # kept for compatibility with existing callers.
        return {
            "risk_score": 0.0,
            "risk_factors": [],
            "reasoning": "Failed to parse LLM response",
        }
|
|
|
|
# Scores at or above this value are counted as high risk in the observability output.
HIGH_RISK_THRESHOLD = 0.7


def _score_clauses(clauses: list[dict]) -> list[dict]:
    """Return a copy of each clause dict extended with its risk score and risk factors."""
    scored = []
    for clause in clauses:
        result = analyze_risk(clause["text"], clause.get("clause_type", "Other"))
        scored.append({
            **clause,
            "risk_score": result.get("risk_score", 0.0),
            "risk_factors": result.get("risk_factors", []),
        })
    return scored


def risk_analysis_node(state: ContractState) -> dict:
    """LangGraph node: analyze risk for all classified clauses.

    Reads ``state["classified_clauses"]`` and scores each clause with
    ``analyze_risk``. When an observability logger is available, it also logs
    summary metrics to a ``risk_analysis_node`` span.

    Returns:
        ``{"risk_scores": [...]}``, where each item is the original clause dict
        plus ``risk_score`` and ``risk_factors`` keys.
    """
    clauses = state["classified_clauses"]

    logger = get_logger()
    if not logger:
        return {"risk_scores": _score_clauses(clauses)}

    # Score inside the span so its lifetime covers the per-clause LLM calls,
    # rather than opening it only after all the work is already done.
    with logger.start_span("risk_analysis_node") as span:
        risk_scored = _score_clauses(clauses)
        risk_scores = [c["risk_score"] for c in risk_scored]
        span.log(
            input={"clauses_received": len(clauses)},
            output={
                "clauses_scored": len(risk_scored),
                "high_risk_count": sum(1 for s in risk_scores if s >= HIGH_RISK_THRESHOLD),
                "avg_risk_score": sum(risk_scores) / len(risk_scores) if risk_scores else 0.0,
            },
            metadata={
                "risk_scores": risk_scores,
            },
        )

    return {"risk_scores": risk_scored}
|
|