"""Benchmark Agent: compares contract clauses against real CUAD contract examples.

Uses hybrid retrieval, combining two methods and merging their de-duplicated
results before passing them to the LLM:
- ChromaDB (semantic/vector search): finds conceptually similar clauses.
- BM25 (keyword search): finds clauses with exact or near-exact matching terms.
If neither index is available, the LLM benchmarks from general legal knowledge alone.
"""
|
|
| import json |
| import os |
| import re |
|
|
| import chromadb |
| from observability import get_logger |
| from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction |
| from langchain_anthropic import ChatAnthropic |
| from langchain_core.prompts import ChatPromptTemplate |
| from rank_bm25 import BM25Okapi |
| from state import ContractState |
|
|
# Prebuilt CUAD retrieval store (Chroma collection plus chunks.json for BM25),
# located at ../data/cuad_vector_store relative to this module.
_MODULE_DIR = os.path.dirname(__file__)
STORE_DIR = os.path.join(_MODULE_DIR, "..", "data", "cuad_vector_store")
COLLECTION_NAME = "cuad_contracts"  # Chroma collection name inside STORE_DIR
EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # sentence-transformers model for embeddings
N_RESULTS = 3  # examples requested from each retriever per clause
|
|
# System prompt for the benchmarking LLM, used as a ChatPromptTemplate string:
# {examples_section} is filled in per clause with the retrieved CUAD examples
# (or a "no examples" notice), while the doubled braces {{ }} are template
# escapes that render as literal braces in the JSON response specification.
# benchmark_clause parses the model's reply as JSON with these keys.
SYSTEM_PROMPT = """You are a legal benchmarking analyst for commercial contracts.

Given a contract clause, its classified type, and examples of similar clauses retrieved
from real CUAD commercial contracts, evaluate how standard the clause is.

The examples below come from two retrieval methods:
- Semantic matches: conceptually similar clauses found via vector search
- Keyword matches: clauses with similar exact phrasing found via keyword search

{examples_section}

Respond with ONLY valid JSON in this exact format:
{{
"benchmark_similarity": 0.0 to 1.0 (0 = highly unusual, 1 = very standard),
"deviations": ["deviation 1", "deviation 2"],
"standard_language_summary": "brief description of what is typical for this clause type",
"reasoning": "brief explanation of how this clause compares to the CUAD examples"
}}"""
|
|
# Chat model that scores each clause; replies are capped at 512 tokens.
llm = ChatAnthropic(
    model="claude-haiku-4-5-20251001",
    max_tokens=512,
)

# The system message carries the instructions plus the retrieved examples;
# the human message carries the clause under review.
_HUMAN_TEMPLATE = "Clause type: {clause_type}\n\nClause text:\n{clause_text}"
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_PROMPT),
        ("human", _HUMAN_TEMPLATE),
    ]
)

# Pipeline: format the prompt, then invoke the model.
chain = prompt | llm
|
|
| |
# Lazily populated module caches (filled on first use by _get_collection and
# _get_bm25 respectively).

# Chroma collection handle for semantic search.
_collection = None

# BM25 index and the chunk records it was built from, for keyword search.
_bm25_index, _bm25_corpus = None, None
|
|
# Set once opening the vector store has failed, so the load is not retried --
# and its warning not reprinted -- for every clause.
_collection_load_failed = False


def _get_collection():
    """Return the CUAD Chroma collection, opening it on first use.

    The collection is cached in the module-level ``_collection``. If opening
    it fails, a warning is printed once, the failure is remembered, and this
    call and every later one return ``None`` so semantic retrieval is skipped.
    """
    global _collection, _collection_load_failed
    if _collection is not None or _collection_load_failed:
        return _collection
    try:
        ef = SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL)
        client = chromadb.PersistentClient(path=STORE_DIR)
        _collection = client.get_collection(COLLECTION_NAME, embedding_function=ef)
    except Exception as e:  # best-effort: any failure means "no vector store"
        _collection_load_failed = True
        print(f"[benchmark_agent] Warning: could not load CUAD vector store: {e}")
        print("[benchmark_agent] Run 'python scripts/build_vector_store.py' to build it.")
        print("[benchmark_agent] Falling back to LLM-only benchmarking.")
    return _collection
|
|
| |
# Set once building the BM25 index has failed, so the attempt (and its warning)
# is not repeated for every clause.
_bm25_load_failed = False


