# legislation-tracker / vectorstore / bills_vectorstore.py
# Uploaded by ramanna — "Upload 7 files" (commit b501a8e, verified)
# vectorstore/bills_vectorstore.py
from __future__ import annotations
import os, json, hashlib, time
from pathlib import Path
from typing import Dict, List, Optional, Iterable, Any
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
try:
from langchain_chroma import Chroma
except Exception:
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
DEFAULT_EMBED_MODEL = os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")
DEFAULT_PERSIST_DIR = "data/bills_vectorstore"
DEFAULT_COLLECTION = "bills"
DEFAULT_MANIFEST = "data/bills_vectorstore_manifest.json"
def get_embeddings(model: Optional[str] = None) -> OpenAIEmbeddings:
    """Build an OpenAIEmbeddings client, failing fast when no API key is set.

    Falls back to DEFAULT_EMBED_MODEL when *model* is None/empty.
    """
    key = os.getenv("OPENAI_API_KEY")
    if not key:
        raise RuntimeError("OPENAI_API_KEY is not set. Check your .env or environment.")
    chosen_model = model or DEFAULT_EMBED_MODEL
    # chunk_size=32 bounds how many texts are embedded per API request.
    return OpenAIEmbeddings(api_key=key, model=chosen_model, chunk_size=32)
def _sha256(text: str) -> str:
import hashlib
return hashlib.sha256(text.encode("utf-8")).hexdigest()
def _bill_id(b: Dict[str, Any]) -> str:
return f"{b.get('state','Unknown')}_{b.get('bill_number','Unknown')}"
def _bill_text(b: Dict[str, Any]) -> str:
title = b.get("title") or ""
summary = b.get("description") or ""
txt = b.get("text") or ""
return f"Title: {title}\n\nSummary: {summary}\n\nFull Text:\n{txt}"
def _bill_hash(b: Dict[str, Any]) -> str:
payload = json.dumps({
"title": b.get("title"),
"description": b.get("description"),
"text": b.get("text"),
"status": b.get("status"),
"last_action_date": b.get("last_action_date"),
}, ensure_ascii=False, sort_keys=True)
return _sha256(payload)
def _manifest_load(path: str) -> Dict[str, Dict[str, str]]:
p = Path(path)
if not p.exists():
return {}
try:
return json.loads(p.read_text(encoding="utf-8"))
except Exception:
return {}
def _manifest_save(path: str, data: Dict[str, Dict[str, str]]) -> None:
Path(path).parent.mkdir(parents=True, exist_ok=True)
Path(path).write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
def _clean_metadata(meta: Dict[str, Any]) -> Dict[str, Any]:
"""Keep only metadata values that Chroma accepts: str/int/float/bool and not None."""
allowed_types = (str, int, float, bool)
cleaned: Dict[str, Any] = {}
for k, v in meta.items():
if v is None:
continue
if isinstance(v, allowed_types):
cleaned[k] = v
else:
# If you prefer to drop complex types instead of stringifying, replace with `continue`
cleaned[k] = str(v)
return cleaned
def _make_doc(b: Dict[str, Any]) -> Document:
    """Build a single (un-chunked) Document for bill *b* with Chroma-safe metadata.

    Removed: two unreachable, duplicated `meta = ...; return ...` pairs that
    followed the original return statement (dead code).
    """
    sponsors_list = b.get("sponsors") or []
    if isinstance(sponsors_list, list):
        sponsors_str = "; ".join(map(str, sponsors_list))
    else:
        sponsors_str = str(sponsors_list) if sponsors_list else ""
    # Flatten {"category": ["sub", ...]} into "category:sub; ..." for filtering.
    flat_iapp = []
    iapp = b.get("iapp_categories")
    if isinstance(iapp, dict):
        for k, v in iapp.items():
            if isinstance(v, list):
                for sub in v:
                    flat_iapp.append(f"{k}:{sub}")
    iapp_str = "; ".join(flat_iapp) if flat_iapp else ""
    meta = {
        "doc_id": _bill_id(b),
        "state": b.get("state"),
        "session_year": b.get("session_year"),
        "legislative_body": b.get("chamber") or b.get("legislative_body") or None,
        "status": b.get("status"),
        "title": b.get("title"),
        "bill_number": b.get("bill_number"),
        "sponsors": sponsors_str,
        "last_action_date": b.get("last_action_date"),
        "iapp_flat": iapp_str,
    }
    # Chroma rejects None and non-scalar metadata values.
    meta = _clean_metadata(meta)
    return Document(page_content=_bill_text(b), metadata=meta)
def _load_bills(source_json_path: str) -> List[Dict[str, Any]]:
data = json.loads(Path(source_json_path).read_text(encoding="utf-8"))
if not isinstance(data, list):
raise ValueError(f"{source_json_path} must contain a list of bills")
return data
def load_vectorstore(
    persist_dir: str = DEFAULT_PERSIST_DIR,
    collection: str = DEFAULT_COLLECTION,
    embeddings: Optional[OpenAIEmbeddings] = None,
) -> Chroma:
    """Open (creating the directory if needed) the persistent Chroma collection."""
    embeddings = embeddings or get_embeddings()
    # Ensure the persistence directory exists before Chroma touches it.
    Path(persist_dir).mkdir(parents=True, exist_ok=True)
    return Chroma(
        embedding_function=embeddings,
        collection_name=collection,
        persist_directory=persist_dir,
    )
