File size: 17,430 Bytes
cbb1b1a
 
 
 
 
 
 
 
d230384
cbb1b1a
 
 
 
 
 
 
 
 
 
 
 
 
dc0c45b
 
 
 
 
31ce85c
dc0c45b
 
 
cbb1b1a
 
dc0c45b
 
 
 
e45e84e
 
 
 
 
 
31ce85c
dc0c45b
 
31ce85c
 
 
 
 
 
 
 
 
 
 
dc0c45b
 
 
 
cbb1b1a
 
31ce85c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbb1b1a
 
 
dc0c45b
 
 
 
 
 
 
 
 
 
 
31ce85c
 
 
 
 
 
 
 
 
 
 
e45e84e
31ce85c
 
 
00577b6
31ce85c
e45e84e
 
 
 
58caf37
00577b6
dc0c45b
 
 
 
cbb1b1a
 
 
dc0c45b
cbb1b1a
dc0c45b
 
 
 
 
cbb1b1a
dc0c45b
 
 
 
 
 
 
 
 
 
cbb1b1a
 
 
dc0c45b
 
 
 
 
cbb1b1a
dc0c45b
 
 
 
 
cbb1b1a
31ce85c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbb1b1a
dc0c45b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b016462
dc0c45b
 
 
 
 
 
 
31ce85c
 
 
 
 
dc0c45b
31ce85c
 
dc0c45b
 
31ce85c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc0c45b
31ce85c
 
 
 
dc0c45b
31ce85c
dc0c45b
 
b016462
 
 
dc0c45b
 
 
 
 
 
b016462
dc0c45b
31ce85c
 
 
 
 
 
 
 
 
 
dc0c45b
 
 
 
 
b016462
dc0c45b
 
 
 
 
 
 
 
 
 
 
 
b016462
dc0c45b
 
 
 
 
 
 
 
 
31ce85c
 
 
dc0c45b
 
31ce85c
 
 
 
dc0c45b
 
 
31ce85c
 
 
dc0c45b
 
31ce85c
dc0c45b
 
 
 
31ce85c
dc0c45b
 
 
 
 
 
31ce85c
dc0c45b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
"""
Claude Managed Agent (CMA) — orchestration layer.

Maintains per-session state:
    - conversation history (list of Anthropic message dicts)
    - SessionMemory (extracted entities, domain, draft versions, etc.)
    - uploaded document paths

Model: env-driven via GUIDE_MODEL (default claude-sonnet-4-6)

Decision flow per user turn:
    1. Presidio redaction happens at the API layer before this module is called.
    2. Is domain known?  No → call classify_domain.
    3. Are minimum fields complete?  No → ask one follow-up question.
    4. Was a document uploaded?  Yes → call process_document, merge entities.
    5. HITL gate: present summary → wait for [USER CONFIRMED].
    6. [USER CONFIRMED] received → call draft_complaint.
    7. User asks next steps → call recommend_action.

The agent loop runs until the model returns a non-tool_use stop reason.
"""

from __future__ import annotations

import json
import logging
import os
import time

import anthropic

from src.agent.memory import SessionMemory
from src.agent.prompts import SYSTEM_PROMPT
from src.agent.tools import TOOL_DEFINITIONS, execute_tool

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Model name is env-driven so the same code runs against two backends:
#   • Local via the athena LiteLLM gateway, which only exposes Bedrock model
#     IDs (e.g. bedrock-claude-3-5-haiku-20241022-v1:0) — set GUIDE_MODEL in .env.
#   • HF Spaces / public Anthropic API — leave GUIDE_MODEL unset to use the
#     default claude-sonnet-4-6.
_MODEL = os.getenv("GUIDE_MODEL", "claude-sonnet-4-6")
# Output-token budget passed as max_tokens on every request.
_MAX_TOKENS = 4096
# Hard cap on tool-use rounds per turn to prevent runaway loops.
_MAX_TOOL_ROUNDS = 12
# History window size, in messages, used to keep input tokens low on
# rate-limited keys. It counts individual history entries (tool_use and
# tool_result messages included), not user/assistant exchanges.
# SessionMemory preserves all extracted entities so context is not lost.
_MAX_HISTORY_TURNS = 10
# Retry on 429 (RPM/TPM) — opt-in via .env so local devs can tune without
# committing. Retry is enabled ONLY when GUIDE_RETRY_MAX is set to a value > 0;
# otherwise the agent uses a plain stream with no backoff.
# NOTE: GUIDEAgent builds its client with max_retries=0, so when retry is
# disabled here a 429 is not retried at all.
# These values are parsed at import time; a non-numeric setting raises
# ValueError when the module is imported.
_RETRY_MAX = int(os.getenv("GUIDE_RETRY_MAX", "0"))
_RETRY_BASE_DELAY = float(os.getenv("GUIDE_RETRY_BASE_DELAY", "10.0"))  # seconds
_RETRY_MAX_DELAY = float(os.getenv("GUIDE_RETRY_MAX_DELAY", "60.0"))    # seconds
_RETRY_ENABLED = _RETRY_MAX > 0
# Returned when the tool-round cap is hit without the model producing any text.
_FALLBACK_REPLY = (
    "I'm sorry, I encountered an issue processing your request. "
    "Please try again or rephrase your message."
)


