Buckets:
| from __future__ import annotations | |
| import json | |
| import re | |
| from collections import defaultdict | |
| from threading import Lock | |
| from typing import Any | |
| from agent.tools import TOOL_SCHEMAS, calculate_premium, check_claim_status, search_policy | |
| from model_server.llm_client import get_llm | |
# Registry of callable tools, keyed by the exact ``tool_name`` the model must
# emit; AgentEngine rejects any tool_name that is not a key here.
# NOTE(review): keys should match those of TOOL_SCHEMAS, which is what the
# model is shown -- confirm when adding a tool.
TOOLS = dict(
    search_policy=search_policy,
    calculate_premium=calculate_premium,
    check_claim_status=check_claim_status,
)
# System prompt sent as the first ("system") message of every model call built
# by AgentEngine._build_messages. It asks the model to reply with exactly one
# JSON object: either a tool call or a final answer, which AgentEngine then
# parses and validates. The per-intent "call X before final" rules at the end
# are also enforced in code (AgentEngine._required_tool_for_message), so keep
# the prompt and that keyword routing in sync when editing either.
AGENT_SYSTEM_PROMPT = """You are InsureCo's agentic broker assistant.
You can decide to call tools or answer directly.
Return ONLY JSON in one of these two formats:
1) Tool call:
{"action":"tool","tool_name":"search_policy","arguments":{"query":"..."}}
2) Final response:
{"action":"final","answer":"..."}
Rules:
- Never output markdown.
- If you call tools, use only listed tool names and valid argument keys.
- Use tool results to ground answers.
- If info is missing, say so clearly.
- For claim status requests, call check_claim_status before final.
- For premium computation requests, call calculate_premium before final.
- For policy coverage/deductible/benefit/exclusion questions, call search_policy before final.
"""
class AgentEngine:
    """Tool-calling chat loop for the InsureCo broker assistant.

    ``chat`` repeatedly asks the LLM for a JSON action (a tool call or a final
    answer), executes requested tools from ``TOOLS``, and feeds the results
    back through a per-request scratchpad until the model answers or
    ``max_turns`` model calls have been made. Finished exchanges are stored in
    an in-memory per-session history trimmed to ``max_history_pairs`` pairs.
    """

    def __init__(self, max_history_pairs: int = 12, default_top_k: int = 4):
        """Configure history retention and retrieval depth.

        Args:
            max_history_pairs: User/assistant pairs kept (and shown to the
                model) per session; older messages are discarded. Values
                <= 0 keep no history.
            default_top_k: ``top_k`` forwarded to ``search_policy``.
        """
        self.max_history_pairs = max_history_pairs
        self.default_top_k = default_top_k
        self._sessions: dict[str, list[dict[str, str]]] = defaultdict(list)
        # Serialises every read/write of _sessions.
        self._lock = Lock()

    def list_sessions(self) -> list[str]:
        """Return the ids of sessions that have stored history."""
        with self._lock:
            return list(self._sessions.keys())

    def clear_session(self, session_id: str) -> None:
        """Forget all stored history for ``session_id`` (no-op if unknown)."""
        with self._lock:
            self._sessions.pop(session_id, None)

    def _trim(self, messages: list[dict[str, str]]) -> list[dict[str, str]]:
        """Return the last ``max_history_pairs`` user/assistant pairs of ``messages``."""
        limit = max(self.max_history_pairs, 0) * 2
        # Guard limit == 0 explicitly: messages[-0:] returns the WHOLE list.
        return messages[-limit:] if limit else []

    def _history(self, session_id: str) -> list[dict[str, str]]:
        """Return a copy of the retained history messages for ``session_id``."""
        with self._lock:
            return self._trim(self._sessions.get(session_id, []))

    def _save_turn(self, session_id: str, user_message: str, assistant_message: str) -> None:
        """Append one user/assistant exchange to a session and trim its history."""
        with self._lock:
            history = self._sessions.get(session_id, [])
            history.extend(
                [
                    {"role": "user", "content": user_message},
                    {"role": "assistant", "content": assistant_message},
                ]
            )
            self._sessions[session_id] = self._trim(history)

    def _tools_text(self) -> str:
        """Render ``TOOL_SCHEMAS`` as one bullet line per tool for the prompt."""
        return "\n".join(
            f"- {name}: {schema['description']} | args={json.dumps(schema['arguments'])}"
            for name, schema in TOOL_SCHEMAS.items()
        )

    def _build_messages(
        self,
        session_id: str,
        user_message: str,
        scratchpad: list[dict[str, Any]],
        turn_index: int,
        max_turns: int,
    ) -> list[dict[str, str]]:
        """Build the system + user messages for one model call.

        The user message bundles the current request, the session history,
        the tool catalogue, this request's scratchpad (tool results and
        errors so far) and the turn counter.
        """
        history = self._history(session_id)
        history_text = "\n".join(
            f"{m['role'].upper()}: {m['content']}" for m in history
        ) or "(no prior conversation)"
        scratch_text = json.dumps(scratchpad, ensure_ascii=False, indent=2) if scratchpad else "[]"
        user_payload = (
            f"Current user message:\n{user_message}\n\n"
            f"Conversation history:\n{history_text}\n\n"
            f"Available tools:\n{self._tools_text()}\n\n"
            f"Tool execution log this request:\n{scratch_text}\n\n"
            f"Turn {turn_index}/{max_turns}: choose next action."
        )
        return [
            {"role": "system", "content": AGENT_SYSTEM_PROMPT},
            {"role": "user", "content": user_payload},
        ]

    def _extract_json_object(self, text: str) -> tuple[dict[str, Any] | None, str | None]:
        """Extract a JSON object from raw model output.

        Tries the whole (stripped) text first. Otherwise, and also when the
        text is valid JSON but not an object (e.g. a list wrapping the
        action), returns the first ``{...}`` that decodes to a dict.

        Returns:
            ``(obj, None)`` on success, or ``(None, error_message)``.
        """
        cleaned = text.strip()
        top_level_not_object = False
        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(parsed, dict):
                return parsed, None
            top_level_not_object = True

        # Scan each "{" and decode from there. If an outer object is malformed
        # this can return an inner one; _normalize_action then rejects it.
        decoder = json.JSONDecoder()
        for idx, ch in enumerate(cleaned):
            if ch != "{":
                continue
            try:
                candidate, _ = decoder.raw_decode(cleaned[idx:])
            except json.JSONDecodeError:
                continue
            if isinstance(candidate, dict):
                return candidate, None

        if top_level_not_object:
            return None, "Response JSON is not an object."
        return None, "Could not parse valid JSON object from model response."

    def _normalize_action(self, obj: dict[str, Any]) -> tuple[str, dict[str, Any] | str | None, str | None]:
        """Validate a parsed model action.

        Returns:
            ``("final", answer_text, None)``,
            ``("tool", {"tool_name": ..., "arguments": {...}}, None)``, or
            ``("invalid", None, error_message)``. ``"answer"`` and
            ``"call_tool"`` are accepted as aliases of the two action names.
        """
        action = str(obj.get("action", "")).strip().lower()
        if action in {"final", "answer"}:
            answer = obj.get("answer")
            if not isinstance(answer, str) or not answer.strip():
                return "invalid", None, "Final action missing non-empty 'answer'."
            return "final", answer.strip(), None
        if action in {"tool", "call_tool"}:
            tool_name = str(obj.get("tool_name", "")).strip()
            if tool_name not in TOOLS:
                return "invalid", None, f"Unknown tool_name: {tool_name!r}"
            arguments = obj.get("arguments")
            if arguments is None:
                # A missing or null "arguments" field means "no arguments".
                arguments = {}
            if isinstance(arguments, str):
                # Some models emit the arguments object as a JSON string.
                try:
                    arguments = json.loads(arguments)
                except json.JSONDecodeError:
                    return "invalid", None, "Tool arguments string is not valid JSON."
            if not isinstance(arguments, dict):
                return "invalid", None, "Tool arguments must be an object."
            return "tool", {"tool_name": tool_name, "arguments": arguments}, None
        return "invalid", None, "Action must be 'tool' or 'final'."

    def _required_tool_for_message(self, user_message: str) -> str | None:
        """Keyword-route a message to the tool that must run before a final answer.

        Claim keywords win over premium, which wins over policy search.
        Returns ``None`` when no rule matches.
        """
        text = user_message.lower()
        if "claim status" in text or re.search(r"\bclm-\d+\b", text):
            return "check_claim_status"
        if "premium" in text and ("risk" in text or "coverage" in text):
            return "calculate_premium"
        if any(k in text for k in ["covered", "coverage", "deductible", "exclusion", "policy"]):
            return "search_policy"
        return None

    def _run_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        use_reranker: bool,
    ) -> dict[str, Any]:
        """Validate ``arguments`` and invoke ``TOOLS[tool_name]``.

        Raises:
            ValueError: if a required argument is missing or the tool is
                unsupported. ``float()`` conversion errors for premium
                arguments propagate unchanged.
        """
        fn = TOOLS[tool_name]
        if tool_name == "search_policy":
            query = str(arguments.get("query", "")).strip()
            if not query:
                raise ValueError("search_policy requires 'query'.")
            return fn(query=query, use_reranker=use_reranker, top_k=self.default_top_k)
        if tool_name == "calculate_premium":
            if "coverage" not in arguments or "risk_score" not in arguments:
                raise ValueError("calculate_premium requires coverage and risk_score.")
            return fn(
                coverage=float(arguments["coverage"]),
                risk_score=float(arguments["risk_score"]),
            )
        if tool_name == "check_claim_status":
            claim_id = str(arguments.get("claim_id", "")).strip()
            if not claim_id:
                raise ValueError("check_claim_status requires 'claim_id'.")
            return fn(claim_id=claim_id)
        raise ValueError(f"Unsupported tool: {tool_name}")

    def _fallback_arguments(self, tool_name: str, user_message: str) -> dict[str, Any]:
        """Heuristically extract tool arguments directly from the user's text.

        Used only after the model failed to call the required tool itself.
        For claims, no ``claim_id`` is returned when the message contains
        none, so ``_run_tool`` rejects the call. A claim id is never invented,
        because that would report someone else's claim.
        """
        text = user_message.strip()
        if tool_name == "check_claim_status":
            match = re.search(r"\bclm-\d+\b", text, flags=re.IGNORECASE)
            return {"claim_id": match.group(0).upper()} if match else {}
        if tool_name == "calculate_premium":
            # Drop digit-grouping commas first ("250,000" -> "250000"),
            # otherwise the number splits into 250 and 0.
            ungrouped = re.sub(r"(?<=\d),(?=\d{3}(?!\d))", "", text)
            nums = [float(x) for x in re.findall(r"\d+(?:\.\d+)?", ungrouped)]
            # NOTE(review): these defaults are assumed inputs, and the
            # fallback reply does not mention that they were used.
            coverage = nums[0] if nums else 100000.0
            risk = nums[1] if len(nums) > 1 else 1.0
            return {"coverage": coverage, "risk_score": risk}
        return {"query": text}

    def _fallback_response(
        self,
        pending_tool: str | None,
        user_message: str,
        use_reranker: bool,
        max_turns: int,
    ) -> tuple[str, dict[str, Any] | None]:
        """Build the reply used when the turn budget is exhausted.

        If ``pending_tool`` is set, run it once with arguments extracted from
        the message and summarise its result. This is best effort: on any
        failure the generic "limit reached" message is kept.

        Returns:
            ``(message, record)``, where ``record`` is the trace entry of a
            successful fallback tool call, or ``None``.
        """
        message = (
            f"I reached the maximum tool loop limit ({max_turns} turns) for this request. "
            "Please refine the question or split it into smaller steps."
        )
        if not pending_tool:
            return message, None
        try:
            arguments = self._fallback_arguments(pending_tool, user_message)
            result = self._run_tool(
                tool_name=pending_tool,
                arguments=arguments,
                use_reranker=use_reranker,
            )
        except Exception:
            # Best effort by design: e.g. no claim id could be extracted.
            return message, None

        record = {
            "tool_name": pending_tool,
            "arguments": arguments,
            "result": result,
            "fallback_used": True,
        }
        if not isinstance(result, dict):
            return message, record
        if pending_tool == "check_claim_status":
            message = (
                f"Claim {result.get('claim_id')} status: "
                f"{result.get('status')} (updated {result.get('updated_at')})."
            )
        elif pending_tool == "calculate_premium":
            message = (
                f"Estimated monthly premium: "
                f"{result.get('monthly_premium')} using "
                f"{result.get('formula')}."
            )
        elif pending_tool == "search_policy":
            rows = result.get("results") or []
            if rows and isinstance(rows[0], dict):
                top = rows[0]
                message = (
                    f"Top policy match: {top.get('source')} | {top.get('section')}. "
                    f"{top.get('content')}"
                )
        return message, record

    def chat(
        self,
        session_id: str,
        user_message: str,
        use_reranker: bool = True,
        max_turns: int = 6,
    ) -> dict[str, Any]:
        """Answer ``user_message`` within ``session_id``, calling tools as needed.

        Args:
            session_id: Key of the conversation history to read and append to.
            user_message: The user's new message.
            use_reranker: Forwarded to ``search_policy``.
            max_turns: Maximum number of model calls for this request.

        Returns:
            A dict with ``response``, ``session_id``, ``tool_calls_made``,
            ``tool_trace`` (successful tool calls in order) and
            ``turns_used``. ``max_turns_reached: True`` is added when the
            budget ran out.

        Parse errors, validation errors, tool failures and policy violations
        are not raised. They are appended to the scratchpad so the model can
        correct itself next turn. Errors from the LLM client itself are not
        caught.
        """
        scratchpad: list[dict[str, Any]] = []
        tool_calls: list[dict[str, Any]] = []
        # Tools the model has invoked this request, including failed calls.
        # A failed attempt satisfies the required-tool policy, so the model
        # can report the problem (e.g. a missing claim id) instead of having
        # every final answer rejected until max_turns is used up.
        attempted_tools: set[str] = set()
        required_tool = self._required_tool_for_message(user_message)

        for turn in range(1, max_turns + 1):
            messages = self._build_messages(
                session_id=session_id,
                user_message=user_message,
                scratchpad=scratchpad,
                turn_index=turn,
                max_turns=max_turns,
            )
            raw = get_llm().complete(messages)

            parsed, parse_error = self._extract_json_object(raw)
            if parse_error:
                scratchpad.append(
                    {
                        "type": "parse_error",
                        "message": parse_error,
                        "raw_model_output": raw[:1200],
                    }
                )
                continue

            action_type, payload, action_error = self._normalize_action(parsed)
            if action_error:
                scratchpad.append(
                    {
                        "type": "validation_error",
                        "message": action_error,
                        "parsed_model_output": parsed,
                    }
                )
                continue

            if action_type == "final":
                if required_tool and required_tool not in attempted_tools:
                    scratchpad.append(
                        {
                            "type": "policy_enforcement",
                            "message": (
                                f"Before final answer, call required tool '{required_tool}' "
                                f"for this user intent."
                            ),
                        }
                    )
                    continue
                answer = str(payload)
                self._save_turn(session_id=session_id, user_message=user_message, assistant_message=answer)
                return {
                    "response": answer,
                    "session_id": session_id,
                    "tool_calls_made": len(tool_calls),
                    "tool_trace": tool_calls,
                    "turns_used": turn,
                }

            # action_type == "tool": _normalize_action guarantees a dict payload.
            assert isinstance(payload, dict)
            tool_name = str(payload["tool_name"])
            arguments = payload["arguments"]
            attempted_tools.add(tool_name)
            try:
                result = self._run_tool(
                    tool_name=tool_name,
                    arguments=arguments,
                    use_reranker=use_reranker,
                )
            except Exception as exc:
                # Any tool failure is reported back to the model rather than
                # aborting the request.
                scratchpad.append(
                    {
                        "type": "tool_error",
                        "tool_name": tool_name,
                        "arguments": arguments,
                        "message": str(exc),
                    }
                )
                continue
            call_record = {"tool_name": tool_name, "arguments": arguments, "result": result}
            tool_calls.append(call_record)
            scratchpad.append({"type": "tool_result", **call_record})

        # Budget exhausted. Run the intent's required tool directly, but only
        # if it never succeeded (a call to some other tool does not count).
        pending_tool = None
        if required_tool and not any(c["tool_name"] == required_tool for c in tool_calls):
            pending_tool = required_tool
        graceful, fallback_record = self._fallback_response(
            pending_tool, user_message, use_reranker, max_turns
        )
        if fallback_record is not None:
            tool_calls.append(fallback_record)

        self._save_turn(session_id=session_id, user_message=user_message, assistant_message=graceful)
        return {
            "response": graceful,
            "session_id": session_id,
            "tool_calls_made": len(tool_calls),
            "tool_trace": tool_calls,
            "turns_used": max_turns,
            "max_turns_reached": True,
        }
Xet Storage Details
- Size:
- 13.5 kB
- Xet hash:
- 60a6f4af708f1e454c1e08cacbe93d01d4023a620c66d52760b8e3b7b08b8744
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.