File size: 3,350 Bytes
61411b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from __future__ import annotations

import json
import logging
import os
from typing import Any, Dict, List

from langchain_core.messages import HumanMessage

from ai_business_automation_agent.prompts.validation_prompt import VALIDATION_PROMPT
from ai_business_automation_agent.utils import append_agent_log, parse_llm_json
from ai_business_automation_agent.vectorstore.pinecone_client import PineconeVectorStore

logger = logging.getLogger(__name__)


def _format_policy_context(chunks: List[Dict[str, Any]]) -> str:
    if not chunks:
        return "No policy context available."
    lines = []
    for c in chunks:
        score = c.get("score")
        text = (c.get("text") or "").strip()
        if text:
            lines.append(f"- (score={score}) {text}")
    return "\n".join(lines).strip() or "No policy context available."


def run_validation_agent(state: Dict[str, Any], llm) -> Dict[str, Any]:
    """Validate extracted invoice data against retrieved policy context.

    Reads ``extracted_data`` and ``vendor_verification`` from *state*,
    performs a best-effort RAG lookup against the ``policies`` Pinecone
    namespace, prompts *llm* with ``VALIDATION_PROMPT``, and returns a dict
    of state updates containing ``validation_status`` plus agent log entries.
    A JSON-parse failure of the model output is downgraded to a
    ``needs_review`` status rather than raised.
    """
    extraction: Dict[str, Any] = state.get("extracted_data") or {}
    verification: Dict[str, Any] = state.get("vendor_verification") or {}

    context_text = "No policy context available."
    try:
        store = PineconeVectorStore(namespace="policies")
        # Seeding is opt-out via env var; any truthy spelling enables it.
        if os.getenv("SEED_VECTORSTORE", "true").lower() in {"1", "true", "yes"}:
            store.seed_default_policies()
        retrieval_query = json.dumps(
            {
                "invoice": extraction.get("invoice", {}),
                "vendor": extraction.get("vendor", {}),
                "vendor_verification": verification,
            },
            ensure_ascii=False,
        )
        retrieved = store.retrieve(retrieval_query, top_k=5)
        context_text = _format_policy_context(retrieved)
        retrieval_info = {"retrieved": retrieved}
    except Exception as exc:
        # Retrieval is best-effort: fall back to the default context string
        # and record the failure so downstream consumers can see it.
        logger.warning("Pinecone retrieval unavailable: %s", exc)
        retrieval_info = {"error": str(exc)}

    rendered_prompt = VALIDATION_PROMPT.format(
        extracted_json=json.dumps(extraction, ensure_ascii=False),
        vendor_verification_json=json.dumps(verification, ensure_ascii=False),
        policy_context=context_text,
    )
    response = llm.invoke([HumanMessage(content=rendered_prompt)])
    raw_text = getattr(response, "content", str(response))
    result, parse_error = parse_llm_json(raw_text)

    out: Dict[str, Any] = {}
    if parse_error:
        logger.warning("Validation JSON parse error: %s", parse_error)
        out["validation_status"] = {
            "status": "needs_review",
            "issues": [{"code": "PARSING_ERROR", "severity": "high", "message": parse_error}],
            "compliance_flags": [],
            "validated_fields": {},
            "recommendation": "manual_review",
            "raw_model_output": raw_text,
            "rag": retrieval_info,
        }
        out.update(append_agent_log(state, agent="validation", event="error", payload={"error": parse_error}))
    else:
        result["rag"] = retrieval_info
        out["validation_status"] = result
        out.update(append_agent_log(state, agent="validation", event="ok", payload=result))

    # NOTE(review): each append_agent_log call receives the original *state*
    # and the results are merged via dict.update — if the helper returns a
    # full log list derived from state, later calls would overwrite earlier
    # entries. Presumably it returns a reducer-friendly delta; confirm.
    out.update(append_agent_log(state, agent="validation", event="rag", payload=retrieval_info))
    out.update(append_agent_log(state, agent="validation", event="prompt", payload={"prompt": rendered_prompt}))
    out.update(append_agent_log(state, agent="validation", event="raw_response", payload={"text": raw_text}))
    return out