# DungeonMaster-AI — src/agents/rules_arbiter.py
"""
DungeonMaster AI - Rules Arbiter Agent
Specialized agent for rules lookup with LRU caching.
Uses only rules-related MCP tools for focused, accurate responses.
"""
from __future__ import annotations
import logging
import time
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING
from llama_index.core.agent.workflow import FunctionAgent
from llama_index.core.tools import FunctionTool
from src.config.settings import get_settings
from .exceptions import RulesAgentError
from .models import RulesResponse, ToolCallInfo
if TYPE_CHECKING:
from llama_index.core.llms import LLM
logger = logging.getLogger(__name__)
# =============================================================================
# Rules System Prompt
# =============================================================================
def load_rules_prompt() -> str:
    """
    Load the rules arbiter system prompt.

    Prefers the ``rules_system.txt`` file in the configured prompts
    directory; falls back to a built-in default prompt when that file
    is missing.

    Returns:
        System prompt string.
    """
    prompt_path = Path(get_settings().prompts_dir) / "rules_system.txt"
    if not prompt_path.exists():
        # Built-in fallback used when no prompt file is deployed.
        return """You are a D&D 5e rules expert and arbiter.
Your responsibilities:
1. Look up rules accurately using the available tools
2. Cite sources when providing rule information
3. Explain rules clearly and concisely
4. Help adjudicate edge cases fairly
CRITICAL: Always use tools to verify rules - never guess or assume.
When answering:
- Use search_rules for general mechanics questions
- Use get_monster for creature stats
- Use get_spell for spell mechanics
- Use get_condition for status effect rules
Provide the rule, then a brief explanation of how it applies."""
    return prompt_path.read_text()
# =============================================================================
# Cache Wrappers
# =============================================================================
class RulesCache:
    """
    Bounded LRU cache for rules lookups.

    Caches monster stats, spell info, condition info, and full query
    results to avoid repeated MCP calls. Each category has its own
    store sharing one maximum size; keys are case-insensitive. Eviction
    is true LRU: a cache hit refreshes an entry's recency, and the
    least recently used entry is evicted when a category is full.
    """

    def __init__(self, maxsize: int = 100) -> None:
        """
        Initialize the rules cache.

        Args:
            maxsize: Maximum cache size per category.
        """
        self._maxsize = maxsize
        # Python dicts preserve insertion order, so LRU order is kept by
        # re-inserting entries on access (most recent at the end) and
        # evicting from the front (least recently used).
        self._monster_cache: dict[str, dict[str, object]] = {}
        self._spell_cache: dict[str, dict[str, object]] = {}
        self._condition_cache: dict[str, dict[str, object]] = {}
        self._query_cache: dict[str, str] = {}
        # Cache statistics
        self._hits = 0
        self._misses = 0

    def _get(self, cache: dict, key: str):
        """Look up *key* in *cache*, updating LRU order and hit/miss stats."""
        if key in cache:
            self._hits += 1
            # Re-insert to mark as most recently used.
            value = cache.pop(key)
            cache[key] = value
            return value
        self._misses += 1
        return None

    def _set(self, cache: dict, key: str, value) -> None:
        """Insert *key* into *cache*, evicting the LRU entry when full."""
        if key in cache:
            # Remove first so the re-insert below refreshes recency.
            del cache[key]
        elif len(cache) >= self._maxsize:
            # Evict the least recently used entry (front of the dict).
            oldest = next(iter(cache))
            del cache[oldest]
        cache[key] = value

    def get_monster(self, name: str) -> dict[str, object] | None:
        """Get cached monster stats."""
        return self._get(self._monster_cache, name.lower())

    def set_monster(self, name: str, data: dict[str, object]) -> None:
        """Cache monster stats."""
        self._set(self._monster_cache, name.lower(), data)

    def get_spell(self, name: str) -> dict[str, object] | None:
        """Get cached spell info."""
        return self._get(self._spell_cache, name.lower())

    def set_spell(self, name: str, data: dict[str, object]) -> None:
        """Cache spell info."""
        self._set(self._spell_cache, name.lower(), data)

    def get_condition(self, name: str) -> dict[str, object] | None:
        """Get cached condition info."""
        return self._get(self._condition_cache, name.lower())

    def set_condition(self, name: str, data: dict[str, object]) -> None:
        """Cache condition info."""
        self._set(self._condition_cache, name.lower(), data)

    def get_query(self, query: str) -> str | None:
        """Get cached query result."""
        return self._get(self._query_cache, query.lower().strip())

    def set_query(self, query: str, result: str) -> None:
        """Cache query result."""
        self._set(self._query_cache, query.lower().strip(), result)

    @property
    def hit_rate(self) -> float:
        """Cache hit rate in [0.0, 1.0]; 0.0 before any lookups."""
        total = self._hits + self._misses
        return self._hits / total if total > 0 else 0.0

    def clear(self) -> None:
        """Clear all caches and reset hit/miss statistics."""
        self._monster_cache.clear()
        self._spell_cache.clear()
        self._condition_cache.clear()
        self._query_cache.clear()
        self._hits = 0
        self._misses = 0
