Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| # vectorstore/bills_vectorstore.py | |
| from __future__ import annotations | |
| import os, json, hashlib, time | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Iterable, Any | |
| from dotenv import load_dotenv, find_dotenv | |
| load_dotenv(find_dotenv()) | |
| try: | |
| from langchain_chroma import Chroma | |
| except Exception: | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain.schema import Document | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
# Embedding model used for all vectors; override with the OPENAI_EMBED_MODEL env var.
DEFAULT_EMBED_MODEL = os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")
# Directory where the Chroma collection is persisted on disk.
DEFAULT_PERSIST_DIR = "data/bills_vectorstore"
# Name of the Chroma collection holding bill chunks.
DEFAULT_COLLECTION = "bills"
# JSON manifest mapping doc_id -> content hash; used to skip unchanged bills on upsert.
DEFAULT_MANIFEST = "data/bills_vectorstore_manifest.json"
def get_embeddings(model: Optional[str] = None) -> OpenAIEmbeddings:
    """Build an OpenAIEmbeddings client, failing fast when no API key is configured.

    Args:
        model: Embedding model name; falls back to DEFAULT_EMBED_MODEL when None.

    Raises:
        RuntimeError: if OPENAI_API_KEY is absent/empty in the environment.
    """
    key = os.getenv("OPENAI_API_KEY")
    if not key:
        raise RuntimeError("OPENAI_API_KEY is not set. Check your .env or environment.")
    chosen_model = model if model is not None else DEFAULT_EMBED_MODEL
    return OpenAIEmbeddings(api_key=key, model=chosen_model, chunk_size=32)
| def _sha256(text: str) -> str: | |
| import hashlib | |
| return hashlib.sha256(text.encode("utf-8")).hexdigest() | |
| def _bill_id(b: Dict[str, Any]) -> str: | |
| return f"{b.get('state','Unknown')}_{b.get('bill_number','Unknown')}" | |
| def _bill_text(b: Dict[str, Any]) -> str: | |
| title = b.get("title") or "" | |
| summary = b.get("description") or "" | |
| txt = b.get("text") or "" | |
| return f"Title: {title}\n\nSummary: {summary}\n\nFull Text:\n{txt}" | |
| def _bill_hash(b: Dict[str, Any]) -> str: | |
| payload = json.dumps({ | |
| "title": b.get("title"), | |
| "description": b.get("description"), | |
| "text": b.get("text"), | |
| "status": b.get("status"), | |
| "last_action_date": b.get("last_action_date"), | |
| }, ensure_ascii=False, sort_keys=True) | |
| return _sha256(payload) | |
| def _manifest_load(path: str) -> Dict[str, Dict[str, str]]: | |
| p = Path(path) | |
| if not p.exists(): | |
| return {} | |
| try: | |
| return json.loads(p.read_text(encoding="utf-8")) | |
| except Exception: | |
| return {} | |
| def _manifest_save(path: str, data: Dict[str, Dict[str, str]]) -> None: | |
| Path(path).parent.mkdir(parents=True, exist_ok=True) | |
| Path(path).write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8") | |
| def _clean_metadata(meta: Dict[str, Any]) -> Dict[str, Any]: | |
| """Keep only metadata values that Chroma accepts: str/int/float/bool and not None.""" | |
| allowed_types = (str, int, float, bool) | |
| cleaned: Dict[str, Any] = {} | |
| for k, v in meta.items(): | |
| if v is None: | |
| continue | |
| if isinstance(v, allowed_types): | |
| cleaned[k] = v | |
| else: | |
| # If you prefer to drop complex types instead of stringifying, replace with `continue` | |
| cleaned[k] = str(v) | |
| return cleaned | |
def _make_doc(b: Dict[str, Any]) -> Document:
    """Build a single (un-chunked) Document for a bill with Chroma-safe metadata.

    Fix: the original ended with two unreachable, duplicated
    `meta = {...}; return Document(...)` sequences after the first return;
    that dead code is removed. Behavior is unchanged.
    """
    # Sponsors may arrive as a list or a scalar; normalize to one string.
    sponsors = b.get("sponsors") or []
    if isinstance(sponsors, list):
        sponsors_str = "; ".join(map(str, sponsors))
    else:
        sponsors_str = str(sponsors) if sponsors else ""

    # Flatten {"category": ["sub", ...]} into "category:sub; ..." pairs.
    flat_iapp: List[str] = []
    iapp = b.get("iapp_categories")
    if isinstance(iapp, dict):
        for k, v in iapp.items():
            if isinstance(v, list):
                for sub in v:
                    flat_iapp.append(f"{k}:{sub}")
    iapp_str = "; ".join(flat_iapp) if flat_iapp else ""

    meta = {
        "doc_id": _bill_id(b),
        "state": b.get("state"),
        "session_year": b.get("session_year"),
        "legislative_body": b.get("chamber") or b.get("legislative_body") or None,
        "status": b.get("status"),
        "title": b.get("title"),
        "bill_number": b.get("bill_number"),
        "sponsors": sponsors_str,
        "last_action_date": b.get("last_action_date"),
        "iapp_flat": iapp_str,
    }
    # Strip None / non-scalar values so Chroma accepts the metadata.
    meta = _clean_metadata(meta)
    return Document(page_content=_bill_text(b), metadata=meta)
