from __future__ import annotations

import logging
import os
from typing import Any, Dict, List, Optional

from ai_business_automation_agent.embeddings.embedding_model import embed_texts

logger = logging.getLogger(__name__)


class PineconeVectorStore:
    """
    Minimal Pinecone wrapper for policy/compliance retrieval.

    Supports both:
    - pinecone-client (legacy) import style: import pinecone
    - newer pinecone SDK import style: from pinecone import Pinecone

    The index is looked up (and created if missing) when the object is
    constructed, so instantiation performs network calls.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        index_name: Optional[str] = None,
        cloud: Optional[str] = None,
        region: Optional[str] = None,
        namespace: str = "policies",
        dimension: int = 384,
    ) -> None:
        """
        Resolve configuration and connect to (or create) the index.

        Args:
            api_key: Pinecone API key; falls back to ``PINECONE_API_KEY``.
            index_name: Index to use; falls back to ``PINECONE_INDEX_NAME``
                and then to ``"ai-bpa-agent"``.
            cloud: Serverless cloud provider; falls back to
                ``PINECONE_CLOUD`` and then to ``"aws"``.
            region: Serverless region; falls back to ``PINECONE_REGION``
                and then to ``"us-east-1"``.
            namespace: Namespace used for every upsert and query.
            dimension: Vector dimension used when the index has to be
                created. It must match the length of the vectors returned
                by ``embed_texts`` (the default of 384 is presumably the
                output size of the configured embedding model - confirm
                before changing either).

        Raises:
            ValueError: If no API key is supplied or found in the environment.
        """
        self.api_key = api_key or os.getenv("PINECONE_API_KEY", "")
        self.index_name = index_name or os.getenv("PINECONE_INDEX_NAME", "ai-bpa-agent")
        self.cloud = cloud or os.getenv("PINECONE_CLOUD", "aws")
        self.region = region or os.getenv("PINECONE_REGION", "us-east-1")
        self.namespace = namespace
        self.dimension = dimension
        if not self.api_key:
            raise ValueError("Missing PINECONE_API_KEY.")
        self._index = self._init_index()

    @staticmethod
    def _entry_name(entry: Any) -> Optional[str]:
        """Return the index name carried by one ``list_indexes`` entry, if any."""
        if isinstance(entry, str):
            return entry or None
        if isinstance(entry, dict):
            name = entry.get("name")
        else:
            name = getattr(entry, "name", None)
        return str(name) if name else None

    @staticmethod
    def _extract_index_names(raw: Any) -> set[str]:
        """
        Normalise the result of ``list_indexes()`` into a set of index names.

        The shape of the result varies by SDK version; the shapes handled are
        a dict with an ``"indexes"`` list, a plain list of names or
        descriptions, an object exposing an ``.indexes`` list, and an object
        exposing a ``.names()`` method. Unknown shapes yield an empty set.
        """
        if isinstance(raw, dict):
            entries: Any = raw.get("indexes", []) or []
        elif isinstance(raw, list):
            entries = raw
        else:
            entries = getattr(raw, "indexes", None)
            if not isinstance(entries, list):
                names_fn = getattr(raw, "names", None)
                entries = list(names_fn()) if callable(names_fn) else []

        names: set[str] = set()
        for entry in entries:
            name = PineconeVectorStore._entry_name(entry)
            if name:
                names.add(name)
        return names

    def _init_index(self):
        """
        Return a handle to ``self.index_name``, creating the index if needed.

        The newer SDK is preferred. The legacy client is used only when the
        newer entry point cannot be imported; any error raised while talking
        to Pinecone through the newer SDK (authentication, network, index
        creation) propagates to the caller instead of being hidden behind an
        unrelated failure of the legacy code path.
        """
        try:
            from pinecone import Pinecone  # type: ignore
        except ImportError:
            return self._init_legacy_index()

        pc = Pinecone(api_key=self.api_key)
        # list_indexes shape varies by pinecone SDK version
        existing = self._extract_index_names(pc.list_indexes())  # type: ignore[call-arg]
        if self.index_name not in existing:
            logger.info(
                "Creating Pinecone index '%s' (cloud=%s region=%s)",
                self.index_name,
                self.cloud,
                self.region,
            )
            pc.create_index(
                name=self.index_name,
                dimension=self.dimension,
                metric="cosine",
                spec={"serverless": {"cloud": self.cloud, "region": self.region}},
            )
        return pc.Index(self.index_name)

    def _init_legacy_index(self):
        """Connect through the legacy ``pinecone-client`` module-level API."""
        import pinecone  # type: ignore

        pinecone.init(api_key=self.api_key, environment=os.getenv("PINECONE_ENVIRONMENT", ""))
        if self.index_name not in pinecone.list_indexes():
            logger.info("Creating Pinecone index '%s' (legacy)", self.index_name)
            pinecone.create_index(self.index_name, dimension=self.dimension, metric="cosine")
        return pinecone.Index(self.index_name)

    def seed_default_policies(self) -> None:
        """
        Idempotently seed a small set of example policy/rule documents.

        Document ids are fixed, so repeated calls upsert the same records.
        In production, replace this with your real corp policies and
        compliance corpus.
        """
        docs = [
            (
                "policy-1",
                "Invoices must include invoice number, invoice date, vendor name, and total amount.",
                {"type": "policy", "topic": "required_fields"},
            ),
            (
                "policy-2",
                "If vendor is flagged or unknown, route invoice to manual review or reject based on risk severity.",
                {"type": "policy", "topic": "vendor_risk"},
            ),
            (
                "rule-1",
                "Reject invoices where subtotal + tax differs from total by more than 0.02 (rounding tolerance).",
                {"type": "rule", "topic": "totals_consistency"},
            ),
            (
                "rule-2",
                "For high-severity compliance issues (e.g., missing total, missing invoice number), reject the invoice.",
                {"type": "rule", "topic": "compliance"},
            ),
        ]
        texts = [text for _, text, _ in docs]
        vectors = embed_texts(texts)
        # The source text is stored in the metadata so retrieve() can return it.
        upserts = [
            {"id": doc_id, "values": vec, "metadata": {"text": text, **meta}}
            for (doc_id, text, meta), vec in zip(docs, vectors)
        ]
        self._index.upsert(vectors=upserts, namespace=self.namespace)

    def retrieve(self, query: str, *, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Return the stored documents most similar to ``query``.

        Args:
            query: Free-text query; it is embedded with ``embed_texts``.
            top_k: Maximum number of matches requested from the index.

        Returns:
            One dict per match with the keys ``"score"``, ``"text"`` (taken
            from the match metadata, ``None`` when absent) and ``"metadata"``.
        """
        vec = embed_texts([query])[0]
        res = self._index.query(
            vector=vec,
            top_k=top_k,
            include_metadata=True,
            namespace=self.namespace,
        )
        # Responses are dicts or attribute-style objects depending on the SDK.
        if isinstance(res, dict):
            matches = res.get("matches")
        else:
            matches = getattr(res, "matches", None)

        out: List[Dict[str, Any]] = []
        for m in matches or []:
            if isinstance(m, dict):
                md = m.get("metadata")
                score = m.get("score")
            else:
                md = getattr(m, "metadata", None)
                score = getattr(m, "score", None)
            # A record stored without metadata comes back with metadata=None.
            md = md or {}
            out.append({"score": score, "text": md.get("text"), "metadata": md})
        return out