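# NOTE(review): the lines below were captured from the Hugging Face file viewer
# (path, author, last commit message and hash, file size); they are not code.
# guide / src/agent/agent.py
# Last commit by anmol-iisc: "UI enhancements, letter text redundant text removed" (d230384)
# Raw | History | Blame | Contribute | Delete
# 17.4 kB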
"""
Claude Managed Agent (CMA) — orchestration layer.
Maintains per-session state:
- conversation history (list of Anthropic message dicts)
- SessionMemory (extracted entities, domain, draft versions, etc.)
- uploaded document paths
Model: env-driven via GUIDE_MODEL (default claude-sonnet-4-6)
Decision flow per user turn:
1. Presidio redaction happens at the API layer before this module is called.
2. Is domain known? No → call classify_domain.
3. Are minimum fields complete? No → ask one follow-up question.
4. Was a document uploaded? Yes → call process_document, merge entities.
5. HITL gate: present summary → wait for [USER CONFIRMED].
6. [USER CONFIRMED] received → call draft_complaint.
7. User asks next steps → call recommend_action.
Steps 2-7 are carried out by the model following the system prompt rules;
this module only runs the tool-use loop and dispatches the tools it requests.
The agent loop runs until the model returns a non-tool_use stop reason, or
until the per-turn cap on tool-use rounds is reached.
"""
from __future__ import annotations
import json
import logging
import os
import time
import anthropic
from src.agent.memory import SessionMemory
from src.agent.prompts import SYSTEM_PROMPT
from src.agent.tools import TOOL_DEFINITIONS, execute_tool
logger = logging.getLogger(__name__)
# Model name is env-driven so the same code runs against two backends:
# β€’ Local via the athena LiteLLM gateway, which only exposes Bedrock model
# IDs (e.g. bedrock-claude-3-5-haiku-20241022-v1:0) β€” set GUIDE_MODEL in .env.
# β€’ HF Spaces / public Anthropic API β€” leave GUIDE_MODEL unset to use the
# default claude-sonnet-4-6.
_MODEL = os.getenv("GUIDE_MODEL", "claude-sonnet-4-6")
_MAX_TOKENS = 4096
# Hard cap on tool-use rounds per turn to prevent runaway loops.
_MAX_TOOL_ROUNDS = 12
# Send only the last N messages to keep input tokens low on rate-limited keys.
# SessionMemory preserves all extracted entities so context is not lost.
_MAX_HISTORY_TURNS = 10
# Retry on 429 (RPM/TPM) β€” opt-in via .env so local devs can tune without
# committing. Retry is enabled ONLY when GUIDE_RETRY_MAX is set to a value > 0;
# otherwise the agent uses a plain stream (no backoff) on top of the
# Anthropic SDK's own built-in retries.
_RETRY_MAX = int(os.getenv("GUIDE_RETRY_MAX", "0"))
_RETRY_BASE_DELAY = float(os.getenv("GUIDE_RETRY_BASE_DELAY", "10.0")) # seconds
_RETRY_MAX_DELAY = float(os.getenv("GUIDE_RETRY_MAX_DELAY", "60.0")) # seconds
_RETRY_ENABLED = _RETRY_MAX > 0
_FALLBACK_REPLY = (
"I'm sorry, I encountered an issue processing your request. "
"Please try again or rephrase your message."
)
def _serialize_block(block) -> dict:
"""Convert an Anthropic SDK content block to a plain API-safe dict.
model_dump() includes LangSmith-injected fields (e.g. parsed_output) that
the Anthropic API rejects with 400. Only emit the fields each block type
actually accepts.
"""
t = block.type if hasattr(block, "type") else block.get("type")
if t == "text":
text = block.text if hasattr(block, "text") else block["text"]
return {"type": "text", "text": text}
if t == "tool_use":
bid = block.id if hasattr(block, "id") else block["id"]
name = block.name if hasattr(block, "name") else block["name"]
inp = block.input if hasattr(block, "input") else block["input"]
return {"type": "tool_use", "id": bid, "name": name, "input": inp}
# Fallback: strip to primitive types only via JSON round-trip
import json as _json
return _json.loads(_json.dumps(block if isinstance(block, dict) else block.model_dump(), default=str))
def _stream_once(client, **kwargs):
"""Plain stream β€” drain the token stream and return the final message.
Used when retry is disabled (GUIDE_RETRY_MAX unset/0). The Anthropic SDK
still applies its own built-in retries underneath; this adds no extra
backoff logic.
"""
with client.messages.stream(**kwargs) as stream:
for _chunk in stream.text_stream:
pass
return stream.get_final_message()
def _retry_after_seconds(exc):
    """Return the Retry-After delay carried by *exc*, in seconds, or None.

    Only the delay-seconds form of the header is honoured. Return None, so the
    caller falls back to its own backoff, when:
      * there is no response or no header;
      * the header uses the HTTP-date form (also allowed by RFC 9110);
      * the value cannot be parsed, or is NaN or +inf.
    Negative values are clamped to 0 so time.sleep() never rejects them.
    """
    response = getattr(exc, "response", None)
    if response is None:
        return None
    raw = response.headers.get("retry-after")
    if not raw:
        return None
    try:
        seconds = float(raw)
    except (TypeError, ValueError):
        return None
    # float() also accepts "nan"/"inf", which time.sleep() cannot take.
    if seconds != seconds or seconds == float("inf"):
        return None
    return max(seconds, 0.0)


def _stream_with_retry(client, **kwargs):
    """
    Call client.messages.stream(**kwargs) and retry on 429 RateLimitError.

    Only invoked when retry is enabled (GUIDE_RETRY_MAX > 0). Behaviour is
    controlled by three env vars (set in .env, never committed):
        GUIDE_RETRY_MAX        — max attempts after the first (0 = no retry)
        GUIDE_RETRY_BASE_DELAY — initial backoff in seconds (default 10.0)
        GUIDE_RETRY_MAX_DELAY  — cap on backoff (default 60.0)

    On each retry the delay doubles (exponential backoff, capped at
    _RETRY_MAX_DELAY). A usable numeric Retry-After header overrides the
    backoff; an HTTP-date or malformed header is ignored rather than raising.
    Non-429 errors propagate immediately. Raises the last RateLimitError once
    all attempts are exhausted.
    """
    # A negative GUIDE_RETRY_MAX still allows exactly one attempt.
    attempts = max(_RETRY_MAX, 0) + 1
    delay = _RETRY_BASE_DELAY
    for attempt in range(attempts):
        try:
            with client.messages.stream(**kwargs) as stream:
                for _chunk in stream.text_stream:
                    pass
                return stream.get_final_message()
        except anthropic.RateLimitError as exc:
            if attempt + 1 >= attempts:
                raise
            retry_after = _retry_after_seconds(exc)
            if retry_after is not None:
                wait = retry_after
            else:
                # Clamp at 0 in case GUIDE_RETRY_BASE_DELAY was set negative.
                wait = max(min(delay, _RETRY_MAX_DELAY), 0.0)
            logger.warning(
                "429 RateLimitError (attempt %d/%d) — waiting %.1fs before retry. %s",
                attempt + 1, attempts, wait, exc,
            )
            time.sleep(wait)
            delay = min(delay * 2, _RETRY_MAX_DELAY)
class GUIDEAgent:
    """Stateful Claude Managed Agent for a single user session.

    Each instance owns its own Anthropic client, message history,
    SessionMemory (passed to every tool call) and queue of uploaded
    document paths, so sessions do not share state.
    """

    def __init__(self, session_id: str) -> None:
        self._session_id = session_id
        self._memory = SessionMemory()
        # Anthropic message history — list of {"role": ..., "content": ...} dicts.
        # Content may be a string (user text) or a list of content blocks
        # (assistant turns with text + tool_use blocks, tool_result turns).
        self._history: list[dict] = []
        # Document paths queued by add_document(); prepended as context on the
        # next send_message() call then cleared.
        self._pending_documents: list[str] = []
        # One Anthropic client per agent instance so sessions are independent.
        # If LITELLM_PROXY_URL is set, route through the LiteLLM gateway;
        # otherwise use the Anthropic API directly with ANTHROPIC_API_KEY.
        # max_retries=0: the Anthropic SDK retries 429s twice by default, and
        # each retry re-sends the full request — re-charging the per-minute
        # token bucket immediately, with no wait for it to refill. On a tight
        # TPM cap that triples token consumption per user action. SDK retries
        # are therefore disabled; the only remaining retry path is
        # _stream_with_retry, which is used only when GUIDE_RETRY_MAX > 0.
        # With GUIDE_RETRY_MAX unset/0, a 429 propagates to the caller.
        litellm_url = os.getenv("LITELLM_PROXY_URL")
        if litellm_url:
            from litellm import get_litellm_gateway_api_key
            self._client = anthropic.Anthropic(
                base_url=litellm_url,
                api_key=get_litellm_gateway_api_key(),
                max_retries=0,
            )
        else:
            _key = os.environ.get("ANTHROPIC_API_KEY", "")
            # Diagnostic: logs only the key's last 3 characters and its length.
            logger.info("GUIDEAgent init key ends=%s len=%d", _key[-3:] if _key else "EMPTY", len(_key))
            self._client = anthropic.Anthropic(
                api_key=_key,
                max_retries=0,
            )

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------
    def send_message(self, user_text: str) -> str:
        """
        Process *user_text* through the CMA loop and return the assistant reply.

        *user_text* must already be PII-redacted (Presidio runs at the API layer).
        If documents were queued via add_document(), their paths are prepended to
        the user's message so the agent sees the "[Document uploaded: <path>]"
        prefix mandated by Rule 3 of the system prompt.
        """
        # Prepend queued document notifications
        if self._pending_documents:
            prefixes = "\n".join(
                f"[Document uploaded: {p}]" for p in self._pending_documents
            )
            user_text = f"{prefixes}\n\n{user_text}"
            self._pending_documents.clear()
        self._history.append({"role": "user", "content": user_text})
        return self._run_agent_loop()

    def confirm_entities(self, verified_entities: dict) -> str:
        """
        Inject a [USER CONFIRMED] message with the user-verified entity values,
        then run the agent loop to trigger draft_complaint (Rule 5).

        *verified_entities* is the dict submitted from the HITL Verify Entities
        panel, e.g. {"ORG": "HDFC Bank", "AMOUNT": "₹5000"}.
        """
        confirmation = (
            f"[USER CONFIRMED]: {json.dumps(verified_entities, ensure_ascii=False)}"
        )
        self._history.append({"role": "user", "content": confirmation})
        return self._run_agent_loop()

    def generate_escalation(self) -> str:
        """
        Second, SEPARATE request that produces the escalation guide.

        The draft letter (confirm_entities) and the escalation guide are split
        into two distinct agent turns — two separate Anthropic requests — so each
        stays within the per-minute token budget. Calling this after the draft
        also lets the token bucket refill between the two requests, avoiding the
        429 storm that occurred when both were generated in one continuous turn.
        Relies on Rule 7 of the system prompt: only when this follow-up arrives
        does the model call recommend_action() and emit the escalation guide.
        """
        request = (
            "Now generate the escalation guide for this complaint. "
            "Call recommend_action() with the confirmed domain, entities, and "
            "prior_contact, then present the numbered escalation path per Rule 7. "
            "Output ONLY the escalation guide — do not repeat the complaint letter."
        )
        self._history.append({"role": "user", "content": request})
        return self._run_agent_loop()

    def add_document(self, file_path: str) -> None:
        """Queue a document path so it appears in the next send_message() turn."""
        self._pending_documents.append(file_path)

    def get_history(self) -> list[dict]:
        """Return the current conversation history (shallow copy)."""
        return list(self._history)

    # ------------------------------------------------------------------
    # Internal agent loop
    # ------------------------------------------------------------------
    def _history_window(self, limit: int) -> list[dict]:
        """
        Return a recent slice of the history that is valid to send to the API.

        Starts from the last *limit* messages (``limit <= 0`` means no
        trimming), then moves the start forward to the first "clean" user
        message: string content, or a block list without tool_result blocks.
        A naive ``history[-limit:]`` can begin with an assistant turn, or with
        a tool_result turn whose matching tool_use was sliced off. The API
        rejects the latter with a 400.

        If the window has no clean start (e.g. a long tool chain), it widens
        back to the nearest earlier clean user message, so the result may be
        longer than *limit*. Returns a new list; the history is not modified.
        """
        history = self._history

        def is_clean_user_turn(message: dict) -> bool:
            if message.get("role") != "user":
                return False
            content = message.get("content")
            if isinstance(content, str):
                return True
            return not any(
                isinstance(b, dict) and b.get("type") == "tool_result"
                for b in content or ()
            )

        cut = max(len(history) - limit, 0) if limit > 0 else 0
        for i in range(cut, len(history)):
            if is_clean_user_turn(history[i]):
                return history[i:]
        for i in range(cut - 1, -1, -1):
            if is_clean_user_turn(history[i]):
                return history[i:]
        return list(history)

    def _run_agent_loop(self) -> str:
        """
        Stream CMA responses and execute tool calls until stop_reason != "tool_use".

        Each round sends a trimmed, API-valid window of the history (see
        _history_window) through the streaming API. The full response is
        captured via the stream's get_final_message() so tool_use blocks can
        be inspected and dispatched. Tool results are appended as a "user"
        turn before the next round; at most _MAX_TOOL_ROUNDS rounds run.

        Returns the text of every round joined by blank lines, so earlier
        output (e.g. the draft letter) is not overwritten by text emitted in a
        later round. This may be "" if the model finished without any text.
        If the round cap is hit, the accumulated text is returned, or
        _FALLBACK_REPLY when there is none.
        """
        all_text_parts: list[str] = []
        # Retry mode is fixed at import time, so choose the stream function once.
        stream_fn = _stream_with_retry if _RETRY_ENABLED else _stream_once
        for round_num in range(_MAX_TOOL_ROUNDS):
            logger.debug(
                "Session %s: agent round %d — history length %d",
                self._session_id, round_num + 1, len(self._history),
            )
            response = stream_fn(
                self._client,
                model=_MODEL,
                system=[{"type": "text", "text": SYSTEM_PROMPT, "cache_control": {"type": "ephemeral"}}],
                messages=self._history_window(_MAX_HISTORY_TURNS),
                tools=TOOL_DEFINITIONS,
                max_tokens=_MAX_TOKENS,
            )
            # Token-usage diagnostics. Logs uncached input separately from
            # cache reads/creations so we can see whether prompt caching is
            # crediting us against the per-minute token cap (on Bedrock, cache
            # reads often still count toward TPM).
            usage = getattr(response, "usage", None)
            if usage is not None:
                inp = getattr(usage, "input_tokens", 0) or 0
                out = getattr(usage, "output_tokens", 0) or 0
                cread = getattr(usage, "cache_read_input_tokens", 0) or 0
                ccreate = getattr(usage, "cache_creation_input_tokens", 0) or 0
                logger.info(
                    "Session %s: round %d usage — input=%d (uncached), "
                    "cache_read=%d, cache_create=%d, output=%d, total_in=%d",
                    self._session_id, round_num + 1,
                    inp, cread, ccreate, out, inp + cread + ccreate,
                )
            stop_reason = response.stop_reason
            # Record the assistant turn as plain dicts with only the fields the
            # API accepts — model_dump() includes LangSmith-injected extras like
            # `parsed_output` that cause 400 errors on subsequent rounds.
            assistant_blocks = [_serialize_block(b) for b in response.content]
            if stop_reason != "tool_use":
                # Tool calls are dispatched only when stop_reason is "tool_use".
                # A tool_use left in history without a following tool_result
                # (e.g. one cut off by max_tokens) makes the API reject every
                # later request in this session, so keep only the other blocks.
                assistant_blocks = [
                    b for b in assistant_blocks if b.get("type") != "tool_use"
                ]
            if assistant_blocks:
                # The API rejects non-final messages with empty content, so an
                # empty assistant turn is not recorded.
                self._history.append({"role": "assistant", "content": assistant_blocks})
            # Accumulate text across all rounds so the draft letter (emitted in
            # one round) is not overwritten by the escalation guide (emitted in
            # a later round after recommend_action completes).
            current_text = "".join(
                block.text
                for block in response.content
                if hasattr(block, "text") and block.text
            )
            if current_text:
                all_text_parts.append(current_text)
            if stop_reason == "max_tokens":
                # The model ran out of output budget mid-turn — text is truncated
                # (e.g. a complaint letter cut off mid-signature). Surface this
                # loudly; it is otherwise indistinguishable from a clean finish.
                logger.warning(
                    "Session %s: round %d hit max_tokens (=%d) — response TRUNCATED. "
                    "Raise _MAX_TOKENS or split drafting from escalation.",
                    self._session_id, round_num + 1, _MAX_TOKENS,
                )
            if stop_reason != "tool_use":
                logger.info(
                    "Session %s: agent loop complete (round %d, stop=%s)",
                    self._session_id, round_num + 1, stop_reason,
                )
                return "\n\n".join(all_text_parts)
            # Dispatch all tool calls from this response
            tool_result_blocks = self._execute_tool_calls(response.content)
            self._history.append(
                {"role": "user", "content": tool_result_blocks}
            )
        # Exceeded max rounds — return whatever text we have
        logger.warning(
            "Session %s: agent loop hit max rounds (%d). Returning partial reply.",
            self._session_id, _MAX_TOOL_ROUNDS,
        )
        return "\n\n".join(all_text_parts) or _FALLBACK_REPLY

    def _execute_tool_calls(self, content_blocks) -> list[dict]:
        """
        Execute every tool_use block in *content_blocks* via execute_tool().

        Returns the matching tool_result dicts in order, ready to be appended
        to the conversation history as a "user" turn. Blocks may be SDK
        objects or plain dicts. Results are JSON-encoded, with non-JSON
        values coerced by ``str``.

        If execute_tool raises, the exception is logged and reported to the
        model as a tool_result with ``is_error: True`` instead of propagating.
        Every tool_use must be answered by a tool_result; otherwise all later
        requests in this session would be rejected by the API.
        """
        results = []
        for block in content_blocks:
            # Support both SDK objects and plain dicts (after serialization)
            block_type = block.type if hasattr(block, "type") else block.get("type")
            if block_type != "tool_use":
                continue
            name = block.name if hasattr(block, "name") else block["name"]
            bid = block.id if hasattr(block, "id") else block["id"]
            inp = block.input if hasattr(block, "input") else block["input"]
            logger.info(
                "Session %s: tool %r (id=%s) input=%s",
                self._session_id,
                name,
                bid,
                json.dumps(inp, ensure_ascii=False, default=str)[:200],
            )
            try:
                result = execute_tool(name, inp, self._memory)
            except Exception as exc:  # reported to the model, not swallowed
                logger.exception(
                    "Session %s: tool %r (id=%s) raised; reporting error to model",
                    self._session_id, name, bid,
                )
                results.append(
                    {
                        "type": "tool_result",
                        "tool_use_id": bid,
                        "content": f"Tool {name} failed: {type(exc).__name__}: {exc}",
                        "is_error": True,
                    }
                )
                continue
            logger.debug(
                "Session %s: tool %r result=%s",
                self._session_id,
                name,
                json.dumps(result, ensure_ascii=False, default=str)[:200],
            )
            results.append(
                {
                    "type": "tool_result",
                    "tool_use_id": bid,
                    "content": json.dumps(
                        result, ensure_ascii=False, default=str
                    ),
                }
            )
        return results