File size: 2,017 Bytes
4afc4db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""Client-side state and observation types for the MedChain Env environment."""

from openenv.core.env_server import State
from openenv.core.env_server.types import Observation
from pydantic import Field
from typing import List, Optional

# Tool names known to the MedChain Env client, kept in declaration order.
# NOTE(review): presumably the values reported through
# ``MedObservation.available_tools`` — confirm against the server.
AVAILABLE_TOOLS: List[str] = [
    # inbox
    "read_inbox",
    # query_* tools
    "query_erp",
    "query_supplier",
    "query_forecast",
    # remaining tools
    "submit_po",
    "transfer",
    "quarantine_lot",
    "file_justification",
    "end_shift",
]


class MedchainState(State):
    """State model for the runtime state exposed by the environment server.

    Every field declared here carries a default (empty string or zero).
    """

    # Which task is running; the description lists the known task names.
    task: str = Field(
        "",
        description="Task name (single_ward_stable / multi_ward_seasonal / hospital_network_crisis)",
    )

    # Progress through the simulation.
    day: int = Field(0, description="Current simulation day (1-indexed)")
    max_days: int = Field(0, description="Total simulation days for this task")
    actions_remaining: int = Field(0, description="Actions left this shift")

    # Purchase-order budget, in dollars per the field descriptions.
    budget_used: float = Field(
        0.0,
        description="Outstanding committed PO budget ($)",
    )
    budget_limit: float = Field(
        0.0,
        description="Budget ceiling for outstanding orders ($)",
    )

    # Counters for pending work.
    unread_messages: int = Field(0, description="Unread inbox messages")
    orders_in_transit: int = Field(0, description="POs currently in transit")

class MedObservation(Observation):
    """First observation of an episode, as returned by reset().

    Carries the shift dashboard text, plus a tool-name list and an episode
    id; each field falls back to an empty value when not supplied.
    """

    dashboard: str = Field("", description="Dashboard state")
    available_tools: List[str] = Field(
        default_factory=list,  # fresh empty list per instance
        description="Available tools",
    )
    episode_id: str = Field("", description="Episode ID")


class MedchainToolObservation(Observation):
    """Observation produced by each tool-call step.

    ``error_msg`` defaults to ``None``; per its description it holds the
    error message when the call failed.
    """

    tool_name: str = Field("", description="Name of the tool that was called")
    tool_result: str = Field("", description="Text result from the tool")
    error_msg: Optional[str] = Field(
        None,
        description="Error message if call failed",
    )