def _get_bm25():
    """Return ``(index, corpus)`` for keyword search, building them on first use.

    ``corpus`` is the list of chunk records loaded from ``chunks.json`` in
    STORE_DIR (each record needs a ``"text"`` field). ``index`` is a BM25Okapi
    index over their lower-cased, whitespace-split text. Both are cached in
    module globals and are published together only once the index has been
    built. On failure a warning is printed once, the failure is remembered,
    and ``(None, None)`` is returned from then on.
    """
    global _bm25_index, _bm25_corpus, _bm25_load_failed
    if _bm25_index is not None or _bm25_load_failed:
        return _bm25_index, _bm25_corpus
    chunks_path = os.path.join(STORE_DIR, "chunks.json")
    try:
        with open(chunks_path, encoding="utf-8") as f:
            corpus = json.load(f)
        tokenized = [doc["text"].lower().split() for doc in corpus]
        index = BM25Okapi(tokenized)
    except Exception as e:  # best-effort: a missing/invalid file just disables BM25
        _bm25_load_failed = True
        print(f"[benchmark_agent] BM25 index not available: {e}")
    else:
        _bm25_corpus = corpus
        _bm25_index = index
    return _bm25_index, _bm25_corpus
|
|
def _retrieve_semantic(clause_text: str, clause_type: str) -> list[dict]:
    """Retrieve up to N_RESULTS semantically similar CUAD examples via ChromaDB.

    The query is the clause type followed by the first 500 characters of the
    clause. Returns ``[]`` when the vector store is unavailable or the query
    fails (a warning is printed in the latter case). Each hit is
    ``{"text": ..., "source": ...}``. Hits without text are skipped, and a
    missing ``source`` metadata value becomes ``"unknown"``.
    """
    collection = _get_collection()
    if collection is None:
        return []
    query = f"{clause_type}: {clause_text[:500]}"
    try:
        results = collection.query(query_texts=[query], n_results=N_RESULTS)
    except Exception as e:  # degrade to keyword/LLM-only rather than abort the node
        print(f"[benchmark_agent] Warning: semantic search failed: {e}")
        return []
    # One query was sent, so each result field holds a single inner list.
    documents = (results.get("documents") or [[]])[0] or []
    metadatas = (results.get("metadatas") or [[]])[0] or []
    return [
        {"text": doc, "source": (meta or {}).get("source", "unknown")}
        for doc, meta in zip(documents, metadatas)
        if doc
    ]
|
|
| |
def _retrieve_bm25(clause_text: str, clause_type: str) -> list[dict]:
    """Retrieve up to N_RESULTS keyword-matched CUAD examples via BM25.

    The query is the clause type plus the first 500 characters of the clause,
    lower-cased and whitespace-split (the same tokenization used to build the
    index). Chunks without a positive BM25 score are not returned, so
    unrelated text is never presented to the LLM as a "keyword match".
    Returns ``[]`` when the BM25 index is unavailable. A missing ``source``
    field becomes ``"unknown"``.
    """
    bm25, corpus = _get_bm25()
    if bm25 is None:
        return []
    query_tokens = f"{clause_type} {clause_text[:500]}".lower().split()
    scores = bm25.get_scores(query_tokens)
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [
        {"text": corpus[i]["text"], "source": corpus[i].get("source", "unknown")}
        for i in ranked[:N_RESULTS]
        if scores[i] > 0
    ]
|
|
| |
def _retrieve_cuad_examples(clause_text: str, clause_type: str) -> tuple[str, list[str]]:
    """Run both retrievers and format their merged results for the prompt.

    Semantic (ChromaDB) hits are kept first, then keyword (BM25) hits. Any hit
    whose first 80 characters match an already-kept example is dropped, so the
    LLM does not see the same clause twice.

    Returns:
        ``(examples_text, sources)``. ``examples_text`` holds the labelled
        examples (each text truncated to 400 characters) separated by blank
        lines, or ``""`` when nothing was retrieved. ``sources`` lists each
        kept example's source in the same order.
    """
    seen_prefixes = set()

    def _keep_new(hits):
        # Filter out hits already seen (by 80-char text prefix), in order.
        fresh = []
        for hit in hits:
            prefix = hit["text"][:80]
            if prefix in seen_prefixes:
                continue
            seen_prefixes.add(prefix)
            fresh.append(hit)
        return fresh

    semantic_hits = _keep_new(_retrieve_semantic(clause_text, clause_type))
    keyword_hits = _keep_new(_retrieve_bm25(clause_text, clause_type))

    sources = [hit["source"] for hit in semantic_hits + keyword_hits]

    parts = [
        f"Semantic match {rank} (source: {hit['source']}):\n{hit['text'][:400]}"
        for rank, hit in enumerate(semantic_hits, start=1)
    ] + [
        f"Keyword match {rank} (source: {hit['source']}):\n{hit['text'][:400]}"
        for rank, hit in enumerate(keyword_hits, start=1)
    ]

    return "\n\n".join(parts), sources
|
|
def _response_text(response) -> str:
    """Return the text of a chat-model reply.

    ``response.content`` is returned as-is when it is a string. When it is a
    list of content parts, string parts and the ``"text"`` of dict parts are
    concatenated, and other parts are ignored.
    """
    content = response.content
    if isinstance(content, str):
        return content
    pieces = []
    for part in content or []:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict) and isinstance(part.get("text"), str):
            pieces.append(part["text"])
    return "".join(pieces)


