File size: 8,460 Bytes
3487f22 908ff10 3487f22 908ff10 3487f22 908ff10 3487f22 908ff10 3487f22 908ff10 3487f22 908ff10 3487f22 908ff10 3487f22 fc12fa1 3487f22 fc12fa1 3487f22 908ff10 3487f22 908ff10 3487f22 908ff10 3487f22 908ff10 3487f22 fc12fa1 908ff10 fc12fa1 5399658 908ff10 3487f22 908ff10 3487f22 908ff10 3487f22 908ff10 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 | """Benchmark Agent: compares clauses against real CUAD contract examples.
Retrieval is hybrid (BM25 + ChromaDB):
two retrieval methods are run and their results are merged before being passed to the LLM:
- ChromaDB (semantic/vector search): finds conceptually similar clauses
- BM25 (keyword search): finds clauses with exact or near-exact matching terms
"""
import json
import os
import re
import chromadb
from observability import get_logger
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from rank_bm25 import BM25Okapi
from state import ContractState
# Directory of the persisted CUAD store: used below both as the ChromaDB
# persistent-client path and as the location of chunks.json for BM25.
STORE_DIR = os.path.join(os.path.dirname(__file__), "..", "data", "cuad_vector_store")
# Name of the ChromaDB collection queried for semantic matches.
COLLECTION_NAME = "cuad_contracts"
# Model name handed to the sentence-transformer embedding function.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
N_RESULTS = 3 # examples to retrieve per method (semantic + keyword)
# System prompt for the benchmarking call. {examples_section} is a template
# variable supplied by benchmark_clause(); the doubled braces around the JSON
# skeleton are escapes so the prompt template emits literal { and }.
SYSTEM_PROMPT = """You are a legal benchmarking analyst for commercial contracts.
Given a contract clause, its classified type, and examples of similar clauses retrieved
from real CUAD commercial contracts, evaluate how standard the clause is.
The examples below come from two retrieval methods:
- Semantic matches: conceptually similar clauses found via vector search
- Keyword matches: clauses with similar exact phrasing found via keyword search
{examples_section}
Respond with ONLY valid JSON in this exact format:
{{
"benchmark_similarity": 0.0 to 1.0 (0 = highly unusual, 1 = very standard),
"deviations": ["deviation 1", "deviation 2"],
"standard_language_summary": "brief description of what is typical for this clause type",
"reasoning": "brief explanation of how this clause compares to the CUAD examples"
}}"""
# Chat model client, created when this module is imported; max_tokens caps the
# length of each benchmark reply.
llm = ChatAnthropic(model="claude-haiku-4-5-20251001", max_tokens=512)
# Two-message prompt: the system instructions plus the clause under review.
prompt = ChatPromptTemplate.from_messages([
("system", SYSTEM_PROMPT),
("human", "Clause type: {clause_type}\n\nClause text:\n{clause_text}"),
])
# Prompt piped into the model; benchmark_clause() invokes this once per clause.
chain = prompt | llm
# Cached ChromaDB collection. Starts as None and is filled in lazily by
# _get_collection() on first use (not at import time); stays None if loading fails.
_collection = None
# Cached BM25 index and its backing corpus, filled in lazily by _get_bm25()
# on first use; both stay None if chunks.json cannot be loaded.
_bm25_index = None
_bm25_corpus = None # list of {"text": ..., "source": ...}
def _get_collection():
    """Return the cached CUAD ChromaDB collection, opening it on first use.

    The collection is stored in the module-level ``_collection`` so later
    calls reuse it. If the store cannot be opened, three warning lines are
    printed and ``None`` is returned; because nothing is cached in that case,
    the next call will try again.

    Returns:
        The ChromaDB collection, or ``None`` when it could not be loaded.
    """
    global _collection
    # Fast path: already opened by an earlier call.
    if _collection is not None:
        return _collection
    try:
        embedder = SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL)
        store = chromadb.PersistentClient(path=STORE_DIR)
        _collection = store.get_collection(COLLECTION_NAME, embedding_function=embedder)
    except Exception as err:
        # Best-effort: callers treat None as "no semantic retrieval available".
        for message in (
            f"[benchmark_agent] Warning: could not load CUAD vector store: {err}",
            "[benchmark_agent] Run 'python scripts/build_vector_store.py' to build it.",
            "[benchmark_agent] Falling back to LLM-only benchmarking.",
        ):
            print(message)
    return _collection
