Spaces: Running on CPU Upgrade
| # vectorstore/pinecone_bills_vectorstore.py | |
| from __future__ import annotations | |
| import os, json, time | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Any | |
| from datetime import datetime | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_core.documents import Document | |
| from pinecone import Pinecone | |
| from langchain_pinecone import PineconeVectorStore | |
| from dotenv import load_dotenv | |
| from pathlib import Path as _Path | |
| load_dotenv(dotenv_path=_Path.cwd() / ".env") | |
| DEFAULT_EMBED_MODEL = os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small") | |
| DEFAULT_COLLECTION = os.getenv("PINECONE_INDEX", "legislation-tracker") | |
| DEFAULT_MANIFEST = "data/bills_vectorstore_manifest.json" | |
| # Use empty string for namespace if not specified (Pinecone default) | |
| DEFAULT_NAMESPACE = os.getenv("PINECONE_NAMESPACE", "") | |
def get_embeddings(model: Optional[str] = None, *, chunk_size: int = 32) -> OpenAIEmbeddings:
    """Build an OpenAI embeddings client from environment configuration.

    Args:
        model: Embedding model name; defaults to ``DEFAULT_EMBED_MODEL``.
        chunk_size: Number of texts embedded per API request (previously a
            hard-coded 32; now overridable, default unchanged).

    Returns:
        A configured ``OpenAIEmbeddings`` instance.

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is not set in the environment.
    """
    key = os.getenv("OPENAI_API_KEY")
    if not key:
        raise RuntimeError("OPENAI_API_KEY not set")
    return OpenAIEmbeddings(api_key=key, model=model or DEFAULT_EMBED_MODEL, chunk_size=chunk_size)
| def _clean_meta(m: dict) -> dict: | |
| out = {} | |
| for k, v in m.items(): | |
| if v is None: | |
| continue | |
| if isinstance(v, (str, bool, int, float)): | |
| out[k] = v | |
| elif isinstance(v, (list, tuple)): | |
| out[k] = [str(x) for x in v if x is not None] | |
| else: | |
| out[k] = str(v) | |
| return out | |
| def _sha256(text: str) -> str: | |
| import hashlib | |
| return hashlib.sha256(text.encode("utf-8")).hexdigest() | |
| def _bill_id(b: Dict[str, Any]) -> str: | |
| return f"{b.get('state','Unknown')}_{b.get('bill_number','Unknown')}" | |
| def _bill_text(b: Dict[str, Any]) -> str: | |
| title = b.get("title") or "" | |
| summary = b.get("description") or "" | |
| txt = b.get("text") or "" | |
| return f"Title: {title}\n\nSummary: {summary}\n\nFull Text:\n{txt}" | |
def _bill_hash(b: Dict[str, Any]) -> str:
    """Content fingerprint over the fields that trigger re-embedding.

    Serializes the tracked fields with sorted keys so the hash is stable
    regardless of dict insertion order.
    """
    tracked_fields = ("title", "description", "text", "status", "last_action_date")
    payload = {field: b.get(field) for field in tracked_fields}
    return _sha256(json.dumps(payload, ensure_ascii=False, sort_keys=True))