def _parse_json_object(text: str):
    """Return the JSON object contained in an LLM reply, or None.

    Candidates are tried in order:
    1. the body of a Markdown code fence;
    2. the whole reply;
    3. the span from the first ``{`` to the last ``}`` (for replies that wrap
       the JSON in prose).

    Only a candidate that parses to a JSON object (dict) is accepted.
    """
    candidates = []
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
    if fenced:
        candidates.append(fenced.group(1))
    candidates.append(text)
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def _clamp_similarity(value) -> float:
    """Coerce a model-supplied similarity score to a float in [0.0, 1.0].

    Missing, non-numeric, or NaN values become 0.0. Out-of-range numbers are
    clipped to the nearest bound.
    """
    import math

    try:
        score = float(value)
    except (TypeError, ValueError):
        return 0.0
    if math.isnan(score):
        return 0.0
    return min(max(score, 0.0), 1.0)


def benchmark_clause(clause_text: str, clause_type: str) -> dict:
    """Benchmark a single clause against CUAD examples.

    Retrieves similar CUAD clauses (semantic + keyword), asks the LLM to
    compare the clause against them, and parses the JSON reply.

    Args:
        clause_text: Full text of the clause under review.
        clause_type: Classified clause type, used in the retrieval query and
            in the prompt.

    Returns:
        A dict that always contains:
        - ``benchmark_similarity``: float clipped to [0.0, 1.0];
        - ``deviations``, ``standard_language_summary`` and ``reasoning``;
        - ``_sources``: the sources of the examples shown to the model.
        Any other keys the model returned are kept. If the reply contains no
        parseable JSON object, a fallback result is returned instead, with
        similarity 0.0 and reasoning "Failed to parse LLM response".
    """
    examples_text, sources = _retrieve_cuad_examples(clause_text, clause_type)

    if examples_text:
        examples_section = f"Similar clauses retrieved from CUAD contracts:\n\n{examples_text}"
    else:
        examples_section = "No CUAD examples available. Use general legal knowledge."

    response = chain.invoke({
        "examples_section": examples_section,
        "clause_type": clause_type,
        "clause_text": clause_text,
    })

    parsed = _parse_json_object(_response_text(response).strip())
    if parsed is None:
        result = {
            "benchmark_similarity": 0.0,
            "deviations": [],
            "standard_language_summary": "",
            "reasoning": "Failed to parse LLM response",
        }
    else:
        # Fill in any keys the model omitted so callers see a consistent shape.
        result = {
            "deviations": [],
            "standard_language_summary": "",
            "reasoning": "",
            **parsed,
        }
        # Downstream code averages this value, so it must be a bounded float.
        result["benchmark_similarity"] = _clamp_similarity(
            parsed.get("benchmark_similarity")
        )

    result["_sources"] = sources
    return result
|
|
def _describe_retrieval_method() -> str:
    """Name the retrieval methods whose indexes are currently loaded (for logging)."""
    semantic_ready = _collection is not None
    keyword_ready = _bm25_index is not None
    if semantic_ready and keyword_ready:
        return "hybrid (semantic + BM25)"
    if semantic_ready:
        return "semantic only"
    if keyword_ready:
        return "BM25 only"
    return "none (LLM only)"


def benchmark_node(state: ContractState) -> dict:
    """LangGraph node: benchmark every clause against CUAD contracts.

    Clauses are read from ``state["risk_scores"]``. If that key is missing or
    None, ``state["classified_clauses"]`` is used instead (or an empty list if
    that is missing too). Each clause dict must have ``"text"`` and may have
    ``"clause_type"`` (default ``"Other"``).

    Returns:
        ``{"benchmark_results": [...]}``, where each entry is the input clause
        dict extended with ``benchmark_similarity`` and a human-readable
        ``benchmark_source``.
    """
    source = state.get("risk_scores")
    if source is None:
        source = state.get("classified_clauses") or []

    benchmarked = []
    for clause in source:
        result = benchmark_clause(clause["text"], clause.get("clause_type", "Other"))
        # Several retrieved chunks may share a source; name each source once.
        sources = list(dict.fromkeys(result.get("_sources", [])))
        benchmark_source = (
            "CUAD: " + ", ".join(s[:50] for s in sources[:2])
            if sources
            else "CUAD (LLM fallback. Run build_vector_store.py)"
        )
        benchmarked.append({
            **clause,
            "benchmark_similarity": result.get("benchmark_similarity", 0.0),
            "benchmark_source": benchmark_source,
        })

    logger = get_logger()
    if logger:
        with logger.start_span("benchmark_node") as span:
            sim_scores = [c["benchmark_similarity"] for c in benchmarked]
            span.log(
                input={"clauses_received": len(source)},
                output={
                    "clauses_benchmarked": len(benchmarked),
                    "avg_benchmark_similarity": (
                        sum(sim_scores) / len(sim_scores) if sim_scores else 0.0
                    ),
                    # Reflect which indexes actually loaded, not just BM25.
                    "retrieval_method": _describe_retrieval_method(),
                },
            )

    return {"benchmark_results": benchmarked}
|
|