# litigation-tracker/vectorstore/cases_vectorstore.py
"""
Pinecone-backed vectorstore utilities for the AI Litigation Tracker.
Responsibilities:
- Embed case text using OpenAI embeddings.
- Chunk long documents and mean-pool embeddings into a single case vector.
- Upsert vectors into a Pinecone index.
- Run global similarity search for RAG (query_global).
- Look up a single case by normalized docket number or case name (get_case_by_filter).
Metadata stored with each vector includes:
- docket_number, case_name
- court_id, filing_date
- jurisdiction, courtlistener_url, latest_update
- n_docs (document count for the case)
"""
import os
import hashlib
import time
from typing import Dict, List, Optional
from dotenv import load_dotenv
from pinecone import Pinecone
from openai import OpenAI
import tiktoken
# Load environment variables from a local .env file
# (OPENAI_API_KEY, PINECONE_API_KEY, and optional overrides below).
load_dotenv()

# Embedding model and index name, overridable via the environment.
OPENAI_EMBED_MODEL = os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")
PINECONE_INDEX = os.getenv("PINECONE_INDEX", "ai-litigation-cases")

# Module-level singleton clients, reused across all calls in this module.
# NOTE(review): os.environ[...] raises KeyError at import time when either
# API key is missing — intentional fail-fast, but the error carries only
# the bare key name.
_oai = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
_pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
try:
    _index = _pc.Index(PINECONE_INDEX)
except Exception as e:
    # Fail early with a clear message if the index has not been created yet.
    # NOTE(review): some Pinecone SDK versions construct Index handles lazily
    # without contacting the service, in which case a missing index would not
    # raise here — confirm against the pinned SDK version.
    raise RuntimeError(
        f"Pinecone index '{PINECONE_INDEX}' not found. "
        "Run data_updating_scripts.create_pinecone_index first."
    ) from e
# ============================================================
# Internal helpers
# ============================================================
def _norm(s: Optional[str]) -> Optional[str]:
"""
Normalize strings for case-insensitive lookups.
Returns lowercased, stripped text or None if input is not a string.
"""
return s.lower().strip() if isinstance(s, str) else None
def _make_id(court_id: str, docket_number: str) -> str:
"""
Create a stable, opaque ID for a (court_id, docket_number) pair.
Uses SHA-1 for compactness and to avoid leaking the raw identifiers in the ID.
"""
return hashlib.sha1(f"{court_id}|{docket_number}".encode()).hexdigest()
def _chunk_text(txt: str, max_tokens: int = 750) -> List[str]:
    """
    Split text into token-bounded chunks using tiktoken's cl100k_base encoding.

    Args:
        txt: Raw text to tokenize and chunk (None/empty treated as "").
        max_tokens: Maximum tokens per chunk (approximate prompt size control).

    Returns:
        Decoded text chunks of at most max_tokens tokens each; empty input
        produces the single chunk [""].
    """
    encoder = tiktoken.get_encoding("cl100k_base")
    token_ids = encoder.encode(txt or "")
    pieces: List[str] = []
    for start in range(0, len(token_ids), max_tokens):
        pieces.append(encoder.decode(token_ids[start : start + max_tokens]))
    return pieces if pieces else [""]
def _embed_batch(texts: List[str]) -> List[List[float]]:
    """
    Embed a batch of texts with the configured OpenAI embedding model.

    Texts are submitted in sub-batches of 32 for reliability. Each sub-batch
    is retried up to 4 times with linear backoff (1.5s, 3.0s, 4.5s) before
    the last error is re-raised.

    Args:
        texts: Strings to embed.

    Returns:
        One embedding vector per input string, in input order.
    """
    batch_size = 32
    max_attempts = 4
    results: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        sub_batch = texts[start : start + batch_size]
        attempt = 0
        while True:
            try:
                resp = _oai.embeddings.create(model=OPENAI_EMBED_MODEL, input=sub_batch)
            except Exception:
                attempt += 1
                if attempt >= max_attempts:
                    # Exhausted retries — surface the failure to the caller.
                    raise
                # Linear backoff to avoid hammering the API.
                time.sleep(1.5 * attempt)
            else:
                results.extend(d.embedding for d in resp.data)
                break
    return results
