Spaces:
Sleeping
Sleeping
File size: 3,170 Bytes
5fe93dd 9fdf940 5fe93dd 9fdf940 5fe93dd 9fdf940 5fe93dd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | import uuid
from typing import List, Dict, Any
from api.schemas.state import NexusState
class EpisodeState:
    """Mutable bookkeeping for a single multi-agent investigation episode.

    Tracks the conversation (globally and per agent), tool calls and clues
    gathered so far, progress flags for the root-cause / fix workflow, reward
    accounting, and a private deep copy of the scenario's ``initial_state``.
    ``to_pydantic`` exports a snapshot as a ``NexusState`` schema object.
    """

    def __init__(self, scenario_id: str, task: str, difficulty: str, max_rounds: int, scenario_data: dict = None):
        """Create a fresh episode with a new random ``episode_id``.

        Args:
            scenario_id: Identifier of the scenario being played.
            task: Task label for this episode.
            difficulty: Difficulty label for this episode.
            max_rounds: Round budget recorded on the state (not enforced here).
            scenario_data: Optional scenario definition. Only its
                ``"initial_state"`` entry is read. It is deep-copied so the
                episode can mutate ``system_state`` without touching the
                shared scenario definition.
        """
        # Imported lazily, presumably to avoid an import cycle with config — TODO confirm.
        from config import settings
        import copy

        self.episode_id = str(uuid.uuid4())
        self.scenario_id = scenario_id
        self.task = task
        self.difficulty = difficulty
        self.current_round = 1
        self.max_rounds = max_rounds

        # Conversation history: one (initially empty) list per configured agent id.
        self.messages_by_agent: Dict[str, List[str]] = {a["id"]: [] for a in settings.AGENTS}
        self.all_messages: List[str] = []
        self.last_partner_message: str = ""

        # Investigation artefacts.
        self.tool_calls_made: List[Dict] = []
        self.clues_found: List[str] = []
        self.previous_tool_calls: List[str] = []

        # Workflow progress flags.
        self.root_cause_found = False
        self.fix_proposed = False
        self.fix_correct = False
        self.fix_verified = False

        # Reward accounting and episode lifecycle.
        self.cumulative_reward = 0.0
        self.reward_history: List[float] = []
        self.done = False
        self.investigation_stage = "investigating"
        self.steps_taken = 0

        # A missing and an explicitly null "initial_state" both yield an empty
        # dict, so system_state is always a fresh, independently mutable object.
        initial_state = scenario_data.get("initial_state") if scenario_data else None
        self.system_state = copy.deepcopy(initial_state) if initial_state is not None else {}

    def add_message(self, agent_id: str, message: str):
        """Record ``message`` from ``agent_id`` and advance turn bookkeeping.

        Increments ``steps_taken`` and appends the message to the global and
        per-agent histories. A list is created for an agent id not seen at
        construction. The message is also stored as ``last_partner_message``.
        The round counter advances whenever the last agent in
        ``settings.AGENTS`` speaks.
        """
        from config import settings

        self.steps_taken += 1
        self.all_messages.append(message)
        self.messages_by_agent.setdefault(agent_id, []).append(message)
        # A full round means every agent has spoken once in the current
        # sequence. This is approximated by closing the round when the last
        # configured agent speaks.
        if settings.AGENTS and agent_id == settings.AGENTS[-1]["id"]:
            self.current_round += 1
        self.last_partner_message = message

    def add_tool_call(self, tool_name: str, params: dict):
        """Record a tool invocation.

        Appends ``{"tool_name": ..., "params": ...}`` to ``tool_calls_made``
        and a ``"<tool_name>:<str(params)>"`` signature to
        ``previous_tool_calls``. ``params`` is deep-copied so later mutation
        of the caller's object cannot rewrite the recorded history.
        """
        import copy

        # NOTE: the signature uses str(params), so for dicts it depends on key
        # insertion order. The format is kept as-is for existing consumers.
        call_signature = f"{tool_name}:{str(params)}"
        self.tool_calls_made.append({"tool_name": tool_name, "params": copy.deepcopy(params)})
        self.previous_tool_calls.append(call_signature)

    def add_clue(self, clue: str):
        """Append ``clue`` to ``clues_found`` unless it is already recorded (keeps first-seen order)."""
        if clue not in self.clues_found:
            self.clues_found.append(clue)

    def to_pydantic(self) -> NexusState:
        """Return a ``NexusState`` snapshot of the exported fields.

        NOTE(review): ``fix_correct``, ``system_state``, ``steps_taken``,
        ``all_messages``, ``previous_tool_calls`` and ``last_partner_message``
        are not passed here. Confirm against the NexusState schema whether
        that is intentional.
        """
        return NexusState(
            episode_id=self.episode_id,
            scenario_id=self.scenario_id,
            task=self.task,
            difficulty=self.difficulty,
            current_round=self.current_round,
            max_rounds=self.max_rounds,
            messages_by_agent=self.messages_by_agent,
            tool_calls_made=self.tool_calls_made,
            clues_found=self.clues_found,
            root_cause_found=self.root_cause_found,
            fix_proposed=self.fix_proposed,
            fix_verified=self.fix_verified,
            cumulative_reward=self.cumulative_reward,
            reward_history=self.reward_history,
            done=self.done,
            investigation_stage=self.investigation_stage
        )
|