def _serialize_block(block) -> dict:
    """Reduce a content block to the plain dict shape the API accepts.

    *block* may be an Anthropic SDK object (fields read as attributes) or an
    already-plain dict (fields read by key). Dumping the SDK object wholesale
    would carry along LangSmith-injected extras such as ``parsed_output``,
    which the Anthropic API rejects with a 400, so ``text`` and ``tool_use``
    blocks are rebuilt from an explicit list of fields.

    Any other block type is passed through a JSON encode/decode round trip
    (non-serialisable values become ``str``) so only primitive types remain.
    """

    def field(key):
        # Attribute access for SDK objects, item access for dicts.
        return getattr(block, key) if hasattr(block, key) else block[key]

    kind = block.type if hasattr(block, "type") else block.get("type")

    if kind == "text":
        return {"type": "text", "text": field("text")}

    if kind == "tool_use":
        return {
            "type": "tool_use",
            "id": field("id"),
            "name": field("name"),
            "input": field("input"),
        }

    # Unknown block type: keep everything, but force it down to primitives.
    raw = block if isinstance(block, dict) else block.model_dump()
    return json.loads(json.dumps(raw, default=str))


def _stream_once(client, **kwargs):
    """Open one streaming request, exhaust it, and return the final message.

    This is the path taken when retry is disabled (GUIDE_RETRY_MAX unset/0).
    It adds no backoff of its own: any exception raised by the client
    propagates to the caller unchanged.

    *kwargs* are forwarded verbatim to ``client.messages.stream``.
    """
    with client.messages.stream(**kwargs) as live:
        # The text chunks themselves are not needed here; iterating simply
        # drives the stream to completion so the final message is available.
        for _ in live.text_stream:
            continue
        return live.get_final_message()


def _retry_after_seconds(header_value) -> "float | None":
    """Interpret a Retry-After header value as a delay in seconds.

    Returns None when the header is missing or empty, is not a plain number
    (HTTP also allows an HTTP-date form), or is negative / NaN / infinite.
    The caller then falls back to its own backoff instead of failing inside
    ``float()`` or ``time.sleep()`` and hiding the original rate-limit error.
    """
    if not header_value:
        return None
    try:
        seconds = float(header_value)
    except (TypeError, ValueError):
        return None
    # NaN fails both comparisons, so it is rejected here as well.
    if not 0 <= seconds < float("inf"):
        return None
    return seconds


def _stream_with_retry(client, **kwargs):
    """
    Call client.messages.stream(**kwargs) and retry on 429 RateLimitError.

    Only invoked when retry is enabled (GUIDE_RETRY_MAX > 0). Behaviour is
    controlled by three env vars (set in .env, never committed):
        GUIDE_RETRY_MAX         — max attempts after the first (0 = no retry)
        GUIDE_RETRY_BASE_DELAY  — initial backoff in seconds (default 10.0)
        GUIDE_RETRY_MAX_DELAY   — cap on backoff (default 60.0)

    On each retry the delay doubles (exponential backoff, capped at
    _RETRY_MAX_DELAY). If a usable numeric Retry-After header is present its
    value is used instead (uncapped); an unparseable header is ignored.

    Returns the final message of the first attempt that succeeds.
    Raises the original RateLimitError once all retries are exhausted; any
    other exception propagates immediately without retry.
    """
    delay = _RETRY_BASE_DELAY
    for attempt in range(_RETRY_MAX + 1):
        try:
            with client.messages.stream(**kwargs) as stream:
                # Drain the token stream so the final message is complete.
                for _chunk in stream.text_stream:
                    pass
                return stream.get_final_message()
        except anthropic.RateLimitError as exc:
            if attempt >= _RETRY_MAX:
                raise
            response = getattr(exc, "response", None)
            header = response.headers.get("retry-after") if response is not None else None
            hinted = _retry_after_seconds(header)
            wait = hinted if hinted is not None else min(delay, _RETRY_MAX_DELAY)
            logger.warning(
                "429 RateLimitError (attempt %d/%d) — waiting %.1fs before retry. %s",
                attempt + 1, _RETRY_MAX + 1, wait, exc,
            )
            time.sleep(wait)
            delay = min(delay * 2, _RETRY_MAX_DELAY)