# =============================================================================
# RulesArbiterAgent
# =============================================================================
class RulesArbiterAgent:
    """
    Specialized rules lookup agent with LRU caching.

    Wraps a FunctionAgent restricted to rules-related MCP tools so its
    answers stay focused on rules, with a RulesCache in front so
    repeated lookups skip the MCP round trip.

    Rules tools used (plus their ``mcp_``-prefixed variants):
    - search_rules / search_monsters / search_spells: topic searches
    - get_monster: monster stat block
    - get_spell: spell description
    - get_class_info: class features
    - get_race_info: race abilities
    - get_item: item details
    - get_condition: condition effects
    """

    # Allow-list used to filter the full MCP tool set down to rules tools.
    RULES_TOOLS = frozenset([
        "search_rules",
        "get_monster",
        "search_monsters",
        "get_spell",
        "search_spells",
        "get_class_info",
        "get_race_info",
        "get_item",
        "get_condition",
        # MCP prefixed versions
        "mcp_search_rules",
        "mcp_get_monster",
        "mcp_search_monsters",
        "mcp_get_spell",
        "mcp_search_spells",
        "mcp_get_class_info",
        "mcp_get_race_info",
        "mcp_get_item",
        "mcp_get_condition",
    ])

    def __init__(
        self,
        llm: LLM,
        tools: list[FunctionTool],
        cache_size: int = 100,
    ) -> None:
        """
        Initialize the Rules Arbiter agent.

        Args:
            llm: LlamaIndex LLM instance.
            tools: List of ALL MCP tools (will be filtered to rules only).
            cache_size: Maximum cache size per category.
        """
        self._llm = llm
        # Keep only the rules tools; everything else is ignored.
        self._tools = [
            tool for tool in tools
            if tool.metadata.name in self.RULES_TOOLS
        ]
        # Lazy %-args: skips formatting when INFO is disabled.
        logger.info(
            "RulesArbiterAgent initialized with %d rules tools", len(self._tools)
        )
        self._cache = RulesCache(maxsize=cache_size)
        self._system_prompt = load_rules_prompt()
        # Focused agent: sees only the rules tools and the rules prompt.
        self._agent = FunctionAgent(
            llm=llm,
            tools=self._tools,
            system_prompt=self._system_prompt,
        )

    @property
    def cache_hit_rate(self) -> float:
        """Fraction of lookups served from cache (0.0 before any lookups)."""
        return self._cache.hit_rate

    async def lookup(self, query: str) -> RulesResponse:
        """
        Look up rules for a query.

        Checks the query cache first; on a miss, runs the underlying
        agent, recording tool calls and caching both individual tool
        results and the final answer.

        Args:
            query: The rules question or topic.

        Returns:
            RulesResponse with answer and sources.

        Raises:
            RulesAgentError: If the underlying agent run fails.
        """
        cached_result = self._cache.get_query(query)
        if cached_result is not None:
            logger.debug("Cache hit for query: %s...", query[:50])
            return RulesResponse(
                answer=cached_result,
                sources=["cache"],
                confidence=1.0,
                from_cache=True,
            )
        try:
            handler = self._agent.run(user_msg=query)
            response_text = ""
            tool_calls: list[ToolCallInfo] = []
            sources: list[str] = []
            # Event classes are matched by name to avoid importing the
            # workflow event types directly.
            async for event in handler.stream_events():
                event_type = type(event).__name__
                if event_type == "AgentOutput":
                    response_text = str(event.response) if hasattr(event, "response") else ""
                elif event_type == "ToolCall":
                    if hasattr(event, "tool_name"):
                        tool_info = ToolCallInfo(
                            tool_name=event.tool_name,
                            arguments=getattr(event, "arguments", {}),
                        )
                        tool_calls.append(tool_info)
                        sources.append(event.tool_name)
                elif event_type == "ToolCallResult":
                    # Results stream in call order, so attach each one to
                    # the most recently recorded tool call.
                    if hasattr(event, "result") and tool_calls:
                        tool_calls[-1].result = event.result
                        tool_calls[-1].success = True
                        # Cache specific lookups (monster/spell/condition)
                        self._cache_tool_result(tool_calls[-1])
            # Cache the full query result only when we got an answer.
            if response_text:
                self._cache.set_query(query, response_text)
            return RulesResponse(
                answer=response_text,
                sources=sorted(set(sources)),  # deduped, deterministic order
                confidence=1.0 if tool_calls else 0.5,
                tool_calls=tool_calls,
                from_cache=False,
            )
        except Exception as e:
            logger.error("Rules lookup failed: %s", e)
            raise RulesAgentError(str(e)) from e

    def _cache_tool_result(self, tool_call: ToolCallInfo) -> None:
        """Cache get_monster/get_spell/get_condition results by entity name."""
        if not tool_call.success or not tool_call.result:
            return
        result = tool_call.result
        if not isinstance(result, dict):
            # Only structured dict results are cacheable.
            return
        # Normalize mcp_-prefixed tool names to their bare form.
        tool_name = tool_call.tool_name.removeprefix("mcp_")
        setter = {
            "get_monster": self._cache.set_monster,
            "get_spell": self._cache.set_spell,
            "get_condition": self._cache.set_condition,
        }.get(tool_name)
        if setter is None:
            return
        name = tool_call.arguments.get("name", "")
        if name:
            setter(str(name), result)

    async def _lookup_entity(
        self,
        cached: dict[str, object] | None,
        query: str,
        keyword: str,
    ) -> dict[str, object] | None:
        """
        Shared cache-then-agent lookup used by the public entity getters.

        Args:
            cached: Previously cached entry, if any.
            query: Natural-language query to run on a cache miss.
            keyword: Substring identifying the relevant tool name.

        Returns:
            The entity dict, or None if not found.
        """
        if cached:
            return cached
        response = await self.lookup(query)
        # Pull the structured result off the matching tool call, if any.
        for tool_call in response.tool_calls:
            if keyword in tool_call.tool_name.lower():
                if tool_call.result and isinstance(tool_call.result, dict):
                    return tool_call.result
        return None

    async def get_monster_stats(self, monster_name: str) -> dict[str, object] | None:
        """
        Get monster stats with caching.

        Args:
            monster_name: Name of the monster.

        Returns:
            Monster stat block or None if not found.
        """
        return await self._lookup_entity(
            self._cache.get_monster(monster_name),
            f"Get full stat block for {monster_name}",
            "monster",
        )

    async def get_spell_info(self, spell_name: str) -> dict[str, object] | None:
        """
        Get spell info with caching.

        Args:
            spell_name: Name of the spell.

        Returns:
            Spell info or None if not found.
        """
        return await self._lookup_entity(
            self._cache.get_spell(spell_name),
            f"Get full description for the spell {spell_name}",
            "spell",
        )

    async def get_condition_info(self, condition_name: str) -> dict[str, object] | None:
        """
        Get condition info with caching.

        Args:
            condition_name: Name of the condition.

        Returns:
            Condition info or None if not found.
        """
        return await self._lookup_entity(
            self._cache.get_condition(condition_name),
            f"What are the effects of the {condition_name} condition?",
            "condition",
        )

    def clear_cache(self) -> None:
        """Clear all caches and reset hit/miss statistics."""
        self._cache.clear()
        logger.info("Rules cache cleared")