"""
Pinecone-backed vectorstore utilities for the AI Litigation Tracker.

Responsibilities:
- Embed case text using OpenAI embeddings.
- Chunk long documents and mean-pool embeddings into a single case vector.
- Upsert vectors into a Pinecone index.
- Run global similarity search for RAG (query_global).
- Look up a single case by normalized docket number or case name
  (get_case_by_filter).

Metadata stored with each vector includes:
- docket_number, case_name
- court_id, filing_date
- jurisdiction, courtlistener_url, latest_update
- n_docs (document count for the case)
"""

import hashlib
import os
import time
from typing import Dict, List, Optional

import tiktoken
from dotenv import load_dotenv
from openai import OpenAI
from pinecone import Pinecone

# Load environment variables (e.g., OPENAI_API_KEY, PINECONE_API_KEY, PINECONE_INDEX)
load_dotenv()

OPENAI_EMBED_MODEL = os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")
PINECONE_INDEX = os.getenv("PINECONE_INDEX", "ai-litigation-cases")

# Modest embedding batch size for reliability.
_EMBED_BATCH_SIZE = 32

# Output dimensions of the OpenAI embedding models we expect to encounter.
# Used by get_case_by_filter to build a correctly sized dummy query vector.
_MODEL_DIMS = {
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
    "text-embedding-ada-002": 1536,
}

# Reuse single clients across calls.
_oai = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
_pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])

try:
    _index = _pc.Index(PINECONE_INDEX)
except Exception as e:
    # Fail early with a clear message if the index has not been created yet.
    raise RuntimeError(
        f"Pinecone index '{PINECONE_INDEX}' not found. "
        "Run data_updating_scripts.create_pinecone_index first."
    ) from e


# ============================================================
# Internal helpers
# ============================================================

def _norm(s: Optional[str]) -> Optional[str]:
    """
    Normalize strings for case-insensitive lookups.

    Returns lowercased, stripped text, or None if input is not a string.
    """
    return s.lower().strip() if isinstance(s, str) else None


def _make_id(court_id: str, docket_number: str) -> str:
    """
    Create a stable, opaque ID for a (court_id, docket_number) pair.

    Uses SHA-1 for compactness and to avoid leaking the raw identifiers
    in the ID (not used for any security purpose).
    """
    return hashlib.sha1(f"{court_id}|{docket_number}".encode()).hexdigest()


def _chunk_text(txt: str, max_tokens: int = 750) -> List[str]:
    """
    Split text into chunks based on token count using tiktoken.

    Args:
        txt: Raw text to tokenize and chunk.
        max_tokens: Maximum tokens per chunk (approximate prompt size control).

    Returns:
        A list of decoded text chunks. Returns [""] for empty input so
        callers always have at least one chunk to embed.
    """
    enc = tiktoken.get_encoding("cl100k_base")
    ids = enc.encode(txt or "")
    chunks = [
        enc.decode(ids[i : i + max_tokens])
        for i in range(0, len(ids), max_tokens)
    ]
    return chunks or [""]


def _embed_batch(texts: List[str]) -> List[List[float]]:
    """
    Embed a batch of texts using the configured OpenAI embedding model.

    Implements simple retry with linear backoff for transient errors.

    Args:
        texts: List of strings to embed.

    Returns:
        List of embedding vectors (one per input string), in order.

    Raises:
        Exception: Re-raises the final API error after 4 failed attempts.
    """
    out: List[List[float]] = []
    for start in range(0, len(texts), _EMBED_BATCH_SIZE):
        batch = texts[start : start + _EMBED_BATCH_SIZE]
        for attempt in range(4):
            try:
                resp = _oai.embeddings.create(model=OPENAI_EMBED_MODEL, input=batch)
                out.extend([d.embedding for d in resp.data])
                break
            except Exception:
                if attempt == 3:
                    # After 4 attempts, re-raise to surface the failure.
                    raise
                # Simple linear backoff to avoid hammering the API.
                time.sleep(1.5 * (attempt + 1))
    return out


def _mean_pool(vectors: List[List[float]]) -> List[float]:
    """
    Compute the element-wise mean of a list of vectors.

    Used to aggregate chunk-level embeddings into a single case-level
    embedding. Returns [] for empty input.
    """
    if not vectors:
        return []
    n = len(vectors)
    acc = [0.0] * len(vectors[0])
    for v in vectors:
        acc = [a + x for a, x in zip(acc, v)]
    return [a / n for a in acc]


# ============================================================
# Public API: vector creation + indexing
# ============================================================

