| """ | |
| Pinecone-backed vectorstore utilities for the AI Litigation Tracker. | |
| Responsibilities: | |
| - Embed case text using OpenAI embeddings. | |
| - Chunk long documents and mean-pool embeddings into a single case vector. | |
| - Upsert vectors into a Pinecone index. | |
| - Run global similarity search for RAG (query_global). | |
| - Look up a single case by normalized docket number or case name (get_case_by_filter). | |
| Metadata stored with each vector includes: | |
| - docket_number, case_name | |
| - court_id, filing_date | |
| - jurisdiction, courtlistener_url, latest_update | |
| - n_docs (document count for the case) | |
| """ | |
| import os | |
| import hashlib | |
| import time | |
| from typing import Dict, List, Optional | |
| from dotenv import load_dotenv | |
| from pinecone import Pinecone | |
| from openai import OpenAI | |
| import tiktoken | |
| # Load environment variables (e.g., OPENAI_API_KEY, PINECONE_API_KEY, PINECONE_INDEX) | |
| load_dotenv() | |
| OPENAI_EMBED_MODEL = os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small") | |
| PINECONE_INDEX = os.getenv("PINECONE_INDEX", "ai-litigation-cases") | |
| # Reuse single clients across calls | |
| _oai = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) | |
| _pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"]) | |
try:
    # Index() only returns a lazy handle in recent clients, so probe the
    # control plane first to fail early with a clear message if the index
    # has not been created yet.
    _pc.describe_index(PINECONE_INDEX)
    _index = _pc.Index(PINECONE_INDEX)
except Exception as e:
    raise RuntimeError(
        f"Pinecone index '{PINECONE_INDEX}' not found. "
        "Run data_updating_scripts.create_pinecone_index first."
    ) from e

# ============================================================
# Internal helpers
# ============================================================
def _norm(s: Optional[str]) -> Optional[str]:
    """
    Normalize strings for case-insensitive lookups.

    Returns lowercased, stripped text, or None if the input is not a string.
    """
    return s.lower().strip() if isinstance(s, str) else None
def _make_id(court_id: str, docket_number: str) -> str:
    """
    Create a stable, opaque ID for a (court_id, docket_number) pair.

    SHA-1 keeps IDs compact and uniform regardless of what characters the raw
    identifiers contain, and identical inputs always hash to the same ID, so
    re-indexing a case overwrites its existing vector.
    """
    return hashlib.sha1(f"{court_id}|{docket_number}".encode()).hexdigest()
def _chunk_text(txt: str, max_tokens: int = 750) -> List[str]:
    """
    Split text into chunks based on token count using tiktoken.

    Args:
        txt: Raw text to tokenize and chunk.
        max_tokens: Maximum tokens per chunk (approximate prompt size control).

    Returns:
        A list of decoded text chunks. Returns [""] for empty input.
    """
    enc = tiktoken.get_encoding("cl100k_base")
    ids = enc.encode(txt or "")
    chunks = [enc.decode(ids[i : i + max_tokens]) for i in range(0, len(ids), max_tokens)]
    return chunks or [""]
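
# Illustrative sketch of the chunking behavior (the input text is hypothetical):
#
#     chunks = _chunk_text("some long filing text " * 500, max_tokens=750)
#     # Each chunk decodes from a window of at most 750 token IDs, so chunk
#     # sizes are roughly bounded by max_tokens when re-tokenized.
#
# Chunks are non-overlapping token windows; overlapping windows are a common
# variant if retrieval quality suffers at chunk boundaries.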
def _embed_batch(texts: List[str]) -> List[List[float]]:
    """
    Embed a batch of texts using the configured OpenAI embedding model.

    Implements a simple retry with linear backoff for transient errors.

    Args:
        texts: List of strings to embed.

    Returns:
        List of embedding vectors (one per input string), in order.
    """
    out: List[List[float]] = []
    i = 0
    while i < len(texts):
        batch = texts[i : i + 32]  # modest batch size for reliability
        for attempt in range(4):
            try:
                resp = _oai.embeddings.create(model=OPENAI_EMBED_MODEL, input=batch)
                out.extend([d.embedding for d in resp.data])
                break
            except Exception:
                if attempt == 3:
                    # After 4 attempts, re-raise to surface the failure.
                    raise
                # Simple linear backoff to avoid hammering the API.
                time.sleep(1.5 * (attempt + 1))
        i += 32
    return out
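
# Illustrative sketch (requires a live OPENAI_API_KEY; inputs are hypothetical):
#
#     vecs = _embed_batch(["fair use doctrine", "copyright infringement"])
#     assert len(vecs) == 2                # one vector per input, order preserved
#     assert len(vecs[0]) == len(vecs[1])  # uniform dimension for a given model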
def _mean_pool(vectors: List[List[float]]) -> List[float]:
    """
    Compute the element-wise mean of a list of vectors.

    Used to aggregate chunk-level embeddings into a single case-level embedding.
    """
    if not vectors:
        return []
    d = len(vectors[0])
    acc = [0.0] * d
    for v in vectors:
        for j in range(d):
            acc[j] += v[j]
    return [x / len(vectors) for x in acc]
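
# Worked micro-example: the centroid of two 3-d vectors is their element-wise mean.
#
#     _mean_pool([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])  # -> [2.0, 3.0, 4.0]
#
# Mean-pooling weights every chunk equally; a length-weighted mean is a possible
# refinement if chunk sizes vary widely.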

# ============================================================
# Public API: vector creation + indexing
# ============================================================
def case_to_vector_payload(
    *,
    docket_number: str,
    case_name: str,
    court_id: str,
    filing_date: Optional[str],
    concatenated_plain_text: str,
    extra_meta: Optional[Dict] = None,
):
    """
    Build a Pinecone-ready vector payload for a single case.

    Steps:
        1. Tokenize and chunk the concatenated case text.
        2. Embed each chunk with OpenAI embeddings.
        3. Mean-pool chunk embeddings into a single centroid vector.
        4. Construct a stable ID and attach metadata used for filtering and display.

    Args:
        docket_number: Raw docket number string.
        case_name: Case name/title.
        court_id: CourtListener slug (e.g., "mdd", "nysd").
        filing_date: Filing date as a string (e.g., "08302023" or "2023-08-30").
        concatenated_plain_text: Combined plain text for all documents in the case.
        extra_meta: Optional dict for additional metadata such as:
            - n_docs
            - courtlistener_url
            - jurisdiction
            - latest_update

    Returns:
        (stable_id, embedding_vector, metadata_dict)
    """
    chunks = _chunk_text(concatenated_plain_text)
    embs = _embed_batch(chunks)
    centroid = _mean_pool(embs)

    stable_id = _make_id(court_id, docket_number)
    extra_meta = extra_meta or {}
    metadata = {
        "docket_number": docket_number,
        "docket_number_norm": _norm(docket_number),
        "case_name": case_name,
        "case_name_norm": _norm(case_name),
        "court_id": court_id,
        "filing_date": filing_date or "",
        "n_docs": extra_meta.get("n_docs"),
        "courtlistener_url": extra_meta.get("courtlistener_url"),
        "jurisdiction": extra_meta.get("jurisdiction"),
        "latest_update": extra_meta.get("latest_update"),
    }
    # Drop None values so the index metadata stays lean.
    metadata = {k: v for k, v in metadata.items() if v is not None}
    return stable_id, centroid, metadata
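
# Illustrative usage (all values hypothetical; embedding requires live keys):
#
#     vid, vec, meta = case_to_vector_payload(
#         docket_number="1:23-cv-01234",
#         case_name="Doe v. Example AI, Inc.",
#         court_id="nysd",
#         filing_date="2023-08-30",
#         concatenated_plain_text=full_case_text,
#         extra_meta={"n_docs": 12, "jurisdiction": "S.D.N.Y."},
#     )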
def already_indexed(*, court_id: str, docket_number: str) -> bool:
    """
    Check whether a case is already present in the Pinecone index.

    Args:
        court_id: CourtListener slug.
        docket_number: Docket number for the case.

    Returns:
        True if an entry with the stable ID exists in the index, else False.
    """
    vid = _make_id(court_id, docket_number)
    res = _index.fetch(ids=[vid])
    # Recent pinecone clients return a FetchResponse object with a .vectors
    # attribute; older ones returned a plain dict. Handle both.
    vectors = getattr(res, "vectors", None)
    if vectors is None and isinstance(res, dict):
        vectors = res.get("vectors")
    return vid in (vectors or {})
def upsert_cases(vectors: List[Dict]) -> None:
    """
    Upsert a list of vector payloads into the Pinecone index.

    Args:
        vectors: List of dicts of the form {"id": ..., "values": [...], "metadata": {...}}.

    Notes:
        - Upserts are idempotent by ID: re-upserting an existing ID overwrites
          the stored vector and metadata rather than duplicating it.
        - Requests are batched for efficiency and API friendliness.
    """
    for i in range(0, len(vectors), 100):
        _index.upsert(vectors=vectors[i : i + 100])
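
# Illustrative indexing flow (hypothetical values), tying the pieces together;
# note the tuple from case_to_vector_payload must be reshaped into the dict
# form that upsert_cases expects:
#
#     if not already_indexed(court_id="nysd", docket_number="1:23-cv-01234"):
#         vid, vec, meta = case_to_vector_payload(...)
#         upsert_cases([{"id": vid, "values": vec, "metadata": meta}])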

# ============================================================
# Public API: querying
# ============================================================
def query_global(question: str, top_k: int = 5) -> List[Dict]:
    """
    Run a global semantic search over all cases for a natural-language question.

    Args:
        question: Text query to embed and search with.
        top_k: Maximum number of matches to return.

    Returns:
        A list of dictionaries, one per match:
            {
                "score": <similarity score>,
                ...<all stored metadata fields>...
            }
    """
    q = _embed_batch([question])[0]
    res = _index.query(vector=q, top_k=top_k, include_metadata=True)
    return [{"score": m["score"], **m["metadata"]} for m in res.get("matches", [])]
def get_case_by_filter(
    *,
    docket_number: Optional[str] = None,
    case_name: Optional[str] = None,
) -> Optional[Dict]:
    """
    Look up a single case by normalized docket number and/or case name.

    This is primarily used by case-specific Q&A in rag/chains.py.

    Args:
        docket_number: Optional docket string (case-insensitive).
        case_name: Optional case name string (case-insensitive).

    Returns:
        A metadata dict for the first matching case, or None if no match.
        The returned metadata mirrors what was stored in case_to_vector_payload.
    """
    flt: Dict[str, str] = {}
    if docket_number:
        flt["docket_number_norm"] = _norm(docket_number)
    if case_name:
        flt["case_name_norm"] = _norm(case_name)
    if not flt:
        return None

    # "Dummy" query: use a zero vector and rely on the metadata filter to
    # select the matching case. The vector dimension must match the embedding
    # model: text-embedding-3-large is 3072-dimensional, while both
    # text-embedding-3-small and text-embedding-ada-002 are 1536-dimensional.
    dim = 3072 if "3-large" in OPENAI_EMBED_MODEL else 1536
    res = _index.query(
        vector=[0.0] * dim,
        top_k=1,
        include_metadata=True,
        filter=flt,
    )
    matches = res.get("matches", [])
    return matches[0]["metadata"] if matches else None
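
# Minimal smoke test, assuming live OPENAI/PINECONE credentials and a populated
# index. The question and docket number below are illustrative placeholders.
if __name__ == "__main__":
    print(query_global("generative AI copyright lawsuits", top_k=2))
    print(get_case_by_filter(docket_number="1:23-cv-01234"))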