def _chunk_bill(b: Dict[str, Any], *, size: int = 1500, overlap: int = 200) -> List[Document]:
    """Split one bill into overlapping text chunks, each carrying shared metadata.

    Every chunk receives the bill-level metadata plus chunk_index/chunk_total so
    the pieces can be grouped back together or deleted as a unit later.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", ". ", " ", ""],
    )
    pieces = splitter.split_text(_bill_text(b)) or ["(no content)"]

    sponsors = b.get("sponsors")
    if isinstance(sponsors, list):
        sponsors_str = "; ".join(map(str, sponsors))
    else:
        sponsors_str = sponsors or ""

    # Flatten {"category": ["sub", ...]} into "category:sub" entries.
    iapp = b.get("iapp_categories") or {}
    flat: List[str] = []
    if isinstance(iapp, dict):
        flat = [
            f"{cat}:{sub}"
            for cat, subs in iapp.items()
            if isinstance(subs, list)
            for sub in subs
        ]

    # Clean out None / non-scalar values before using this as Chroma metadata.
    shared_meta = _clean_metadata({
        "doc_id": _bill_id(b),
        "state": b.get("state"),
        "session_year": b.get("session_year"),
        "legislative_body": b.get("chamber") or b.get("legislative_body") or None,
        "status": b.get("status"),
        "title": b.get("title"),
        "bill_number": b.get("bill_number"),
        "sponsors": sponsors_str,
        "last_action_date": b.get("last_action_date"),
        "iapp_flat": "; ".join(flat),
    })

    total = len(pieces)
    return [
        Document(
            page_content=chunk,
            metadata={**shared_meta, "chunk_index": idx, "chunk_total": total},
        )
        for idx, chunk in enumerate(pieces)
    ]
def upsert_from_bills_json(
    source_json_path: str = "data/known_bills_visualize.json",
    persist_dir: str = DEFAULT_PERSIST_DIR,
    collection: str = DEFAULT_COLLECTION,
    manifest_path: str = DEFAULT_MANIFEST,
    embed_model: Optional[str] = None,
    batch_size: int = 128,
) -> Dict[str, int]:
    """Incrementally embed bills from a JSON file into the Chroma store.

    A content-hash manifest lets unchanged bills be skipped; bills whose hash
    changed get their previous chunks deleted and are re-chunked/re-embedded.
    Returns a summary dict: total_bills, embedded, skipped_unchanged, elapsed_sec.
    """
    t0 = time.time()
    bills = _load_bills(source_json_path)
    embeddings = get_embeddings(embed_model)
    vs = load_vectorstore(persist_dir, collection, embeddings)
    manifest = _manifest_load(manifest_path)
    manifest_meta = manifest.get("_meta", {})
    # If the embedding model changed, stored hashes are stale: drop the whole
    # manifest so every bill is re-embedded with the new model.
    if manifest_meta.get("embed_model") != (embed_model or DEFAULT_EMBED_MODEL):
        manifest = {}
    manifest["_meta"] = {"embed_model": embed_model or DEFAULT_EMBED_MODEL}
    to_docs, to_ids = [], []
    added, skipped = 0, 0
    for b in bills:
        # Nothing embeddable (no text, summary, or title) -> skip.
        if not (b.get("text") or b.get("description") or b.get("title")):
            skipped += 1
            continue
        doc_id = _bill_id(b)
        hsh = _bill_hash(b)
        # Unchanged since last run -> skip.
        if manifest.get(doc_id, {}).get("hash") == hsh:
            skipped += 1
            continue
        # Best-effort removal of this bill's previous chunks before re-adding.
        try:
            vs.delete(where={"doc_id": doc_id})
        except Exception:
            pass
        chunks = _chunk_bill(b)
        for d in chunks:
            to_docs.append(d)
            # Chunk ids are "<doc_id>::c<index>" so repeated upserts stay stable.
            to_ids.append(f"{doc_id}::c{d.metadata['chunk_index']}")
            # Flush in batches to bound memory use and per-call payload size.
            if len(to_docs) >= batch_size:
                vs.add_documents(documents=to_docs, ids=to_ids)
                to_docs, to_ids = [], []
        manifest[doc_id] = {"hash": hsh}
        added += 1
    # Flush any remaining partial batch.
    if to_docs:
        vs.add_documents(documents=to_docs, ids=to_ids)
    # persist() only exists on older Chroma wrappers; newer ones presumably
    # auto-persist — hence the hasattr guard.
    if hasattr(vs, "persist"):
        vs.persist()
    manifest["_meta"] = {"embed_model": embed_model or DEFAULT_EMBED_MODEL}
    _manifest_save(manifest_path, manifest)
    return {
        "total_bills": len(bills),
        "embedded": added,
        "skipped_unchanged": skipped,
        "elapsed_sec": int(time.time() - t0),
    }
def get_retriever(persist_dir=DEFAULT_PERSIST_DIR, collection=DEFAULT_COLLECTION, k=8, filter_kwargs=None):
    """Return a LangChain retriever over the bills store (top-*k*, optional metadata filter)."""
    store = load_vectorstore(persist_dir=persist_dir, collection=collection)
    kwargs = {"k": k}
    if filter_kwargs:
        kwargs["filter"] = filter_kwargs
    return store.as_retriever(search_kwargs=kwargs)
def similarity_search(
    query: str,
    k: int = 5,
    where: Optional[Dict[str, Any]] = None,
    persist_dir: str = DEFAULT_PERSIST_DIR,
    collection: str = DEFAULT_COLLECTION,
):
    """Run a top-*k* similarity search; an empty or None *where* means no filter."""
    store = load_vectorstore(persist_dir=persist_dir, collection=collection)
    # Normalize {} -> None: Chroma rejects an empty filter dict.
    metadata_filter = where or None
    return store.similarity_search(query, k=k, filter=metadata_filter)