# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Wire-protocol schemas for the CrisisWorld environment.

Implements design doc §11.1 — CrisisWorld-facing OpenEnv interface:

- Outer action variants (6 MVP + 1 V2-declared-rejected) as a discriminated
  union on the ``kind`` field.
- CrisisworldcortexAction: OpenEnv wire wrapper carrying the payload.
- CrisisworldcortexObservation: full observation shape (regions, resources,
  policy, meta, action log).
- Supporting atoms (RegionTelemetry, Restriction, LegalConstraint,
  ResourceInventory, ExecutedAction) and literal type aliases.

Latent world state is deliberately NOT declared here — it lives in
``server/simulator/seir_model.py`` (session 5+) so the wire/internal boundary
is enforced structurally (nothing that imports ``models`` can see latent
state).
"""

from typing import Annotated, List, Literal, Union

from openenv.core.env_server.types import Action, Observation
from pydantic import BaseModel, Field

# ============================================================================
# Shared vocabulary
# ============================================================================

# Closed set of resource names accepted wherever an action names a resource.
ResourceType = Literal[
    "test_kits",
    "hospital_beds",
    "mobile_units",
    "vaccine_doses",
]

# Region identifiers are plain strings, e.g. "R1", "R2", ...
RegionId = str
# Closed set of restriction levels shared by actions and observations.
Severity = Literal["none", "light", "moderate", "strict"]


# ============================================================================
# Outer-action variants — 6 MVP + 1 V2 (design §6.3)
# ============================================================================
#
# Each variant pins ``kind`` to exactly one literal (with the same value as its
# default); ``OuterActionPayload`` below uses that field as the discriminator.
#
# ``Field(..., <bounds>)`` is the explicit spelling of "required, with
# validation bounds".
#
# NOTE: Pydantic copies model docstrings into the generated JSON schema
# ``description``, so docstring wording is part of the published schema. Extra
# maintainer notes in this section are therefore written as ``#`` comments.


class DeployResource(BaseModel):
    """Deploy a quantity of a resource type to a region."""

    kind: Literal["deploy_resource"] = "deploy_resource"
    region: RegionId
    resource_type: ResourceType
    # Required; negative quantities fail validation.
    quantity: int = Field(..., ge=0)


class RequestData(BaseModel):
    """Reduce telemetry noise for a region for a few ticks (costs cognition budget)."""

    kind: Literal["request_data"] = "request_data"
    region: RegionId
    data_type: Literal["case_survey", "hospital_audit", "compliance_check"]


class RestrictMovement(BaseModel):
    """Apply a restriction severity to a region.

    Strict severity may be blocked by a legal constraint until
    ``escalate(national)`` is invoked (design §6.5).
    """

    kind: Literal["restrict_movement"] = "restrict_movement"
    region: RegionId
    severity: Severity


class Escalate(BaseModel):
    """Escalate to a higher authority. Unlocks additional action classes."""

    kind: Literal["escalate"] = "escalate"
    to_authority: Literal["regional", "national"]


class ReallocateBudget(BaseModel):
    """Transfer resource units between resource types (small efficiency loss)."""

    kind: Literal["reallocate_budget"] = "reallocate_budget"
    from_resource: ResourceType
    to_resource: ResourceType
    # Required; negative amounts fail validation.
    amount: int = Field(..., ge=0)


class NoOp(BaseModel):
    """Advance the tick without intervention."""

    kind: Literal["no_op"] = "no_op"


class PublicCommunication(BaseModel):
    """[V2] Declared for forward-compatibility.

    Rejected at runtime in MVP per design §6.3 / §19: env marks
    ``accepted=False`` in the action log; the training-reward grader applies
    the -0.1 well-formed-illegal penalty.
    """

    kind: Literal["public_communication"] = "public_communication"
    audience: Literal["general", "workers", "leaders"]
    message_class: Literal["informational", "reassurance", "directive"]
    # Required; must lie in the closed interval [0.0, 1.0].
    honesty: float = Field(..., ge=0.0, le=1.0)


# Tagged union over every action variant; the ``kind`` value selects the model.
OuterActionPayload = Annotated[
    Union[
        DeployResource,
        RequestData,
        RestrictMovement,
        Escalate,
        ReallocateBudget,
        NoOp,
        PublicCommunication,
    ],
    Field(discriminator="kind"),
]


# ============================================================================
# Wire-protocol Action (frozen class name) — wraps the discriminated union
# ============================================================================


class CrisisworldcortexAction(Action):
    """OpenEnv wire wrapper carrying an ``OuterActionPayload``.

    JSON shape::

        {"action": {"kind": "deploy_resource", "region": "R1", ...}, "metadata": {}}

    The class name is frozen by the OpenEnv template and must not be renamed.
    """

    action: OuterActionPayload


# ============================================================================
# Observation atoms
# ============================================================================


class RegionTelemetry(BaseModel):
    """Per-region observed telemetry (delayed + noised derivative of latent state)."""

    region: RegionId
    # Required non-negative count.
    reported_cases_d_ago: int = Field(..., ge=0)
    # Both of the following are required and bounded to [0.0, 1.0].
    hospital_load: float = Field(..., ge=0.0, le=1.0)
    compliance_proxy: float = Field(..., ge=0.0, le=1.0)


# One restriction entry: a region, its severity level, and a non-negative
# tick counter.
class Restriction(BaseModel):
    region: RegionId
    severity: Severity
    ticks_remaining: int = Field(..., ge=0)


# One legal-constraint entry. ``unlock_via`` admits only the value "escalate",
# which is also its default.
class LegalConstraint(BaseModel):
    rule_id: str
    blocked_action: str  # e.g. "restrict_movement.strict"
    unlock_via: Literal["escalate"] = "escalate"


# Resource counters; each defaults to 0 and must be non-negative.
# NOTE(review): ``ResourceType`` spells the bed resource "hospital_beds" while
# this model exposes ``hospital_beds_free`` — presumably intentional (free-bed
# count); confirm before looking up inventory fields by resource-type name.
class ResourceInventory(BaseModel):
    test_kits: int = Field(default=0, ge=0)
    hospital_beds_free: int = Field(default=0, ge=0)
    mobile_units: int = Field(default=0, ge=0)
    vaccine_doses: int = Field(default=0, ge=0)


class ExecutedAction(BaseModel):
    """Entry in the recent-action log.

    ``accepted=False`` signals an illegal action (e.g. PublicCommunication in
    MVP) that the env rejected — grader reads this to apply the -0.1
    well-formed-illegal penalty.
    """

    # Required non-negative tick index.
    tick: int = Field(..., ge=0)
    action: OuterActionPayload
    accepted: bool


# ============================================================================
# Wire-protocol Observation (frozen class name; expanded fields)
# ============================================================================


class CrisisworldcortexObservation(Observation):
    """Full CrisisWorld observation per design §11.1.

    The class name is frozen by the OpenEnv template. Inherits ``done`` and
    ``reward`` from ``openenv.core.env_server.types.Observation``.
    """

    # Every field has a default, so an observation can be built with no
    # arguments: lists start empty, counters start at 0, and ``resources``
    # starts as a default-constructed ``ResourceInventory``.
    regions: List[RegionTelemetry] = Field(default_factory=list)
    resources: ResourceInventory = Field(default_factory=ResourceInventory)
    active_restrictions: List[Restriction] = Field(default_factory=list)
    legal_constraints: List[LegalConstraint] = Field(default_factory=list)

    # Non-negative counters.
    tick: int = Field(default=0, ge=0)
    ticks_remaining: int = Field(default=0, ge=0)
    cognition_budget_remaining: int = Field(default=0, ge=0)

    recent_action_log: List[ExecutedAction] = Field(default_factory=list)