Spaces:
Sleeping
Sleeping
"""
Claude Managed Agent (CMA) — orchestration layer.

Maintains per-session state:
  - conversation history (list of Anthropic message dicts)
  - SessionMemory (extracted entities, domain, draft versions, etc.)
  - uploaded document paths

Model: env-driven via GUIDE_MODEL (default claude-sonnet-4-6).

Decision flow per user turn:
  1. Presidio redaction happens at the API layer before this module is called.
  2. Is the domain known? No → call classify_domain.
  3. Are the minimum fields complete? No → ask one follow-up question.
  4. Was a document uploaded? Yes → call process_document and merge entities.
  5. HITL gate: present a summary → wait for [USER CONFIRMED].
  6. [USER CONFIRMED] received → call draft_complaint.
  7. User asks for next steps → call recommend_action.

The agent loop runs until the model returns a non-tool_use stop reason, or
until the per-turn cap on tool-use rounds (_MAX_TOOL_ROUNDS) is reached.
"""
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import time | |
| import anthropic | |
| from src.agent.memory import SessionMemory | |
| from src.agent.prompts import SYSTEM_PROMPT | |
| from src.agent.tools import TOOL_DEFINITIONS, execute_tool | |
logger = logging.getLogger(__name__)

# Model name is env-driven so the same code runs against two backends:
#   • Local via the athena LiteLLM gateway, which only exposes Bedrock model
#     IDs (e.g. bedrock-claude-3-5-haiku-20241022-v1:0) — set GUIDE_MODEL in .env.
#   • HF Spaces / public Anthropic API — leave GUIDE_MODEL unset to use the
#     default claude-sonnet-4-6.
_MODEL = os.getenv("GUIDE_MODEL", "claude-sonnet-4-6")

# Output-token budget for each model response (one round of the agent loop).
_MAX_TOKENS = 4096

# Hard cap on tool-use rounds per turn to prevent runaway loops.
_MAX_TOOL_ROUNDS = 12

# Send only the last N *messages* (despite the name, this counts individual
# messages, not user/assistant turn pairs) to keep input tokens low on
# rate-limited keys. SessionMemory preserves the extracted entities, so that
# context is not lost when older messages fall out of the window.
_MAX_HISTORY_TURNS = 10

# Backoff retry on 429 (RPM/TPM) — opt-in via .env so local devs can tune it
# without committing. Retry is enabled ONLY when GUIDE_RETRY_MAX is set to a
# value > 0. Otherwise each request is sent exactly once with no retry at all:
# GUIDEAgent builds its Anthropic client with max_retries=0, so the SDK's
# built-in retries are disabled as well.
_RETRY_MAX = int(os.getenv("GUIDE_RETRY_MAX", "0"))
_RETRY_BASE_DELAY = float(os.getenv("GUIDE_RETRY_BASE_DELAY", "10.0"))  # seconds
_RETRY_MAX_DELAY = float(os.getenv("GUIDE_RETRY_MAX_DELAY", "60.0"))  # seconds
_RETRY_ENABLED = _RETRY_MAX > 0

# Reply returned when the agent loop exhausts _MAX_TOOL_ROUNDS without
# having produced any text.
_FALLBACK_REPLY = (
    "I'm sorry, I encountered an issue processing your request. "
    "Please try again or rephrase your message."
)
def _block_field(block, name: str):
    """Read *name* from an SDK content block (attribute) or a plain dict (key)."""
    return getattr(block, name) if hasattr(block, name) else block[name]


def _serialize_block(block) -> dict:
    """Convert an Anthropic SDK content block to a plain API-safe dict.

    model_dump() includes LangSmith-injected fields (e.g. parsed_output) that
    the Anthropic API rejects with 400. "text" and "tool_use" blocks are
    therefore rebuilt with only the fields those block types accept. Any other
    block type is passed through a JSON round-trip (``default=str``) so only
    primitive values remain.

    Args:
        block: An SDK content block object, or a dict with a "type" key.

    Returns:
        A new plain dict suitable for an assistant message's content list.
    """
    block_type = block.type if hasattr(block, "type") else block.get("type")
    if block_type == "text":
        return {"type": "text", "text": _block_field(block, "text")}
    if block_type == "tool_use":
        return {
            "type": "tool_use",
            "id": _block_field(block, "id"),
            "name": _block_field(block, "name"),
            "input": _block_field(block, "input"),
        }
    # Fallback: strip to primitive types only via a JSON round-trip. The
    # module-level json import is used; no local re-import is needed.
    raw = block if isinstance(block, dict) else block.model_dump()
    return json.loads(json.dumps(raw, default=str))
