meet4150's picture
download
raw
13.5 kB
from __future__ import annotations

import json
import logging
import re
from collections import defaultdict
from threading import Lock
from typing import Any

from agent.tools import TOOL_SCHEMAS, calculate_premium, check_claim_status, search_policy
from model_server.llm_client import get_llm
# Registry of callables the agent may execute, keyed by the tool name the model
# emits in {"action": "tool", "tool_name": ...}. Names not present here are
# rejected during action validation; keys should stay in sync with TOOL_SCHEMAS,
# which is what the model is shown as "Available tools".
TOOLS = dict(
    search_policy=search_policy,
    calculate_premium=calculate_premium,
    check_claim_status=check_claim_status,
)
# System prompt sent as the first message on every turn of AgentEngine.chat.
# It defines the JSON-only action protocol that AgentEngine parses: either
# {"action": "tool", "tool_name": ..., "arguments": {...}} or
# {"action": "final", "answer": ...}. The "call X before final" rules are not
# left to the model alone: AgentEngine also keyword-matches the user message
# (see AgentEngine._required_tool_for_message) and rejects final answers that
# skip the required tool.
AGENT_SYSTEM_PROMPT = """You are InsureCo's agentic broker assistant.
You can decide to call tools or answer directly.
Return ONLY JSON in one of these two formats:
1) Tool call:
{"action":"tool","tool_name":"search_policy","arguments":{"query":"..."}}
2) Final response:
{"action":"final","answer":"..."}
Rules:
- Never output markdown.
- If you call tools, use only listed tool names and valid argument keys.
- Use tool results to ground answers.
- If info is missing, say so clearly.
- For claim status requests, call check_claim_status before final.
- For premium computation requests, call calculate_premium before final.
- For policy coverage/deductible/benefit/exclusion questions, call search_policy before final.
"""
class AgentEngine:
    """Tool-using chat agent that loops between an LLM and InsureCo tools.

    Each ``chat`` call repeatedly asks the model (``get_llm().complete``) for a
    JSON action, runs any requested tool, feeds results back through a
    per-request scratchpad, and stops when the model returns a final answer or
    ``max_turns`` is exhausted. Conversation history is kept in memory per
    session id, guarded by a lock, and trimmed to the newest
    ``max_history_pairs`` user/assistant pairs.
    """

    # Claim IDs look like "CLM-<digits>"; matched case-insensitively.
    _CLAIM_ID_RE = re.compile(r"\bclm-\d+\b", re.IGNORECASE)
    # Numbers in free text. Thousands-grouped values ("100,000") are matched as
    # one number instead of being split into "100" and "000".
    _NUMBER_RE = re.compile(r"\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d+(?:\.\d+)?")
    # Keywords marking a policy question; mirrors the system prompt's
    # "coverage/deductible/benefit/exclusion questions" rule.
    _POLICY_KEYWORDS = ("covered", "coverage", "deductible", "benefit", "exclusion", "policy")
    _logger = logging.getLogger(__name__)

    def __init__(self, max_history_pairs: int = 12, default_top_k: int = 4):
        # Number of user/assistant pairs kept per session (<= 0 keeps none).
        self.max_history_pairs = max_history_pairs
        # top_k forwarded to search_policy.
        self.default_top_k = default_top_k
        self._sessions: dict[str, list[dict[str, str]]] = defaultdict(list)
        self._lock = Lock()

    def list_sessions(self) -> list[str]:
        """Return the ids of all sessions that have stored history."""
        with self._lock:
            return list(self._sessions.keys())

    def clear_session(self, session_id: str) -> None:
        """Drop all stored history for *session_id* (no-op if unknown)."""
        with self._lock:
            self._sessions.pop(session_id, None)

    def _trim_history(self, history: list[dict[str, str]]) -> list[dict[str, str]]:
        """Return a new list holding the newest ``max_history_pairs`` pairs of *history*.

        ``history[-limit:]`` would keep *everything* when ``limit`` is 0
        (because ``-0 == 0``), so non-positive limits are handled explicitly.
        """
        limit = self.max_history_pairs * 2
        return history[-limit:] if limit > 0 else []

    def _history(self, session_id: str) -> list[dict[str, str]]:
        """Return a trimmed copy of the stored messages for *session_id*."""
        with self._lock:
            return self._trim_history(self._sessions.get(session_id, []))

    def _save_turn(self, session_id: str, user_message: str, assistant_message: str) -> None:
        """Append one user/assistant exchange to *session_id* and re-trim it."""
        with self._lock:
            history = self._sessions.get(session_id, [])
            history.extend(
                [
                    {"role": "user", "content": user_message},
                    {"role": "assistant", "content": assistant_message},
                ]
            )
            self._sessions[session_id] = self._trim_history(history)

    def _tools_text(self) -> str:
        """Render TOOL_SCHEMAS as one "- name: description | args=..." line per tool."""
        return "\n".join(
            f"- {name}: {schema['description']} | args={json.dumps(schema['arguments'])}"
            for name, schema in TOOL_SCHEMAS.items()
        )

    def _build_messages(
        self,
        session_id: str,
        user_message: str,
        scratchpad: list[dict[str, Any]],
        turn_index: int,
        max_turns: int,
    ) -> list[dict[str, str]]:
        """Build the [system, user] message pair for one turn of the tool loop.

        The user payload bundles the current message, prior session history,
        the tool catalogue, this request's tool/error log and the turn counter.
        """
        history = self._history(session_id)
        history_text = "\n".join(
            [f"{m['role'].upper()}: {m['content']}" for m in history]
        ) or "(no prior conversation)"
        scratch_text = json.dumps(scratchpad, ensure_ascii=False, indent=2) if scratchpad else "[]"
        user_payload = (
            f"Current user message:\n{user_message}\n\n"
            f"Conversation history:\n{history_text}\n\n"
            f"Available tools:\n{self._tools_text()}\n\n"
            f"Tool execution log this request:\n{scratch_text}\n\n"
            f"Turn {turn_index}/{max_turns}: choose next action."
        )
        return [
            {"role": "system", "content": AGENT_SYSTEM_PROMPT},
            {"role": "user", "content": user_payload},
        ]

    def _extract_json_object(self, text: str) -> tuple[dict[str, Any] | None, str | None]:
        """Parse a JSON object out of model output.

        Returns ``(obj, None)`` on success or ``(None, error_message)``. The
        whole text is tried first; if it is valid JSON but not an object, that
        is reported as an error. Otherwise the first decodable ``{...}``
        embedded in surrounding text (e.g. prose or code fences) is used.
        """
        cleaned = text.strip()
        try:
            parsed = json.loads(cleaned)
            if isinstance(parsed, dict):
                return parsed, None
            return None, "Response JSON is not an object."
        except json.JSONDecodeError:
            pass
        decoder = json.JSONDecoder()
        for idx, ch in enumerate(cleaned):
            if ch != "{":
                continue
            try:
                parsed, _ = decoder.raw_decode(cleaned[idx:])
                if isinstance(parsed, dict):
                    return parsed, None
            except json.JSONDecodeError:
                continue
        return None, "Could not parse valid JSON object from model response."

    def _normalize_action(self, obj: dict[str, Any]) -> tuple[str, dict[str, Any] | str | None, str | None]:
        """Validate a parsed model action.

        Returns ``(kind, payload, error)`` where kind is ``"final"`` (payload is
        the stripped answer), ``"tool"`` (payload is ``{"tool_name", "arguments"}``)
        or ``"invalid"`` (payload is None and error explains why).
        """
        action = str(obj.get("action", "")).strip().lower()
        if action in {"final", "answer"}:
            answer = obj.get("answer")
            if not isinstance(answer, str) or not answer.strip():
                return "invalid", None, "Final action missing non-empty 'answer'."
            return "final", answer.strip(), None
        if action in {"tool", "call_tool"}:
            tool_name = str(obj.get("tool_name", "")).strip()
            if tool_name not in TOOLS:
                return "invalid", None, f"Unknown tool_name: {tool_name!r}"
            arguments = obj.get("arguments")
            if arguments is None:
                # Missing or JSON-null "arguments" means "no arguments";
                # _run_tool reports any required key that is then absent.
                arguments = {}
            if isinstance(arguments, str):
                # Some models emit the arguments object as a JSON-encoded string.
                try:
                    arguments = json.loads(arguments)
                except json.JSONDecodeError:
                    return "invalid", None, "Tool arguments string is not valid JSON."
            if not isinstance(arguments, dict):
                return "invalid", None, "Tool arguments must be an object."
            return "tool", {"tool_name": tool_name, "arguments": arguments}, None
        return "invalid", None, "Action must be 'tool' or 'final'."

    def _required_tool_for_message(self, user_message: str) -> str | None:
        """Return the tool that must succeed before a final answer, or None.

        Keyword heuristic mirroring the system prompt rules; checked in order
        claim status > premium > policy question.
        """
        text = user_message.lower()
        if "claim status" in text or self._CLAIM_ID_RE.search(text):
            return "check_claim_status"
        if "premium" in text and ("risk" in text or "coverage" in text):
            return "calculate_premium"
        if any(keyword in text for keyword in self._POLICY_KEYWORDS):
            return "search_policy"
        return None

    def _run_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        use_reranker: bool,
    ) -> dict[str, Any]:
        """Validate *arguments* for *tool_name* and invoke the tool.

        Raises:
            KeyError: if *tool_name* is not in TOOLS.
            ValueError: if a required argument is missing/empty, a numeric
                argument cannot be converted with ``float``, or the tool is
                registered but not handled here.
        """
        fn = TOOLS[tool_name]
        if tool_name == "search_policy":
            query = str(arguments.get("query", "")).strip()
            if not query:
                raise ValueError("search_policy requires 'query'.")
            return fn(query=query, use_reranker=use_reranker, top_k=self.default_top_k)
        if tool_name == "calculate_premium":
            if "coverage" not in arguments or "risk_score" not in arguments:
                raise ValueError("calculate_premium requires coverage and risk_score.")
            return fn(
                coverage=float(arguments["coverage"]),
                risk_score=float(arguments["risk_score"]),
            )
        if tool_name == "check_claim_status":
            claim_id = str(arguments.get("claim_id", "")).strip()
            if not claim_id:
                raise ValueError("check_claim_status requires 'claim_id'.")
            return fn(claim_id=claim_id)
        raise ValueError(f"Unsupported tool: {tool_name}")

    def _fallback_arguments(self, tool_name: str, user_message: str) -> dict[str, Any]:
        """Heuristically extract arguments for *tool_name* from raw user text.

        Used only when the model never made the required call itself. Claim
        IDs and coverage amounts are never invented: when they cannot be found
        the key is omitted, so ``_run_tool`` rejects the call rather than
        answering about someone else's claim or an arbitrary coverage.
        """
        text = user_message.strip()
        if tool_name == "check_claim_status":
            match = self._CLAIM_ID_RE.search(text)
            return {"claim_id": match.group(0).upper()} if match else {}
        if tool_name == "calculate_premium":
            nums = [float(raw.replace(",", "")) for raw in self._NUMBER_RE.findall(text)]
            if not nums:
                return {}
            # Positional heuristic: first number = coverage, second = risk score.
            # NOTE(review): risk_score falls back to 1.0 when only one number is
            # present; confirm that is a neutral value for calculate_premium.
            return {"coverage": nums[0], "risk_score": nums[1] if len(nums) > 1 else 1.0}
        return {"query": text}

    def _summarize_tool_result(self, tool_name: str, result: dict[str, Any]) -> str | None:
        """Render a one-line user-facing summary of a tool *result*.

        Returns None when there is nothing to show (unknown tool, or a
        search_policy result with no rows).
        """
        if tool_name == "check_claim_status":
            return (
                f"Claim {result.get('claim_id')} status: "
                f"{result.get('status')} (updated {result.get('updated_at')})."
            )
        if tool_name == "calculate_premium":
            return (
                f"Estimated monthly premium: "
                f"{result.get('monthly_premium')} using "
                f"{result.get('formula')}."
            )
        if tool_name == "search_policy":
            rows = result.get("results", [])
            if rows:
                top = rows[0]
                return (
                    f"Top policy match: {top.get('source')} | {top.get('section')}. "
                    f"{top.get('content')}"
                )
        return None

    def _fallback_answer(
        self,
        required_tool: str,
        user_message: str,
        tool_calls: list[dict[str, Any]],
        use_reranker: bool,
    ) -> str | None:
        """Best-effort answer for when the loop ends without a final answer.

        Summarizes the latest successful call to *required_tool* from this
        request. If there was none, runs the tool with arguments extracted from
        *user_message* and appends that call (marked ``fallback_used``) to
        *tool_calls*. Returns None when no summary can be produced. Failures
        are logged, not raised, because the caller already has a generic reply.
        """
        previous = [call for call in tool_calls if call["tool_name"] == required_tool]
        try:
            if previous:
                return self._summarize_tool_result(required_tool, previous[-1]["result"])
            fallback_args = self._fallback_arguments(required_tool, user_message)
            fallback_result = self._run_tool(
                tool_name=required_tool,
                arguments=fallback_args,
                use_reranker=use_reranker,
            )
            tool_calls.append(
                {
                    "tool_name": required_tool,
                    "arguments": fallback_args,
                    "result": fallback_result,
                    "fallback_used": True,
                }
            )
            return self._summarize_tool_result(required_tool, fallback_result)
        except Exception:
            # Tools are arbitrary external code; any failure just means no summary.
            self._logger.warning(
                "Fallback for required tool %r failed.", required_tool, exc_info=True
            )
            return None

    def chat(
        self,
        session_id: str,
        user_message: str,
        use_reranker: bool = True,
        max_turns: int = 6,
    ) -> dict[str, Any]:
        """Run the tool loop for one user message and persist the exchange.

        Each turn asks the LLM for a JSON action. Unparseable output, invalid
        actions, tool errors, and final answers given before the required tool
        (if any) has succeeded are recorded in a per-request scratchpad that the
        model sees on the next turn.

        Args:
            session_id: Conversation key for history storage.
            user_message: The user's message for this request.
            use_reranker: Forwarded to ``search_policy``.
            max_turns: Maximum number of LLM calls for this request.

        Returns:
            A dict with ``response``, ``session_id``, ``tool_calls_made``,
            ``tool_trace`` and ``turns_used``. If no final answer is reached
            within ``max_turns`` it also carries ``max_turns_reached=True`` and
            the response is a best-effort tool summary or a generic message.
        """
        scratchpad: list[dict[str, Any]] = []
        tool_calls: list[dict[str, Any]] = []
        required_tool = self._required_tool_for_message(user_message)
        for turn in range(1, max_turns + 1):
            messages = self._build_messages(
                session_id=session_id,
                user_message=user_message,
                scratchpad=scratchpad,
                turn_index=turn,
                max_turns=max_turns,
            )
            raw = get_llm().complete(messages)
            parsed, parse_error = self._extract_json_object(raw)
            if parse_error:
                scratchpad.append(
                    {
                        "type": "parse_error",
                        "message": parse_error,
                        "raw_model_output": raw[:1200],
                    }
                )
                continue
            action_type, payload, action_error = self._normalize_action(parsed)
            if action_error:
                scratchpad.append(
                    {
                        "type": "validation_error",
                        "message": action_error,
                        "parsed_model_output": parsed,
                    }
                )
                continue
            if action_type == "final":
                answer = str(payload)
                if required_tool and not any(c["tool_name"] == required_tool for c in tool_calls):
                    # Reject the premature answer and tell the model which tool to call.
                    scratchpad.append(
                        {
                            "type": "policy_enforcement",
                            "message": (
                                f"Before final answer, call required tool '{required_tool}' "
                                f"for this user intent."
                            ),
                        }
                    )
                    continue
                self._save_turn(session_id=session_id, user_message=user_message, assistant_message=answer)
                return {
                    "response": answer,
                    "session_id": session_id,
                    "tool_calls_made": len(tool_calls),
                    "tool_trace": tool_calls,
                    "turns_used": turn,
                }
            # action_type == "tool"
            assert isinstance(payload, dict)
            tool_name = str(payload["tool_name"])
            arguments = payload["arguments"]
            try:
                result = self._run_tool(
                    tool_name=tool_name,
                    arguments=arguments,
                    use_reranker=use_reranker,
                )
            except Exception as exc:
                # Surface the failure to the model so it can retry with fixed arguments.
                scratchpad.append(
                    {
                        "type": "tool_error",
                        "tool_name": tool_name,
                        "arguments": arguments,
                        "message": str(exc),
                    }
                )
                continue
            call_record = {"tool_name": tool_name, "arguments": arguments, "result": result}
            tool_calls.append(call_record)
            scratchpad.append({"type": "tool_result", **call_record})
        graceful = (
            f"I reached the maximum tool loop limit ({max_turns} turns) for this request. "
            "Please refine the question or split it into smaller steps."
        )
        if required_tool:
            summary = self._fallback_answer(
                required_tool=required_tool,
                user_message=user_message,
                tool_calls=tool_calls,
                use_reranker=use_reranker,
            )
            if summary:
                graceful = summary
        self._save_turn(session_id=session_id, user_message=user_message, assistant_message=graceful)
        return {
            "response": graceful,
            "session_id": session_id,
            "tool_calls_made": len(tool_calls),
            "tool_trace": tool_calls,
            "turns_used": max_turns,
            "max_turns_reached": True,
        }

Xet Storage Details

Size:
13.5 kB
·
Xet hash:
60a6f4af708f1e454c1e08cacbe93d01d4023a620c66d52760b8e3b7b08b8744

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.