parthib07's picture
Upload 52 files
61411b5 verified
from __future__ import annotations
import logging
import os
from typing import Any, Dict, List, Optional
from ai_business_automation_agent.embeddings.embedding_model import embed_texts
# Module-level logger; used below to report Pinecone index creation.
logger = logging.getLogger(__name__)
class PineconeVectorStore:
    """
    Minimal Pinecone wrapper for policy/compliance retrieval.

    Supports both:
    - pinecone-client (legacy) import style: import pinecone
    - newer pinecone SDK import style: from pinecone import Pinecone

    The SDK flavour is chosen solely by whether ``from pinecone import
    Pinecone`` succeeds. Errors raised while talking to the service (auth,
    network, index creation) are propagated instead of triggering a silent
    fallback to the legacy client.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        index_name: Optional[str] = None,
        cloud: Optional[str] = None,
        region: Optional[str] = None,
        namespace: str = "policies",
        dimension: int = 384,
    ) -> None:
        """Resolve configuration and connect to (or create) the index.

        Explicit arguments win. Otherwise the ``PINECONE_*`` environment
        variables are used, and empty values fall back to the defaults.

        Args:
            api_key: Pinecone API key (env ``PINECONE_API_KEY``).
            index_name: Index to use/create (env ``PINECONE_INDEX_NAME``,
                default ``"ai-bpa-agent"``).
            cloud: Serverless cloud used when creating the index with the
                newer SDK (env ``PINECONE_CLOUD``, default ``"aws"``).
            region: Serverless region used when creating the index with the
                newer SDK (env ``PINECONE_REGION``, default ``"us-east-1"``).
            namespace: Namespace used for upserts and queries.
            dimension: Vector dimension used if the index must be created.
                It must match the size of the vectors ``embed_texts``
                produces. The default is 384 (TODO confirm against the
                embedding model).

        Raises:
            ValueError: If no API key is available or ``dimension`` is not
                positive.
        """
        self.api_key = api_key or os.getenv("PINECONE_API_KEY", "")
        # Chained `or` so an env var that is set but empty still falls back
        # to the built-in default instead of becoming "".
        self.index_name = index_name or os.getenv("PINECONE_INDEX_NAME") or "ai-bpa-agent"
        self.cloud = cloud or os.getenv("PINECONE_CLOUD") or "aws"
        self.region = region or os.getenv("PINECONE_REGION") or "us-east-1"
        self.namespace = namespace
        self.dimension = dimension
        if not self.api_key:
            raise ValueError("Missing PINECONE_API_KEY.")
        if dimension <= 0:
            raise ValueError("dimension must be a positive integer.")
        self._index = self._init_index()

    def _init_index(self) -> Any:
        """Return an index handle, creating the index first if it is missing."""
        try:
            from pinecone import Pinecone  # type: ignore
        except ImportError:
            # Only the import decides which SDK is installed: the legacy
            # pinecone-client has no `Pinecone` class. Failures after this
            # point must surface as-is instead of being masked by a
            # confusing error from the legacy API.
            logger.debug("pinecone.Pinecone not available; using legacy pinecone-client API.")
            return self._init_legacy_index()
        return self._init_modern_index(Pinecone)

    def _init_modern_index(self, pinecone_cls: Any) -> Any:
        """Connect with the newer SDK, creating a serverless index if absent."""
        pc = pinecone_cls(api_key=self.api_key)
        existing = self._extract_index_names(pc.list_indexes())  # type: ignore[call-arg]
        if self.index_name not in existing:
            logger.info("Creating Pinecone index '%s' (cloud=%s region=%s)", self.index_name, self.cloud, self.region)
            pc.create_index(
                name=self.index_name,
                dimension=self.dimension,
                metric="cosine",
                spec={"serverless": {"cloud": self.cloud, "region": self.region}},
            )
        return pc.Index(self.index_name)

    def _init_legacy_index(self) -> Any:
        """Connect with the legacy pinecone-client, creating the index if absent."""
        import pinecone  # type: ignore

        pinecone.init(api_key=self.api_key, environment=os.getenv("PINECONE_ENVIRONMENT", ""))
        if self.index_name not in pinecone.list_indexes():
            logger.info("Creating Pinecone index '%s' (legacy)", self.index_name)
            pinecone.create_index(self.index_name, dimension=self.dimension, metric="cosine")
        return pinecone.Index(self.index_name)

    @staticmethod
    def _extract_index_names(raw: Any) -> set[str]:
        """Collect index names from a ``list_indexes()`` result.

        The shape varies by SDK version. It may be a dict with an
        ``"indexes"`` list, a plain list, or an object exposing an
        ``indexes`` list. Each entry may be a name string, a dict with a
        ``"name"`` key, or an object with a ``name`` attribute. Entries
        without a truthy name are skipped. An unrecognised shape yields an
        empty set.
        """
        if isinstance(raw, dict):
            entries = raw.get("indexes")
        elif isinstance(raw, (list, tuple)):
            entries = raw
        else:
            entries = getattr(raw, "indexes", None)
        if not isinstance(entries, (list, tuple)):
            return set()

        names: set[str] = set()
        for entry in entries:
            if isinstance(entry, str):
                name = entry
            elif isinstance(entry, dict):
                name = entry.get("name")
            else:
                name = getattr(entry, "name", None)
            if name:
                names.add(str(name))
        return names

    @staticmethod
    def _as_vector(vec: Any) -> List[float]:
        """Coerce an embedding into a plain ``list[float]`` for the Pinecone API.

        This guards against ``embed_texts`` returning array-like rows (its
        return type is not visible here). Plain float lists pass through
        unchanged in value.
        """
        return [float(x) for x in vec]

    @staticmethod
    def _normalize_matches(res: Any) -> List[Dict[str, Any]]:
        """Turn a query response (dict- or object-style) into plain dicts.

        Each output item has ``score``, ``text`` (from ``metadata["text"]``)
        and ``metadata``. A missing or ``None`` metadata becomes ``{}``, and
        a missing or ``None`` match list yields ``[]``.
        """
        if isinstance(res, dict):
            matches = res.get("matches")
        else:
            matches = getattr(res, "matches", None)

        out: List[Dict[str, Any]] = []
        for m in matches or []:
            if isinstance(m, dict):
                md = m.get("metadata")
                score = m.get("score")
            else:
                md = getattr(m, "metadata", None)
                score = getattr(m, "score", None)
            # Metadata can be present-but-None in SDK response objects.
            md = md or {}
            out.append({"score": score, "text": md.get("text"), "metadata": md})
        return out

    def seed_default_policies(self) -> None:
        """
        Idempotently seed a small set of example policy/rule documents.

        The documents use fixed IDs, so re-running upserts (overwrites) the
        same vectors in ``self.namespace``.
        In production, replace this with your real corp policies and compliance corpus.

        Raises:
            ValueError: If ``embed_texts`` returns a different number of
                vectors than documents. Previously ``zip`` silently dropped
                the extras.
        """
        docs = [
            (
                "policy-1",
                "Invoices must include invoice number, invoice date, vendor name, and total amount.",
                {"type": "policy", "topic": "required_fields"},
            ),
            (
                "policy-2",
                "If vendor is flagged or unknown, route invoice to manual review or reject based on risk severity.",
                {"type": "policy", "topic": "vendor_risk"},
            ),
            (
                "rule-1",
                "Reject invoices where subtotal + tax differs from total by more than 0.02 (rounding tolerance).",
                {"type": "rule", "topic": "totals_consistency"},
            ),
            (
                "rule-2",
                "For high-severity compliance issues (e.g., missing total, missing invoice number), reject the invoice.",
                {"type": "rule", "topic": "compliance"},
            ),
        ]
        texts = [text for _, text, _ in docs]
        vectors = list(embed_texts(texts))
        if len(vectors) != len(docs):
            raise ValueError(
                f"embed_texts returned {len(vectors)} vectors for {len(docs)} documents."
            )
        upserts = [
            {"id": doc_id, "values": self._as_vector(vec), "metadata": {"text": text, **meta}}
            for (doc_id, text, meta), vec in zip(docs, vectors)
        ]
        self._index.upsert(vectors=upserts, namespace=self.namespace)

    def retrieve(self, query: str, *, top_k: int = 5) -> List[Dict[str, Any]]:
        """Embed ``query`` and return up to ``top_k`` nearest stored documents.

        Args:
            query: Free-text query to embed.
            top_k: Maximum number of matches to request from the index.

        Returns:
            A list of ``{"score", "text", "metadata"}`` dicts in the order
            the index returned them.
        """
        vec = self._as_vector(embed_texts([query])[0])
        res = self._index.query(vector=vec, top_k=top_k, include_metadata=True, namespace=self.namespace)
        return self._normalize_matches(res)