def case_to_vector_payload(
    *,
    docket_number: str,
    case_name: str,
    court_id: str,
    filing_date: Optional[str],
    concatenated_plain_text: str,
    extra_meta: Optional[Dict] = None,
):
    """
    Build a Pinecone-ready vector payload for a single case.

    Steps:
        1. Tokenize and chunk the concatenated case text.
        2. Embed each chunk with OpenAI embeddings.
        3. Mean-pool chunk embeddings into a single centroid vector.
        4. Construct a stable ID and attach metadata used for filtering
           and display.

    Args:
        docket_number: Raw docket number string.
        case_name: Case name/title.
        court_id: CourtListener slug (e.g., "mdd", "nysd").
        filing_date: Filing date as a string (e.g., "08302023" or "2023-08-30").
        concatenated_plain_text: Combined plain text for all documents in the case.
        extra_meta: Optional dict for additional metadata such as:
            - n_docs
            - courtlistener_url
            - jurisdiction
            - latest_update

    Returns:
        (stable_id, embedding_vector, metadata_dict)
    """
    chunks = _chunk_text(concatenated_plain_text)
    embs = _embed_batch(chunks)
    centroid = _mean_pool(embs)

    stable_id = _make_id(court_id, docket_number)
    extra_meta = extra_meta or {}

    metadata = {
        "docket_number": docket_number,
        "docket_number_norm": _norm(docket_number),
        "case_name": case_name,
        "case_name_norm": _norm(case_name),
        "court_id": court_id,
        "filing_date": filing_date or "",
        "n_docs": extra_meta.get("n_docs"),
        "courtlistener_url": extra_meta.get("courtlistener_url"),
        "jurisdiction": extra_meta.get("jurisdiction"),
        "latest_update": extra_meta.get("latest_update"),
    }
    # Drop None values so the index metadata stays lean.
    metadata = {k: v for k, v in metadata.items() if v is not None}

    return stable_id, centroid, metadata


def already_indexed(*, court_id: str, docket_number: str) -> bool:
    """
    Check whether a case is already present in the Pinecone index.

    Args:
        court_id: CourtListener slug.
        docket_number: Docket number for the case.

    Returns:
        True if an entry with the stable ID exists in the index, else False.
    """
    vid = _make_id(court_id, docket_number)
    res = _index.fetch(ids=[vid])
    # Newer Pinecone SDKs return a FetchResponse object exposing `.vectors`;
    # older ones return a plain dict. Support both shapes.
    vectors = getattr(res, "vectors", None)
    if vectors is None:
        vectors = (res or {}).get("vectors") or {}
    return vid in vectors


def upsert_cases(vectors: List[Dict]) -> None:
    """
    Upsert a list of vector payloads into the Pinecone index.

    Args:
        vectors: List of dicts of the form
            {"id": ..., "values": [...], "metadata": {...}}.

    Notes:
        - Pinecone handles deduplication by ID, so upserts are idempotent.
        - We still batch requests for efficiency and API friendliness.
    """
    for i in range(0, len(vectors), 100):
        _index.upsert(vectors=vectors[i : i + 100])


# ============================================================
# Public API: querying
# ============================================================

def query_global(question: str, top_k: int = 5) -> List[Dict]:
    """
    Run a global semantic search over all cases for a natural language question.

    Args:
        question: Text query to embed and search with.
        top_k: Maximum number of matches to return.

    Returns:
        A list of dictionaries, one per match, each containing the match
        score plus the stored metadata fields:
            {"score": <float>, ...metadata...}
    """
    q = _embed_batch([question])[0]
    res = _index.query(vector=q, top_k=top_k, include_metadata=True)
    # Tolerate both object-style (`.matches`) and dict-style responses.
    matches = getattr(res, "matches", None)
    if matches is None:
        matches = res.get("matches", [])
    # Guard against matches that carry no metadata at all.
    return [{"score": m["score"], **(m["metadata"] or {})} for m in matches]


def get_case_by_filter(
    *,
    docket_number: Optional[str] = None,
    case_name: Optional[str] = None,
) -> Optional[Dict]:
    """
    Look up a single case by normalized docket number and/or case name.

    This is primarily used by case-specific Q&A in rag/chains.py.

    Args:
        docket_number: Optional docket string (case-insensitive).
        case_name: Optional case name string (case-insensitive).

    Returns:
        A metadata dict for the first matching case, or None if no match
        (or if neither filter argument was supplied). The returned metadata
        mirrors what was stored in case_to_vector_payload.
    """
    flt: Dict[str, str] = {}
    if docket_number:
        flt["docket_number_norm"] = _norm(docket_number)
    if case_name:
        flt["case_name_norm"] = _norm(case_name)
    if not flt:
        return None

    # We do a "dummy" query using a zero vector and rely on the metadata
    # filter to select the matching case. The dimension must match the
    # embedding model's output; use the known table, falling back to the
    # original small/large heuristic for unrecognized model names.
    # NOTE(review): some Pinecone index metrics reject all-zero query
    # vectors — confirm this works against the deployed index metric.
    dim = _MODEL_DIMS.get(
        OPENAI_EMBED_MODEL,
        1536 if "3-small" in OPENAI_EMBED_MODEL else 3072,
    )
    res = _index.query(
        vector=[0.0] * dim,
        top_k=1,
        include_metadata=True,
        filter=flt,
    )
    matches = getattr(res, "matches", None)
    if matches is None:
        matches = res.get("matches", [])
    return matches[0]["metadata"] if matches else None