# legislation-tracker / vectorstore / pinecone_bills_vectorstore.py
# Provenance: uploaded to Hugging Face by "ramanna" ("Upload 7 files", commit b501a8e, verified)
# vectorstore/pinecone_bills_vectorstore.py
from __future__ import annotations
import os, json, time
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from pinecone import Pinecone
from langchain_pinecone import PineconeVectorStore
from dotenv import load_dotenv
from pathlib import Path as _Path
load_dotenv(dotenv_path=_Path.cwd() / ".env")
DEFAULT_EMBED_MODEL = os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")
DEFAULT_COLLECTION = os.getenv("PINECONE_INDEX", "legislation-tracker")
DEFAULT_MANIFEST = "data/bills_vectorstore_manifest.json"
# Use empty string for namespace if not specified (Pinecone default)
DEFAULT_NAMESPACE = os.getenv("PINECONE_NAMESPACE", "")
def get_embeddings(model: Optional[str] = None, chunk_size: int = 32) -> OpenAIEmbeddings:
    """Build an OpenAI embeddings client for this project.

    Args:
        model: Embedding model name; defaults to ``DEFAULT_EMBED_MODEL``.
        chunk_size: Texts per embedding API batch (previously hard-coded to 32;
            now a backward-compatible parameter).

    Returns:
        A configured ``OpenAIEmbeddings`` instance.

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is not set in the environment.
    """
    key = os.getenv("OPENAI_API_KEY")
    if not key:
        raise RuntimeError("OPENAI_API_KEY not set")
    return OpenAIEmbeddings(api_key=key, model=model or DEFAULT_EMBED_MODEL, chunk_size=chunk_size)
def _clean_meta(m: dict) -> dict:
out = {}
for k, v in m.items():
if v is None:
continue
if isinstance(v, (str, bool, int, float)):
out[k] = v
elif isinstance(v, (list, tuple)):
out[k] = [str(x) for x in v if x is not None]
else:
out[k] = str(v)
return out
def _sha256(text: str) -> str:
import hashlib
return hashlib.sha256(text.encode("utf-8")).hexdigest()
def _bill_id(b: Dict[str, Any]) -> str:
return f"{b.get('state','Unknown')}_{b.get('bill_number','Unknown')}"
def _bill_text(b: Dict[str, Any]) -> str:
title = b.get("title") or ""
summary = b.get("description") or ""
txt = b.get("text") or ""
return f"Title: {title}\n\nSummary: {summary}\n\nFull Text:\n{txt}"
def _bill_hash(b: Dict[str, Any]) -> str:
    """Content hash over the fields that should trigger re-embedding.

    ``sort_keys=True`` makes the serialization independent of dict ordering,
    so the hash is stable across runs.
    """
    tracked = ("title", "description", "text", "status", "last_action_date")
    snapshot = {name: b.get(name) for name in tracked}
    return _sha256(json.dumps(snapshot, ensure_ascii=False, sort_keys=True))