# BM25 keyword index, built from the chunks.json saved by build_vector_store.py
def _get_bm25():
    """Return the cached ``(bm25_index, corpus)`` pair, building it on first use.

    The corpus is read from ``chunks.json`` inside ``STORE_DIR`` and each
    entry's ``"text"`` is lower-cased and whitespace-split to form the BM25
    documents. On any failure a message is printed and the index stays
    ``None``, so the next call retries.

    Returns:
        A tuple ``(index, corpus)``; ``index`` is ``None`` when the BM25
        index could not be built.
    """
    global _bm25_index, _bm25_corpus
    # Fast path: index already built by an earlier call.
    if _bm25_index is not None:
        return _bm25_index, _bm25_corpus
    chunks_file = os.path.join(STORE_DIR, "chunks.json")
    try:
        with open(chunks_file, encoding="utf-8") as handle:
            _bm25_corpus = json.load(handle)
        token_lists = []
        for entry in _bm25_corpus:
            # Same tokenisation the query side uses: lowercase + whitespace split.
            token_lists.append(entry["text"].lower().split())
        _bm25_index = BM25Okapi(token_lists)
    except Exception as err:
        print(f"[benchmark_agent] BM25 index not available: {err}")
    return _bm25_index, _bm25_corpus
def _retrieve_semantic(clause_text: str, clause_type: str) -> list[dict]:
    """Retrieve the top-N semantically similar CUAD examples via ChromaDB.

    Args:
        clause_text: Text of the clause being benchmarked; only the first
            500 characters are used in the query.
        clause_type: Classified clause type, prefixed to the query text.

    Returns:
        A list of ``{"text": ..., "source": ...}`` dicts, at most
        ``N_RESULTS`` long. Empty when the vector store is unavailable.
        ``"source"`` is ``"unknown"`` for hits stored without that metadata.
    """
    collection = _get_collection()
    if collection is None:
        return []
    query = f"{clause_type}: {clause_text[:500]}"
    results = collection.query(query_texts=[query], n_results=N_RESULTS)
    hits = []
    for doc, meta in zip(results["documents"][0], results["metadatas"][0]):
        # A hit without document text cannot be shown to the LLM; skip it
        # instead of letting later slicing fail on None.
        if doc is None:
            continue
        # Metadata may be absent (None) or lack "source" for some chunks;
        # degrade to a placeholder rather than raising.
        source = (meta or {}).get("source", "unknown")
        hits.append({"text": doc, "source": source})
    return hits
# Keyword retrieval with BM25
def _retrieve_bm25(clause_text: str, clause_type: str) -> list[dict]:
    """Retrieve the top-N keyword-matched CUAD examples via BM25.

    Args:
        clause_text: Text of the clause being benchmarked; only the first
            500 characters are used in the query.
        clause_type: Classified clause type, included in the query tokens.

    Returns:
        A list of ``{"text": ..., "source": ...}`` dicts, at most
        ``N_RESULTS`` long, ordered by descending BM25 score. Empty when the
        BM25 index is unavailable, the query has no tokens, or no chunk
        scores above zero.
    """
    bm25, corpus = _get_bm25()
    if bm25 is None:
        return []
    # Tokenise the same way the index was built: lowercase + whitespace split.
    query_tokens = f"{clause_type} {clause_text[:500]}".lower().split()
    if not query_tokens:
        # An empty query scores every chunk 0; returning the first N chunks
        # as "keyword matches" would be misleading.
        return []
    scores = bm25.get_scores(query_tokens)
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:N_RESULTS]
    return [
        {"text": corpus[i]["text"], "source": corpus[i]["source"]}
        for i in top_indices
        # A non-positive score means the chunk did not actually match the
        # query, so it is not a keyword match worth showing.
        if scores[i] > 0
    ]
# Hybrid retrieval: merge semantic + BM25 hits, drop duplicates, render prompt text
def _retrieve_cuad_examples(clause_text: str, clause_type: str) -> tuple[str, list[str]]:
    """Combine ChromaDB (semantic) and BM25 (keyword) hits into prompt text.

    Semantic hits are processed first, then keyword hits. A hit is dropped
    when the first 80 characters of its text were already seen, so a chunk
    found by both methods appears only once (as a semantic match). Each kept
    hit is rendered with a per-method 1-based number and its text truncated
    to 400 characters.

    Args:
        clause_text: Text of the clause being benchmarked.
        clause_type: Classified clause type.

    Returns:
        ``(examples_text, sources)``: the rendered examples joined by blank
        lines (empty string when there are none) and the source of every
        kept hit in the same order.
    """
    labelled_hits = [
        ("Semantic match", _retrieve_semantic(clause_text, clause_type)),
        ("Keyword match", _retrieve_bm25(clause_text, clause_type)),
    ]
    seen_prefixes = set()
    kept = []  # (label, per-method rank, hit) for every hit that survives dedupe
    for label, hits in labelled_hits:
        rank = 0
        for hit in hits:
            prefix = hit["text"][:80]
            if prefix in seen_prefixes:
                continue
            seen_prefixes.add(prefix)
            rank += 1
            kept.append((label, rank, hit))
    sources = [hit["source"] for _, _, hit in kept]
    rendered = [
        f"{label} {rank} (source: {hit['source']}):\n{hit['text'][:400]}"
        for label, rank, hit in kept
    ]
    return "\n\n".join(rendered), sources