| def _load_bills(source_json_path: str) -> List[Dict[str, Any]]: | |
| data = json.loads(Path(source_json_path).read_text(encoding="utf-8")) | |
| if not isinstance(data, list): | |
| raise ValueError(f"{source_json_path} must contain a list of bills") | |
| return data | |
def load_vectorstore(
    persist_dir: str = DEFAULT_PERSIST_DIR,
    collection: str = DEFAULT_COLLECTION,
    embeddings: Optional[OpenAIEmbeddings] = None,
) -> Chroma:
    """Open (or create) the persistent Chroma collection for bills.

    Args:
        persist_dir: On-disk directory for the Chroma store; created if absent.
        collection: Chroma collection name.
        embeddings: Embedding client; a default client is built when omitted.
    """
    active_embeddings = embeddings or get_embeddings()
    Path(persist_dir).mkdir(parents=True, exist_ok=True)
    return Chroma(
        embedding_function=active_embeddings,
        collection_name=collection,
        persist_directory=persist_dir,
    )
def _chunk_bill(b: Dict[str, Any], *, size: int = 1500, overlap: int = 200) -> List[Document]:
    """Split one bill's text into overlapping chunks, each carrying the bill's metadata.

    Every chunk Document shares the same cleaned base metadata plus its own
    chunk_index / chunk_total, so chunks can be regrouped per bill later.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", ". ", " ", ""],
    )
    # Guarantee at least one chunk even for empty bills.
    pieces = splitter.split_text(_bill_text(b)) or ["(no content)"]

    # Sponsors may be a list or a scalar; normalize to a single string.
    sponsors = b.get("sponsors")
    if isinstance(sponsors, list):
        sponsors_str = "; ".join(str(s) for s in sponsors)
    else:
        sponsors_str = sponsors or ""

    # Flatten {"category": ["sub", ...]} into "category:sub" entries.
    flattened: List[str] = []
    categories = b.get("iapp_categories") or {}
    if isinstance(categories, dict):
        for cat, subs in categories.items():
            if isinstance(subs, list):
                flattened.extend(f"{cat}:{s}" for s in subs)

    # 🔑 Clean out None / bad types before using this as metadata
    base_meta = _clean_metadata({
        "doc_id": _bill_id(b),
        "state": b.get("state"),
        "session_year": b.get("session_year"),
        "legislative_body": b.get("chamber") or b.get("legislative_body") or None,
        "status": b.get("status"),
        "title": b.get("title"),
        "bill_number": b.get("bill_number"),
        "sponsors": sponsors_str,
        "last_action_date": b.get("last_action_date"),
        "iapp_flat": "; ".join(flattened),
    })

    docs: List[Document] = []
    for idx, piece in enumerate(pieces):
        chunk_meta = {**base_meta, "chunk_index": idx, "chunk_total": len(pieces)}
        docs.append(Document(page_content=piece, metadata=chunk_meta))
    return docs
def upsert_from_bills_json(
    source_json_path: str = "data/known_bills_visualize.json",
    persist_dir: str = DEFAULT_PERSIST_DIR,
    collection: str = DEFAULT_COLLECTION,
    manifest_path: str = DEFAULT_MANIFEST,
    embed_model: Optional[str] = None,
    batch_size: int = 128,
) -> Dict[str, int]:
    """Incrementally embed bills from a JSON file into the Chroma store.

    A manifest of per-bill content hashes lets unchanged bills be skipped;
    switching the embedding model wipes the manifest so every bill is
    re-embedded under the new model.

    Fix: the original assigned manifest["_meta"] the identical value twice
    (once up front and again just before saving); the redundant second
    write is removed.

    Args:
        source_json_path: JSON file containing a list of bill dicts.
        persist_dir / collection: Where the Chroma store lives.
        manifest_path: Hash manifest used for change detection.
        embed_model: Embedding model override; defaults to DEFAULT_EMBED_MODEL.
        batch_size: Number of chunk documents per add_documents() call.

    Returns:
        Counters: total_bills, embedded, skipped_unchanged, elapsed_sec.
    """
    t0 = time.time()
    bills = _load_bills(source_json_path)
    embeddings = get_embeddings(embed_model)
    vs = load_vectorstore(persist_dir, collection, embeddings)

    manifest = _manifest_load(manifest_path)
    # Hashes recorded under a different embedding model are meaningless here:
    # reset the manifest so everything gets re-embedded with the new model.
    if manifest.get("_meta", {}).get("embed_model") != (embed_model or DEFAULT_EMBED_MODEL):
        manifest = {}
    manifest["_meta"] = {"embed_model": embed_model or DEFAULT_EMBED_MODEL}

    to_docs, to_ids = [], []
    added, skipped = 0, 0
    for b in bills:
        # Nothing embeddable at all -> skip.
        if not (b.get("text") or b.get("description") or b.get("title")):
            skipped += 1
            continue
        doc_id = _bill_id(b)
        hsh = _bill_hash(b)
        if manifest.get(doc_id, {}).get("hash") == hsh:
            skipped += 1
            continue
        # Remove any stale chunks for this bill before re-adding.
        try:
            vs.delete(where={"doc_id": doc_id})
        except Exception:
            pass  # best-effort: nothing to delete, or backend lacks support
        for d in _chunk_bill(b):
            to_docs.append(d)
            to_ids.append(f"{doc_id}::c{d.metadata['chunk_index']}")
            # Flush in batches to bound memory and API payload size.
            if len(to_docs) >= batch_size:
                vs.add_documents(documents=to_docs, ids=to_ids)
                to_docs, to_ids = [], []
        manifest[doc_id] = {"hash": hsh}
        added += 1
    if to_docs:
        vs.add_documents(documents=to_docs, ids=to_ids)
    # Older langchain Chroma wrappers require an explicit persist();
    # newer ones auto-persist and no longer expose the method.
    if hasattr(vs, "persist"):
        vs.persist()
    _manifest_save(manifest_path, manifest)
    return {
        "total_bills": len(bills),
        "embedded": added,
        "skipped_unchanged": skipped,
        "elapsed_sec": int(time.time() - t0),
    }
def get_retriever(persist_dir=DEFAULT_PERSIST_DIR, collection=DEFAULT_COLLECTION, k=8, filter_kwargs=None):
    """Build a LangChain retriever over the bills collection.

    Args:
        persist_dir / collection: Which Chroma store to open.
        k: Number of results per query.
        filter_kwargs: Optional metadata filter passed through to Chroma.
    """
    store = load_vectorstore(persist_dir=persist_dir, collection=collection)
    if filter_kwargs:
        kwargs = {"k": k, "filter": filter_kwargs}
    else:
        kwargs = {"k": k}
    return store.as_retriever(search_kwargs=kwargs)
def similarity_search(
    query: str,
    k: int = 5,
    where: Optional[Dict[str, Any]] = None,
    persist_dir: str = DEFAULT_PERSIST_DIR,
    collection: str = DEFAULT_COLLECTION,
):
    """Run a plain similarity search against the bills collection.

    An empty `where` dict is normalized to None — Chroma rejects an
    empty-dict filter, so only a non-empty filter is forwarded.
    """
    store = load_vectorstore(persist_dir=persist_dir, collection=collection)
    active_filter = where or None  # <-- key line: {} and None both mean "no filter"
    return store.similarity_search(query, k=k, filter=active_filter)