def _stream_once(client, **kwargs):
    """Send one streamed Messages request and return the final message.

    Opens ``client.messages.stream(**kwargs)``, drains the text stream (the
    spot where incremental tokens could later be forwarded to a UI), and
    returns ``stream.get_final_message()``.

    Used when retry is disabled (GUIDE_RETRY_MAX unset/0). No retry happens
    here, and GUIDEAgent builds its clients with ``max_retries=0``, so the SDK
    does not retry either. Any API error propagates to the caller.
    """
    with client.messages.stream(**kwargs) as stream:
        for _chunk in stream.text_stream:
            pass  # tokens are not forwarded anywhere yet
        return stream.get_final_message()
def _parse_retry_after(value) -> float | None:
    """Return a Retry-After header value in seconds, or None if unusable.

    Only the delta-seconds form is understood. None, an empty string, an
    HTTP-date, a negative number, NaN and infinity all yield None, so the
    caller falls back to its own backoff. float() raises on an HTTP-date, and
    time.sleep() rejects negative, NaN and infinite values.
    """
    if value is None:
        return None
    try:
        seconds = float(value)
    except (TypeError, ValueError):
        return None
    # The chained comparison is False for NaN, so this rejects NaN, inf and < 0.
    if not 0.0 <= seconds < float("inf"):
        return None
    return seconds


def _stream_with_retry(client, **kwargs):
    """
    Call _stream_once(client, **kwargs) and retry on 429 RateLimitError.

    Only invoked when retry is enabled (GUIDE_RETRY_MAX > 0). Behaviour is
    controlled by three env vars (set in .env, never committed):
        GUIDE_RETRY_MAX         — max attempts after the first (0 = no retry)
        GUIDE_RETRY_BASE_DELAY  — initial backoff in seconds (default 10.0)
        GUIDE_RETRY_MAX_DELAY   — cap on backoff (default 60.0)

    On each retry the delay doubles (exponential backoff, capped at
    _RETRY_MAX_DELAY). If the response carries a usable Retry-After header
    (non-negative delta-seconds), its value is used instead and is not capped.
    An unparseable header, such as an HTTP-date, falls back to the backoff
    delay rather than raising.

    Other exceptions propagate immediately. Raises the original
    RateLimitError once all retries are exhausted.
    """
    delay = _RETRY_BASE_DELAY
    for attempt in range(_RETRY_MAX + 1):
        try:
            return _stream_once(client, **kwargs)
        except anthropic.RateLimitError as exc:
            if attempt >= _RETRY_MAX:
                raise
            response = getattr(exc, "response", None)
            header = response.headers.get("retry-after") if response is not None else None
            server_wait = _parse_retry_after(header)
            wait = server_wait if server_wait is not None else min(delay, _RETRY_MAX_DELAY)
            logger.warning(
                "429 RateLimitError (attempt %d/%d) — waiting %.1fs before retry. %s",
                attempt + 1, _RETRY_MAX + 1, wait, exc,
            )
            time.sleep(wait)
            delay = min(delay * 2, _RETRY_MAX_DELAY)
