Spaces:
Runtime error
Runtime error
File size: 5,888 Bytes
from __future__ import annotations
import logging
import os
from typing import Any, Dict, List, Optional
from ai_business_automation_agent.embeddings.embedding_model import embed_texts
logger = logging.getLogger(__name__)
class PineconeVectorStore:
    """
    Minimal Pinecone wrapper for policy/compliance retrieval.

    Supports both:
    - pinecone-client (legacy) import style: import pinecone
    - newer pinecone SDK import style: from pinecone import Pinecone
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        index_name: Optional[str] = None,
        cloud: Optional[str] = None,
        region: Optional[str] = None,
        namespace: str = "policies",
        dimension: int = 384,
    ) -> None:
        """
        Resolve configuration and open (creating if necessary) the index.

        Each of ``api_key``, ``index_name``, ``cloud`` and ``region`` falls
        back to the environment variable ``PINECONE_API_KEY``,
        ``PINECONE_INDEX_NAME``, ``PINECONE_CLOUD`` and ``PINECONE_REGION``
        respectively when not supplied.

        Args:
            api_key: Pinecone API key.
            index_name: Name of the index to open or create.
            cloud: Cloud provider used when creating a serverless index.
            region: Region used when creating a serverless index.
            namespace: Namespace used for upserts and queries.
            dimension: Vector dimension used only when the index has to be
                created. Presumably must equal the length of the vectors
                returned by ``embed_texts`` -- confirm against that helper.

        Raises:
            ValueError: If no API key is given and ``PINECONE_API_KEY`` is
                unset or empty.
        """
        self.api_key = api_key or os.getenv("PINECONE_API_KEY", "")
        self.index_name = index_name or os.getenv("PINECONE_INDEX_NAME", "ai-bpa-agent")
        self.cloud = cloud or os.getenv("PINECONE_CLOUD", "aws")
        self.region = region or os.getenv("PINECONE_REGION", "us-east-1")
        self.namespace = namespace
        self.dimension = dimension
        if not self.api_key:
            raise ValueError("Missing PINECONE_API_KEY.")
        self._index = self._init_index()

    @staticmethod
    def _extract_index_names(raw: Any) -> set[str]:
        """
        Collect index names from whatever ``list_indexes()`` returned.

        The result shape varies by SDK version. Accepted shapes are a dict
        with an ``"indexes"`` entry, a plain list, or an object exposing an
        ``indexes`` list. Items may be strings, dicts with a ``"name"`` key,
        or objects with a ``name`` attribute. Anything else yields no names.
        """
        if isinstance(raw, dict):
            items = raw.get("indexes") or []
        elif isinstance(raw, list):
            items = raw
        else:
            indexes = getattr(raw, "indexes", None)
            items = indexes if isinstance(indexes, list) else []

        names: set[str] = set()
        for item in items:
            if isinstance(item, str):
                name = item
            elif isinstance(item, dict):
                name = item.get("name")
            else:
                name = getattr(item, "name", None)
            if name:
                names.add(str(name))
        return names

    def _init_index(self):
        """
        Return a handle to ``self.index_name``, creating the index if missing.

        The newer SDK is used when ``from pinecone import Pinecone`` succeeds;
        otherwise the legacy ``pinecone-client`` path is taken.
        """
        # Only a missing/old SDK triggers the legacy fallback. Failures inside
        # the new-SDK path (auth, network, index creation) propagate as-is
        # rather than being swallowed and replaced by an unrelated error from
        # the legacy entry points.
        try:
            from pinecone import Pinecone  # type: ignore
        except ImportError:
            return self._init_legacy_index()

        pc = Pinecone(api_key=self.api_key)
        # list_indexes shape varies by pinecone SDK version
        existing = self._extract_index_names(pc.list_indexes())  # type: ignore[call-arg]
        if self.index_name not in existing:
            logger.info(
                "Creating Pinecone index '%s' (cloud=%s region=%s)",
                self.index_name,
                self.cloud,
                self.region,
            )
            pc.create_index(
                name=self.index_name,
                dimension=self.dimension,
                metric="cosine",
                spec={"serverless": {"cloud": self.cloud, "region": self.region}},
            )
        return pc.Index(self.index_name)

    def _init_legacy_index(self):
        """Open (creating if missing) the index through legacy pinecone-client."""
        import pinecone  # type: ignore

        pinecone.init(api_key=self.api_key, environment=os.getenv("PINECONE_ENVIRONMENT", ""))
        if self.index_name not in pinecone.list_indexes():
            logger.info("Creating Pinecone index '%s' (legacy)", self.index_name)
            pinecone.create_index(self.index_name, dimension=self.dimension, metric="cosine")
        return pinecone.Index(self.index_name)

    def seed_default_policies(self) -> None:
        """
        Idempotently seed a small set of example policy/rule documents.

        Documents are upserted under fixed ids into ``self.namespace``.
        In production, replace this with your real corp policies and compliance corpus.
        """
        docs = [
            (
                "policy-1",
                "Invoices must include invoice number, invoice date, vendor name, and total amount.",
                {"type": "policy", "topic": "required_fields"},
            ),
            (
                "policy-2",
                "If vendor is flagged or unknown, route invoice to manual review or reject based on risk severity.",
                {"type": "policy", "topic": "vendor_risk"},
            ),
            (
                "rule-1",
                "Reject invoices where subtotal + tax differs from total by more than 0.02 (rounding tolerance).",
                {"type": "rule", "topic": "totals_consistency"},
            ),
            (
                "rule-2",
                "For high-severity compliance issues (e.g., missing total, missing invoice number), reject the invoice.",
                {"type": "rule", "topic": "compliance"},
            ),
        ]
        vectors = embed_texts([text for _, text, _ in docs])
        # The raw text is stored in metadata so retrieve() can return it.
        upserts = [
            {"id": doc_id, "values": vec, "metadata": {"text": text, **meta}}
            for (doc_id, text, meta), vec in zip(docs, vectors)
        ]
        self._index.upsert(vectors=upserts, namespace=self.namespace)

    def retrieve(self, query: str, *, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Embed ``query`` and return the closest matches from ``self.namespace``.

        Args:
            query: Text to embed and search for.
            top_k: Number of matches requested from the index.

        Returns:
            One dict per match with keys ``"score"``, ``"text"`` (taken from
            the match metadata, ``None`` if absent) and ``"metadata"``.
        """
        vec = embed_texts([query])[0]
        res = self._index.query(vector=vec, top_k=top_k, include_metadata=True, namespace=self.namespace)
        # The response is a dict or an object depending on the SDK version.
        if isinstance(res, dict):
            matches = res.get("matches")
        else:
            matches = getattr(res, "matches", None)

        out: List[Dict[str, Any]] = []
        for m in matches or []:
            if isinstance(m, dict):
                md = m.get("metadata") or {}
                score = m.get("score")
            else:
                # `or {}` guards against a match whose metadata is None.
                md = getattr(m, "metadata", None) or {}
                score = getattr(m, "score", None)
            out.append({"score": score, "text": md.get("text"), "metadata": md})
        return out
|