""" DungeonMaster AI - Rules Arbiter Agent Specialized agent for rules lookup with LRU caching. Uses only rules-related MCP tools for focused, accurate responses. """ from __future__ import annotations import logging import time from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING from llama_index.core.agent.workflow import FunctionAgent from llama_index.core.tools import FunctionTool from src.config.settings import get_settings from .exceptions import RulesAgentError from .models import RulesResponse, ToolCallInfo if TYPE_CHECKING: from llama_index.core.llms import LLM logger = logging.getLogger(__name__) # ============================================================================= # Rules System Prompt # ============================================================================= def load_rules_prompt() -> str: """ Load the rules arbiter system prompt. Returns: System prompt string. """ settings = get_settings() prompts_dir = settings.prompts_dir rules_path = Path(prompts_dir) / "rules_system.txt" if rules_path.exists(): return rules_path.read_text() # Fallback prompt return """You are a D&D 5e rules expert and arbiter. Your responsibilities: 1. Look up rules accurately using the available tools 2. Cite sources when providing rule information 3. Explain rules clearly and concisely 4. Help adjudicate edge cases fairly CRITICAL: Always use tools to verify rules - never guess or assume. When answering: - Use search_rules for general mechanics questions - Use get_monster for creature stats - Use get_spell for spell mechanics - Use get_condition for status effect rules Provide the rule, then a brief explanation of how it applies.""" # ============================================================================= # Cache Wrappers # ============================================================================= class RulesCache: """ LRU cache wrapper for rules lookups. Caches monster stats, spell info, and rule queries to avoid repeated MCP calls. """ def __init__(self, maxsize: int = 100) -> None: """ Initialize the rules cache. Args: maxsize: Maximum cache size per category. """ self._maxsize = maxsize self._monster_cache: dict[str, dict[str, object]] = {} self._spell_cache: dict[str, dict[str, object]] = {} self._condition_cache: dict[str, dict[str, object]] = {} self._query_cache: dict[str, str] = {} # Cache statistics self._hits = 0 self._misses = 0 def get_monster(self, name: str) -> dict[str, object] | None: """Get cached monster stats.""" key = name.lower() if key in self._monster_cache: self._hits += 1 return self._monster_cache[key] self._misses += 1 return None def set_monster(self, name: str, data: dict[str, object]) -> None: """Cache monster stats.""" key = name.lower() if len(self._monster_cache) >= self._maxsize: # Remove oldest entry (simple FIFO, not true LRU) oldest = next(iter(self._monster_cache)) del self._monster_cache[oldest] self._monster_cache[key] = data def get_spell(self, name: str) -> dict[str, object] | None: """Get cached spell info.""" key = name.lower() if key in self._spell_cache: self._hits += 1 return self._spell_cache[key] self._misses += 1 return None def set_spell(self, name: str, data: dict[str, object]) -> None: """Cache spell info.""" key = name.lower() if len(self._spell_cache) >= self._maxsize: oldest = next(iter(self._spell_cache)) del self._spell_cache[oldest] self._spell_cache[key] = data def get_condition(self, name: str) -> dict[str, object] | None: """Get cached condition info.""" key = name.lower() if key in self._condition_cache: self._hits += 1 return self._condition_cache[key] self._misses += 1 return None def set_condition(self, name: str, data: dict[str, object]) -> None: """Cache condition info.""" key = name.lower() if len(self._condition_cache) >= self._maxsize: oldest = next(iter(self._condition_cache)) del self._condition_cache[oldest] self._condition_cache[key] = data def get_query(self, query: str) -> str | None: """Get cached query result.""" key = query.lower().strip() if key in self._query_cache: self._hits += 1 return self._query_cache[key] self._misses += 1 return None def set_query(self, query: str, result: str) -> None: """Cache query result.""" key = query.lower().strip() if len(self._query_cache) >= self._maxsize: oldest = next(iter(self._query_cache)) del self._query_cache[oldest] self._query_cache[key] = result @property def hit_rate(self) -> float: """Calculate cache hit rate.""" total = self._hits + self._misses return self._hits / total if total > 0 else 0.0 def clear(self) -> None: """Clear all caches.""" self._monster_cache.clear() self._spell_cache.clear() self._condition_cache.clear() self._query_cache.clear() self._hits = 0 self._misses = 0 # ============================================================================= # RulesArbiterAgent # ============================================================================= class RulesArbiterAgent: """ Specialized rules lookup agent with LRU caching. Only uses rules-related MCP tools: - search_rules: Search rules by topic - get_monster: Get monster stat block - get_spell: Get spell description - get_class_info: Get class features - get_race_info: Get race abilities - get_condition: Get condition effects """ # Rules tool names to filter RULES_TOOLS = frozenset([ "search_rules", "get_monster", "search_monsters", "get_spell", "search_spells", "get_class_info", "get_race_info", "get_item", "get_condition", # MCP prefixed versions "mcp_search_rules", "mcp_get_monster", "mcp_search_monsters", "mcp_get_spell", "mcp_search_spells", "mcp_get_class_info", "mcp_get_race_info", "mcp_get_item", "mcp_get_condition", ]) def __init__( self, llm: LLM, tools: list[FunctionTool], cache_size: int = 100, ) -> None: """ Initialize the Rules Arbiter agent. Args: llm: LlamaIndex LLM instance. tools: List of ALL MCP tools (will be filtered to rules only). cache_size: Maximum cache size per category. """ self._llm = llm # Filter to rules tools only self._tools = [ tool for tool in tools if tool.metadata.name in self.RULES_TOOLS ] logger.info( f"RulesArbiterAgent initialized with {len(self._tools)} rules tools" ) # Initialize cache self._cache = RulesCache(maxsize=cache_size) # Load system prompt self._system_prompt = load_rules_prompt() # Create focused FunctionAgent self._agent = FunctionAgent( llm=llm, tools=self._tools, system_prompt=self._system_prompt, ) @property def cache_hit_rate(self) -> float: """Get cache hit rate.""" return self._cache.hit_rate async def lookup(self, query: str) -> RulesResponse: """ Look up rules for a query. Args: query: The rules question or topic. Returns: RulesResponse with answer and sources. """ start_time = time.time() # Check cache first cached_result = self._cache.get_query(query) if cached_result: logger.debug(f"Cache hit for query: {query[:50]}...") return RulesResponse( answer=cached_result, sources=["cache"], confidence=1.0, from_cache=True, ) try: # Run the agent handler = self._agent.run(user_msg=query) response_text = "" tool_calls: list[ToolCallInfo] = [] sources: list[str] = [] async for event in handler.stream_events(): event_type = type(event).__name__ if event_type == "AgentOutput": response_text = str(event.response) if hasattr(event, "response") else "" elif event_type == "ToolCall": if hasattr(event, "tool_name"): tool_info = ToolCallInfo( tool_name=event.tool_name, arguments=getattr(event, "arguments", {}), ) tool_calls.append(tool_info) sources.append(event.tool_name) elif event_type == "ToolCallResult": if hasattr(event, "result") and tool_calls: tool_calls[-1].result = event.result tool_calls[-1].success = True # Cache specific lookups self._cache_tool_result(tool_calls[-1]) # Cache the full query result if response_text: self._cache.set_query(query, response_text) return RulesResponse( answer=response_text, sources=list(set(sources)), confidence=1.0 if tool_calls else 0.5, tool_calls=tool_calls, from_cache=False, ) except Exception as e: logger.error(f"Rules lookup failed: {e}") raise RulesAgentError(str(e)) from e def _cache_tool_result(self, tool_call: ToolCallInfo) -> None: """Cache individual tool results.""" if not tool_call.success or not tool_call.result: return result = tool_call.result if not isinstance(result, dict): return tool_name = tool_call.tool_name.replace("mcp_", "") if tool_name == "get_monster": name = tool_call.arguments.get("name", "") if name: self._cache.set_monster(str(name), result) elif tool_name == "get_spell": name = tool_call.arguments.get("name", "") if name: self._cache.set_spell(str(name), result) elif tool_name == "get_condition": name = tool_call.arguments.get("name", "") if name: self._cache.set_condition(str(name), result) async def get_monster_stats(self, monster_name: str) -> dict[str, object] | None: """ Get monster stats with caching. Args: monster_name: Name of the monster. Returns: Monster stat block or None if not found. """ # Check cache cached = self._cache.get_monster(monster_name) if cached: return cached # Look up via agent response = await self.lookup(f"Get full stat block for {monster_name}") # Try to extract monster data from response for tool_call in response.tool_calls: if "monster" in tool_call.tool_name.lower(): if tool_call.result and isinstance(tool_call.result, dict): return tool_call.result return None async def get_spell_info(self, spell_name: str) -> dict[str, object] | None: """ Get spell info with caching. Args: spell_name: Name of the spell. Returns: Spell info or None if not found. """ # Check cache cached = self._cache.get_spell(spell_name) if cached: return cached # Look up via agent response = await self.lookup(f"Get full description for the spell {spell_name}") # Try to extract spell data from response for tool_call in response.tool_calls: if "spell" in tool_call.tool_name.lower(): if tool_call.result and isinstance(tool_call.result, dict): return tool_call.result return None async def get_condition_info(self, condition_name: str) -> dict[str, object] | None: """ Get condition info with caching. Args: condition_name: Name of the condition. Returns: Condition info or None if not found. """ # Check cache cached = self._cache.get_condition(condition_name) if cached: return cached # Look up via agent response = await self.lookup(f"What are the effects of the {condition_name} condition?") # Try to extract condition data from response for tool_call in response.tool_calls: if "condition" in tool_call.tool_name.lower(): if tool_call.result and isinstance(tool_call.result, dict): return tool_call.result return None def clear_cache(self) -> None: """Clear all caches.""" self._cache.clear() logger.info("Rules cache cleared")