autofarm / tests /sim_support.py
isabelku's picture
AutoFarm Space deploy
826dd96
from __future__ import annotations
from pathlib import Path
from autofarm.contracts import (
ChallengeState,
EpisodeOutcome,
EpisodeOutcomeStatus,
FieldZoneState,
PriorityBand,
RiddleCategory,
SimulatorMode,
)
from autofarm.sim.engine import (
DEFAULT_HOME_ZONE_ID,
DEFAULT_INTERACTIVE_BASE_STATION_ZONE_ID,
DEFAULT_INTERACTIVE_MAX_ENERGY,
HiddenZoneState,
SimulatorEnvironment,
clone_hidden_zone_state,
default_hidden_for_zone,
)
from autofarm.sim.scenarios import ChallengeTemplate, InspectFixture, RIDDLE_SAFETY_MAX_STEPS, RiddleSpec, ZoneScenarioTruth
from autofarm.sim.scenarios import available_challenge_templates, fallback_goal_policy
class EmptyCatalog:
    """Stand-in sample catalog for tests in which every lookup comes back empty."""

    def sample_by_identity(self, *, dataset_name: str, source_dataset: str, sample_id: str):
        """Report that no sample matches the given identity (always ``None``)."""
        del dataset_name, source_dataset, sample_id
        return None

    def samples_for_tag(self, dataset_name: str, tag: str, positive: bool = True):
        """Return a fresh empty list regardless of dataset, tag or polarity."""
        del dataset_name, tag, positive
        return list()

    def samples_for_dataset(self, dataset_name: str):
        """Return a fresh empty list regardless of the dataset requested."""
        del dataset_name
        return list()
def make_fixture(*expected_anomalies: str, simulated_backend_error: str | None = None) -> InspectFixture:
    """Build an ``InspectFixture`` with no sample reference.

    Args:
        *expected_anomalies: anomaly labels, stored on the fixture as a tuple.
        simulated_backend_error: optional error value forwarded unchanged to
            ``InspectFixture.simulated_backend_error``.
    """
    fixture_fields = {
        "sample_ref": None,
        "expected_anomalies": tuple(expected_anomalies),
        "simulated_backend_error": simulated_backend_error,
    }
    return InspectFixture(**fixture_fields)
