import copy
import uuid
from typing import Any, Dict, List, Optional

from api.schemas.state import NexusState


class EpisodeState:
    """Mutable bookkeeping for one episode.

    Tracks the messages exchanged per agent, the tool calls issued, the clues
    collected, progress flags, and reward accounting. ``to_pydantic`` exports
    a subset of this state as a ``NexusState``.
    """

    def __init__(
        self,
        scenario_id: str,
        task: str,
        difficulty: str,
        max_rounds: int,
        scenario_data: Optional[dict] = None,
    ):
        """Create a fresh episode with a newly generated UUID4 ``episode_id``.

        Args:
            scenario_id: Identifier of the scenario being played.
            task: Task label stored on the episode.
            difficulty: Difficulty label stored on the episode.
            max_rounds: Round limit stored on the episode; it is not enforced
                by this class.
            scenario_data: Optional scenario dict. When truthy, its
                ``"initial_state"`` entry (default ``{}``) is deep-copied into
                ``system_state``; otherwise ``system_state`` starts empty.
        """
        # Imported at call time rather than at module load.
        # NOTE(review): presumably avoids an import cycle with config - confirm.
        from config import settings

        self.episode_id = str(uuid.uuid4())
        self.scenario_id = scenario_id
        self.task = task
        self.difficulty = difficulty
        self.current_round = 1
        self.max_rounds = max_rounds

        # One (initially empty) message list per configured agent id.
        self.messages_by_agent: Dict[str, List[str]] = {
            a["id"]: [] for a in settings.AGENTS
        }
        self.all_messages: List[str] = []
        self.tool_calls_made: List[Dict] = []
        self.clues_found: List[str] = []
        self.last_partner_message: str = ""
        # "name:params" signature strings, one per call to add_tool_call.
        self.previous_tool_calls: List[str] = []

        self.root_cause_found = False
        self.fix_proposed = False
        self.fix_correct = False
        self.fix_verified = False

        self.cumulative_reward = 0.0
        self.reward_history: List[float] = []
        self.done = False
        self.investigation_stage = "investigating"
        self.steps_taken = 0

        # Deep copy so that later changes to system_state never leak back into
        # the caller's scenario_data.
        self.system_state = (
            copy.deepcopy(scenario_data.get("initial_state", {}))
            if scenario_data
            else {}
        )

    def add_message(self, agent_id: str, message: str):
        """Record ``message`` from ``agent_id`` and advance the step counter.

        The message is appended to the global log and to the agent's own log
        (created on demand for agent ids not seen before), and becomes
        ``last_partner_message``.
        """
        self.steps_taken += 1
        self.all_messages.append(message)
        if agent_id not in self.messages_by_agent:
            self.messages_by_agent[agent_id] = []
        self.messages_by_agent[agent_id].append(message)

        from config import settings

        # A full round is defined as all agents having spoken at least once in
        # the current sequence. We approximate this by incrementing the round
        # whenever the last agent in the configured list speaks.
        if settings.AGENTS and agent_id == settings.AGENTS[-1]["id"]:
            self.current_round += 1

        self.last_partner_message = message

    def add_tool_call(self, tool_name: str, params: dict):
        """Record a tool invocation.

        Appends ``{"tool_name": ..., "params": ...}`` to ``tool_calls_made``
        (``params`` is stored by reference, not copied) and a
        ``"<tool_name>:<params>"`` signature string to ``previous_tool_calls``.
        """
        call_signature = f"{tool_name}:{params}"
        self.tool_calls_made.append({"tool_name": tool_name, "params": params})
        self.previous_tool_calls.append(call_signature)

    def add_clue(self, clue: str):
        """Append ``clue`` to ``clues_found`` unless it is already present."""
        if clue not in self.clues_found:
            self.clues_found.append(clue)

    def to_pydantic(self) -> NexusState:
        """Build a ``NexusState`` from the current episode state.

        Only the fields passed below are exported; attributes such as
        ``all_messages``, ``fix_correct``, ``steps_taken`` and
        ``system_state`` are not included.
        """
        return NexusState(
            episode_id=self.episode_id,
            scenario_id=self.scenario_id,
            task=self.task,
            difficulty=self.difficulty,
            current_round=self.current_round,
            max_rounds=self.max_rounds,
            messages_by_agent=self.messages_by_agent,
            tool_calls_made=self.tool_calls_made,
            clues_found=self.clues_found,
            root_cause_found=self.root_cause_found,
            fix_proposed=self.fix_proposed,
            fix_verified=self.fix_verified,
            cumulative_reward=self.cumulative_reward,
            reward_history=self.reward_history,
            done=self.done,
            investigation_stage=self.investigation_stage,
        )