File size: 5,888 Bytes
61411b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import annotations

import logging
import os
from typing import Any, Dict, List, Optional

from ai_business_automation_agent.embeddings.embedding_model import embed_texts

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class PineconeVectorStore:
    """Minimal Pinecone wrapper for policy/compliance retrieval.

    Supports both:

    - the newer pinecone SDK import style: ``from pinecone import Pinecone``
    - the legacy pinecone-client import style: ``import pinecone``

    The index handle is resolved during construction; the index is created
    first when its name is not among the existing indexes.
    """

    # Vector dimension used when ``dimension`` is not passed to ``__init__``.
    DEFAULT_DIMENSION = 384

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        index_name: Optional[str] = None,
        cloud: Optional[str] = None,
        region: Optional[str] = None,
        namespace: str = "policies",
        dimension: int = DEFAULT_DIMENSION,
    ) -> None:
        """Resolve configuration and connect to (or create) the index.

        Each of ``api_key``, ``index_name``, ``cloud`` and ``region`` falls
        back to its ``PINECONE_*`` environment variable when not given (or
        when given as an empty value).

        Args:
            api_key: Pinecone API key; falls back to ``PINECONE_API_KEY``.
            index_name: Index to use; falls back to ``PINECONE_INDEX_NAME``,
                then ``"ai-bpa-agent"``.
            cloud: Serverless cloud for index creation; falls back to
                ``PINECONE_CLOUD``, then ``"aws"``.
            region: Serverless region for index creation; falls back to
                ``PINECONE_REGION``, then ``"us-east-1"``.
            namespace: Namespace used for upserts and queries.
            dimension: Vector dimension used only when the index has to be
                created. NOTE(review): presumably this must equal the length
                of the vectors returned by ``embed_texts`` -- confirm.

        Raises:
            ValueError: If no API key is supplied or found in the environment.
        """
        self.api_key = api_key or os.getenv("PINECONE_API_KEY", "")
        self.index_name = index_name or os.getenv("PINECONE_INDEX_NAME", "ai-bpa-agent")
        self.cloud = cloud or os.getenv("PINECONE_CLOUD", "aws")
        self.region = region or os.getenv("PINECONE_REGION", "us-east-1")
        self.namespace = namespace
        self.dimension = dimension

        if not self.api_key:
            raise ValueError("Missing PINECONE_API_KEY.")

        self._index = self._init_index()

    def _init_index(self):
        """Return an index handle, creating the index first if it is missing.

        The newer SDK is preferred. The legacy client is used only when the
        ``Pinecone`` class cannot be imported. Errors raised by the newer SDK
        itself are deliberately NOT swallowed: falling through to the legacy
        path in that situation would discard the real failure and replace it
        with an unrelated one.
        """
        try:
            from pinecone import Pinecone  # type: ignore
        except ImportError:
            return self._init_legacy_index()
        return self._init_sdk_index(Pinecone)

    def _init_sdk_index(self, pinecone_cls):
        """Connect through the newer SDK's client class ``pinecone_cls``."""
        pc = pinecone_cls(api_key=self.api_key)
        # list_indexes shape varies by pinecone SDK version
        existing = self._existing_index_names(pc.list_indexes())
        if self.index_name not in existing:
            logger.info(
                "Creating Pinecone index '%s' (cloud=%s region=%s)",
                self.index_name,
                self.cloud,
                self.region,
            )
            pc.create_index(
                name=self.index_name,
                dimension=self.dimension,
                metric="cosine",
                spec={"serverless": {"cloud": self.cloud, "region": self.region}},
            )
        return pc.Index(self.index_name)

    def _init_legacy_index(self):
        """Connect through the legacy module-level pinecone-client API."""
        import pinecone  # type: ignore

        pinecone.init(api_key=self.api_key, environment=os.getenv("PINECONE_ENVIRONMENT", ""))
        if self.index_name not in pinecone.list_indexes():
            logger.info("Creating Pinecone index '%s' (legacy)", self.index_name)
            pinecone.create_index(self.index_name, dimension=self.dimension, metric="cosine")
        return pinecone.Index(self.index_name)

    @staticmethod
    def _existing_index_names(raw: Any) -> set[str]:
        """Extract index names from a ``list_indexes()`` result.

        Accepted shapes: a dict with an ``"indexes"`` list, a plain list, or
        an object exposing an ``indexes`` list. Entries may be strings, dicts
        with a ``"name"`` key, or objects with a ``name`` attribute. Anything
        else (and any empty name) is ignored.
        """
        if isinstance(raw, dict):
            entries = raw.get("indexes") or []
        elif isinstance(raw, list):
            entries = raw
        else:
            # Some versions return an object with `.indexes`
            entries = getattr(raw, "indexes", None)
            if not isinstance(entries, list):
                entries = []

        names: set[str] = set()
        for entry in entries:
            if isinstance(entry, str):
                name = entry
            elif isinstance(entry, dict):
                name = entry.get("name")
            else:
                name = getattr(entry, "name", None)
            if name:
                names.add(str(name))
        return names

    def seed_default_policies(self) -> None:
        """Seed a small set of example policy/rule documents.

        Documents are upserted under fixed ids into ``self.namespace``, so
        calling this repeatedly writes the same ids again rather than adding
        new ones.
        In production, replace this with your real corp policies and
        compliance corpus.
        """
        docs = [
            (
                "policy-1",
                "Invoices must include invoice number, invoice date, vendor name, and total amount.",
                {"type": "policy", "topic": "required_fields"},
            ),
            (
                "policy-2",
                "If vendor is flagged or unknown, route invoice to manual review or reject based on risk severity.",
                {"type": "policy", "topic": "vendor_risk"},
            ),
            (
                "rule-1",
                "Reject invoices where subtotal + tax differs from total by more than 0.02 (rounding tolerance).",
                {"type": "rule", "topic": "totals_consistency"},
            ),
            (
                "rule-2",
                "For high-severity compliance issues (e.g., missing total, missing invoice number), reject the invoice.",
                {"type": "rule", "topic": "compliance"},
            ),
        ]

        vectors = embed_texts([text for _, text, _ in docs])
        # The document text is stored in metadata so retrieve() can return it.
        upserts = [
            {"id": doc_id, "values": vec, "metadata": {"text": text, **meta}}
            for (doc_id, text, meta), vec in zip(docs, vectors)
        ]

        self._index.upsert(vectors=upserts, namespace=self.namespace)

    def retrieve(self, query: str, *, top_k: int = 5) -> List[Dict[str, Any]]:
        """Embed ``query`` and return up to ``top_k`` matches from the index.

        Args:
            query: Text to embed and search with.
            top_k: Number of matches requested from the index.

        Returns:
            One dict per match with keys ``"score"``, ``"text"`` (the
            ``"text"`` metadata entry, or ``None``) and ``"metadata"``.
        """
        vec = embed_texts([query])[0]
        res = self._index.query(
            vector=vec,
            top_k=top_k,
            include_metadata=True,
            namespace=self.namespace,
        )
        return self._format_matches(res)

    @staticmethod
    def _format_matches(res: Any) -> List[Dict[str, Any]]:
        """Flatten a query response (dict or attribute-style) into plain dicts.

        A missing or ``None`` ``matches`` collection yields an empty list, and
        a missing or ``None`` ``metadata`` is treated as an empty dict, so a
        match without metadata no longer raises.
        """
        if isinstance(res, dict):
            matches = res.get("matches")
        else:
            matches = getattr(res, "matches", None)

        out: List[Dict[str, Any]] = []
        for m in matches or []:
            if isinstance(m, dict):
                md = m.get("metadata")
                score = m.get("score")
            else:
                md = getattr(m, "metadata", None)
                score = getattr(m, "score", None)
            md = md or {}
            out.append({"score": score, "text": md.get("text"), "metadata": md})
        return out