def _mean_pool(vectors: List[List[float]]) -> List[float]:
"""
Compute the element-wise mean of a list of vectors.
Used to aggregate chunk-level embeddings into a single case-level embedding.
"""
if not vectors:
return []
d = len(vectors[0])
acc = [0.0] * d
for v in vectors:
for j in range(d):
acc[j] += v[j]
return [x / len(vectors) for x in acc]
# ============================================================
# Public API: vector creation + indexing
# ============================================================
def case_to_vector_payload(
    *,
    docket_number: str,
    case_name: str,
    court_id: str,
    filing_date: Optional[str],
    concatenated_plain_text: str,
    extra_meta: Optional[Dict] = None,
):
    """
    Build a Pinecone-ready (id, vector, metadata) triple for one case.

    Pipeline:
        1. Token-chunk the concatenated case text.
        2. Embed every chunk with OpenAI embeddings.
        3. Mean-pool chunk embeddings into a single centroid vector.
        4. Derive a stable ID and assemble filter/display metadata.

    Args:
        docket_number: Raw docket number string.
        case_name: Case name/title.
        court_id: CourtListener slug (e.g., "mdd", "nysd").
        filing_date: Filing date string (e.g., "08302023" or "2023-08-30").
        concatenated_plain_text: Combined plain text of all case documents.
        extra_meta: Optional extras: n_docs, courtlistener_url, jurisdiction,
            latest_update.

    Returns:
        (stable_id, embedding_vector, metadata_dict); None-valued metadata
        entries are dropped to keep the index lean.
    """
    extra = extra_meta or {}
    chunk_embeddings = _embed_batch(_chunk_text(concatenated_plain_text))
    centroid = _mean_pool(chunk_embeddings)
    vector_id = _make_id(court_id, docket_number)
    raw_metadata = {
        "docket_number": docket_number,
        "docket_number_norm": _norm(docket_number),
        "case_name": case_name,
        "case_name_norm": _norm(case_name),
        "court_id": court_id,
        "filing_date": filing_date or "",
        "n_docs": extra.get("n_docs"),
        "courtlistener_url": extra.get("courtlistener_url"),
        "jurisdiction": extra.get("jurisdiction"),
        "latest_update": extra.get("latest_update"),
    }
    metadata = {key: value for key, value in raw_metadata.items() if value is not None}
    return vector_id, centroid, metadata
def already_indexed(*, court_id: str, docket_number: str) -> bool:
    """
    Report whether a case already exists in the Pinecone index.

    Args:
        court_id: CourtListener slug.
        docket_number: Docket number for the case.

    Returns:
        True if a vector with the case's stable ID is present, else False.
    """
    vector_id = _make_id(court_id, docket_number)
    response = _index.fetch(ids=[vector_id]) or {}
    # A missing "vectors" mapping is treated the same as an empty one.
    return vector_id in (response.get("vectors") or {})
def upsert_cases(vectors: List[Dict]) -> None:
    """
    Upsert vector payloads into the Pinecone index in batches of 100.

    Args:
        vectors: Dicts of the form {"id": ..., "values": [...], "metadata": {...}}.

    Notes:
        - Pinecone dedupes by ID, so repeated upserts are idempotent.
        - Batching keeps individual requests small and API-friendly.
    """
    batch_size = 100
    start = 0
    while start < len(vectors):
        _index.upsert(vectors=vectors[start : start + batch_size])
        start += batch_size
# ============================================================
# Public API: querying
# ============================================================
def query_global(question: str, top_k: int = 5) -> List[Dict]:
    """
    Run a global semantic search across all cases for a natural-language question.

    Args:
        question: Text query to embed and search with.
        top_k: Maximum number of matches to return.

    Returns:
        One dict per match: {"score": <similarity>, **stored_metadata}.
    """
    query_vector = _embed_batch([question])[0]
    response = _index.query(vector=query_vector, top_k=top_k, include_metadata=True)
    hits: List[Dict] = []
    for match in response.get("matches", []):
        hits.append({"score": match["score"], **match["metadata"]})
    return hits
def get_case_by_filter(
    *,
    docket_number: Optional[str] = None,
    case_name: Optional[str] = None,
) -> Optional[Dict]:
    """
    Look up a single case by normalized docket number and/or case name.

    Primarily used by case-specific Q&A in rag/chains.py. When both arguments
    are given, the Pinecone metadata filter requires both to match.

    Args:
        docket_number: Optional docket string (case-insensitive).
        case_name: Optional case name string (case-insensitive).

    Returns:
        The metadata dict of the first matching case (as stored by
        case_to_vector_payload), or None when no filter is given or no
        vector matches.
    """
    flt: Dict[str, str] = {}
    if docket_number:
        flt["docket_number_norm"] = _norm(docket_number)
    if case_name:
        flt["case_name_norm"] = _norm(case_name)
    if not flt:
        return None
    # We do a "dummy" query with a zero vector and rely solely on the metadata
    # filter; the vector's dimension must match the index / embedding model.
    # BUGFIX: the old heuristic (1536 iff "3-small" in name, else 3072) sized
    # the vector wrongly for text-embedding-ada-002 (1536 dims), which makes
    # Pinecone reject the query with a dimension mismatch. Use an explicit
    # table for known models; unknown models fall back to the old heuristic.
    known_dims = {
        "text-embedding-3-small": 1536,
        "text-embedding-ada-002": 1536,
        "text-embedding-3-large": 3072,
    }
    dim = known_dims.get(
        OPENAI_EMBED_MODEL,
        1536 if "3-small" in OPENAI_EMBED_MODEL else 3072,
    )
    res = _index.query(
        vector=[0.0] * dim,
        top_k=1,
        include_metadata=True,
        filter=flt,
    )
    matches = res.get("matches", [])
    return matches[0]["metadata"] if matches else None