def _manifest_load(path: str) -> Dict[str, Dict[str, str]]:
p = Path(path)
if not p.exists():
return {}
try:
return json.loads(p.read_text(encoding="utf-8"))
except Exception:
return {}
def _manifest_save(path: str, data: Dict[str, Dict[str, str]]) -> None:
Path(path).parent.mkdir(parents=True, exist_ok=True)
Path(path).write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
def _flatten_iapp(iapp: Any) -> list[str]:
flat: List[str] = []
if isinstance(iapp, dict):
for k, v in iapp.items():
if isinstance(v, list):
for sub in v:
flat.append(f"{k}:{sub}")
return flat # keep as list[str] for Pinecone $in filters
def _parse_session_years(val) -> tuple[int | None, int | None]:
"""
Accepts:
- string like '2023-2024'
- dict with keys 'year_start'/'year_end'
Returns (start, end) as ints or (None, None)
"""
if isinstance(val, str) and "-" in val:
try:
a, b = val.split("-", 1)
return int(a), int(b)
except Exception:
return None, None
if isinstance(val, dict):
try:
return int(val.get("year_start")), int(val.get("year_end"))
except Exception:
return None, None
return None, None
def _to_epoch(date_str: str | None) -> int | None:
"""
Accepts YYYY-MM-DD or ISO-8601; returns Unix epoch seconds or None
"""
if not date_str:
return None
try:
ds = date_str.replace("Z", "")
return int(datetime.fromisoformat(ds).timestamp())
except Exception:
return None
def _chunk_bill(b: Dict[str, Any], *, size: int = 1500, overlap: int = 200) -> List[Document]:
    """Split one bill into embeddable ``Document`` chunks with filter-ready metadata.

    Every chunk carries the full bill-level metadata plus its own index,
    the total chunk count, and the raw chunk text (kept in metadata so the
    text survives in Pinecone alongside the vector).
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", ". ", " ", ""],
    )
    pieces = splitter.split_text(_bill_text(b)) or ["(no content)"]

    start_year, end_year = _parse_session_years(b.get("session_year"))
    last_action = b.get("last_action_date")
    sponsors = b.get("sponsors")
    if isinstance(sponsors, list):
        sponsors = "; ".join(map(str, sponsors))
    else:
        sponsors = sponsors or ""

    shared_meta = {
        "doc_id": _bill_id(b),
        "state": b.get("state"),
        "session_year": b.get("session_year"),
        "session_year_start": start_year,
        "session_year_end": end_year,
        "legislative_body": b.get("chamber") or b.get("legislative_body") or "",
        "status": b.get("status"),
        "title": b.get("title"),
        "bill_number": b.get("bill_number"),
        "sponsors": sponsors,
        "last_action_date": last_action,
        "last_action_ts": _to_epoch(last_action),
        "iapp_flat": _flatten_iapp(b.get("iapp_categories")),  # list[str] for $in filters
    }

    total = len(pieces)
    documents: List[Document] = []
    for index, chunk in enumerate(pieces):
        meta = _clean_meta({**shared_meta, "chunk_index": index, "chunk_total": total, "text": chunk})
        documents.append(Document(page_content=chunk, metadata=meta))
    return documents
def _load_bills(path: str) -> List[Dict[str, Any]]:
import json as _json
from pathlib import Path as _P
data = _json.loads(_P(path).read_text(encoding="utf-8"))
if not isinstance(data, list):
raise ValueError(f"{path} must contain a list of bills")
return data
def _vectorstore(embeddings: OpenAIEmbeddings) -> PineconeVectorStore:
    """Connect to the configured Pinecone index and wrap it for LangChain.

    An empty ``PINECONE_NAMESPACE`` collapses to ``None`` so Pinecone's
    default namespace is used.
    """
    client = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
    index = client.Index(os.getenv("PINECONE_INDEX", DEFAULT_COLLECTION))
    return PineconeVectorStore(
        index=index,
        embedding=embeddings,
        namespace=DEFAULT_NAMESPACE or None,
    )
def upsert_from_bills_json(
    source_json_path: str = "data/known_bills_visualize.json",
    manifest_path: str = DEFAULT_MANIFEST,
    embed_model: Optional[str] = None,
    batch_size: int = 128,
) -> Dict[str, int]:
    """Incrementally embed bills from a JSON file into Pinecone.

    A content hash per bill is kept in a local manifest file; bills whose
    hash is unchanged are skipped. Chunks are upserted in batches of
    *batch_size*. Returns counters: total_bills, embedded, skipped_unchanged,
    elapsed_sec.
    """
    t0 = time.time()
    bills = _load_bills(source_json_path)
    embeddings = get_embeddings(embed_model)
    vs = _vectorstore(embeddings)
    manifest = _manifest_load(manifest_path)
    meta = manifest.get("_meta", {})
    model_in_use = embed_model or DEFAULT_EMBED_MODEL
    # A model change invalidates every stored hash: reset the manifest so all
    # bills are re-embedded with the new model.
    if meta.get("embed_model") != model_in_use:
        manifest = {"_meta": {"embed_model": model_in_use}}
    to_upsert: List[Document] = []
    added, skipped = 0, 0
    for b in bills:
        # Nothing embeddable at all -> skip (counted together with unchanged
        # bills in the 'skipped_unchanged' result).
        if not (b.get("text") or b.get("description") or b.get("title")):
            skipped += 1
            continue
        doc_id = _bill_id(b)
        hsh = _bill_hash(b)
        if manifest.get(doc_id, {}).get("hash") == hsh:
            skipped += 1
            continue
        # Remove stale chunks for this bill before re-adding. Best-effort:
        # delete-by-metadata-filter can be unsupported on some index types.
        try:
            vs.delete(filter={"doc_id": doc_id})
        except Exception:
            pass
        for d in _chunk_bill(b):
            to_upsert.append(d)
            # Flush a full batch mid-stream to bound memory / request size.
            if len(to_upsert) >= batch_size:
                vs.add_documents(documents=to_upsert)
                to_upsert = []
        # NOTE(review): the hash is recorded before the final batch flush, so
        # a failure below could leave this bill marked as embedded — confirm
        # whether that is acceptable for this pipeline.
        manifest[doc_id] = {"hash": hsh}
        added += 1
    if to_upsert:
        # Flush the final partial batch.
        vs.add_documents(documents=to_upsert)
    manifest["_meta"] = {"embed_model": model_in_use}
    _manifest_save(manifest_path, manifest)
    return {
        "total_bills": len(bills),
        "embedded": added,
        "skipped_unchanged": skipped,
        "elapsed_sec": int(time.time() - t0),
    }
def get_retriever(k=8, filter_kwargs: Optional[Dict[str, Any]] = None):
    """Build a LangChain retriever over the bills index.

    Args:
        k: Number of chunks to retrieve per query.
        filter_kwargs: Optional Pinecone metadata filter; only attached when
            truthy.
    """
    vs = _vectorstore(get_embeddings())
    search_kwargs: Dict[str, Any] = {"k": k}
    if filter_kwargs:
        search_kwargs["filter"] = filter_kwargs
    return vs.as_retriever(search_kwargs=search_kwargs)
def similarity_search(query: str, k: int = 5, where: Optional[Dict[str, Any]] = None):
    """Run a plain similarity search against the bills index.

    Args:
        query: Natural-language query text.
        k: Number of results to return.
        where: Optional Pinecone metadata filter (passed as ``filter``;
            falsy values are normalized to ``None``).
    """
    store = _vectorstore(get_embeddings())
    return store.similarity_search(query, k=k, filter=where or None)