| def _manifest_load(path: str) -> Dict[str, Dict[str, str]]: | |
| p = Path(path) | |
| if not p.exists(): | |
| return {} | |
| try: | |
| return json.loads(p.read_text(encoding="utf-8")) | |
| except Exception: | |
| return {} | |
| def _manifest_save(path: str, data: Dict[str, Dict[str, str]]) -> None: | |
| Path(path).parent.mkdir(parents=True, exist_ok=True) | |
| Path(path).write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8") | |
| def _flatten_iapp(iapp: Any) -> list[str]: | |
| flat: List[str] = [] | |
| if isinstance(iapp, dict): | |
| for k, v in iapp.items(): | |
| if isinstance(v, list): | |
| for sub in v: | |
| flat.append(f"{k}:{sub}") | |
| return flat # keep as list[str] for Pinecone $in filters | |
| def _parse_session_years(val) -> tuple[int | None, int | None]: | |
| """ | |
| Accepts: | |
| - string like '2023-2024' | |
| - dict with keys 'year_start'/'year_end' | |
| Returns (start, end) as ints or (None, None) | |
| """ | |
| if isinstance(val, str) and "-" in val: | |
| try: | |
| a, b = val.split("-", 1) | |
| return int(a), int(b) | |
| except Exception: | |
| return None, None | |
| if isinstance(val, dict): | |
| try: | |
| return int(val.get("year_start")), int(val.get("year_end")) | |
| except Exception: | |
| return None, None | |
| return None, None | |
| def _to_epoch(date_str: str | None) -> int | None: | |
| """ | |
| Accepts YYYY-MM-DD or ISO-8601; returns Unix epoch seconds or None | |
| """ | |
| if not date_str: | |
| return None | |
| try: | |
| ds = date_str.replace("Z", "") | |
| return int(datetime.fromisoformat(ds).timestamp()) | |
| except Exception: | |
| return None | |
def _chunk_bill(b: Dict[str, Any], *, size: int = 1500, overlap: int = 200) -> List[Document]:
    """Split one bill into overlapping text chunks with Pinecone-safe metadata.

    Every chunk carries the same bill-level metadata (doc_id, state, status,
    session years, sponsors, ...) plus its own chunk_index/chunk_total and a
    'text' copy of the chunk content. Metadata is sanitized via _clean_meta.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", ". ", " ", ""],
    )
    pieces = splitter.split_text(_bill_text(b)) or ["(no content)"]

    start_year, end_year = _parse_session_years(b.get("session_year"))
    action_date = b.get("last_action_date")
    sponsors = b.get("sponsors")
    if isinstance(sponsors, list):
        sponsors_str = "; ".join(map(str, sponsors))
    else:
        sponsors_str = sponsors or ""

    shared_meta = {
        "doc_id": _bill_id(b),
        "state": b.get("state"),
        "session_year": b.get("session_year"),
        "session_year_start": start_year,
        "session_year_end": end_year,
        "legislative_body": b.get("chamber") or b.get("legislative_body") or "",
        "status": b.get("status"),
        "title": b.get("title"),
        "bill_number": b.get("bill_number"),
        "sponsors": sponsors_str,
        "last_action_date": action_date,
        "last_action_ts": _to_epoch(action_date),
        "iapp_flat": _flatten_iapp(b.get("iapp_categories")),  # list[str]
    }

    documents: List[Document] = []
    chunk_total = len(pieces)
    for index, piece in enumerate(pieces):
        per_chunk = {**shared_meta, "chunk_index": index, "chunk_total": chunk_total, "text": piece}
        documents.append(Document(page_content=piece, metadata=_clean_meta(per_chunk)))
    return documents
| def _load_bills(path: str) -> List[Dict[str, Any]]: | |
| import json as _json | |
| from pathlib import Path as _P | |
| data = _json.loads(_P(path).read_text(encoding="utf-8")) | |
| if not isinstance(data, list): | |
| raise ValueError(f"{path} must contain a list of bills") | |
| return data | |
def _vectorstore(embeddings: OpenAIEmbeddings) -> PineconeVectorStore:
    """Connect to the configured Pinecone index and wrap it for LangChain."""
    client = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
    index = client.Index(os.getenv("PINECONE_INDEX", DEFAULT_COLLECTION))
    # An empty namespace string means "use Pinecone's default"; pass None then.
    return PineconeVectorStore(
        index=index,
        embedding=embeddings,
        namespace=DEFAULT_NAMESPACE or None,
    )
def upsert_from_bills_json(
    source_json_path: str = "data/known_bills_visualize.json",
    manifest_path: str = DEFAULT_MANIFEST,
    embed_model: Optional[str] = None,
    batch_size: int = 128,
) -> Dict[str, int]:
    """Incrementally embed bills from a JSON file into the Pinecone index.

    A local manifest of content hashes is used to skip bills that have not
    changed since the last run. Changed bills have their previous chunks
    deleted (via a doc_id metadata filter) before new chunks are added in
    batches.

    Args:
        source_json_path: Path to the JSON list of bill dicts.
        manifest_path: Path to the hash manifest used for change detection.
        embed_model: Embedding model name; defaults to DEFAULT_EMBED_MODEL.
        batch_size: Number of chunk documents per add_documents() call.

    Returns:
        Summary counters: total_bills, embedded, skipped_unchanged,
        elapsed_sec.
    """
    t0 = time.time()
    bills = _load_bills(source_json_path)
    embeddings = get_embeddings(embed_model)
    vs = _vectorstore(embeddings)
    manifest = _manifest_load(manifest_path)
    meta = manifest.get("_meta", {})
    model_in_use = embed_model or DEFAULT_EMBED_MODEL
    # A model switch invalidates every stored hash: reset the manifest so all
    # bills are re-embedded with the new model.
    if meta.get("embed_model") != model_in_use:
        manifest = {"_meta": {"embed_model": model_in_use}}
    to_upsert: List[Document] = []
    added, skipped = 0, 0
    for b in bills:
        # Skip bills with no embeddable content at all.
        if not (b.get("text") or b.get("description") or b.get("title")):
            skipped += 1
            continue
        doc_id = _bill_id(b)
        hsh = _bill_hash(b)
        # Unchanged since the last run (same content hash) -> nothing to do.
        if manifest.get(doc_id, {}).get("hash") == hsh:
            skipped += 1
            continue
        # Best-effort removal of the bill's previous chunks; failures are
        # deliberately ignored rather than aborting the whole run.
        # NOTE(review): some Pinecone index types may not support
        # metadata-filtered deletes — confirm against the index config.
        try:
            vs.delete(filter={"doc_id": doc_id})
        except Exception:
            pass
        for d in _chunk_bill(b):
            to_upsert.append(d)
            # Flush in batches to bound memory use and request size.
            if len(to_upsert) >= batch_size:
                vs.add_documents(documents=to_upsert)
                to_upsert = []
        # NOTE(review): the manifest records this bill before the final flush
        # below; if that flush fails, the bill is marked done with some
        # chunks never written.
        manifest[doc_id] = {"hash": hsh}
        added += 1
    if to_upsert:
        vs.add_documents(documents=to_upsert)
    manifest["_meta"] = {"embed_model": model_in_use}
    _manifest_save(manifest_path, manifest)
    return {
        "total_bills": len(bills),
        "embedded": added,
        "skipped_unchanged": skipped,
        "elapsed_sec": int(time.time() - t0),
    }
def get_retriever(k=8, filter_kwargs: Optional[Dict[str, Any]] = None):
    """Build a LangChain retriever over the bills index.

    Args:
        k: Number of chunks to retrieve per query.
        filter_kwargs: Optional Pinecone metadata filter applied to search.
    """
    store = _vectorstore(get_embeddings())
    search_kwargs: Dict[str, Any] = {"k": k}
    if filter_kwargs:
        search_kwargs["filter"] = filter_kwargs
    return store.as_retriever(search_kwargs=search_kwargs)
def similarity_search(query: str, k: int = 5, where: Optional[Dict[str, Any]] = None):
    """Run a raw similarity search; *where* is an optional metadata filter."""
    store = _vectorstore(get_embeddings())
    return store.similarity_search(query, k=k, filter=where or None)