File size: 4,845 Bytes
cb70a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""OpenEnv-compatible Varaha wildfire drone environment."""

from __future__ import annotations

import sys
import os
import uuid
from typing import Any, Callable, Optional

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import EnvironmentMetadata

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from varaha_env import VarahaConfig, VarahaEnv, build_random_world
from openenv_wrapper.models import VarahaAction, VarahaObservation, VarahaState


class VarahaEnvironment(Environment[VarahaAction, VarahaObservation, VarahaState]):
    """OpenEnv adapter around the core ``VarahaEnv`` wildfire-logistics simulator.

    In every episode the drone has to drop supplies at responder zones close to
    wildfire hazards and then fly back to base.  Passing ``world_fn`` hands a
    world-construction callback to the core env (used for domain-randomised
    worlds).
    """

    def __init__(
        self,
        config: Optional[VarahaConfig] = None,
        world_fn: Optional[Callable[..., None]] = None,
    ) -> None:
        """Create the wrapped core env.

        Args:
            config: Simulator configuration; a default ``VarahaConfig()`` is
                built when this is falsy.
            world_fn: Optional world-building callback forwarded unchanged to
                ``VarahaEnv``.
        """
        super().__init__()
        self._config = config if config else VarahaConfig()
        self._world_fn = world_fn
        self._env = VarahaEnv(config=self._config, world_fn=world_fn)
        self._episode_id = str(uuid.uuid4())
        # ``info`` dict from the most recent ``step``; emptied by ``reset``.
        self._last_info: dict[str, Any] = {}

    # ------------------------------------------------------------------
    # OpenEnv abstract interface
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> VarahaObservation:
        """Begin a new episode and return its initial observation.

        Args:
            seed: Forwarded to the core env's ``reset``.
            episode_id: Identifier for the new episode; a fresh UUID4 string
                is used when this is falsy.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The first observation, with zero reward and ``done=False``.
        """
        if episode_id:
            self._episode_id = episode_id
        else:
            self._episode_id = str(uuid.uuid4())
        initial = self._env.reset(seed=seed)
        self._last_info = {}
        return self._build_observation(initial, reward=0.0, done=False)

    def step(
        self,
        action: VarahaAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> VarahaObservation:
        """Advance the simulator by one step using ``action``.

        The typed action is flattened into the plain dict the core env
        expects.  ``timeout_s`` and ``kwargs`` are accepted for interface
        compatibility but are not used.

        Returns:
            The resulting observation; its ``metadata`` carries the step info.
        """
        field_names = ("ax", "ay", "az", "deliver", "recharge", "tool_call")
        payload = {name: getattr(action, name) for name in field_names}
        outcome = self._env.step(payload)
        observed, step_reward, finished, step_info = outcome
        self._last_info = step_info
        return self._build_observation(
            observed, reward=step_reward, done=finished, info=step_info
        )

    @property
    def state(self) -> VarahaState:
        """Summarise the current episode's bookkeeping.

        Reports the episode id, the step count, and the cumulative reward and
        battery rounded to 4 decimals.  It also reports how many targets are
        marked delivered out of the total, and the core env's success flag.
        """
        env = self._env
        completed = len([target for target in env.targets if target.delivered])
        return VarahaState(
            episode_id=self._episode_id,
            step_count=env.step_count,
            cumulative_reward=round(env.cumulative_reward, 4),
            deliveries_completed=completed,
            total_targets=len(env.targets),
            battery=round(env.drone.battery, 4),
            success=env._is_success(),
        )

    # ------------------------------------------------------------------
    # Optional overrides
    # ------------------------------------------------------------------

    def get_metadata(self) -> EnvironmentMetadata:
        """Describe this environment (name, summary, version, author)."""
        summary = (
            "A 3D drone delivery environment where an agent must navigate "
            "wildfire hazards and obstacles to deliver supplies to responder "
            "zones, then return to base."
        )
        return EnvironmentMetadata(
            name="Varaha Wildfire Logistics",
            description=summary,
            version="1.0.0",
            author="Varaha Team",
        )

    def close(self) -> None:
        """No-op: this wrapper performs no cleanup."""
        return None

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _build_observation(
        self,
        obs_dict: dict[str, Any],
        *,
        reward: float,
        done: bool,
        info: dict[str, Any] | None = None,
    ) -> VarahaObservation:
        """Convert the core env's observation dict into a ``VarahaObservation``.

        Args:
            obs_dict: Raw observation from the core env.
            reward: Step reward; rounded to 4 decimals.
            done: Whether the episode has ended; when true the core env's
                trace is attached.
            info: Step info dict (a falsy value is replaced by ``{}``).
        """
        details = info if info else {}

        trace = None
        if done:
            trace = self._env.get_trace()

        fields: dict[str, Any] = {
            "done": done,
            "reward": round(reward, 4),
            "metadata": {"info": details},
        }
        # Keys indexed directly: a missing one raises KeyError.
        for key in (
            "drone_position",
            "drone_velocity",
            "battery",
            "carrying_payload",
            "alive",
            "targets",
        ):
            fields[key] = obs_dict[key]
        # Keys that may be absent fall back to an empty container.
        fields["hazards"] = obs_dict.get("hazards", [])
        fields["mission"] = obs_dict.get("mission", {})
        fields["last_tool_result"] = obs_dict.get("last_tool_result", {})
        # The core env names this "step"; the observation model uses "step_num".
        fields["step_num"] = obs_dict["step"]
        fields["max_steps"] = obs_dict["max_steps"]
        fields["reward_breakdown"] = details.get("reward_breakdown", {})
        fields["success"] = self._env._is_success()
        fields["trace"] = trace
        return VarahaObservation(**fields)