class GUIDEAgent:
    """Stateful Claude Managed Agent for a single user session.

    Owns one Anthropic client, the session's message history, a SessionMemory
    that is passed to every execute_tool() call, and any document paths queued
    for the next user turn. Each public turn method appends one user message
    and runs the agent loop.
    """

    def __init__(self, session_id: str) -> None:
        """Create the per-session state and the Anthropic client.

        Args:
            session_id: Identifier included in this agent's log lines.
        """
        self._session_id = session_id
        self._memory = SessionMemory()
        # Anthropic message history — list of {"role": ..., "content": ...} dicts.
        # Content is either a string (user text) or a list of content-block
        # dicts (assistant turns with text + tool_use blocks, tool_result turns).
        self._history: list[dict] = []
        # Document paths queued by add_document(); prepended to the next
        # send_message() text, then cleared.
        self._pending_documents: list[str] = []
        # One Anthropic client per agent instance so sessions are independent.
        # If LITELLM_PROXY_URL is set, route through the LiteLLM gateway;
        # otherwise call the Anthropic API directly with ANTHROPIC_API_KEY.
        #
        # max_retries=0: by default the Anthropic SDK retries 429s twice, and
        # each retry re-sends the full request immediately, re-charging the
        # per-minute token bucket before it has refilled. On a tight TPM cap
        # that triples token use per user action. With SDK retries disabled,
        # the only retry is the module's opt-in _stream_with_retry backoff
        # (GUIDE_RETRY_MAX > 0). When that is off, a failed request is not
        # retried at all.
        litellm_url = os.getenv("LITELLM_PROXY_URL")
        if litellm_url:
            from litellm import get_litellm_gateway_api_key

            self._client = anthropic.Anthropic(
                base_url=litellm_url,
                api_key=get_litellm_gateway_api_key(),
                max_retries=0,
            )
        else:
            _key = os.environ.get("ANTHROPIC_API_KEY", "")
            # Diagnostic only: logs the key's last 3 characters and its length.
            logger.info(
                "GUIDEAgent init key ends=%s len=%d",
                _key[-3:] if _key else "EMPTY",
                len(_key),
            )
            self._client = anthropic.Anthropic(
                api_key=_key,
                max_retries=0,
            )

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def send_message(self, user_text: str) -> str:
        """
        Process *user_text* through the CMA loop and return the assistant reply.

        *user_text* must already be PII-redacted (Presidio runs at the API layer).
        If documents were queued via add_document(), their paths are prepended to
        the user's message, so the agent sees the "[Document uploaded: <path>]"
        prefix mandated by Rule 3 of the system prompt.
        """
        if self._pending_documents:
            prefixes = "\n".join(
                f"[Document uploaded: {p}]" for p in self._pending_documents
            )
            user_text = f"{prefixes}\n\n{user_text}"
            self._pending_documents.clear()
        self._history.append({"role": "user", "content": user_text})
        return self._run_agent_loop()

    def confirm_entities(self, verified_entities: dict) -> str:
        """
        Inject a [USER CONFIRMED] message with the user-verified entity values,
        then run the agent loop to trigger draft_complaint (Rule 5).

        *verified_entities* is the dict submitted from the HITL Verify Entities
        panel, e.g. {"ORG": "HDFC Bank", "AMOUNT": "₹5000"}.
        """
        confirmation = (
            f"[USER CONFIRMED]: {json.dumps(verified_entities, ensure_ascii=False)}"
        )
        self._history.append({"role": "user", "content": confirmation})
        return self._run_agent_loop()

    def generate_escalation(self) -> str:
        """
        Run a second, SEPARATE agent turn that produces the escalation guide.

        The draft letter (confirm_entities) and the escalation guide are split
        into two distinct agent turns, each its own Anthropic request, so each
        stays within the per-minute token budget. Calling this after the draft
        also lets the token bucket refill between the two requests. That avoids
        the 429 storm that occurred when both were generated in one continuous
        turn.

        Relies on Rule 7 of the system prompt: the model calls
        recommend_action() and emits the escalation guide only when this
        follow-up arrives.
        """
        request = (
            "Now generate the escalation guide for this complaint. "
            "Call recommend_action() with the confirmed domain, entities, and "
            "prior_contact, then present the numbered escalation path per Rule 7. "
            "Output ONLY the escalation guide — do not repeat the complaint letter."
        )
        self._history.append({"role": "user", "content": request})
        return self._run_agent_loop()

    def add_document(self, file_path: str) -> None:
        """Queue a document path so it appears in the next send_message() turn."""
        self._pending_documents.append(file_path)

    def get_history(self) -> list[dict]:
        """Return the current conversation history (shallow copy)."""
        return list(self._history)

    # ------------------------------------------------------------------
    # Internal agent loop
    # ------------------------------------------------------------------

    def _run_agent_loop(self) -> str:
        """
        Stream CMA responses and execute tool calls until stop_reason != "tool_use".

        Each round does the following:
          - sends a trimmed window of the history (see _trim_history);
          - records the assistant turn;
          - when the model asked for tools, executes them and appends the
            tool_result turn before the next round.
        Requests use the streaming API, and the full response is captured via
        stream.get_final_message() so tool_use blocks can be dispatched.

        Returns:
            The text of every round in this turn, joined with blank lines. It
            may be "" if the model stopped without emitting any text. If
            _MAX_TOOL_ROUNDS is exhausted, returns the text gathered so far,
            or _FALLBACK_REPLY when there is none.
        """
        all_text_parts: list[str] = []
        # Loop-invariant request pieces, built once per turn. The backoff-retry
        # path is used only when enabled via env (GUIDE_RETRY_MAX > 0).
        stream_fn = _stream_with_retry if _RETRY_ENABLED else _stream_once
        system = [
            {"type": "text", "text": SYSTEM_PROMPT, "cache_control": {"type": "ephemeral"}}
        ]

        for round_num in range(_MAX_TOOL_ROUNDS):
            messages = self._trim_history(self._history, _MAX_HISTORY_TURNS)
            logger.debug(
                "Session %s: agent round %d — history length %d, sending %d",
                self._session_id, round_num + 1, len(self._history), len(messages),
            )

            response = stream_fn(
                self._client,
                model=_MODEL,
                system=system,
                messages=messages,
                tools=TOOL_DEFINITIONS,
                max_tokens=_MAX_TOKENS,
            )

            # Token-usage diagnostics: log uncached input vs cache reads/writes
            # to see whether prompt caching reduces what counts against the
            # per-minute token cap (on Bedrock, cache reads may still count
            # toward TPM).
            usage = getattr(response, "usage", None)
            if usage is not None:
                inp = getattr(usage, "input_tokens", 0) or 0
                out = getattr(usage, "output_tokens", 0) or 0
                cread = getattr(usage, "cache_read_input_tokens", 0) or 0
                ccreate = getattr(usage, "cache_creation_input_tokens", 0) or 0
                logger.info(
                    "Session %s: round %d usage — input=%d (uncached), "
                    "cache_read=%d, cache_create=%d, output=%d, total_in=%d",
                    self._session_id, round_num + 1,
                    inp, cread, ccreate, out, inp + cread + ccreate,
                )

            # Record the full assistant turn (may include tool_use blocks) as
            # plain dicts holding only API-accepted fields. model_dump() would
            # include LangSmith-injected extras like `parsed_output` that cause
            # 400 errors on subsequent rounds.
            self._history.append(
                {"role": "assistant", "content": [_serialize_block(b) for b in response.content]}
            )

            # Accumulate text across all rounds so text emitted in an early
            # round is not overwritten by text from a later round.
            current_text = "".join(
                block.text
                for block in response.content
                if hasattr(block, "text") and block.text
            )
            if current_text:
                all_text_parts.append(current_text)

            if response.stop_reason == "max_tokens":
                # The model ran out of output budget mid-turn, so the text is
                # truncated (e.g. a complaint letter cut off mid-signature).
                # Surface this loudly; it otherwise looks like a clean finish.
                # NOTE(review): if the truncated content ends in an incomplete
                # tool_use block, that block is recorded without a tool_result.
                # The API rejects histories like that, so consider stripping it.
                logger.warning(
                    "Session %s: round %d hit max_tokens (=%d) — response TRUNCATED. "
                    "Raise _MAX_TOKENS or split drafting from escalation.",
                    self._session_id, round_num + 1, _MAX_TOKENS,
                )

            if response.stop_reason != "tool_use":
                logger.info(
                    "Session %s: agent loop complete (round %d, stop=%s)",
                    self._session_id, round_num + 1, response.stop_reason,
                )
                return "\n\n".join(all_text_parts)

            # Answer every tool_use from this response in a single user turn.
            self._history.append(
                {"role": "user", "content": self._execute_tool_calls(response.content)}
            )

        logger.warning(
            "Session %s: agent loop hit max rounds (%d). Returning partial reply.",
            self._session_id, _MAX_TOOL_ROUNDS,
        )
        return "\n\n".join(all_text_parts) or _FALLBACK_REPLY

    @staticmethod
    def _is_turn_start(message: dict) -> bool:
        """True if *message* opens a user turn: a "user" message with no tool_result blocks."""
        if message.get("role") != "user":
            return False
        content = message.get("content")
        if isinstance(content, str):
            return True
        return not any(
            isinstance(part, dict) and part.get("type") == "tool_result"
            for part in content or ()
        )

    @staticmethod
    def _trim_history(history: list[dict], max_messages: int) -> list[dict]:
        """
        Return the tail of *history* to send, never splitting a tool exchange.

        Plain slicing can start the window on a tool_result turn whose matching
        tool_use was cut off, or on an assistant message. The Messages API
        rejects a tool_result without its tool_use in the preceding message.
        The window must therefore start on a user turn opener (see
        _is_turn_start).

        The window starts at the last *max_messages* messages and moves forward
        to the first turn opener inside it. If the window contains none (a
        single turn with many tool rounds), the start moves backward to the
        user message that opened the current turn instead. In that case the
        result may exceed *max_messages* so that the request stays valid. A
        non-positive *max_messages* keeps the whole history, matching the
        original ``history[-0:]`` behaviour.

        Returns a new list; *history* itself is not modified.
        """
        start = max(len(history) - max_messages, 0) if max_messages > 0 else 0
        for i in range(start, len(history)):
            if GUIDEAgent._is_turn_start(history[i]):
                return history[i:]
        for i in range(start - 1, -1, -1):
            if GUIDEAgent._is_turn_start(history[i]):
                return history[i:]
        return history[start:]

    def _execute_tool_calls(self, content_blocks) -> list[dict]:
        """
        Execute every tool_use block in *content_blocks* and return the
        matching tool_result dicts, ready to append as one "user" turn.

        Every tool_use gets exactly one tool_result. If execute_tool() raises,
        the error is logged and reported to the model as a tool_result with
        is_error=True. Letting the exception escape would leave the
        already-recorded assistant tool_use without a tool_result in the
        history, and the API rejects later requests that include it.
        """
        results = []
        for block in content_blocks:
            # Support both SDK objects and plain dicts.
            block_type = block.type if hasattr(block, "type") else block.get("type")
            if block_type != "tool_use":
                continue
            name = block.name if hasattr(block, "name") else block["name"]
            bid = block.id if hasattr(block, "id") else block["id"]
            inp = block.input if hasattr(block, "input") else block["input"]
            logger.info(
                "Session %s: tool %r (id=%s) input=%s",
                self._session_id,
                name,
                bid,
                json.dumps(inp, ensure_ascii=False, default=str)[:200],
            )
            try:
                result = execute_tool(name, inp, self._memory)
            except Exception as exc:  # tool boundary: report to the model, keep history valid
                logger.exception(
                    "Session %s: tool %r (id=%s) raised", self._session_id, name, bid,
                )
                results.append(
                    {
                        "type": "tool_result",
                        "tool_use_id": bid,
                        "content": json.dumps(
                            {"error": f"{type(exc).__name__}: {exc}"},
                            ensure_ascii=False,
                            default=str,
                        ),
                        "is_error": True,
                    }
                )
                continue
            logger.debug(
                "Session %s: tool %r result=%s",
                self._session_id,
                name,
                json.dumps(result, ensure_ascii=False, default=str)[:200],
            )
            results.append(
                {
                    "type": "tool_result",
                    "tool_use_id": bid,
                    "content": json.dumps(result, ensure_ascii=False, default=str),
                }
            )
        return results