class GUIDEAgent:
    """Stateful Claude Managed Agent for a single user session.

    Each instance owns its conversation history, a SessionMemory object that
    is handed to every tool call, a queue of uploaded-document paths, and its
    own Anthropic client.
    """

    def __init__(self, session_id: str) -> None:
        """Create an empty session and build the Anthropic client for it.

        *session_id* is used only to tag log lines.
        """
        self._session_id = session_id
        self._memory = SessionMemory()
        # Anthropic message history — list of {"role": ..., "content": ...} dicts.
        # Content may be a string (user text) or a list of content blocks
        # (assistant turns with text + tool_use blocks, tool_result turns).
        self._history: list[dict] = []
        # Document paths queued by add_document(); prepended as context on the
        # next send_message() call then cleared.
        self._pending_documents: list[str] = []
        # One Anthropic client per agent instance so sessions are independent.
        # If LITELLM_PROXY_URL is set, route through LiteLLM gateway; otherwise
        # use the Anthropic API directly with ANTHROPIC_API_KEY.
        # max_retries=0: the Anthropic SDK retries 429s TWICE by default, and
        # each retry RE-SENDS the full request — re-charging the per-minute token
        # bucket immediately, with no wait for it to refill. On a tight TPM cap
        # that triples token consumption per user action and guarantees the storm.
        # Disable SDK retries so ONLY our _stream_with_retry backoff fires — it
        # waits ~60s for the bucket to refill before re-sending.
        litellm_url = os.getenv("LITELLM_PROXY_URL")
        if litellm_url:
            from litellm import get_litellm_gateway_api_key
            self._client = anthropic.Anthropic(
                base_url=litellm_url,
                api_key=get_litellm_gateway_api_key(),
                max_retries=0,
            )
        else:
            _key = os.environ.get("ANTHROPIC_API_KEY", "")
            # Log only presence and length: no part of the secret itself
            # belongs in application logs.
            logger.info(
                "GUIDEAgent init: ANTHROPIC_API_KEY %s (len=%d)",
                "set" if _key else "EMPTY", len(_key),
            )
            self._client = anthropic.Anthropic(
                api_key=_key,
                max_retries=0,
            )

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def send_message(self, user_text: str) -> str:
        """
        Process *user_text* through the CMA loop and return the assistant reply.

        *user_text* must already be PII-redacted (Presidio runs at the API layer).

        If documents were queued via add_document(), their paths are prepended to
        the user's message so the agent sees the "[Document uploaded: <path>]"
        prefix mandated by Rule 3 of the system prompt. The queue is cleared
        once its entries have been attached.
        """
        # Prepend queued document notifications
        if self._pending_documents:
            prefixes = "\n".join(
                f"[Document uploaded: {p}]" for p in self._pending_documents
            )
            user_text = f"{prefixes}\n\n{user_text}"
            self._pending_documents.clear()

        self._history.append({"role": "user", "content": user_text})
        return self._run_agent_loop()

    def confirm_entities(self, verified_entities: dict) -> str:
        """
        Inject a [USER CONFIRMED] message with the user-verified entity values,
        then run the agent loop to trigger draft_complaint (Rule 5).

        *verified_entities* is the dict submitted from the HITL Verify Entities
        panel, e.g. {"ORG": "HDFC Bank", "AMOUNT": "₹5000"}. It is embedded in
        the message as JSON with non-ASCII characters preserved.
        """
        confirmation = (
            f"[USER CONFIRMED]: {json.dumps(verified_entities, ensure_ascii=False)}"
        )
        self._history.append({"role": "user", "content": confirmation})
        return self._run_agent_loop()

    def generate_escalation(self) -> str:
        """
        Second, SEPARATE request that produces the escalation guide.

        The draft letter (confirm_entities) and the escalation guide are split
        into two distinct agent turns — two separate Anthropic requests — so each
        stays within the per-minute token budget. Calling this after the draft
        also lets the token bucket refill between the two requests, avoiding the
        429 storm that occurred when both were generated in one continuous turn.

        Relies on Rule 7 of the system prompt: only when this follow-up arrives
        does the model call recommend_action() and emit the escalation guide.
        """
        request = (
            "Now generate the escalation guide for this complaint. "
            "Call recommend_action() with the confirmed domain, entities, and "
            "prior_contact, then present the numbered escalation path per Rule 7. "
            "Output ONLY the escalation guide — do not repeat the complaint letter."
        )
        self._history.append({"role": "user", "content": request})
        return self._run_agent_loop()

    def add_document(self, file_path: str) -> None:
        """Queue a document path so it appears in the next send_message() turn."""
        self._pending_documents.append(file_path)

    def get_history(self) -> list[dict]:
        """Return the current conversation history (shallow copy)."""
        return list(self._history)

    # ------------------------------------------------------------------
    # History windowing
    # ------------------------------------------------------------------

    @staticmethod
    def _is_turn_start(message: dict) -> bool:
        """Return True if *message* is a user message that can open a request.

        That means role "user" and content that is either a plain string or a
        block list with no tool_result block. A tool_result message cannot
        open a request because its matching tool_use would be missing.
        """
        if message.get("role") != "user":
            return False
        content = message.get("content")
        if isinstance(content, str):
            return True
        return not any(
            isinstance(block, dict) and block.get("type") == "tool_result"
            for block in (content or [])
        )

    @staticmethod
    def _window_start(history: list[dict], max_messages: int) -> int:
        """Return the index at which the request window into *history* begins.

        The window nominally covers the last *max_messages* entries. A blind
        slice can begin on an assistant message or on a tool_result whose
        tool_use was cut off, and the Messages API rejects both. The start is
        therefore moved BACKWARDS to the nearest message accepted by
        _is_turn_start(), so the window may be slightly longer than
        *max_messages* but is always well-formed. Returns 0 for an empty
        history or when no such message is found.
        """
        start = max(len(history) - max(max_messages, 1), 0)
        while start > 0 and not GUIDEAgent._is_turn_start(history[start]):
            start -= 1
        return start

    # ------------------------------------------------------------------
    # Internal agent loop
    # ------------------------------------------------------------------

    def _log_usage(self, response, round_number: int) -> None:
        """Log token usage for one response, if the response carries any.

        Logs input vs cache hits so we can see whether prompt caching is
        crediting us against the per-minute token cap (on Bedrock, cache reads
        often still count toward TPM). The uncached input is what actually
        drains the per-minute bucket.
        """
        usage = getattr(response, "usage", None)
        if usage is None:
            return
        inp = getattr(usage, "input_tokens", 0) or 0
        out = getattr(usage, "output_tokens", 0) or 0
        cread = getattr(usage, "cache_read_input_tokens", 0) or 0
        ccreate = getattr(usage, "cache_creation_input_tokens", 0) or 0
        logger.info(
            "Session %s: round %d usage — input=%d (uncached), "
            "cache_read=%d, cache_create=%d, output=%d, total_in=%d",
            self._session_id, round_number,
            inp, cread, ccreate, out, inp + cread + ccreate,
        )

    def _run_agent_loop(self) -> str:
        """
        Stream CMA responses and execute tool calls until stop_reason != "tool_use".

        Each round sends a trailing window of the history (see _window_start),
        records the assistant turn, and, if the model asked for tools, runs
        them and appends their results as a "user" turn before the next round.

        Returns the text produced across all rounds of this turn, joined with
        blank lines (may be "" if the model finished without emitting text).
        If _MAX_TOOL_ROUNDS is exhausted, returns the text gathered so far, or
        _FALLBACK_REPLY when there is none. Exceptions from the streaming call
        propagate to the caller.
        """
        all_text_parts: list[str] = []

        # Use the backoff-retry path only when retry is enabled via env
        # (GUIDE_RETRY_MAX > 0); otherwise plain stream. Invariant across rounds.
        stream_fn = _stream_with_retry if _RETRY_ENABLED else _stream_once

        for round_num in range(_MAX_TOOL_ROUNDS):
            logger.debug(
                "Session %s: agent round %d — history length %d",
                self._session_id, round_num + 1, len(self._history),
            )

            window_start = self._window_start(self._history, _MAX_HISTORY_TURNS)
            response = stream_fn(
                self._client,
                model=_MODEL,
                system=[{"type": "text", "text": SYSTEM_PROMPT, "cache_control": {"type": "ephemeral"}}],
                messages=self._history[window_start:],
                tools=TOOL_DEFINITIONS,
                max_tokens=_MAX_TOKENS,
            )

            self._log_usage(response, round_num + 1)

            # Record the full assistant turn (may include tool_use blocks).
            # Manually build plain dicts with only the fields the Anthropic API
            # accepts — model_dump() includes LangSmith-injected extras like
            # `parsed_output` that cause 400 errors on subsequent rounds.
            self._history.append(
                {"role": "assistant", "content": [_serialize_block(b) for b in response.content]}
            )

            # Accumulate text across all rounds so text emitted in an early
            # round is not overwritten by text emitted in a later one.
            current_text = "".join(
                block.text
                for block in response.content
                if hasattr(block, "text") and block.text
            )
            if current_text:
                all_text_parts.append(current_text)

            if response.stop_reason == "max_tokens":
                # The model ran out of output budget mid-turn — text is truncated
                # (e.g. a complaint letter cut off mid-signature). Surface this
                # loudly; it is otherwise indistinguishable from a clean finish.
                logger.warning(
                    "Session %s: round %d hit max_tokens (=%d) — response TRUNCATED. "
                    "Raise _MAX_TOKENS or split drafting from escalation.",
                    self._session_id, round_num + 1, _MAX_TOKENS,
                )

            if response.stop_reason != "tool_use":
                logger.info(
                    "Session %s: agent loop complete (round %d, stop=%s)",
                    self._session_id, round_num + 1, response.stop_reason,
                )
                return "\n\n".join(all_text_parts)

            # Dispatch all tool calls from this response
            tool_result_blocks = self._execute_tool_calls(response.content)
            self._history.append(
                {"role": "user", "content": tool_result_blocks}
            )

        # Exceeded max rounds — return whatever text we have
        logger.warning(
            "Session %s: agent loop hit max rounds (%d). Returning partial reply.",
            self._session_id, _MAX_TOOL_ROUNDS,
        )
        return "\n\n".join(all_text_parts) or _FALLBACK_REPLY

    @staticmethod
    def _field(block, key):
        """Read *key* from an SDK content block (attribute) or a plain dict (item)."""
        return getattr(block, key) if hasattr(block, key) else block[key]

    def _execute_tool_calls(self, content_blocks) -> list[dict]:
        """
        Find all tool_use blocks in *content_blocks*, execute each via
        execute_tool(), and return a list of tool_result dicts ready to be
        appended to the conversation history as a "user" turn.

        Results are JSON-encoded (non-serialisable values become str). If a
        tool raises, the exception is logged and reported back to the model as
        a tool_result with "is_error": True. This keeps every recorded
        tool_use paired with a tool_result, which later requests depend on.
        """
        results = []
        for block in content_blocks:
            # Support both SDK objects and plain dicts
            block_type = block.type if hasattr(block, "type") else block.get("type")
            if block_type != "tool_use":
                continue

            name = self._field(block, "name")
            bid = self._field(block, "id")
            inp = self._field(block, "input")

            logger.info(
                "Session %s: tool %r (id=%s) input=%s",
                self._session_id,
                name,
                bid,
                json.dumps(inp, ensure_ascii=False, default=str)[:200],
            )

            try:
                result = execute_tool(name, inp, self._memory)
            except Exception as exc:  # boundary: keep the history well-formed
                logger.exception(
                    "Session %s: tool %r (id=%s) raised", self._session_id, name, bid,
                )
                results.append(
                    {
                        "type": "tool_result",
                        "tool_use_id": bid,
                        "content": json.dumps(
                            {"error": f"{type(exc).__name__}: {exc}"},
                            ensure_ascii=False,
                        ),
                        "is_error": True,
                    }
                )
                continue

            logger.debug(
                "Session %s: tool %r result=%s",
                self._session_id,
                name,
                json.dumps(result, ensure_ascii=False, default=str)[:200],
            )

            results.append(
                {
                    "type": "tool_result",
                    "tool_use_id": bid,
                    "content": json.dumps(
                        result, ensure_ascii=False, default=str
                    ),
                }
            )

        return results