def _parse_benchmark_response(content):
    """Extract the JSON object from an LLM reply, or return None if there is none.

    Accepts either a plain string or a list of content blocks (strings or
    dicts carrying a ``"text"`` entry). A Markdown code fence is unwrapped
    first; if the remaining text is not itself valid JSON, the span from the
    first ``{`` to the last ``}`` is tried as well. Only a JSON *object*
    counts as a successful parse.
    """
    if isinstance(content, list):
        # Flatten block-style content into one string before parsing.
        content = "".join(
            part if isinstance(part, str) else str(part.get("text", ""))
            for part in content
            if isinstance(part, (str, dict))
        )
    text = str(content).strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
    if fenced:
        text = fenced.group(1)
    candidates = [text]
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        # Fallback for replies that wrap the JSON in explanatory prose.
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def benchmark_clause(clause_text: str, clause_type: str) -> dict:
    """Benchmark a single clause against CUAD examples.

    Retrieves similar CUAD clauses, asks the LLM to rate how standard the
    clause is, and parses the JSON reply.

    Args:
        clause_text: Text of the clause to benchmark.
        clause_type: Classified clause type.

    Returns:
        The parsed reply as a dict. ``"benchmark_similarity"`` is always
        present as a float clamped to ``[0.0, 1.0]`` and ``"_sources"`` holds
        the sources of the retrieved examples. When the reply cannot be
        parsed into a JSON object, a fallback dict with similarity ``0.0``
        and reasoning ``"Failed to parse LLM response"`` is returned.
    """
    examples_text, sources = _retrieve_cuad_examples(clause_text, clause_type)
    if examples_text:
        examples_section = f"Similar clauses retrieved from CUAD contracts:\n\n{examples_text}"
    else:
        examples_section = "No CUAD examples available. Use general legal knowledge."
    response = chain.invoke({
        "examples_section": examples_section,
        "clause_type": clause_type,
        "clause_text": clause_text,
    })
    result = _parse_benchmark_response(response.content)
    if result is None:
        result = {
            "benchmark_similarity": 0.0,
            "deviations": [],
            "standard_language_summary": "",
            "reasoning": "Failed to parse LLM response",
        }
    # The model may return the score as a string, null, or out of range;
    # normalise it so downstream averaging cannot fail.
    try:
        similarity = float(result.get("benchmark_similarity", 0.0))
    except (TypeError, ValueError):
        similarity = 0.0
    result["benchmark_similarity"] = min(1.0, max(0.0, similarity))
    result["_sources"] = sources
    return result
def benchmark_node(state: ContractState) -> dict:
    """LangGraph node: benchmark all risk-scored clauses against CUAD contracts.

    Reads clauses from ``state["risk_scores"]`` (falling back to
    ``state["classified_clauses"]`` when that key is absent), benchmarks each
    one, and attaches ``benchmark_similarity`` and ``benchmark_source`` to a
    copy of the clause. If a logger is available, a summary span is logged.

    Args:
        state: Graph state holding the clauses to benchmark.

    Returns:
        ``{"benchmark_results": [...]}`` with one enriched dict per clause.
    """
    benchmarked = []
    source = state.get("risk_scores", state.get("classified_clauses", []))
    for clause in source:
        result = benchmark_clause(clause["text"], clause.get("clause_type", "Other"))
        sources = result.get("_sources", [])
        benchmark_source = (
            # Keep the label short: at most two sources, 50 characters each.
            "CUAD: " + ", ".join(s[:50] for s in sources[:2])
            if sources
            else "CUAD (LLM fallback. Run build_vector_store.py)"
        )
        # The score comes from LLM-produced JSON; force it to a number so the
        # average computed below cannot raise on a string or null.
        try:
            similarity = float(result.get("benchmark_similarity", 0.0))
        except (TypeError, ValueError):
            similarity = 0.0
        benchmarked.append({
            **clause,
            "benchmark_similarity": similarity,
            "benchmark_source": benchmark_source,
        })
    # observability: log benchmark summary as Braintrust span
    logger = get_logger()
    if logger:
        # Report which retrieval backends actually loaded, not just BM25.
        has_semantic = _collection is not None
        has_keyword = _bm25_index is not None
        if has_semantic and has_keyword:
            retrieval_method = "hybrid (semantic + BM25)"
        elif has_semantic:
            retrieval_method = "semantic only"
        elif has_keyword:
            retrieval_method = "BM25 only"
        else:
            retrieval_method = "none (LLM-only fallback)"
        with logger.start_span("benchmark_node") as span:
            sim_scores = [c["benchmark_similarity"] for c in benchmarked]
            span.log(
                input={"clauses_received": len(source)},
                output={
                    "clauses_benchmarked": len(benchmarked),
                    "avg_benchmark_similarity": (
                        sum(sim_scores) / len(sim_scores) if sim_scores else 0.0
                    ),
                    "retrieval_method": retrieval_method,
                },
            )
    return {"benchmark_results": benchmarked}
|