def make_environment(
    *,
    name: str,
    zone_id: str = "zone_r02_c02",
    priority_band: PriorityBand = PriorityBand.HIGH,
    priority_hint: float = 0.88,
    state_confidence: float = 0.4,
    starting_uncertainty: float = 0.6,
    fixtures: tuple[InspectFixture, ...] = (),
    travel_budget: float = 10.0,
) -> SimulatorEnvironment:
    """Build a riddle-mode ``SimulatorEnvironment`` with one active challenge on ``zone_id``.

    ``name`` is reused as the template id, riddle name, evaluator id and most
    descriptive text fields, so it identifies the scenario end to end.

    Args:
        name: scenario / challenge template identifier. If it matches a template
            in ``available_challenge_templates()``, that template's goal policy
            is reused.
        zone_id: the challenge zone; must be one of the 3x3 grid zones produced
            by ``build_test_zone_grid`` (``zone_rRR_cCC`` with RR, CC in 01..03).
            The agent starts here.
        priority_band: hidden true priority band of the challenge zone.
        priority_hint: written to the zone's ``scenario_priority_hint`` static
            feature and to the scenario truth.
        state_confidence: visible state confidence for the challenge zone, also
            used as the truth's ``confidence_override``.
        starting_uncertainty: initial hidden uncertainty of the challenge zone.
        fixtures: inspect fixtures attached to the challenge zone.
        travel_budget: riddle travel budget and initial remaining budget.

    Returns:
        An environment whose ``initialize_indexes()`` has already been called.

    Raises:
        ValueError: if ``zone_id`` is not a zone of the test grid.
    """
    template_policy = available_challenge_templates().get(name)
    # Unknown templates get a fallback policy with no allowed terminal intents
    # and an effectively unreachable confirmation count (999).
    goal_policy = (
        template_policy.goal_policy
        if template_policy is not None
        else fallback_goal_policy().model_copy(
            update={
                "allowed_terminal_intents": [],
                "required_positive_confirmation_count": 999,
            }
        )
    )

    zone_states = build_test_zone_grid(focus_zone_id=zone_id)
    # Use an explicit default: a bare next() would leak StopIteration for an
    # unknown zone_id (turned into RuntimeError inside generators, PEP 479).
    zone_state = next((zone for zone in zone_states if zone.zone_id == zone_id), None)
    if zone_state is None:
        known_zone_ids = ", ".join(zone.zone_id for zone in zone_states)
        raise ValueError(f"zone_id {zone_id!r} is not in the test grid (known zones: {known_zone_ids})")
    # These overrides are applied before the baseline snapshot below, so the
    # baseline copies include them.
    zone_state.static_features["scenario_priority_hint"] = priority_hint
    zone_state.state_confidence = state_confidence

    truth = ZoneScenarioTruth(
        true_priority_band=priority_band,
        inspect_fixtures=fixtures,
        priority_hint=priority_hint,
        confidence_override=state_confidence,
        starting_uncertainty=starting_uncertainty,
    )
    template = ChallengeTemplate(
        template_id=name,
        display_name=name,
        description=name,
        resolver_id=name,
        truth=truth,
        goal_policy=goal_policy,
        interactive_enabled=True,
    )
    riddle = RiddleSpec(
        name=name,
        description=name,
        category=RiddleCategory.MISSION,
        travel_budget=travel_budget,
        challenge_template_id=name,
        target_zone_id=zone_id,
        evaluator_id=name,
        objective=name,
        challenge_summary=name,
        why_it_matters=name,
    )

    # The challenge zone's hidden state carries the scenario truth; every other
    # zone uses the default hidden state derived from its visible state.
    hidden = HiddenZoneState(
        zone_id=zone_id,
        true_priority_band=priority_band,
        inspect_fixtures=fixtures,
        current_uncertainty=starting_uncertainty,
    )
    default_hidden_zone_states = {state.zone_id: default_hidden_for_zone(state) for state in zone_states}
    hidden_zone_states = {
        state.zone_id: default_hidden_for_zone(state)
        for state in zone_states
    }
    hidden_zone_states[zone_id] = hidden
    baseline_zone_states = {state.zone_id: state.model_copy(deep=True) for state in zone_states}

    env = SimulatorEnvironment(
        mode=SimulatorMode.RIDDLE,
        name=name,
        description=name,
        challenge_templates={name: template},
        zone_states=zone_states,
        hidden_zone_states=hidden_zone_states,
        baseline_zone_states=baseline_zone_states,
        default_hidden_zone_states=default_hidden_zone_states,
        catalog=EmptyCatalog(),
        current_zone_id=zone_id,
        travel_budget_remaining=travel_budget,
        max_energy=DEFAULT_INTERACTIVE_MAX_ENERGY,
        energy_remaining=DEFAULT_INTERACTIVE_MAX_ENERGY,
        riddle_spec=riddle,
        outcome=EpisodeOutcome(
            status=EpisodeOutcomeStatus.IN_PROGRESS,
            done_reason="running",
            message="Riddle in progress.",
            safety_step_cap=RIDDLE_SAFETY_MAX_STEPS,
        ),
        active_challenges={
            zone_id: ChallengeState(
                instance_id=f"{name}:{zone_id}:test",
                template_id=name,
                name=name,
                description=name,
                zone_id=zone_id,
                evaluator_id=name,
                goal_policy=goal_policy,
            )
        },
        # The starting zone counts as visited at step 0; all others are unvisited (-1).
        last_visited_step_by_zone={state.zone_id: (0 if state.zone_id == zone_id else -1) for state in zone_states},
    )
    env.initialize_indexes()
    return env
def make_interactive_environment(
    *,
    zone_id: str = DEFAULT_INTERACTIVE_BASE_STATION_ZONE_ID,
    max_energy: float = DEFAULT_INTERACTIVE_MAX_ENERGY,
    travel_budget: float | None = None,
) -> SimulatorEnvironment:
    """Build an interactive-mode ``SimulatorEnvironment`` over the 3x3 test grid.

    The agent starts at ``DEFAULT_HOME_ZONE_ID``. ``zone_id`` is used as the
    grid's focus zone and as ``home_neighbor_zone_id``. No challenge templates
    are registered, and every zone gets a clone of its default hidden state.

    Args:
        zone_id: base-station zone adjacent to home.
        max_energy: accepted for call compatibility but ignored (see note).
        travel_budget: forwarded as ``travel_budget_remaining``; ``None`` is
            passed through unchanged (presumably "no budget limit" — confirm).

    Returns:
        An environment whose ``initialize_indexes()`` has already been called.
    """
    # NOTE(review): max_energy is discarded; energy always comes from
    # DEFAULT_INTERACTIVE_MAX_ENERGY. Confirm callers don't expect it honored.
    del max_energy
    grid = build_test_zone_grid(focus_zone_id=zone_id)
    pristine_hidden = {zone.zone_id: default_hidden_for_zone(zone) for zone in grid}
    working_hidden = {key: clone_hidden_zone_state(value) for key, value in pristine_hidden.items()}
    baselines = {zone.zone_id: zone.model_copy(deep=True) for zone in grid}
    running_outcome = EpisodeOutcome(
        status=EpisodeOutcomeStatus.IN_PROGRESS,
        done_reason="running",
        message="Interactive world is running.",
    )
    # NOTE(review): step 0 is recorded for zone_id (the home neighbour), not for
    # the starting home zone — confirm this is intended.
    visit_steps = {zone.zone_id: 0 if zone.zone_id == zone_id else -1 for zone in grid}

    world = SimulatorEnvironment(
        mode=SimulatorMode.INTERACTIVE,
        name="interactive_world",
        description="interactive_world",
        challenge_templates={},
        zone_states=grid,
        hidden_zone_states=working_hidden,
        baseline_zone_states=baselines,
        default_hidden_zone_states=pristine_hidden,
        catalog=EmptyCatalog(),
        current_zone_id=DEFAULT_HOME_ZONE_ID,
        travel_budget_remaining=travel_budget,
        home_zone_id=DEFAULT_HOME_ZONE_ID,
        home_neighbor_zone_id=zone_id,
        max_energy=DEFAULT_INTERACTIVE_MAX_ENERGY,
        energy_remaining=DEFAULT_INTERACTIVE_MAX_ENERGY,
        outcome=running_outcome,
        last_visited_step_by_zone=visit_steps,
    )
    world.initialize_indexes()
    return world
def parse_zone_coords(zone_id: str) -> tuple[int, int]:
    """Return the ``(row, col)`` integers encoded in an id like ``zone_r02_c03``.

    The second and third underscore-separated parts are taken as the row and
    column tokens; every ``r`` / ``c`` character in them is dropped before
    integer conversion. Malformed ids raise ``IndexError`` or ``ValueError``.
    """
    parts = zone_id.split("_")
    # Pull both tokens before converting, so a missing part reports IndexError first.
    row_part, col_part = parts[1], parts[2]
    row = int(row_part.replace("r", ""))
    col = int(col_part.replace("c", ""))
    return row, col
def build_test_zone_grid(*, focus_zone_id: str) -> list[FieldZoneState]:
    """Return a 3x3 grid of test zones linked to their 4-connected neighbours.

    Zones are named ``zone_rRR_cCC`` with 1-based, zero-padded indices and are
    returned in row-major order. The zone whose id equals ``focus_zone_id``
    gets ``scenario_priority_hint`` 0.4; every other zone gets 0.25. Each
    zone's ``neighbor_zone_ids`` lists the in-grid neighbours in the order
    up, right, down, left.
    """
    span = range(1, 4)
    raw_zones: list[FieldZoneState] = []
    for r in span:
        for c in span:
            zid = f"zone_r{r:02d}_c{c:02d}"
            hint = 0.4 if zid == focus_zone_id else 0.25
            raw_zones.append(
                FieldZoneState(
                    zone_id=zid,
                    timestamp="2026-04-01",
                    static_features={"scenario_priority_hint": hint, "recommended_action": "none"},
                    weather_proxies={},
                    soil_source_status="usda_sda_exact_point",
                    state_confidence=0.75,
                    recent_findings=[],
                    row=r,
                    col=c,
                )
            )

    id_at = {(z.row, z.col): z.zone_id for z in raw_zones}
    offsets = ((-1, 0), (0, 1), (1, 0), (0, -1))

    def neighbours_of(z) -> list[str]:
        # Positions outside the grid are simply absent from id_at.
        candidates = ((z.row + dr, z.col + dc) for dr, dc in offsets)
        return [id_at[pos] for pos in candidates if pos in id_at]

    return [z.model_copy(update={"neighbor_zone_ids": neighbours_of(z)}) for z in raw_zones]