Spaces:
Runtime error
Runtime error
initial push: overflow_env with Gradio RL demo UI
Browse files- DESIGN.md +791 -0
- Dockerfile +22 -0
- README.md +64 -6
- __init__.py +27 -0
- app.py +424 -0
- client.py +92 -0
- models.py +134 -0
- openenv.yaml +6 -0
- policies/__init__.py +5 -0
- policies/__pycache__/__init__.cpython-314.pyc +0 -0
- policies/__pycache__/base_policy.cpython-314.pyc +0 -0
- policies/__pycache__/flat_mlp_policy.cpython-314.pyc +0 -0
- policies/__pycache__/policy_spec.cpython-314.pyc +0 -0
- policies/__pycache__/ticket_attention_policy.cpython-314.pyc +0 -0
- policies/base_policy.py +66 -0
- policies/flat_mlp_policy.py +50 -0
- policies/policy_spec.py +409 -0
- policies/ticket_attention_policy.py +227 -0
- pyproject.toml +33 -0
- requirements.txt +8 -0
- server/__init__.py +0 -0
- server/__pycache__/__init__.cpython-314.pyc +0 -0
- server/__pycache__/overflow_environment.cpython-314.pyc +0 -0
- server/__pycache__/policy_adapter.cpython-314.pyc +0 -0
- server/app.py +46 -0
- server/overflow_environment.py +497 -0
- server/policy_adapter.py +80 -0
- server/requirements.txt +8 -0
- training/__init__.py +0 -0
- training/__pycache__/__init__.cpython-314.pyc +0 -0
- training/__pycache__/curriculum.cpython-314.pyc +0 -0
- training/__pycache__/overflow_gym_env.cpython-314.pyc +0 -0
- training/__pycache__/ppo_trainer.cpython-314.pyc +0 -0
- training/__pycache__/reward.cpython-314.pyc +0 -0
- training/curriculum.py +99 -0
- training/overflow_gym_env.py +170 -0
- training/ppo_trainer.py +329 -0
- training/reward.py +94 -0
DESIGN.md
ADDED
|
@@ -0,0 +1,791 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Overflow Environment — Low-Level Design Document
|
| 2 |
+
|
| 3 |
+
## Table of Contents
|
| 4 |
+
|
| 5 |
+
1. [Architecture Overview](#1-architecture-overview)
|
| 6 |
+
2. [File-by-File Breakdown](#2-file-by-file-breakdown)
|
| 7 |
+
3. [Data Models (Wire Format)](#3-data-models-wire-format)
|
| 8 |
+
4. [Simulation Internals](#4-simulation-internals)
|
| 9 |
+
5. [Step-by-Step Execution Pipeline](#5-step-by-step-execution-pipeline)
|
| 10 |
+
6. [Distance and Collision Model](#6-distance-and-collision-model)
|
| 11 |
+
7. [Reward Function — Complete Breakdown](#7-reward-function--complete-breakdown)
|
| 12 |
+
8. [Scripted Car AI](#8-scripted-car-ai)
|
| 13 |
+
9. [Action Parsing — How LLM Output Becomes a Decision](#9-action-parsing--how-llm-output-becomes-a-decision)
|
| 14 |
+
10. [Observation Text Format](#10-observation-text-format)
|
| 15 |
+
11. [Server Protocol — What Training Scripts Must Send](#11-server-protocol--what-training-scripts-must-send)
|
| 16 |
+
12. [Training Integration — GRPO / TRL](#12-training-integration--grpo--trl)
|
| 17 |
+
13. [Episode Dynamics and RL Characteristics](#13-episode-dynamics-and-rl-characteristics)
|
| 18 |
+
14. [Configuration Constants](#14-configuration-constants)
|
| 19 |
+
15. [Docker and Deployment](#15-docker-and-deployment)
|
| 20 |
+
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
## 1. Architecture Overview
|
| 24 |
+
|
| 25 |
+
```
|
| 26 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
+
β Training Script (GRPO) β
|
| 28 |
+
β calls reset(), reads observation, calls step(action) β
|
| 29 |
+
ββββββββββββββββββββββββββ¬βββββββββββββββββββββββββββββββββ
|
| 30 |
+
β WebSocket (persistent session)
|
| 31 |
+
β JSON messages over ws://host:8000/ws
|
| 32 |
+
βΌ
|
| 33 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 34 |
+
β FastAPI Server (app.py) β
|
| 35 |
+
β create_app(OverflowEnvironment, OverflowAction, β
|
| 36 |
+
β OverflowObservation) β
|
| 37 |
+
β β
|
| 38 |
+
β Endpoints: β
|
| 39 |
+
β WS /ws β primary (stateful session) β
|
| 40 |
+
β POST /reset β HTTP fallback β
|
| 41 |
+
β POST /step β HTTP fallback β
|
| 42 |
+
β GET /state β HTTP fallback β
|
| 43 |
+
β GET /health β health check β
|
| 44 |
+
β GET /schema β JSON schemas for action/obs/state β
|
| 45 |
+
ββββββββββββββββββββββββββ¬βββββββββββββββββββββββββββββββββ
|
| 46 |
+
β
|
| 47 |
+
βΌ
|
| 48 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
+
β OverflowEnvironment (pure Python) β
|
| 50 |
+
β β
|
| 51 |
+
β Internal state: β
|
| 52 |
+
β _cars: List[Car] (5 cars, car 0 = agent) β
|
| 53 |
+
β _state: OverflowState (episode tracking) β
|
| 54 |
+
β _rng: random.Random (seeded per episode) β
|
| 55 |
+
β _done: bool β
|
| 56 |
+
β β
|
| 57 |
+
β Methods: β
|
| 58 |
+
β reset(seed, episode_id) β OverflowObservation β
|
| 59 |
+
β step(OverflowAction) β OverflowObservation β
|
| 60 |
+
β state (property) β OverflowState β
|
| 61 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
**Key invariant**: The training loop calls `reset()`. The LLM agent only calls `step()` via the training harness. Agents can never reset — if they could undo consequences, training breaks.
|
| 65 |
+
|
| 66 |
+
**Session model**: Each WebSocket connection gets its own `OverflowEnvironment` instance. The `create_app` function receives the class (factory), not an instance. When a WebSocket connects, the server instantiates a fresh environment for that session.
|
| 67 |
+
|
| 68 |
+
---
|
| 69 |
+
|
| 70 |
+
## 2. File-by-File Breakdown
|
| 71 |
+
|
| 72 |
+
### `models.py` — Pydantic data models
|
| 73 |
+
|
| 74 |
+
Defines three classes inheriting from OpenEnv core types:
|
| 75 |
+
|
| 76 |
+
| Class | Parent | Purpose |
|
| 77 |
+
|-------|--------|---------|
|
| 78 |
+
| `OverflowAction(Action)` | `openenv.core.env_server.types.Action` | What the LLM sends each step |
|
| 79 |
+
| `OverflowObservation(Observation)` | `openenv.core.env_server.types.Observation` | What the environment returns |
|
| 80 |
+
| `OverflowState(State)` | `openenv.core.env_server.types.State` | Internal state exposed via `/state` |
|
| 81 |
+
|
| 82 |
+
All three are Pydantic `BaseModel` subclasses. The parent classes provide `metadata: Dict[str, Any]` (on Action and Observation) and `episode_id: str`, `step_count: int` (on State). The parent `Observation` provides `done: bool` and `reward: float | None`.
|
| 83 |
+
|
| 84 |
+
### `server/overflow_environment.py` — All game logic
|
| 85 |
+
|
| 86 |
+
Contains:
|
| 87 |
+
- `Car` dataclass — per-car state (id, lane, position, speed, goal, is_agent, reached_goal)
|
| 88 |
+
- `_parse_decision()` — tolerant action parser
|
| 89 |
+
- `_compute_reasoning_bonus()` — reasoning quality scorer
|
| 90 |
+
- `_scripted_car_action()` — NPC car AI
|
| 91 |
+
- `_apply_action()` — mutates a car's speed/lane
|
| 92 |
+
- `_generate_scene_description()` — builds the text observation
|
| 93 |
+
- `OverflowEnvironment(Environment)` — the main class with `reset()`, `step()`, `state`
|
| 94 |
+
|
| 95 |
+
### `server/app.py` — FastAPI wiring
|
| 96 |
+
|
| 97 |
+
Introspects `create_app` to determine if it expects a factory (class) or an instance. Passes `OverflowEnvironment`, `OverflowAction`, `OverflowObservation` to `create_app`. The resulting `app` object is what uvicorn serves.
|
| 98 |
+
|
| 99 |
+
### `client.py` — WebSocket client
|
| 100 |
+
|
| 101 |
+
`OverflowEnv(EnvClient[OverflowAction, OverflowObservation, OverflowState])` with three required methods:
|
| 102 |
+
- `_step_payload(action)` — serializes `OverflowAction` to `{"decision": ..., "reasoning": ...}`
|
| 103 |
+
- `_parse_result(payload)` — deserializes server JSON into `StepResult[OverflowObservation]`
|
| 104 |
+
- `_parse_state(payload)` — deserializes server JSON into `OverflowState`
|
| 105 |
+
|
| 106 |
+
### `__init__.py` — Public API
|
| 107 |
+
|
| 108 |
+
Exports: `OverflowAction`, `OverflowObservation`, `OverflowState`, `OverflowEnv`.
|
| 109 |
+
|
| 110 |
+
---
|
| 111 |
+
|
| 112 |
+
## 3. Data Models (Wire Format)
|
| 113 |
+
|
| 114 |
+
### OverflowAction — What the training script sends to `/step`
|
| 115 |
+
|
| 116 |
+
```json
|
| 117 |
+
{
|
| 118 |
+
"action": {
|
| 119 |
+
"decision": "brake",
|
| 120 |
+
"reasoning": "Car 3 is ahead in my lane, 15 units away, going slower. I should brake."
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
| Field | Type | Required | Default | Description |
|
| 126 |
+
|-------|------|----------|---------|-------------|
|
| 127 |
+
| `decision` | `str` | No | `"maintain"` | One of: `accelerate`, `brake`, `lane_change_left`, `lane_change_right`, `maintain` |
|
| 128 |
+
| `reasoning` | `str` | No | `""` | Free-text chain-of-thought. Affects reward via reasoning bonus (0.0–2.0). |
|
| 129 |
+
|
| 130 |
+
The `decision` field is parsed tolerantly — see Section 9.
|
| 131 |
+
|
| 132 |
+
### OverflowObservation — What the server returns
|
| 133 |
+
|
| 134 |
+
Each observation carries **both** text (for the LLM) and structured data (for the frontend/viz).
|
| 135 |
+
|
| 136 |
+
```json
|
| 137 |
+
{
|
| 138 |
+
"observation": {
|
| 139 |
+
"scene_description": "You are Car 0 in lane 2, position 45, speed 60.\n...",
|
| 140 |
+
"incident_report": "Observer: No incidents this step.",
|
| 141 |
+
"done": false,
|
| 142 |
+
"reward": 1.45,
|
| 143 |
+
"cars": [
|
| 144 |
+
{"carId": 0, "lane": 2, "position": {"x": 45.0, "y": 7.4}, "speed": 60.0, "acceleration": 5.0},
|
| 145 |
+
{"carId": 1, "lane": 1, "position": {"x": 43.0, "y": 3.7}, "speed": 55.0, "acceleration": 0.0}
|
| 146 |
+
],
|
| 147 |
+
"proximities": [
|
| 148 |
+
{"carA": 0, "carB": 1, "distance": 10.5}
|
| 149 |
+
],
|
| 150 |
+
"lane_occupancies": [
|
| 151 |
+
{"lane": 1, "carIds": [1]},
|
| 152 |
+
{"lane": 2, "carIds": [0]}
|
| 153 |
+
],
|
| 154 |
+
"metadata": {}
|
| 155 |
+
},
|
| 156 |
+
"reward": 1.45,
|
| 157 |
+
"done": false
|
| 158 |
+
}
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
#### Text fields (for the LLM)
|
| 162 |
+
|
| 163 |
+
| Field | Type | Description |
|
| 164 |
+
|-------|------|-------------|
|
| 165 |
+
| `scene_description` | `str` | Multi-line text describing all cars. This is what the LLM reads. |
|
| 166 |
+
| `incident_report` | `str` | Observer output. Either `"Observer: No incidents this step."` or a list of CRASH/NEAR MISS events. |
|
| 167 |
+
|
| 168 |
+
#### Structured fields (for the frontend β compatible with Overflow frontend types)
|
| 169 |
+
|
| 170 |
+
| Field | Type | Frontend equivalent |
|
| 171 |
+
|-------|------|---------------------|
|
| 172 |
+
| `cars` | `CarStateData[]` | `CarState[]` β `{carId, lane, position: {x, y}, speed, acceleration}` |
|
| 173 |
+
| `proximities` | `ProximityData[]` | `{carA, carB, distance}[]` β pairwise distances for close cars |
|
| 174 |
+
| `lane_occupancies` | `LaneOccupancyData[]` | `{lane, carIds}[]` β which cars are in each lane |
|
| 175 |
+
|
| 176 |
+
Position `y` is computed as `lane * 3.7` (lane width in metres), matching the frontend's `makeCar` convention.
|
| 177 |
+
|
| 178 |
+
#### Common fields
|
| 179 |
+
|
| 180 |
+
| Field | Type | Description |
|
| 181 |
+
|-------|------|-------------|
|
| 182 |
+
| `done` | `bool` | `true` if episode ended (crash, goal reached, or max steps). |
|
| 183 |
+
| `reward` | `float` | Scalar reward for this step. Sum of all reward components. |
|
| 184 |
+
|
| 185 |
+
The `reward` and `done` appear both inside `observation` and at the top level of the response (OpenEnv convention).
|
| 186 |
+
|
| 187 |
+
### OverflowState — What `/state` returns
|
| 188 |
+
|
| 189 |
+
```json
|
| 190 |
+
{
|
| 191 |
+
"episode_id": "a1b2c3d4-...",
|
| 192 |
+
"step_count": 17,
|
| 193 |
+
"crash_count": 0,
|
| 194 |
+
"near_miss_count": 23,
|
| 195 |
+
"cars_reached_goal": 1,
|
| 196 |
+
"total_cars": 5
|
| 197 |
+
}
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
| Field | Type | Description |
|
| 201 |
+
|-------|------|-------------|
|
| 202 |
+
| `episode_id` | `str` | UUID for this episode. Set on `reset()`. |
|
| 203 |
+
| `step_count` | `int` | How many `step()` calls have been made. |
|
| 204 |
+
| `crash_count` | `int` | Cumulative crash events (each pair counts as 1). |
|
| 205 |
+
| `near_miss_count` | `int` | Cumulative near-miss events (each pair counts as 1). |
|
| 206 |
+
| `cars_reached_goal` | `int` | How many cars (including scripted) reached their goal. |
|
| 207 |
+
| `total_cars` | `int` | Always 5. |
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
## 4. Simulation Internals
|
| 212 |
+
|
| 213 |
+
### The Road
|
| 214 |
+
|
| 215 |
+
- 3 lanes, numbered 1, 2, 3 (1 = leftmost, 3 = rightmost)
|
| 216 |
+
- Road length: ~200 position units
|
| 217 |
+
- No wrapping — cars move forward from low positions toward high positions
|
| 218 |
+
- Lanes are conceptually 10 units apart for distance calculations
|
| 219 |
+
|
| 220 |
+
### Car State
|
| 221 |
+
|
| 222 |
+
Each car is a `Car` dataclass:
|
| 223 |
+
|
| 224 |
+
```python
|
| 225 |
+
@dataclass
|
| 226 |
+
class Car:
|
| 227 |
+
    car_id: int          # 0 = agent, 1–4 = scripted
|
| 228 |
+
lane: int # 1, 2, or 3
|
| 229 |
+
position: float # 0.0 to ~200.0 (along the road)
|
| 230 |
+
speed: float # 20.0 to 90.0
|
| 231 |
+
goal_position: float # 160.0 to 195.0
|
| 232 |
+
is_agent: bool # True only for car 0
|
| 233 |
+
reached_goal: bool # True once position >= goal_position
|
| 234 |
+
```
|
| 235 |
+
|
| 236 |
+
### Initialization (reset)
|
| 237 |
+
|
| 238 |
+
On `reset(seed=N)`:
|
| 239 |
+
1. A `random.Random(seed)` RNG is created (deterministic replays if same seed).
|
| 240 |
+
2. 5 cars are spawned:
|
| 241 |
+
   - **Lane**: random 1–3
|
| 242 |
+
   - **Position**: random 10–80 (spread across the first half of the road)
|
| 243 |
+
   - **Speed**: random 40–70
|
| 244 |
+
   - **Goal**: random 160–195
|
| 245 |
+
3. No two cars occupy the same 10-unit segment in the same lane at spawn (deconflicted via `(lane, position // 10)` hash).
|
| 246 |
+
4. Car 0 is the agent. Cars 1–4 are scripted.
|
| 247 |
+
|
| 248 |
+
### Movement
|
| 249 |
+
|
| 250 |
+
Each step, every active (non-goal-reached) car moves forward:
|
| 251 |
+
|
| 252 |
+
```
|
| 253 |
+
car.position += car.speed * 0.1
|
| 254 |
+
```
|
| 255 |
+
|
| 256 |
+
This means a car at speed 60 moves 6.0 units per step. At that rate, traversing the ~120-unit gap from starting zone (10–80) to goal zone (160–195) takes roughly 20 steps. Faster cars (speed 90) move 9.0 units/step and reach goals sooner.
|
| 257 |
+
|
| 258 |
+
---
|
| 259 |
+
|
| 260 |
+
## 5. Step-by-Step Execution Pipeline
|
| 261 |
+
|
| 262 |
+
When `step(action)` is called, the following happens **in this exact order**:
|
| 263 |
+
|
| 264 |
+
```
|
| 265 |
+
1. GUARD: if episode is already done → return stale observation with reward=0.0
|
| 266 |
+
2. INCREMENT step_count
|
| 267 |
+
3. PARSE the agent's action → one of {accelerate, brake, lane_change_left, lane_change_right, maintain}
|
| 268 |
+
4. APPLY action to Car 0 (mutate speed or lane)
|
| 269 |
+
5. COMPUTE scripted actions for Cars 1β4 and APPLY them
|
| 270 |
+
6. MOVE all active cars forward: position += speed * 0.1
|
| 271 |
+
7. COLLISION DETECTION (pairwise over all active cars):
|
| 272 |
+
   - distance < 5.0 → CRASH (reward -5.0, episode ends)
|
| 273 |
+
   - distance < 15.0 → NEAR MISS (reward -1.0 per pair)
|
| 274 |
+
8. If no crash:
|
| 275 |
+
   a. Check if Car 0 reached its goal → reward +3.0, episode ends
|
| 276 |
+
b. Check if scripted cars reached their goals (state tracking only)
|
| 277 |
+
   c. If episode not ending → SAFE STEP bonus: reward +0.5
|
| 278 |
+
9. REASONING BONUS: score the reasoning text → reward +0.0 to +2.0
|
| 279 |
+
10. MAX STEPS CHECK: if step_count >= 100 → episode ends
|
| 280 |
+
11. BUILD observation text and incident report
|
| 281 |
+
12. RETURN OverflowObservation(scene_description, incident_report, done, reward)
|
| 282 |
+
```
|
| 283 |
+
|
| 284 |
+
**Important ordering detail**: Actions are applied (step 4–5) **before** movement (step 6). This means the agent's speed/lane change takes effect for this step's movement. Collision detection (step 7) happens **after** movement, on the new positions.
|
| 285 |
+
|
| 286 |
+
**Reward accumulation within a step**: A single step's reward is the **sum** of all applicable components. For example, if there are 2 near-miss pairs and the agent is still alive with good reasoning, the reward could be: `(-1.0 * 2) + 0.5 + 1.5 = -1.0`.
|
| 287 |
+
|
| 288 |
+
---
|
| 289 |
+
|
| 290 |
+
## 6. Distance and Collision Model
|
| 291 |
+
|
| 292 |
+
Distance between two cars uses a weighted Euclidean formula:
|
| 293 |
+
|
| 294 |
+
```python
|
| 295 |
+
def distance_to(self, other):
|
| 296 |
+
lane_diff = abs(self.lane - other.lane) * 10.0
|
| 297 |
+
pos_diff = abs(self.position - other.position)
|
| 298 |
+
return sqrt(lane_diff**2 + pos_diff**2)
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
**Implications**:
|
| 302 |
+
- Two cars in the **same lane** at positions 45 and 50: distance = 5.0 (exactly at crash threshold)
|
| 303 |
+
- Two cars in **adjacent lanes** (e.g., lane 1 and lane 2) at the same position: distance = 10.0 (near miss, not crash)
|
| 304 |
+
- Two cars **two lanes apart** at the same position: distance = 20.0 (safe, no incident)
|
| 305 |
+
- Two cars in adjacent lanes, 10 units apart longitudinally: distance = sqrt(100 + 100) ≈ 14.1 (near miss)
|
| 306 |
+
|
| 307 |
+
**Key insight for the agent**: Lane changes provide safety via the 10-unit lane multiplier. Staying in the same lane as another car is the primary crash risk. The agent should use lane changes proactively to maintain distance from cars in its lane.
|
| 308 |
+
|
| 309 |
+
### Collision detection scope
|
| 310 |
+
|
| 311 |
+
Detection is **pairwise over ALL active cars**, not just agent-involving pairs. If Car 2 and Car 3 crash, the episode still ends with -5.0 reward. This means the agent is implicitly responsible for the overall traffic flow β it should avoid creating situations where its actions cause chain reactions among scripted cars.
|
| 312 |
+
|
| 313 |
+
---
|
| 314 |
+
|
| 315 |
+
## 7. Reward Function — Complete Breakdown
|
| 316 |
+
|
| 317 |
+
### Per-step reward components
|
| 318 |
+
|
| 319 |
+
| Component | Value | Condition | Stacks? |
|
| 320 |
+
|-----------|-------|-----------|---------|
|
| 321 |
+
| **Crash** | -5.0 | Any pair distance < 5.0 | Once (episode ends) |
|
| 322 |
+
| **Near miss** | -1.0 | Per pair with distance < 15.0 | Yes, per pair (can be -2.0, -3.0, etc.) |
|
| 323 |
+
| **Safe step** | +0.5 | No crash and episode not ending this step | Once per step |
|
| 324 |
+
| **Goal reached** | +3.0 | Car 0's position >= goal_position | Once (episode ends) |
|
| 325 |
+
| **Reasoning bonus** | +0.0 to +2.0 | Based on reasoning text quality | Once per step |
|
| 326 |
+
|
| 327 |
+
### Reasoning bonus scoring
|
| 328 |
+
|
| 329 |
+
The bonus has three sub-components capped at 2.0 total:
|
| 330 |
+
|
| 331 |
+
**Length bonus** (up to 0.5):
|
| 332 |
+
- `len > 20` chars → +0.2
|
| 333 |
+
- `len > 50` chars → +0.15
|
| 334 |
+
- `len > 100` chars → +0.15
|
| 335 |
+
|
| 336 |
+
**Keyword awareness** (up to 1.0):
|
| 337 |
+
Each keyword found → +0.2, capped at 1.0. Keywords: `ahead`, `behind`, `lane`, `speed`, `distance`, `safe`, `danger`, `collision`, `brake`, `gap`, `close`, `slow`, `fast`, `goal`, `position`.
|
| 338 |
+
|
| 339 |
+
**Structure bonus** (up to 0.5):
|
| 340 |
+
- Contains `<think>` or `because` → +0.25
|
| 341 |
+
- Contains `therefore`, `so i should`, `best option`, or `i will` → +0.25
|
| 342 |
+
|
| 343 |
+
### Typical reward ranges per step
|
| 344 |
+
|
| 345 |
+
| Scenario | Typical reward |
|
| 346 |
+
|----------|---------------|
|
| 347 |
+
| Safe step, no reasoning | +0.5 |
|
| 348 |
+
| Safe step, decent reasoning | +1.0 to +2.0 |
|
| 349 |
+
| Safe step, excellent reasoning | +2.0 to +2.5 |
|
| 350 |
+
| 1 near miss, decent reasoning | -0.5 to +0.5 |
|
| 351 |
+
| 2 near misses, decent reasoning | -1.5 to -0.5 |
|
| 352 |
+
| Crash (any) | -5.0 + reasoning bonus |
|
| 353 |
+
| Goal reached, good reasoning | +3.0 + reasoning bonus |
|
| 354 |
+
|
| 355 |
+
### Episode return (total reward) characteristics
|
| 356 |
+
|
| 357 |
+
Based on testing with seed=42:
|
| 358 |
+
- A "maintain" strategy with decent reasoning gets ~1.1 per step × ~17 steps ≈ 18.7 total, minus near-miss penalties
|
| 359 |
+
- Aggressive "accelerate" strategies reach the goal faster but accumulate more near misses
|
| 360 |
+
- Smart strategies that use lane changes and braking to avoid near misses can maximize total reward
|
| 361 |
+
|
| 362 |
+
---
|
| 363 |
+
|
| 364 |
+
## 8. Scripted Car AI
|
| 365 |
+
|
| 366 |
+
Cars 1–4 use `_scripted_car_action(car, all_cars, rng)`:
|
| 367 |
+
|
| 368 |
+
```
|
| 369 |
+
1. Find the nearest car AHEAD in the SAME LANE
|
| 370 |
+
2. If that car is < 20 units ahead → "brake"
|
| 371 |
+
3. Else if speed < 60 and 10% random chance → "accelerate"
|
| 372 |
+
4. Else if 5% random chance → lane change (random left/right, respecting boundaries)
|
| 373 |
+
5. Else → "maintain"
|
| 374 |
+
```
|
| 375 |
+
|
| 376 |
+
**Characteristics**:
|
| 377 |
+
- Scripted cars are mostly passive — they maintain speed
|
| 378 |
+
- They brake reactively when blocked (but only for same-lane, ahead)
|
| 379 |
+
- They rarely change lanes (5% per step), making their behavior somewhat predictable
|
| 380 |
+
- They never intentionally avoid the agent — only react to cars directly ahead
|
| 381 |
+
- They can accumulate near misses and crashes among themselves
|
| 382 |
+
|
| 383 |
+
This creates an environment where a smart agent can learn to navigate around largely predictable but occasionally erratic traffic.
|
| 384 |
+
|
| 385 |
+
---
|
| 386 |
+
|
| 387 |
+
## 9. Action Parsing — How LLM Output Becomes a Decision
|
| 388 |
+
|
| 389 |
+
The parser `_parse_decision(action)` is intentionally forgiving. It tries three strategies in order:
|
| 390 |
+
|
| 391 |
+
### Strategy 1: Direct field match
|
| 392 |
+
```python
|
| 393 |
+
decision = action.decision.strip().lower().replace(" ", "_")
|
| 394 |
+
# If it's one of {accelerate, brake, lane_change_left, lane_change_right, maintain} → use it
|
| 395 |
+
```
|
| 396 |
+
|
| 397 |
+
### Strategy 2: XML tag extraction
|
| 398 |
+
```python
|
| 399 |
+
text = f"{action.decision} {action.reasoning}".lower()
|
| 400 |
+
match = re.search(r"<action>\s*(\w+)\s*</action>", text)
|
| 401 |
+
# If found and valid → use it
|
| 402 |
+
```
|
| 403 |
+
|
| 404 |
+
This handles LLM outputs like:
|
| 405 |
+
```
|
| 406 |
+
decision: "think about it"
|
| 407 |
+
reasoning: "<think>Car ahead is close</think><action>brake</action>"
|
| 408 |
+
```
|
| 409 |
+
|
| 410 |
+
### Strategy 3: Keyword scan
|
| 411 |
+
```python
|
| 412 |
+
for v in {"accelerate", "brake", "lane_change_left", "lane_change_right", "maintain"}:
|
| 413 |
+
if v in text:
|
| 414 |
+
return v
|
| 415 |
+
```
|
| 416 |
+
|
| 417 |
+
This handles outputs like `decision: "I want to accelerate now"`.
|
| 418 |
+
|
| 419 |
+
### Fallback
|
| 420 |
+
If nothing matches → `"maintain"` (safe default).
|
| 421 |
+
|
| 422 |
+
**For training scripts**: The cleanest format is to put the exact decision string in the `decision` field. The tolerant parsing is there so that LLMs in early training (before they learn the format) still produce valid actions rather than crashing.
|
| 423 |
+
|
| 424 |
+
---
|
| 425 |
+
|
| 426 |
+
## 10. Observation Text Format
|
| 427 |
+
|
| 428 |
+
The `scene_description` field is a multi-line string that the LLM reads as its input. Example:
|
| 429 |
+
|
| 430 |
+
```
|
| 431 |
+
You are Car 0 in lane 2, position 45, speed 60.
|
| 432 |
+
Goal: reach position 180.
|
| 433 |
+
Nearby cars:
|
| 434 |
+
- Car 1: lane 1, position 43, speed 55
|
| 435 |
+
- Car 2: lane 3, position 48, speed 70
|
| 436 |
+
- Car 3: lane 2, position 65, speed 50 [AHEAD IN YOUR LANE - 20 units away]
|
| 437 |
+
- Car 4: lane 1, position 30, speed 65
|
| 438 |
+
```
|
| 439 |
+
|
| 440 |
+
**Annotations added**:
|
| 441 |
+
- `[AHEAD IN YOUR LANE - N units away]` — same lane, ahead of agent
|
| 442 |
+
- `[BEHIND IN YOUR LANE - N units away]` — same lane, behind agent
|
| 443 |
+
- `[REACHED GOAL]` — car has finished
|
| 444 |
+
|
| 445 |
+
The `incident_report` is separate:
|
| 446 |
+
- No incidents: `"Observer: No incidents this step."`
|
| 447 |
+
- With incidents: One line per event, e.g.:
|
| 448 |
+
```
|
| 449 |
+
NEAR MISS between Car 0 and Car 3 (distance: 12.5)
|
| 450 |
+
Car 0 reached its goal at position 180!
|
| 451 |
+
```
|
| 452 |
+
|
| 453 |
+
---
|
| 454 |
+
|
| 455 |
+
## 11. Server Protocol — What Training Scripts Must Send
|
| 456 |
+
|
| 457 |
+
### WebSocket Protocol (Primary — for training)
|
| 458 |
+
|
| 459 |
+
Connect to `ws://host:8000/ws`. All messages are JSON.
|
| 460 |
+
|
| 461 |
+
#### Reset
|
| 462 |
+
|
| 463 |
+
**Send:**
|
| 464 |
+
```json
|
| 465 |
+
{"type": "reset", "data": {"seed": 42}}
|
| 466 |
+
```
|
| 467 |
+
|
| 468 |
+
`data` can include `seed` (int) and/or `episode_id` (str). Both are optional.
|
| 469 |
+
|
| 470 |
+
**Receive:**
|
| 471 |
+
```json
|
| 472 |
+
{
|
| 473 |
+
"type": "observation",
|
| 474 |
+
"data": {
|
| 475 |
+
"observation": {
|
| 476 |
+
"scene_description": "You are Car 0 in lane 3, position 24, speed 40.\n...",
|
| 477 |
+
"incident_report": "",
|
| 478 |
+
"done": false,
|
| 479 |
+
"reward": 0.0,
|
| 480 |
+
"metadata": {}
|
| 481 |
+
},
|
| 482 |
+
"reward": 0.0,
|
| 483 |
+
"done": false
|
| 484 |
+
}
|
| 485 |
+
}
|
| 486 |
+
```
|
| 487 |
+
|
| 488 |
+
#### Step
|
| 489 |
+
|
| 490 |
+
**Send:**
|
| 491 |
+
```json
|
| 492 |
+
{
|
| 493 |
+
"type": "step",
|
| 494 |
+
"data": {
|
| 495 |
+
"decision": "brake",
|
| 496 |
+
"reasoning": "Car ahead is close, braking to maintain safe distance."
|
| 497 |
+
}
|
| 498 |
+
}
|
| 499 |
+
```
|
| 500 |
+
|
| 501 |
+
**Receive:**
|
| 502 |
+
```json
|
| 503 |
+
{
|
| 504 |
+
"type": "observation",
|
| 505 |
+
"data": {
|
| 506 |
+
"observation": {
|
| 507 |
+
"scene_description": "You are Car 0 in lane 3, position 27, speed 35.\n...",
|
| 508 |
+
"incident_report": "Observer: No incidents this step.",
|
| 509 |
+
"done": false,
|
| 510 |
+
"reward": 2.25,
|
| 511 |
+
"metadata": {}
|
| 512 |
+
},
|
| 513 |
+
"reward": 2.25,
|
| 514 |
+
"done": false
|
| 515 |
+
}
|
| 516 |
+
}
|
| 517 |
+
```
|
| 518 |
+
|
| 519 |
+
#### State
|
| 520 |
+
|
| 521 |
+
**Send:**
|
| 522 |
+
```json
|
| 523 |
+
{"type": "state"}
|
| 524 |
+
```
|
| 525 |
+
|
| 526 |
+
**Receive:**
|
| 527 |
+
```json
|
| 528 |
+
{
|
| 529 |
+
"type": "state",
|
| 530 |
+
"data": {
|
| 531 |
+
"episode_id": "a1b2c3d4-...",
|
| 532 |
+
"step_count": 7,
|
| 533 |
+
"crash_count": 0,
|
| 534 |
+
"near_miss_count": 3,
|
| 535 |
+
"cars_reached_goal": 0,
|
| 536 |
+
"total_cars": 5
|
| 537 |
+
}
|
| 538 |
+
}
|
| 539 |
+
```
|
| 540 |
+
|
| 541 |
+
#### Close
|
| 542 |
+
|
| 543 |
+
**Send:**
|
| 544 |
+
```json
|
| 545 |
+
{"type": "close"}
|
| 546 |
+
```
|
| 547 |
+
|
| 548 |
+
### HTTP Protocol (Fallback — for simple testing)
|
| 549 |
+
|
| 550 |
+
Note: The HTTP API creates a **new environment instance per endpoint** in factory mode. The `/reset` and `/step` calls hit separate instances. Use WebSocket for stateful multi-step episodes.
|
| 551 |
+
|
| 552 |
+
```
|
| 553 |
+
POST /reset    Body: {"seed": 42}  →  {"observation": {...}, "reward": 0.0, "done": false}
|
| 554 |
+
POST /step     Body: {"action": {"decision": "brake", "reasoning": "..."}}  →  {"observation": {...}, "reward": ..., "done": ...}
|
| 555 |
+
GET /state     →  {"episode_id": ..., "step_count": ..., ...}
|
| 556 |
+
GET /health    →  {"status": "healthy"}
|
| 557 |
+
GET /schema    →  {"action": {...}, "observation": {...}, "state": {...}}
|
| 558 |
+
```
|
| 559 |
+
|
| 560 |
+
### Using the Python Client
|
| 561 |
+
|
| 562 |
+
```python
|
| 563 |
+
from overflow_env import OverflowEnv, OverflowAction
|
| 564 |
+
|
| 565 |
+
with OverflowEnv(base_url="http://localhost:8000") as env:
|
| 566 |
+
result = env.reset(seed=42)
|
| 567 |
+
# result is StepResult[OverflowObservation]
|
| 568 |
+
# result.observation.scene_description β the text for the LLM
|
| 569 |
+
# result.observation.incident_report β observer output
|
| 570 |
+
# result.reward β float
|
| 571 |
+
# result.done β bool
|
| 572 |
+
|
| 573 |
+
while not result.done:
|
| 574 |
+
# Feed scene_description to LLM, get decision + reasoning back
|
| 575 |
+
llm_decision, llm_reasoning = call_llm(result.observation.scene_description)
|
| 576 |
+
|
| 577 |
+
action = OverflowAction(decision=llm_decision, reasoning=llm_reasoning)
|
| 578 |
+
result = env.step(action)
|
| 579 |
+
|
| 580 |
+
# Episode over
|
| 581 |
+
state = env.state()
|
| 582 |
+
print(f"Steps: {state.step_count}, Crashes: {state.crash_count}")
|
| 583 |
+
```
|
| 584 |
+
|
| 585 |
+
---
|
| 586 |
+
|
| 587 |
+
## 12. Training Integration β GRPO / TRL
|
| 588 |
+
|
| 589 |
+
### System prompt for the LLM
|
| 590 |
+
|
| 591 |
+
The training script should set a system prompt like:
|
| 592 |
+
|
| 593 |
+
```
|
| 594 |
+
You are an autonomous vehicle controller. Each turn you receive a traffic scene description.
|
| 595 |
+
You must output a driving decision and your reasoning.
|
| 596 |
+
|
| 597 |
+
Available decisions: accelerate, brake, lane_change_left, lane_change_right, maintain
|
| 598 |
+
|
| 599 |
+
Output format:
|
| 600 |
+
<think>Your reasoning about the traffic situation</think>
|
| 601 |
+
<action>your_decision</action>
|
| 602 |
+
```
|
| 603 |
+
|
| 604 |
+
### What the training loop does each episode
|
| 605 |
+
|
| 606 |
+
```python
|
| 607 |
+
# 1. Reset environment
|
| 608 |
+
result = env.reset(seed=episode_seed)
|
| 609 |
+
|
| 610 |
+
# 2. Build initial prompt
|
| 611 |
+
messages = [
|
| 612 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 613 |
+
{"role": "user", "content": result.observation.scene_description}
|
| 614 |
+
]
|
| 615 |
+
|
| 616 |
+
trajectory_rewards = []
|
| 617 |
+
|
| 618 |
+
# 3. Loop until done
|
| 619 |
+
while not result.done:
|
| 620 |
+
# 3a. Get LLM completion
|
| 621 |
+
completion = model.generate(messages) # the text the LLM produces
|
| 622 |
+
|
| 623 |
+
# 3b. Parse LLM output into action
|
| 624 |
+
# The environment's parser is tolerant, but for clean training
|
| 625 |
+
# you might also parse on the client side
|
| 626 |
+
action = OverflowAction(
|
| 627 |
+
decision=extract_decision(completion),
|
| 628 |
+
reasoning=completion # pass full text as reasoning
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
# 3c. Step
|
| 632 |
+
result = env.step(action)
|
| 633 |
+
trajectory_rewards.append(result.reward)
|
| 634 |
+
|
| 635 |
+
# 3d. Append to conversation for next turn
|
| 636 |
+
messages.append({"role": "assistant", "content": completion})
|
| 637 |
+
messages.append({"role": "user", "content": (
|
| 638 |
+
result.observation.scene_description + "\n" +
|
| 639 |
+
result.observation.incident_report
|
| 640 |
+
)})
|
| 641 |
+
|
| 642 |
+
# 4. Compute episode return for GRPO
|
| 643 |
+
episode_return = sum(trajectory_rewards)
|
| 644 |
+
```
|
| 645 |
+
|
| 646 |
+
### GRPO reward signal
|
| 647 |
+
|
| 648 |
+
For GRPO (Group Relative Policy Optimization), the reward signal is the **episode return** β the sum of all per-step rewards across the episode. The environment is designed so that:
|
| 649 |
+
|
| 650 |
+
- **Positive episode returns** (agent reached goal safely with good reasoning) indicate good behavior
|
| 651 |
+
- **Negative episode returns** (crashes, many near misses) indicate bad behavior
|
| 652 |
+
- The **reasoning bonus** provides per-step reward shaping that encourages the LLM to explain its thinking, which improves interpretability and can speed up learning
|
| 653 |
+
|
| 654 |
+
### Constructing the reward for TRL
|
| 655 |
+
|
| 656 |
+
If using TRL's `OnlineDPOTrainer` or `GRPOTrainer`:
|
| 657 |
+
|
| 658 |
+
```python
|
| 659 |
+
# Per-step reward is already in result.reward
|
| 660 |
+
# For token-level reward (assign to last token of each turn):
|
| 661 |
+
rewards_per_turn = trajectory_rewards # list of floats, one per step
|
| 662 |
+
|
| 663 |
+
# For episode-level reward (assign to last token of episode):
|
| 664 |
+
episode_reward = sum(trajectory_rewards)
|
| 665 |
+
```
|
| 666 |
+
|
| 667 |
+
---
|
| 668 |
+
|
| 669 |
+
## 13. Episode Dynamics and RL Characteristics
|
| 670 |
+
|
| 671 |
+
### Episode length distribution
|
| 672 |
+
|
| 673 |
+
| Scenario | Typical length |
|
| 674 |
+
|----------|---------------|
|
| 675 |
+
| Aggressive accelerate β goal | 12β20 steps |
|
| 676 |
+
| Moderate maintain β goal | 18β30 steps |
|
| 677 |
+
| Conservative braking | 30β50+ steps |
|
| 678 |
+
| Crash (bad luck or bad driving) | 5β15 steps |
|
| 679 |
+
| Max steps timeout | 100 steps |
|
| 680 |
+
|
| 681 |
+
### What makes this environment learnable
|
| 682 |
+
|
| 683 |
+
1. **Clear signal**: Crashes give -5.0, goals give +3.0. The agent quickly learns that crashing is bad and reaching the goal is good.
|
| 684 |
+
|
| 685 |
+
2. **Gradual improvement**: Near misses (-1.0 each) provide intermediate signal. An agent that learns to avoid near misses gets higher returns than one that just avoids crashes.
|
| 686 |
+
|
| 687 |
+
3. **Speed-accuracy tradeoff**: Accelerating reaches the goal faster (more +3.0 episodes) but increases crash/near-miss risk. The optimal policy is to accelerate when safe and brake/change lanes when needed.
|
| 688 |
+
|
| 689 |
+
4. **Reasoning is rewarded**: The reasoning bonus (up to +2.0/step) means that over a 20-step episode, reasoning alone can contribute up to +40.0. This incentivizes the LLM to produce structured, situation-aware reasoning.
|
| 690 |
+
|
| 691 |
+
5. **Stochasticity**: Scripted cars have random elements (10% accelerate, 5% lane change). This means the same seed produces the same episode, but different seeds produce different traffic patterns, forcing the agent to generalize.
|
| 692 |
+
|
| 693 |
+
6. **All-pairs collision**: The agent is rewarded/punished for the entire traffic system, not just its own car. This means the agent must be aware of the overall traffic flow.
|
| 694 |
+
|
| 695 |
+
### Typical learning progression
|
| 696 |
+
|
| 697 |
+
1. **Random policy**: Mostly "maintain", occasional random actions. Episode return: 0 to 15 (depending on luck).
|
| 698 |
+
2. **Basic safety**: Agent learns to brake when car ahead is close. Fewer crashes, more goals. Episode return: 10 to 25.
|
| 699 |
+
3. **Strategic driving**: Agent learns to change lanes proactively, accelerate when clear, brake early. Episode return: 20 to 40.
|
| 700 |
+
4. **Optimized reasoning**: Agent produces structured reasoning with relevant keywords, maximizing the reasoning bonus. Episode return: 30 to 60.
|
| 701 |
+
|
| 702 |
+
### Reproducibility
|
| 703 |
+
|
| 704 |
+
Passing `seed=N` to `reset()` produces deterministic initial conditions and scripted car behavior (since the `random.Random` instance is seeded). The same seed + same agent actions = same trajectory. This is critical for GRPO, which compares multiple rollouts of the same prompt.
|
| 705 |
+
|
| 706 |
+
---
|
| 707 |
+
|
| 708 |
+
## 14. Configuration Constants
|
| 709 |
+
|
| 710 |
+
All constants are defined at the top of `server/overflow_environment.py`:
|
| 711 |
+
|
| 712 |
+
```python
|
| 713 |
+
NUM_LANES = 3 # Number of road lanes
|
| 714 |
+
ROAD_LENGTH = 200 # Conceptual road length (units)
|
| 715 |
+
NUM_CARS = 5 # Total cars (1 agent + 4 scripted)
|
| 716 |
+
MAX_STEPS = 100 # Maximum steps before forced termination
|
| 717 |
+
CRASH_DISTANCE = 5.0 # Distance threshold for crash
|
| 718 |
+
NEAR_MISS_DISTANCE = 15.0 # Distance threshold for near miss
|
| 719 |
+
|
| 720 |
+
REWARD_CRASH = -5.0 # Reward for any crash
|
| 721 |
+
REWARD_NEAR_MISS = -1.0 # Reward per near-miss pair
|
| 722 |
+
REWARD_SAFE_STEP = 0.5 # Reward for surviving a step
|
| 723 |
+
REWARD_REACHED_GOAL = 3.0 # Reward for reaching goal
|
| 724 |
+
REWARD_REASONING_MAX = 2.0 # Maximum reasoning quality bonus
|
| 725 |
+
|
| 726 |
+
MIN_SPEED = 20 # Minimum car speed
|
| 727 |
+
MAX_SPEED = 90 # Maximum car speed
|
| 728 |
+
SPEED_DELTA = 5 # Speed change per accelerate/brake
|
| 729 |
+
```
|
| 730 |
+
|
| 731 |
+
To tune difficulty:
|
| 732 |
+
- **Easier**: Increase `CRASH_DISTANCE` and `NEAR_MISS_DISTANCE`, decrease `NUM_CARS`, widen starting positions
|
| 733 |
+
- **Harder**: Decrease distances, increase `NUM_CARS`, narrow starting positions, increase `MAX_SPEED`
|
| 734 |
+
- **Longer episodes**: Increase `ROAD_LENGTH` or decrease starting speeds
|
| 735 |
+
- **More reasoning incentive**: Increase `REWARD_REASONING_MAX`
|
| 736 |
+
|
| 737 |
+
---
|
| 738 |
+
|
| 739 |
+
## 15. Docker and Deployment
|
| 740 |
+
|
| 741 |
+
### Local development
|
| 742 |
+
|
| 743 |
+
```bash
|
| 744 |
+
uvicorn overflow_env.server.app:app --host 0.0.0.0 --port 8000 --reload
|
| 745 |
+
```
|
| 746 |
+
|
| 747 |
+
### Docker build
|
| 748 |
+
|
| 749 |
+
```bash
|
| 750 |
+
# From the overflow_env/ directory:
|
| 751 |
+
docker build -t overflow-env:latest -f server/Dockerfile .
|
| 752 |
+
docker run -p 8000:8000 overflow-env:latest
|
| 753 |
+
```
|
| 754 |
+
|
| 755 |
+
The Dockerfile uses a multi-stage build:
|
| 756 |
+
1. **Builder stage**: Installs dependencies with `uv sync` into a `.venv`
|
| 757 |
+
2. **Runtime stage**: Copies the `.venv` and source code, runs uvicorn
|
| 758 |
+
|
| 759 |
+
Base image: `ghcr.io/meta-pytorch/openenv-base:latest`
|
| 760 |
+
|
| 761 |
+
### Push to HuggingFace Spaces
|
| 762 |
+
|
| 763 |
+
```bash
|
| 764 |
+
openenv push --repo-id username/overflow-env
|
| 765 |
+
```
|
| 766 |
+
|
| 767 |
+
### Connect from training script
|
| 768 |
+
|
| 769 |
+
```python
|
| 770 |
+
# Local
|
| 771 |
+
env = OverflowEnv(base_url="http://localhost:8000")
|
| 772 |
+
|
| 773 |
+
# Docker
|
| 774 |
+
env = OverflowEnv.from_docker_image("overflow-env:latest")
|
| 775 |
+
|
| 776 |
+
# HuggingFace Space
|
| 777 |
+
env = OverflowEnv.from_env("username/overflow-env")
|
| 778 |
+
```
|
| 779 |
+
|
| 780 |
+
### openenv.yaml manifest
|
| 781 |
+
|
| 782 |
+
```yaml
|
| 783 |
+
spec_version: 1
|
| 784 |
+
name: overflow_env
|
| 785 |
+
type: space
|
| 786 |
+
runtime: fastapi
|
| 787 |
+
app: server.app:app
|
| 788 |
+
port: 8000
|
| 789 |
+
```
|
| 790 |
+
|
| 791 |
+
This tells OpenEnv tooling how to find and run the environment.
|
Dockerfile
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Single-stage image for the Overflow OpenEnv server (FastAPI + uvicorn).
FROM python:3.11-slim

WORKDIR /app

# Install system deps
# git is needed for any VCS pip requirements; curl is handy for debugging.
RUN apt-get update && \
    apt-get install -y --no-install-recommends git curl && \
    rm -rf /var/lib/apt/lists/*

# Copy environment code into a proper package directory
# so the app module resolves as `overflow_env.server.app`.
COPY . /app/overflow_env

# Install dependencies via pip using requirements.txt
RUN pip install --no-cache-dir -r /app/overflow_env/server/requirements.txt

# Liveness probe hits the /health endpoint using only the Python stdlib,
# so the probe does not depend on curl being present at runtime.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
  CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1

EXPOSE 8000

# Enable the optional web interface served alongside the API.
ENV ENABLE_WEB_INTERFACE=true
CMD ["uvicorn", "overflow_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
README.md
CHANGED
|
@@ -1,12 +1,70 @@
|
|
| 1 |
---
|
| 2 |
-
title: Overflow
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Overflow OpenENV
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# Overflow Environment
|
| 15 |
+
|
| 16 |
+
An autonomous vehicle fleet oversight environment for [OpenEnv](https://github.com/meta-pytorch/OpenEnv).
|
| 17 |
+
|
| 18 |
+
## Overview
|
| 19 |
+
|
| 20 |
+
A 2D road grid with N cars. One car (Car 0) is controlled by an LLM agent, while other cars follow simple scripted driving rules. An observer detects crashes and near-misses each step and computes rewards based on safety.
|
| 21 |
+
|
| 22 |
+
## Quick Start
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
# Install dependencies
|
| 26 |
+
pip install -e .
|
| 27 |
+
|
| 28 |
+
# Run the server
|
| 29 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
```python
|
| 33 |
+
from overflow_env import OverflowEnv, OverflowAction
|
| 34 |
+
|
| 35 |
+
async with OverflowEnv(base_url="http://localhost:8000") as env:
|
| 36 |
+
result = await env.reset()
|
| 37 |
+
print(result.observation.scene_description)
|
| 38 |
+
|
| 39 |
+
action = OverflowAction(decision="maintain", reasoning="Road is clear ahead.")
|
| 40 |
+
result = await env.step(action)
|
| 41 |
+
print(result.observation.incident_report)
|
| 42 |
+
print(f"Reward: {result.reward}, Done: {result.done}")
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## Action Space
|
| 46 |
+
|
| 47 |
+
| Decision | Effect |
|
| 48 |
+
|----------|--------|
|
| 49 |
+
| `accelerate` | Increase speed by 5 |
|
| 50 |
+
| `brake` | Decrease speed by 5 |
|
| 51 |
+
| `lane_change_left` | Move to left lane |
|
| 52 |
+
| `lane_change_right` | Move to right lane |
|
| 53 |
+
| `maintain` | Keep current speed and lane |
|
| 54 |
+
|
| 55 |
+
## Reward Structure
|
| 56 |
+
|
| 57 |
+
| Event | Reward |
|
| 58 |
+
|-------|--------|
|
| 59 |
+
| Crash (distance < 5) | -5.0 |
|
| 60 |
+
| Near miss (distance < 15) | -1.0 |
|
| 61 |
+
| Safe step toward goal | +0.5 |
|
| 62 |
+
| Reached goal | +3.0 |
|
| 63 |
+
| Reasoning quality bonus | +0.0 to +2.0 |
|
| 64 |
+
|
| 65 |
+
## Environment Details
|
| 66 |
+
|
| 67 |
+
- **Road**: 3 lanes, ~200 units long
|
| 68 |
+
- **Cars**: 5 total (1 agent + 4 scripted)
|
| 69 |
+
- **Max steps**: 100 per episode
|
| 70 |
+
- **Speed range**: 20β90 units
|
__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Overflow Environment β Autonomous vehicle fleet oversight for OpenEnv."""
|
| 2 |
+
|
| 3 |
+
try:
|
| 4 |
+
from .client import OverflowEnv
|
| 5 |
+
except ImportError:
|
| 6 |
+
OverflowEnv = None # openenv-core not installed; training-only mode
|
| 7 |
+
|
| 8 |
+
from .models import (
|
| 9 |
+
CarStateData,
|
| 10 |
+
LaneOccupancyData,
|
| 11 |
+
OverflowAction,
|
| 12 |
+
OverflowObservation,
|
| 13 |
+
OverflowState,
|
| 14 |
+
Position,
|
| 15 |
+
ProximityData,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"OverflowAction",
|
| 20 |
+
"OverflowObservation",
|
| 21 |
+
"OverflowState",
|
| 22 |
+
"OverflowEnv",
|
| 23 |
+
"CarStateData",
|
| 24 |
+
"Position",
|
| 25 |
+
"ProximityData",
|
| 26 |
+
"LaneOccupancyData",
|
| 27 |
+
]
|
app.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenENV RL Demo β Gradio UI entrypoint for HuggingFace Spaces.
|
| 3 |
+
|
| 4 |
+
Runs inside the overflow_env package root. All imports use absolute paths
|
| 5 |
+
so they work both as a package (installed) and as a Space (flat root).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import sys, os
|
| 9 |
+
# When running as HF Space, make server/ importable with absolute paths
|
| 10 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 11 |
+
|
| 12 |
+
import math, time, threading
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.optim as optim
|
| 16 |
+
import matplotlib
|
| 17 |
+
matplotlib.use("Agg")
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
import matplotlib.patches as patches
|
| 20 |
+
import gradio as gr
|
| 21 |
+
|
| 22 |
+
from server.overflow_environment import OverflowEnvironment
|
| 23 |
+
from models import OverflowAction
|
| 24 |
+
from policies.flat_mlp_policy import FlatMLPPolicy
|
| 25 |
+
from policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
STEPS_PER_EPISODE = 20
|
| 29 |
+
NUM_LANES = 3
|
| 30 |
+
ROAD_LENGTH = 200
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ββ Observation adapter βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 34 |
+
|
| 35 |
+
def obs_to_vec(overflow_obs) -> np.ndarray:
    """Flatten an Overflow observation into the policy's fixed-size vector.

    Ego state is taken from car 0 (falling back to the first car in the
    list); every other car within 80 units becomes one hazard "ticket"
    vector. Returns an OBS_DIM-length float32 array via build_obs().
    """
    cars = overflow_obs.cars
    if not cars:
        # Nothing visible — all-zero observation.
        return np.zeros(OBS_DIM, dtype=np.float32)

    ego = next((c for c in cars if c.carId == 0), cars[0])
    ego_spd = ego.speed / 4.5        # env speed units -> policy scale
    ego_x = ego.position.x
    ego_y = (ego.lane - 2) * 3.7     # lane index -> lateral metres (3.7 m lanes)

    tickets = []
    for other in cars:
        if other.carId == 0:
            continue
        rel_x = other.position.x - ego.position.x
        rel_y = (other.lane - ego.lane) * 3.7
        other_spd = other.speed / 4.5
        dist = math.sqrt(rel_x**2 + rel_y**2)
        if dist > 80:
            # Too far away to be a hazard worth encoding.
            continue
        # Approximate closing speed, floored at 0.1 to keep TTC finite.
        closing = max(ego_spd - other_spd * math.copysign(1, max(rel_x, 0.01)), 0.1)
        severity = 1.0 if dist < 8 else 0.75 if dist < 15 else 0.5
        tickets.append(build_ticket_vector(
            severity_weight=severity,
            ttl=5.0, pos_x=rel_x, pos_y=rel_y, pos_z=0.0,
            vel_x=other_spd, vel_y=0.0, vel_z=0.0, heading=0.0,
            size_length=4.0, size_width=2.0, size_height=1.5,
            distance=dist, time_to_collision=min(dist / closing, 30.0),
            bearing=math.atan2(rel_y, max(rel_x, 0.01)),
            ticket_type="collision_risk", entity_type="vehicle", confidence=1.0,
        ))

    ticket_arr = np.array(tickets, dtype=np.float32) if tickets else None
    return build_obs(ego_x=ego_x, ego_y=ego_y, ego_z=0.0,
                     ego_vx=ego_spd, ego_vy=0.0,
                     heading=0.0, speed=ego_spd,
                     steer=0.0, throttle=0.5, brake=0.0,
                     ticket_vectors=ticket_arr)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def action_to_decision(a: np.ndarray) -> str:
    """Map a continuous (steer, throttle, brake) action to a discrete decision.

    Steering dominates: |steer| > 0.35 requests a lane change. Otherwise
    brake > 0.25 wins over throttle > 0.20; default is "maintain".
    """
    steer = float(a[0])
    throttle = float(a[1])
    brake = float(a[2])

    if abs(steer) > 0.35:
        return "lane_change_left" if steer < 0 else "lane_change_right"
    if brake > 0.25:
        return "brake"
    if throttle > 0.20:
        return "accelerate"
    return "maintain"
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ββ Global training state βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 80 |
+
|
| 81 |
+
# Policy network and its optimizer live for the whole app session; the
# background training thread updates them in place.
policy = FlatMLPPolicy(obs_dim=OBS_DIM)
optimizer = optim.Adam(policy.parameters(), lr=3e-4, eps=1e-5)

# Per-episode rollout buffers: cleared at each episode start in
# run_episodes_loop() and consumed by _ppo_mini_update() at episode end.
_buf_obs = []    # observation vectors (np.float32 arrays)
_buf_acts = []   # sampled continuous actions (np arrays of shape (3,))
_buf_rews = []   # per-step rewards (float)
_buf_logps = []  # log-probs of the sampled actions under the behavior policy
_buf_vals = []   # value-head estimates at sampling time
_buf_dones = []  # done flags as 0.0 / 1.0

episode_history = []      # one summary dict per finished episode (UI curve/table)
step_log = []             # one dict per environment step (UI road view / log)
_running = False          # polled by the training thread; flipped by start/stop
_lock = threading.Lock()  # guards step_log / episode_history across threads
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _ppo_mini_update():
    """Run one PPO update over the current episode's rollout buffers.

    Computes GAE(gamma=0.99, lambda=0.95) advantages, then a clipped
    surrogate policy loss (ratio clipped to [0.8, 1.2]) plus a value loss
    and an entropy bonus. No-op if fewer than 2 transitions were collected.
    Mutates the module-level `policy` via `optimizer`.
    """
    if len(_buf_obs) < 2:
        return
    # Stack the Python-list buffers into batch tensors.
    obs_t = torch.tensor(np.array(_buf_obs), dtype=torch.float32)
    acts_t = torch.tensor(np.array(_buf_acts), dtype=torch.float32)
    rews_t = torch.tensor(_buf_rews, dtype=torch.float32)
    logp_t = torch.tensor(_buf_logps, dtype=torch.float32)  # behavior log-probs
    vals_t = torch.tensor(_buf_vals, dtype=torch.float32)   # behavior values
    done_t = torch.tensor(_buf_dones, dtype=torch.float32)

    # Generalized Advantage Estimation, accumulated backwards through time.
    gamma, lam = 0.99, 0.95
    adv = torch.zeros_like(rews_t)
    gae = 0.0
    for t in reversed(range(len(rews_t))):
        # Bootstrap value is 0 at the final step of the rollout.
        nv = 0.0 if t == len(rews_t) - 1 else float(vals_t[t + 1])
        d = rews_t[t] + gamma * nv * (1 - done_t[t]) - vals_t[t]  # TD residual
        gae = d + gamma * lam * (1 - done_t[t]) * gae
        adv[t] = gae
    ret = adv + vals_t  # regression targets for the value head
    adv = (adv - adv.mean()) / (adv.std() + 1e-8)  # normalize advantages

    policy.train()
    act_mean, val = policy(obs_t)
    val = val.squeeze(-1)
    # Fixed-std (0.3) Gaussian — must match the sampling distribution used
    # in run_episodes_loop so the importance ratio is meaningful.
    dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3)
    logp = dist.log_prob(acts_t).sum(dim=-1)
    entropy = dist.entropy().sum(dim=-1).mean()
    ratio = torch.exp(logp - logp_t)
    # Clipped surrogate objective written as a loss (hence the minus signs).
    pg = torch.max(-adv * ratio, -adv * ratio.clamp(0.8, 1.2)).mean()
    vf = 0.5 * ((val - ret) ** 2).mean()
    loss = pg + 0.5 * vf - 0.02 * entropy
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
    optimizer.step()
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def run_episodes_loop():
    """Background thread body: run episodes until `_running` is cleared.

    Each episode: reset the environment, roll out up to STEPS_PER_EPISODE
    steps with the current policy (Gaussian exploration, std 0.3), log every
    step into `step_log` under `_lock`, then run one PPO update and append
    an episode summary to `episode_history`.
    """
    global _running
    ep_num = 0
    env = OverflowEnvironment()

    while _running:
        ep_num += 1
        obs = env.reset()
        ep_rew = 0.0
        outcome = "timeout"  # overwritten if the episode ends early

        # Fresh rollout buffers for this episode.
        _buf_obs.clear(); _buf_acts.clear(); _buf_rews.clear()
        _buf_logps.clear(); _buf_vals.clear(); _buf_dones.clear()

        # NOTE: range(1, STEPS_PER_EPISODE + 1) is never empty, so `step`
        # is always bound when referenced after the loop.
        for step in range(1, STEPS_PER_EPISODE + 1):
            if not _running:
                break

            obs_vec = obs_to_vec(obs)
            policy.eval()
            with torch.no_grad():
                obs_t = torch.tensor(obs_vec, dtype=torch.float32).unsqueeze(0)
                act_mean, val = policy(obs_t)
                # Same fixed-std Gaussian as _ppo_mini_update uses.
                dist = torch.distributions.Normal(act_mean.squeeze(0),
                                                  torch.ones(3) * 0.3)
                action = dist.sample().clamp(-1, 1)
                logp = dist.log_prob(action).sum()

            decision = action_to_decision(action.numpy())
            obs = env.step(OverflowAction(decision=decision, reasoning=""))
            reward = float(obs.reward or 0.0)
            done = obs.done
            ep_rew += reward

            # Record the transition (pre-action observation, sampled action).
            _buf_obs.append(obs_vec)
            _buf_acts.append(action.numpy())
            _buf_rews.append(reward)
            _buf_logps.append(float(logp))
            _buf_vals.append(float(val.squeeze()))
            _buf_dones.append(float(done))

            # Publish the step to the UI under the shared lock.
            with _lock:
                step_log.append({
                    "ep": ep_num,
                    "step": step,
                    "decision": decision,
                    "reward": round(reward, 2),
                    "ep_reward": round(ep_rew, 2),
                    "incident": obs.incident_report or "",
                    "cars": [(c.carId, c.lane, c.position.x, c.speed)
                             for c in obs.cars],
                })

            if done:
                # Classify the early termination from the incident report.
                outcome = "CRASH" if "CRASH" in (obs.incident_report or "") else "GOAL"
                break

            # Throttle the loop so the UI animation is watchable.
            time.sleep(0.6)

        _ppo_mini_update()

        with _lock:
            episode_history.append({
                "ep": ep_num,
                "steps": step,
                "reward": round(ep_rew, 2),
                "outcome": outcome,
            })
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# ββ Plot helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 205 |
+
|
| 206 |
+
# Color for the "Action:" banner in render_road, keyed by decision name.
DECISION_COLORS = {
    "accelerate": "#22c55e",         # green
    "brake": "#ef4444",              # red
    "lane_change_left": "#f59e0b",   # amber
    "lane_change_right": "#f59e0b",  # amber
    "maintain": "#60a5fa",           # blue
}
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def render_road(cars_snapshot, last_decision, last_incident):
    """Render the road state as a dark-themed matplotlib figure.

    cars_snapshot: iterable of (car_id, lane, pos_x, speed) tuples; car 0
    is the ego car. last_decision/last_incident annotate the top banners.
    Returns the Figure.
    """
    fig, ax = plt.subplots(figsize=(10, 2.8))

    # Dark theme styling.
    fig.patch.set_facecolor("#0f172a")
    ax.set_facecolor("#1e293b")
    ax.set_xlim(0, ROAD_LENGTH)
    ax.set_ylim(0, NUM_LANES + 1)
    ax.set_yticks([])
    ax.set_xlabel("Position", color="#94a3b8", fontsize=9)
    ax.tick_params(colors="#94a3b8")
    for sp in ax.spines.values():
        sp.set_edgecolor("#334155")

    # Dashed dividers between lanes, plus faint per-lane labels.
    for divider in range(1, NUM_LANES):
        ax.axhline(y=divider + 0.5, color="#334155", linewidth=1,
                   linestyle="--", alpha=0.6)
    for lane_idx in range(1, NUM_LANES + 1):
        ax.text(2, lane_idx, f"L{lane_idx}", color="#475569", fontsize=8,
                va="center")

    # Goal zone shading at the far end of the road.
    ax.axvspan(160, ROAD_LENGTH, alpha=0.12, color="#22c55e")
    ax.text(162, NUM_LANES + 0.6, "GOAL ZONE", color="#22c55e", fontsize=7,
            alpha=0.8)

    # Cars: ego (car 0) drawn blue with a bold label, the rest grey.
    car_w, car_h = 8, 0.55
    for car_id, lane, pos_x, speed in cars_snapshot:
        ego = car_id == 0
        body = patches.FancyBboxPatch(
            (pos_x - car_w / 2, lane - car_h / 2),
            car_w, car_h,
            boxstyle="round,pad=0.05",
            facecolor="#3b82f6" if ego else "#94a3b8",
            edgecolor="#60a5fa" if ego else "#475569",
            linewidth=2.0 if ego else 1.0,
            alpha=0.92,
        )
        ax.add_patch(body)
        tag = f"{'EGO' if ego else f'C{car_id}'}\n{speed:.0f}"
        ax.text(pos_x, lane, tag, ha="center", va="center", fontsize=6.5,
                color="white", fontweight="bold" if ego else "normal")

    # Last-action banner in the top-right corner, colored by decision type.
    ax.text(ROAD_LENGTH - 2, NUM_LANES + 0.65,
            f"Action: {last_decision.replace('_', ' ').upper()}",
            color=DECISION_COLORS.get(last_decision, "#60a5fa"),
            fontsize=8, fontweight="bold", ha="right")

    # Incident banner in the top-center, most severe event first.
    if "CRASH" in last_incident:
        ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, "CRASH",
                color="#ef4444", fontsize=10, fontweight="bold", ha="center")
    elif "NEAR MISS" in last_incident:
        ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, "NEAR MISS",
                color="#f59e0b", fontsize=9, fontweight="bold", ha="center")
    elif "GOAL" in last_incident:
        ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, "GOAL REACHED",
                color="#22c55e", fontsize=10, fontweight="bold", ha="center")

    plt.tight_layout(pad=0.3)
    return fig
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def render_reward_curve(eps):
    """Plot per-episode total reward as outcome-colored bars.

    Adds a moving-average trend line (window up to 5) once at least three
    episodes exist, plus a zero line and an outcome legend. Shows a
    placeholder message when `eps` is empty. Returns the Figure.
    """
    fig, ax = plt.subplots(figsize=(10, 2.8))
    fig.patch.set_facecolor("#0f172a")
    ax.set_facecolor("#1e293b")
    for sp in ax.spines.values():
        sp.set_edgecolor("#334155")
    ax.tick_params(colors="#94a3b8")
    ax.set_xlabel("Episode", color="#94a3b8", fontsize=9)
    ax.set_ylabel("Total Reward", color="#94a3b8", fontsize=9)

    # Placeholder until the first episode completes.
    if not eps:
        ax.text(0.5, 0.5, "Waiting for episodes...", transform=ax.transAxes,
                ha="center", va="center", color="#475569", fontsize=11)
        plt.tight_layout(pad=0.3)
        return fig

    outcome_colors = {"CRASH": "#ef4444", "GOAL": "#22c55e", "timeout": "#60a5fa"}
    episode_ids = [rec["ep"] for rec in eps]
    totals = [rec["reward"] for rec in eps]
    for ep_id, total, rec in zip(episode_ids, totals, eps):
        ax.bar(ep_id, total,
               color=outcome_colors.get(rec["outcome"], "#60a5fa"),
               alpha=0.6, width=0.7)

    # Moving average drawn on top of the bars once enough data exists.
    if len(totals) >= 3:
        window = min(5, len(totals))
        trend = np.convolve(totals, np.ones(window) / window, mode="valid")
        ax.plot(episode_ids[window - 1:], trend, color="#f8fafc", linewidth=2)

    ax.axhline(0, color="#334155", linewidth=0.8)

    from matplotlib.patches import Patch
    ax.legend(handles=[Patch(facecolor="#ef4444", label="crash"),
                       Patch(facecolor="#22c55e", label="goal"),
                       Patch(facecolor="#60a5fa", label="timeout")],
              facecolor="#1e293b", labelcolor="#94a3b8",
              fontsize=8, framealpha=0.6, edgecolor="#334155", loc="upper left")

    plt.tight_layout(pad=0.3)
    return fig
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# ββ Gradio UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 314 |
+
|
| 315 |
+
def start_training():
    """Start the background training loop (no-op if already running).

    Returns Gradio updates for the (start_button, stop_button) pair.

    Fix: the shared `step_log` / `episode_history` lists are read by the
    UI polling thread under `_lock` (see get_updates), but were previously
    cleared here without holding the lock — a poll racing with the clear
    could observe a half-cleared log. Clearing now happens under `_lock`.
    """
    global _running
    if not _running:
        with _lock:
            step_log.clear()
            episode_history.clear()
        # Set the run flag before spawning the worker so its `while _running`
        # condition is already true when the thread starts.
        _running = True
        threading.Thread(target=run_episodes_loop, daemon=True).start()
    return gr.update(value="Running...", interactive=False), gr.update(interactive=True)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def stop_training():
    """Ask the training thread to stop after its current step.

    Returns Gradio updates re-enabling the start button and disabling stop.
    """
    global _running
    _running = False
    return (
        gr.update(value="Start", interactive=True),
        gr.update(interactive=False),
    )
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def get_updates():
    """Timer poll callback: snapshot shared state and render every UI widget.

    Returns a 6-tuple matching the timer's ``outputs`` wiring:
    (road figure, reward figure, step-feed text, episode-table text,
    trend summary text, status string).
    """
    # Snapshot the shared buffers under the lock; all rendering happens
    # outside it so the training thread is blocked as briefly as possible.
    with _lock:
        logs = list(step_log[-20:])
        eps = list(episode_history[-50:])
        last = step_log[-1] if step_log else None

    # Road view reflects the most recent step, or an empty road before any step.
    road_fig = render_road(last["cars"], last["decision"], last["incident"]) if last \
        else render_road([], "maintain", "")
    reward_fig = render_reward_curve(eps)

    # Live step feed: newest first, one line per step, with an incident flag.
    lines = []
    for e in reversed(logs):
        flag = ""
        if "CRASH" in e["incident"]: flag = " π₯"
        elif "GOAL" in e["incident"]: flag = " β"
        elif "NEAR MISS" in e["incident"]: flag = " β "
        lines.append(
            f"ep {e['ep']:>3d} | step {e['step']:>2d} | "
            f"{e['decision']:<20} | r={e['reward']:>+6.2f} | "
            f"ep_total={e['ep_reward']:>7.2f}{flag}"
        )
    step_text = "\n".join(lines) if lines else "Waiting for first episode..."

    # Episode history table: most recent 15 episodes, newest first.
    ep_lines = ["Episode | Steps | Total Reward | Outcome", "-" * 44]
    for e in reversed(eps[-15:]):
        ep_lines.append(
            f" {e['ep']:>4d} | {e['steps']:>3d} | "
            f" {e['reward']:>+8.2f} | {e['outcome']}"
        )
    ep_text = "\n".join(ep_lines) if eps else "No episodes completed yet."

    # Crude learning signal: mean reward of the first half of the window
    # versus the second half.
    if len(eps) >= 2:
        rewards = [e["reward"] for e in eps]
        n = len(rewards)
        half = max(n // 2, 1)
        early = sum(rewards[:half]) / half
        late = sum(rewards[half:]) / max(n - half, 1)
        arrow = "β improving" if late > early else "β declining"
        trend_text = f"Early {half} eps: {early:+.2f} β Last {n-half} eps: {late:+.2f} {arrow}"
    else:
        trend_text = "Collecting data..."

    # NOTE(review): _running is read without _lock here; harmless for a
    # status string, but worth confirming this is intentional.
    status = "β RUNNING" if _running else "β STOPPED"
    return road_fig, reward_fig, step_text, ep_text, trend_text, status
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
# Pre-rendered placeholder figures shown before the first episode arrives.
_EMPTY_ROAD = render_road([], "maintain", "")
_EMPTY_REWARD = render_reward_curve([])

# ββ UI layout ββ every component refreshed by get_updates() is named here.
with gr.Blocks(title="OpenENV RL Demo", theme=gr.themes.Base()) as demo:
    gr.Markdown(
        "# OpenENV RL β Live Policy Training\n"
        "**FlatMLPPolicy** drives Car 0 on a 3-lane road for 20 steps per episode. "
        "PPO mini-update after each episode β watch rewards trend upward over time."
    )

    with gr.Row():
        start_btn = gr.Button("Start", variant="primary", scale=1)
        stop_btn = gr.Button("Stop", variant="stop", interactive=False, scale=1)
        status_box = gr.Textbox(value="β STOPPED", label="Status",
                                interactive=False, scale=0, min_width=130)

    gr.Markdown("### Road View")
    road_plot = gr.Plot(value=_EMPTY_ROAD, show_label=False)

    gr.Markdown("### Episode Reward Curve")
    reward_plot = gr.Plot(value=_EMPTY_REWARD, show_label=False)

    gr.Markdown("### Live Step Feed (last 20 steps)")
    step_display = gr.Textbox(
        value="Press Start to begin...",
        lines=14, max_lines=14, interactive=False,
    )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Episode History")
            ep_display = gr.Textbox(lines=10, interactive=False)
        with gr.Column():
            gr.Markdown("### Reward Trend")
            trend_display = gr.Textbox(lines=3, interactive=False)

    # Poll once per second; get_updates() re-renders every component above.
    timer = gr.Timer(value=1.0)
    timer.tick(
        fn=get_updates,
        outputs=[road_plot, reward_plot, step_display, ep_display, trend_display, status_box],
    )

    # Both buttons toggle each other's interactivity via their return values.
    start_btn.click(fn=start_training, outputs=[start_btn, stop_btn])
    stop_btn.click(fn=stop_training, outputs=[start_btn, stop_btn])


if __name__ == "__main__":
    demo.launch()
|
client.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Overflow Environment Client.
|
| 3 |
+
|
| 4 |
+
Provides the client for connecting to an Overflow Environment server
|
| 5 |
+
via WebSocket for persistent sessions.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Any, Dict, List
|
| 9 |
+
|
| 10 |
+
from openenv.core.client_types import StepResult
|
| 11 |
+
from openenv.core.env_client import EnvClient
|
| 12 |
+
|
| 13 |
+
from .models import (
|
| 14 |
+
CarStateData,
|
| 15 |
+
LaneOccupancyData,
|
| 16 |
+
OverflowAction,
|
| 17 |
+
OverflowObservation,
|
| 18 |
+
OverflowState,
|
| 19 |
+
Position,
|
| 20 |
+
ProximityData,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class OverflowEnv(EnvClient[OverflowAction, OverflowObservation, OverflowState]):
    """
    WebSocket client for the Overflow Environment.

    Example:
        >>> with OverflowEnv(base_url="http://localhost:8000") as env:
        ...     result = env.reset()
        ...     print(result.observation.scene_description)
        ...     print(result.observation.cars)  # structured car data
        ...     action = OverflowAction(decision="maintain", reasoning="Safe for now")
        ...     result = env.step(action)
    """

    def _step_payload(self, action: OverflowAction) -> Dict[str, Any]:
        """Serialize an OverflowAction into the JSON body of a step request."""
        return {"decision": action.decision, "reasoning": action.reasoning}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[OverflowObservation]:
        """Deserialize a server step/reset response into StepResult[OverflowObservation]."""
        raw_obs = payload.get("observation", {})
        done_flag = payload.get("done", False)
        reward_val = payload.get("reward")

        def _car(entry: Dict[str, Any]) -> CarStateData:
            # Build one structured car snapshot; acceleration may be omitted
            # by the server and defaults to 0.0.
            return CarStateData(
                carId=entry["carId"],
                lane=entry["lane"],
                position=Position(**entry["position"]),
                speed=entry["speed"],
                acceleration=entry.get("acceleration", 0.0),
            )

        observation = OverflowObservation(
            scene_description=raw_obs.get("scene_description", ""),
            incident_report=raw_obs.get("incident_report", ""),
            done=done_flag,
            reward=reward_val,
            cars=[_car(c) for c in raw_obs.get("cars", [])],
            proximities=[ProximityData(**p) for p in raw_obs.get("proximities", [])],
            lane_occupancies=[
                LaneOccupancyData(**lo) for lo in raw_obs.get("lane_occupancies", [])
            ],
        )
        return StepResult(observation=observation, reward=reward_val, done=done_flag)

    def _parse_state(self, payload: Dict[str, Any]) -> OverflowState:
        """Deserialize a server state response into an OverflowState."""
        return OverflowState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            crash_count=payload.get("crash_count", 0),
            near_miss_count=payload.get("near_miss_count", 0),
            cars_reached_goal=payload.get("cars_reached_goal", 0),
            total_cars=payload.get("total_cars", 5),
        )
|
models.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data models for the Overflow Environment.
|
| 3 |
+
|
| 4 |
+
An autonomous vehicle fleet oversight environment where an LLM agent
|
| 5 |
+
controls one car on a 2D road grid while other cars follow scripted rules.
|
| 6 |
+
|
| 7 |
+
Structured observation fields (cars, proximities, lane_occupancies) are
|
| 8 |
+
compatible with the Overflow frontend's CarState / AnomalyObservation types.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from typing import Any, Dict, List, Optional
|
| 12 |
+
|
| 13 |
+
from pydantic import BaseModel, Field
|
| 14 |
+
|
| 15 |
+
# Prefer the real OpenEnv base types; fall back to minimal pydantic stand-ins
# with the same field names so this module imports even when the `openenv`
# package is not installed (e.g. frontend-only or test environments).
try:
    from openenv.core.env_server.types import Action, Observation, State
except ImportError:
    class Action(BaseModel): pass
    class Observation(BaseModel):
        done: bool = False
        reward: float = 0.0
    class State(BaseModel):
        episode_id: str = ""
        step_count: int = 0
|
| 25 |
+
|
| 26 |
+
# ββ Structured sub-models (frontend-compatible) βββββββββββββββββββββββββ
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Position(BaseModel):
    """2D position on the road. x = longitudinal, y = lateral."""

    x: float = 0.0  # forward along the road
    y: float = 0.0  # across lanes
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class CarStateData(BaseModel):
    """
    Structured per-car snapshot β matches the frontend CarState interface.

    Field names are camelCase on purpose, to mirror the frontend type:
        interface CarState {
            carId: number; lane: number;
            position: { x: number; y: number };
            speed: number; acceleration: number;
        }
    """

    carId: int          # stable identifier of the car
    lane: int           # lane index the car currently occupies
    position: Position  # 2D pose (x longitudinal, y lateral)
    speed: float        # scalar speed
    acceleration: float = 0.0  # defaults to 0 when the server omits it
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class ProximityData(BaseModel):
    """Pairwise distance between two cars."""

    carA: int        # id of the first car in the pair
    carB: int        # id of the second car in the pair
    distance: float  # distance between carA and carB (units per the server sim)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class LaneOccupancyData(BaseModel):
    """Which cars are in a given lane."""

    lane: int          # lane index
    carIds: List[int]  # ids of every car currently in this lane
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ββ OpenEnv core models βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class OverflowAction(Action):
    """
    Action for the Overflow environment.

    The LLM agent outputs a driving decision and optional reasoning.

    NOTE(review): `decision` is a free-form string; nothing here validates it
    against the documented set β confirm the server rejects unknown values.
    """

    decision: str = Field(
        default="maintain",
        description="Driving decision: accelerate, brake, lane_change_left, lane_change_right, maintain",
    )
    reasoning: str = Field(
        default="",
        description="The LLM's chain-of-thought reasoning for this decision",
    )
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class OverflowObservation(Observation):
    """
    Observation from the Overflow environment.

    Contains both:
    - Text fields (scene_description, incident_report) for the LLM to read.
    - Structured fields (cars, proximities, lane_occupancies) for the frontend
      to render, matching the Overflow frontend AnomalyObservation shape.

    `done` and `reward` are inherited from the base Observation type.
    """

    # ββ Text (for the LLM) ββ
    scene_description: str = Field(
        default="", description="Text description of the traffic scene"
    )
    incident_report: str = Field(
        default="", description="Observer's incident report, empty if no incident"
    )

    # ββ Structured (for the frontend / viz) ββ
    cars: List[CarStateData] = Field(
        default_factory=list, description="Structured state of every car"
    )
    proximities: List[ProximityData] = Field(
        default_factory=list, description="Pairwise proximity measurements"
    )
    lane_occupancies: List[LaneOccupancyData] = Field(
        default_factory=list, description="Per-lane vehicle occupancy"
    )
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class OverflowState(State):
    """
    Internal state for the Overflow environment.

    Episode-level counters; `episode_id` and `step_count` come from the
    base State type.
    """

    crash_count: int = Field(default=0, description="Number of crashes this episode")
    near_miss_count: int = Field(
        default=0, description="Number of near misses this episode"
    )
    cars_reached_goal: int = Field(
        default=0, description="Number of cars that reached their goal"
    )
    total_cars: int = Field(
        default=5, description="Total number of cars in the simulation"
    )
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: overflow_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
policies/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_policy import BasePolicy
|
| 2 |
+
from .flat_mlp_policy import FlatMLPPolicy
|
| 3 |
+
from .ticket_attention_policy import TicketAttentionPolicy
|
| 4 |
+
|
| 5 |
+
__all__ = ["BasePolicy", "FlatMLPPolicy", "TicketAttentionPolicy"]
|
policies/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (381 Bytes). View file
|
|
|
policies/__pycache__/base_policy.cpython-314.pyc
ADDED
|
Binary file (4.05 kB). View file
|
|
|
policies/__pycache__/flat_mlp_policy.cpython-314.pyc
ADDED
|
Binary file (3.71 kB). View file
|
|
|
policies/__pycache__/policy_spec.cpython-314.pyc
ADDED
|
Binary file (18.4 kB). View file
|
|
|
policies/__pycache__/ticket_attention_policy.cpython-314.pyc
ADDED
|
Binary file (11.3 kB). View file
|
|
|
policies/base_policy.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BasePolicy β abstract interface all policies implement.
|
| 3 |
+
|
| 4 |
+
All policies expose the same predict() and train_step() API so the
|
| 5 |
+
curriculum trainer can swap them out transparently.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import abc
|
| 11 |
+
from typing import Any, Dict, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class BasePolicy(nn.Module, abc.ABC):
    """
    Abstract base for all driving policies.

    Subclasses implement:
        forward(obs_tensor) β action_tensor, value_tensor
        encode_obs(obs_np) β torch.Tensor
    """

    def __init__(self, obs_dim: int, action_dim: int = 3):
        super().__init__()
        self.obs_dim = obs_dim        # flat observation dimensionality
        self.action_dim = action_dim  # number of continuous action channels

    @abc.abstractmethod
    def forward(
        self, obs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            action_mean β shape (B, action_dim)
            value β shape (B, 1)
        """
        ...

    def predict(
        self,
        obs: np.ndarray,
        deterministic: bool = False,
        noise_scale: float = 0.1,
    ) -> np.ndarray:
        """Numpy in, numpy out. Used by the env during rollout.

        Args:
            obs: flat observation vector, shape (obs_dim,).
            deterministic: if True, return the action mean with no noise.
            noise_scale: std-dev of the Gaussian exploration noise added when
                sampling stochastically. Default 0.1 preserves the previous
                hard-coded behavior; callers may now tune exploration.

        Returns:
            Action vector, shape (action_dim,).
        """
        # Switch to eval mode so LayerNorm/Dropout-style layers behave
        # deterministically during rollout.
        self.eval()
        with torch.no_grad():
            t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
            mean, _ = self.forward(t)
            if deterministic:
                action = mean
            else:
                action = mean + torch.randn_like(mean) * noise_scale
        return action.squeeze(0).numpy()

    @staticmethod
    def _mlp(dims: list[int], activation=nn.Tanh) -> nn.Sequential:
        """Build a plain MLP: Linear layers with `activation` between them
        and no activation after the final layer."""
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:
                layers.append(activation())
        return nn.Sequential(*layers)
|
policies/flat_mlp_policy.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FlatMLPPolicy β sanity-check baseline.
|
| 3 |
+
|
| 4 |
+
Concatenates the full observation (ego + all tickets flattened) and passes
|
| 5 |
+
it through a standard MLP. No attention, no structure.
|
| 6 |
+
|
| 7 |
+
Use this to:
|
| 8 |
+
1. Verify the reward signal and environment are working
|
| 9 |
+
2. Establish a performance floor
|
| 10 |
+
3. Confirm that TicketAttentionPolicy actually improves over this
|
| 11 |
+
|
| 12 |
+
If FlatMLPPolicy can't learn Stage 1 survival, the reward or env is broken.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from .base_policy import BasePolicy
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class FlatMLPPolicy(BasePolicy):
    """Standard 3-layer MLP over the full flat observation.

    Actor outputs `action_dim` tanh-squashed action means; critic outputs a
    scalar state value. The final actor layer is initialized with a tiny gain
    so initial actions stay near zero.
    """

    def __init__(self, obs_dim: int, hidden: int = 256, action_dim: int = 3):
        # Forward action_dim to the base class so self.action_dim stays in
        # sync with the actor head (previously the head hard-coded 3 while
        # the base default was used independently).
        super().__init__(obs_dim, action_dim)

        self.actor = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.LayerNorm(hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden // 2), nn.Tanh(),
            nn.Linear(hidden // 2, action_dim), nn.Tanh(),
        )
        self.critic = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden // 2), nn.Tanh(),
            nn.Linear(hidden // 2, 1),
        )
        self._init_weights()

    def _init_weights(self):
        """Orthogonal init everywhere; tiny gain on the actor output layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=1.0)
                nn.init.zeros_(m.bias)
        # actor[-1] is the final Tanh, so actor[-2] is the output Linear.
        nn.init.orthogonal_(self.actor[-2].weight, gain=0.01)

    def forward(self, obs: torch.Tensor):
        """Return (action_mean, value) for a batch of flat observations."""
        return self.actor(obs), self.critic(obs)
|
policies/policy_spec.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Policy data input specifications β formal contracts for observation, action, and ticket data.
|
| 3 |
+
|
| 4 |
+
This module defines the exact data shapes, normalization ranges, and semantic meaning
|
| 5 |
+
of every field consumed by OpenENV policies. Use this as the reference when:
|
| 6 |
+
|
| 7 |
+
1. Building a new environment that targets these policies
|
| 8 |
+
2. Writing a bridge/adapter from a different simulator
|
| 9 |
+
3. Implementing a new policy that must interoperate with the existing set
|
| 10 |
+
|
| 11 |
+
All policies share the same raw observation layout (EGO + ticket matrix).
|
| 12 |
+
Specialized policies (ThreatAvoidance, SystemFailure) select subsets internally.
|
| 13 |
+
|
| 14 |
+
Example usage:
|
| 15 |
+
from openenv.policies.policy_spec import ObsSpec, ActionSpec, validate_obs
|
| 16 |
+
|
| 17 |
+
spec = ObsSpec()
|
| 18 |
+
obs = my_env.get_observation()
|
| 19 |
+
validate_obs(obs, spec) # raises ValueError on shape/range mismatch
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
from dataclasses import dataclass, field
|
| 25 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 26 |
+
|
| 27 |
+
import numpy as np
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ββ Ego state specification ββββββββββββββββββββββββββββββββββββββββββββββββββ

EGO_STATE_DIM = 11  # number of floats in the ego slice of the observation

@dataclass(frozen=True)
class EgoField:
    """Description of a single ego state field."""
    index: int                      # position within the ego slice [0, EGO_STATE_DIM)
    name: str                       # short identifier, e.g. "vx"
    unit: str                       # physical unit of the raw value
    raw_range: Tuple[float, float]  # physical range before normalization
    norm_divisor: float             # obs_value = raw_value / norm_divisor
    description: str                # human-readable meaning

# One entry per ego observation slot, in index order.
EGO_FIELDS: List[EgoField] = [
    EgoField(0, "x", "m", (-5000, 5000), 1000.0, "Forward displacement from episode start"),
    EgoField(1, "y", "m", (-6.0, 6.0), 3.7, "Lateral displacement (0 = lane center, + = left)"),
    EgoField(2, "z", "m", (-10, 10), 10.0, "Vertical position (flat road = 0)"),
    EgoField(3, "vx", "m/s", (-20, 20), 20.0, "Forward velocity in world frame"),
    EgoField(4, "vy", "m/s", (-20, 20), 20.0, "Lateral velocity in world frame"),
    EgoField(5, "vz", "m/s", (0, 0), 1.0, "Vertical velocity (always 0 on flat road)"),
    EgoField(6, "heading_sin", "rad", (-1, 1), 1.0, "sin(heading angle), 0 = forward"),
    EgoField(7, "heading_cos", "rad", (-1, 1), 1.0, "cos(heading angle), 1 = forward"),
    EgoField(8, "speed", "m/s", (0, 20), 20.0, "Scalar speed = sqrt(vx^2 + vy^2)"),
    EgoField(9, "steer", "norm", (-1, 1), 1.0, "Current steering command [-1=full left, 1=full right]"),
    EgoField(10, "net_drive", "norm", (-1, 1), 1.0, "throttle - brake [-1=full brake, 1=full throttle]"),
]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ββ Ticket vector specification ββββββββββββββββββββββββββββββββββββββββββββββ

TICKET_VECTOR_DIM = 37  # 18 fixed + 14 type one-hot + 5 entity one-hot
MAX_TICKETS = 16        # fixed number of ticket slots in every observation

# Ticket types (14 total) β one-hot encoded starting at index 18
TICKET_TYPES = [
    "collision_risk", "sudden_brake", "side_impact", "head_on",
    "merge_cut", "rear_end_risk",
    "pedestrian_crossing", "cyclist_lane",
    "tire_blowout", "brake_fade", "steering_loss", "sensor_occlusion",
    "road_hazard", "weather_visibility",
]

# Entity types (5 total) β one-hot encoded after ticket types
ENTITY_TYPES = ["vehicle", "pedestrian", "cyclist", "obstacle", "system"]

# Verify dimension at import time so a list edit cannot silently desync
# the declared vector width.
assert 18 + len(TICKET_TYPES) + len(ENTITY_TYPES) == TICKET_VECTOR_DIM, (
    f"Ticket vector dim mismatch: 18 + {len(TICKET_TYPES)} + {len(ENTITY_TYPES)} "
    f"!= {TICKET_VECTOR_DIM}"
)

@dataclass(frozen=True)
class TicketField:
    """Description of a single ticket vector field."""
    offset: int   # index within the TICKET_VECTOR_DIM vector
    length: int   # number of floats
    name: str
    unit: str
    raw_range: Tuple[float, float]
    norm_divisor: float
    description: str

# One entry per field (or one-hot group), in offset order.
TICKET_FIELDS: List[TicketField] = [
    TicketField(0, 1, "severity_weight", "norm", (0, 1), 1.0, "Severity: 0.25=LOW, 0.5=MED, 0.75=HIGH, 1.0=CRITICAL"),
    TicketField(1, 1, "ttl_norm", "s", (0, 10), 10.0, "Time-to-live remaining, clamped to [0,1]"),
    TicketField(2, 1, "pos_x", "m", (-100, 100), 100.0, "Ego-relative X (forward positive)"),
    TicketField(3, 1, "pos_y", "m", (-50, 50), 50.0, "Ego-relative Y (left positive)"),
    TicketField(4, 1, "pos_z", "m", (-10, 10), 10.0, "Ego-relative Z (up positive)"),
    TicketField(5, 1, "vel_x", "m/s", (-30, 30), 30.0, "Entity velocity X in world frame"),
    TicketField(6, 1, "vel_y", "m/s", (-30, 30), 30.0, "Entity velocity Y in world frame"),
    TicketField(7, 1, "vel_z", "m/s", (-10, 10), 10.0, "Entity velocity Z in world frame"),
    TicketField(8, 1, "heading_sin", "rad", (-1, 1), 1.0, "sin(entity heading relative to ego)"),
    TicketField(9, 1, "heading_cos", "rad", (-1, 1), 1.0, "cos(entity heading relative to ego)"),
    TicketField(10, 1, "size_length", "m", (0, 10), 10.0, "Entity bounding box length"),
    TicketField(11, 1, "size_width", "m", (0, 5), 5.0, "Entity bounding box width"),
    TicketField(12, 1, "size_height", "m", (0, 4), 4.0, "Entity bounding box height"),
    TicketField(13, 1, "distance_norm", "m", (0, 100), 100.0, "Euclidean distance to ego, clamped to [0,1]"),
    TicketField(14, 1, "ttc_norm", "s", (0, 30), 30.0, "Time-to-collision, clamped to [0,1]. 1.0 = no collision"),
    TicketField(15, 1, "bearing_sin", "rad", (-1, 1), 1.0, "sin(bearing angle from ego forward axis)"),
    TicketField(16, 1, "bearing_cos", "rad", (-1, 1), 1.0, "cos(bearing angle from ego forward axis)"),
    TicketField(17, 1, "confidence", "norm", (0, 1), 1.0, "Perception confidence [0=unreliable, 1=certain]"),
    TicketField(18, len(TICKET_TYPES), "type_onehot", "bool", (0, 1), 1.0, "One-hot ticket type"),
    TicketField(18 + len(TICKET_TYPES), len(ENTITY_TYPES), "entity_onehot", "bool", (0, 1), 1.0, "One-hot entity type"),
]
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ββ Full observation specification βββββββββββββββββββββββββββββββββββββββββββ

OBS_DIM = EGO_STATE_DIM + MAX_TICKETS * TICKET_VECTOR_DIM  # 11 + 16*37 = 603

@dataclass(frozen=True)
class ObsSpec:
    """Complete observation space specification."""
    ego_dim: int = EGO_STATE_DIM         # leading ego floats
    ticket_dim: int = TICKET_VECTOR_DIM  # floats per ticket slot
    max_tickets: int = MAX_TICKETS       # fixed number of ticket slots
    total_dim: int = OBS_DIM             # flat observation length
    dtype: str = "float32"
    value_range: Tuple[float, float] = (-1.0, 1.0)  # all fields pre-normalized

# Layout: obs[0:ego_dim] = ego state
#         obs[ego_dim:] reshaped to (max_tickets, ticket_dim)
# Tickets are sorted by severity desc, distance asc. Zero-padded rows = empty slots.
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# ββ Action specification βββββββββββββββββββββββββββββββββββββββββββββββββββββ

@dataclass(frozen=True)
class ActionField:
    """Description of a single action channel."""
    index: int                      # position within the action vector
    name: str
    raw_range: Tuple[float, float]  # valid command range before env scaling
    description: str

ACTION_DIM = 3  # steer, throttle, brake

# One entry per action channel, in index order.
ACTION_FIELDS: List[ActionField] = [
    ActionField(0, "steer", (-1.0, 1.0), "Steering command. -1=full left, +1=full right. Scaled by MAX_STEER=0.6 rad"),
    ActionField(1, "throttle", (-1.0, 1.0), "Throttle command. Only positive values used (clipped to [0,1]). Scaled by MAX_ACCEL=4.0 m/s^2"),
    ActionField(2, "brake", (-1.0, 1.0), "Brake command. Only positive values used (clipped to [0,1]). Scaled by MAX_BRAKE=8.0 m/s^2"),
]

@dataclass(frozen=True)
class ActionSpec:
    """Action space specification."""
    dim: int = ACTION_DIM
    dtype: str = "float32"
    value_range: Tuple[float, float] = (-1.0, 1.0)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# ββ Policy input requirements ββββββββββββββββββββββββββββββββββββββββββββββββ

@dataclass(frozen=True)
class PolicyInputSpec:
    """Describes what a specific policy reads from the observation."""
    name: str                        # policy class name
    reads_ego: bool
    ego_indices: Tuple[int, ...]     # which ego fields are used
    reads_tickets: bool
    ticket_filter: Optional[str]     # None = all, or "kinematic" / "failure"
    max_tickets_used: int            # how many ticket slots the policy actually reads
    requires_history: bool           # whether GRU/recurrent hidden state is needed
    description: str
|
| 174 |
+
|
| 175 |
+
# Registry of per-policy observation contracts, keyed by policy class name.
POLICY_SPECS: Dict[str, PolicyInputSpec] = {
    "SurvivalPolicy": PolicyInputSpec(
        name="SurvivalPolicy",
        reads_ego=True,
        ego_indices=tuple(range(EGO_STATE_DIM)),
        reads_tickets=False,
        ticket_filter=None,
        max_tickets_used=0,
        requires_history=False,
        description="Stage 1 baseline. Reads only ego state (first 11 dims). "
                    "Ticket portion of obs is ignored entirely.",
    ),
    "FlatMLPPolicy": PolicyInputSpec(
        name="FlatMLPPolicy",
        reads_ego=True,
        ego_indices=tuple(range(EGO_STATE_DIM)),
        reads_tickets=True,
        ticket_filter=None,
        max_tickets_used=MAX_TICKETS,
        requires_history=False,
        description="Sanity-check baseline. Reads full flat observation (ego + all tickets "
                    "concatenated). No attention or structure.",
    ),
    "TicketAttentionPolicy": PolicyInputSpec(
        name="TicketAttentionPolicy",
        reads_ego=True,
        ego_indices=tuple(range(EGO_STATE_DIM)),
        reads_tickets=True,
        ticket_filter=None,
        max_tickets_used=MAX_TICKETS,
        requires_history=False,
        description="Main policy (Stage 2+). Cross-attention: ego queries ticket set. "
                    "Order-invariant over tickets. Padding mask on zero-rows.",
    ),
    "ThreatAvoidancePolicy": PolicyInputSpec(
        name="ThreatAvoidancePolicy",
        reads_ego=True,
        ego_indices=tuple(range(EGO_STATE_DIM)),
        reads_tickets=True,
        ticket_filter="kinematic",
        max_tickets_used=1,
        requires_history=False,
        description="Specialist for kinematic threats (collision_risk, sudden_brake, "
                    "side_impact, head_on, merge_cut, rear_end_risk). Extracts the "
                    "highest-severity kinematic ticket and gates between brake/evade branches.",
    ),
    "SystemFailurePolicy": PolicyInputSpec(
        name="SystemFailurePolicy",
        reads_ego=True,
        ego_indices=tuple(range(EGO_STATE_DIM)),
        reads_tickets=True,
        ticket_filter="failure",
        max_tickets_used=1,
        requires_history=False,
        description="Specialist for onboard failures (tire_blowout, brake_fade, steering_loss). "
                    "Mixture-of-experts with one expert per failure type. Initialized with "
                    "domain-correct response priors.",
    ),
    "RecurrentPolicy": PolicyInputSpec(
        name="RecurrentPolicy",
        reads_ego=True,
        ego_indices=tuple(range(EGO_STATE_DIM)),
        reads_tickets=True,
        ticket_filter=None,
        max_tickets_used=MAX_TICKETS,
        requires_history=True,
        description="GRU-based policy for partial observability (Stage 4+). Carries hidden "
                    "state across timesteps. Requires h_prev to be tracked by caller.",
    ),
}
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# ββ Validation helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 248 |
+
|
| 249 |
+
def validate_obs(obs: np.ndarray, spec: Optional["ObsSpec"] = None) -> None:
    """
    Validate an observation array against the spec.

    Checks run in order: rank (must be 1-D), length (must equal
    ``spec.total_dim``), then dtype (must be float32). Raises ValueError with
    a descriptive message on the first mismatch; returns None on success.
    """
    if spec is None:
        spec = ObsSpec()

    if obs.ndim != 1:
        raise ValueError(f"Observation must be 1D, got shape {obs.shape}")

    if obs.shape[0] != spec.total_dim:
        raise ValueError(
            f"Observation dim mismatch: expected {spec.total_dim}, got {obs.shape[0]}. "
            f"Check ego_dim ({spec.ego_dim}) + max_tickets ({spec.max_tickets}) "
            f"* ticket_dim ({spec.ticket_dim})"
        )

    if obs.dtype != np.float32:
        raise ValueError(f"Observation dtype must be float32, got {obs.dtype}")
+
|
| 266 |
+
|
| 267 |
+
def validate_action(action: np.ndarray) -> None:
    """Validate an action array: shape must be (ACTION_DIM,) and every
    component must lie inside [-1, 1]. Raises ValueError otherwise."""
    if action.shape != (ACTION_DIM,):
        raise ValueError(f"Action shape mismatch: expected ({ACTION_DIM},), got {action.shape}")
    too_low = np.any(action < -1.0)
    too_high = np.any(action > 1.0)
    if too_low or too_high:
        raise ValueError(f"Action values must be in [-1, 1], got min={action.min()}, max={action.max()}")
+
|
| 274 |
+
|
| 275 |
+
def build_obs(
    ego_x: float, ego_y: float, ego_z: float,
    ego_vx: float, ego_vy: float,
    heading: float, speed: float,
    steer: float, throttle: float, brake: float,
    ticket_vectors: Optional[np.ndarray] = None,
    max_tickets: int = MAX_TICKETS,
) -> np.ndarray:
    """
    Build a valid observation vector from raw values.

    This is the primary entry point for external environments that want to
    produce observations compatible with OpenENV policies.

    Parameters
    ----------
    ego_x : forward displacement from episode start (metres)
    ego_y : lateral displacement from lane center (metres, + = left)
    ego_z : vertical position (metres)
    ego_vx : forward velocity (m/s)
    ego_vy : lateral velocity (m/s)
    heading : heading angle (radians, 0 = forward)
    speed : scalar speed (m/s)
    steer : current steering command [-1, 1]
    throttle : current throttle command [0, 1]
    brake : current brake command [0, 1]
    ticket_vectors : (N, TICKET_VECTOR_DIM) array of ticket vectors, or None.
        Use EventTicket.to_vector() or build_ticket_vector() to create these.
        Extra rows beyond ``max_tickets`` are silently dropped.
    max_tickets : number of ticket slots (must match policy expectation, default 16)

    Returns
    -------
    obs : np.ndarray of shape (EGO_STATE_DIM + max_tickets * TICKET_VECTOR_DIM,)
    """
    import math

    # Normalisation divisors mirror the environment's nominal physical bounds.
    ego_features = [
        ego_x / 1000.0,
        ego_y / 3.7,       # ROAD_HALF_WIDTH
        ego_z / 10.0,
        ego_vx / 20.0,     # MAX_SPEED
        ego_vy / 20.0,
        0.0,               # vz (flat road)
        math.sin(heading),
        math.cos(heading),
        speed / 20.0,
        steer,
        throttle - brake,  # net drive signal
    ]
    ego = np.asarray(ego_features, dtype=np.float32)

    # Fixed-size ticket grid; unused rows stay all-zero (padding slots).
    slots = np.zeros((max_tickets, TICKET_VECTOR_DIM), dtype=np.float32)
    if ticket_vectors is not None:
        count = min(len(ticket_vectors), max_tickets)
        slots[:count] = ticket_vectors[:count]

    return np.concatenate([ego, slots.reshape(-1)])
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def build_ticket_vector(
    severity_weight: float,
    ttl: float,
    pos_x: float, pos_y: float, pos_z: float,
    vel_x: float, vel_y: float, vel_z: float,
    heading: float,
    size_length: float, size_width: float, size_height: float,
    distance: float,
    time_to_collision: Optional[float],
    bearing: float,
    ticket_type: str,
    entity_type: str,
    confidence: float = 1.0,
) -> np.ndarray:
    """
    Build a single ticket vector from raw values without needing the full
    EventTicket class. Use this when adapting a different simulator.

    Parameters
    ----------
    severity_weight : 0.25 (LOW), 0.5 (MEDIUM), 0.75 (HIGH), 1.0 (CRITICAL)
    ttl : seconds remaining until ticket expires
    pos_x/y/z : ego-relative position (metres)
    vel_x/y/z : entity velocity in world frame (m/s)
    heading : entity heading relative to ego (radians)
    size_length/width/height : entity bounding box (metres)
    distance : euclidean distance to ego (metres)
    time_to_collision : seconds until collision, or None if no collision course
    bearing : angle from ego forward axis (radians)
    ticket_type : one of TICKET_TYPES (e.g., "collision_risk")
    entity_type : one of ENTITY_TYPES (e.g., "vehicle")
    confidence : perception confidence [0, 1]

    Returns
    -------
    vec : np.ndarray of shape (TICKET_VECTOR_DIM,) = (37,)

    Raises
    ------
    ValueError
        If ``ticket_type`` or ``entity_type`` is not a known category.
    """
    import math

    # Validate the categorical fields up front (ticket_type first, matching
    # the order callers can rely on for error reporting).
    if ticket_type not in TICKET_TYPES:
        raise ValueError(f"Unknown ticket_type '{ticket_type}'. Must be one of {TICKET_TYPES}")
    if entity_type not in ENTITY_TYPES:
        raise ValueError(f"Unknown entity_type '{entity_type}'. Must be one of {ENTITY_TYPES}")

    # None means "no collision course": treat as the 30 s horizon (norm = 1.0).
    ttc_seconds = 30.0 if time_to_collision is None else time_to_collision
    ttc_norm = min(ttc_seconds / 30.0, 1.0)

    type_oh = [1.0 if name == ticket_type else 0.0 for name in TICKET_TYPES]
    entity_oh = [1.0 if name == entity_type else 0.0 for name in ENTITY_TYPES]

    components = [
        severity_weight,
        min(ttl / 10.0, 1.0),
        pos_x / 100.0,
        pos_y / 50.0,
        pos_z / 10.0,
        vel_x / 30.0,
        vel_y / 30.0,
        vel_z / 10.0,
        math.sin(heading),
        math.cos(heading),
        size_length / 10.0,
        size_width / 5.0,
        size_height / 4.0,
        min(distance / 100.0, 1.0),
        ttc_norm,
        math.sin(bearing),
        math.cos(bearing),
        confidence,
        *type_oh,
        *entity_oh,
    ]
    return np.array(components, dtype=np.float32)
|
policies/ticket_attention_policy.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TicketAttentionPolicy β the main policy (Stage 2+).
|
| 3 |
+
|
| 4 |
+
Architecture: two-pass "reflective" cross-attention.
|
| 5 |
+
|
| 6 |
+
Pass 1: ego queries tickets β raw threat context
|
| 7 |
+
Pass 2: (ego + raw context) queries tickets again β refined context
|
| 8 |
+
This forces the policy to "think twice" β first perceive, then plan.
|
| 9 |
+
|
| 10 |
+
[ego | refined_context] β steer head β steer action
|
| 11 |
+
β drive head β throttle, brake
|
| 12 |
+
β critic head β value
|
| 13 |
+
|
| 14 |
+
Why two-pass:
|
| 15 |
+
The first pass gathers what threats exist. The second pass re-examines
|
| 16 |
+
tickets knowing what the overall threat picture looks like. This prevents
|
| 17 |
+
the impulsive single-shot responses that cause wild oscillation.
|
| 18 |
+
|
| 19 |
+
Why separate heads:
|
| 20 |
+
Steering requires smooth, conservative output (off-road = death).
|
| 21 |
+
Throttle/brake can be more aggressive. Separate heads + separate
|
| 22 |
+
noise levels let each dimension learn at its own pace.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn as nn
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
|
| 31 |
+
from .base_policy import BasePolicy
|
| 32 |
+
EGO_STATE_DIM = 11
|
| 33 |
+
MAX_TICKETS = 16
|
| 34 |
+
TICKET_VECTOR_DIM = 37
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class TicketAttentionPolicy(BasePolicy):
    """
    Two-pass reflective attention policy.

    Pass 1: perceive -- what threats exist?
    Pass 2: plan -- given what I see, which threats matter most?
    Output: separate steer head (conservative) + drive head (throttle/brake)
    """

    def __init__(
        self,
        obs_dim: int,
        ego_embed: int = 64,
        ticket_embed: int = 64,
        n_heads: int = 4,
        hidden: int = 256,
    ):
        super().__init__(obs_dim)
        # Embedding width must split evenly across attention heads, and the
        # ticket embedding must match the ego/query width (single shared
        # embed_dim for both attention passes).
        assert ego_embed % n_heads == 0
        assert ticket_embed == ego_embed

        self.ego_embed = ego_embed
        self.max_tickets = MAX_TICKETS
        self.ticket_dim = TICKET_VECTOR_DIM

        # -- Encoders ------------------------------------------------------
        # Ego state (11 dims) -> ego_embed query vector.
        self.ego_encoder = nn.Sequential(
            nn.Linear(EGO_STATE_DIM, hidden // 2),
            nn.LayerNorm(hidden // 2),
            nn.Tanh(),
            nn.Linear(hidden // 2, ego_embed),
            nn.LayerNorm(ego_embed),
        )
        # Each ticket row (37 dims) -> ticket_embed key/value vector.
        self.ticket_encoder = nn.Sequential(
            nn.Linear(TICKET_VECTOR_DIM, hidden // 2),
            nn.LayerNorm(hidden // 2),
            nn.ReLU(),
            nn.Linear(hidden // 2, ticket_embed),
            nn.LayerNorm(ticket_embed),
        )

        # -- Pass 1: perceive (ego queries tickets) ------------------------
        self.attn_pass1 = nn.MultiheadAttention(
            embed_dim=ego_embed, num_heads=n_heads,
            dropout=0.0, batch_first=True,
        )
        self.norm1 = nn.LayerNorm(ego_embed)

        # -- Reflection gate: fuse ego + pass1 context for second query ----
        self.reflect_proj = nn.Sequential(
            nn.Linear(ego_embed * 2, ego_embed),
            nn.LayerNorm(ego_embed),
            nn.Tanh(),
        )

        # -- Pass 2: plan (refined query re-attends to tickets) ------------
        self.attn_pass2 = nn.MultiheadAttention(
            embed_dim=ego_embed, num_heads=n_heads,
            dropout=0.0, batch_first=True,
        )
        self.norm2 = nn.LayerNorm(ego_embed)

        # -- Fused representation ------------------------------------------
        fused_dim = ego_embed + ego_embed  # ego + refined context

        # -- Steer head (conservative, smooth output) ----------------------
        # Final Tanh bounds steer to [-1, 1].
        self.steer_head = nn.Sequential(
            nn.Linear(fused_dim, hidden // 2),
            nn.LayerNorm(hidden // 2),
            nn.Tanh(),
            nn.Linear(hidden // 2, hidden // 4),
            nn.Tanh(),
            nn.Linear(hidden // 4, 1),
            nn.Tanh(),
        )

        # -- Drive head (throttle + brake) ---------------------------------
        self.drive_head = nn.Sequential(
            nn.Linear(fused_dim, hidden // 2),
            nn.LayerNorm(hidden // 2),
            nn.Tanh(),
            nn.Linear(hidden // 2, hidden // 4),
            nn.Tanh(),
            nn.Linear(hidden // 4, 2),
            nn.Tanh(),
        )

        # -- Critic head ---------------------------------------------------
        # Unbounded scalar value estimate (no final activation).
        self.critic = nn.Sequential(
            nn.Linear(fused_dim, hidden),
            nn.LayerNorm(hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden // 2),
            nn.Tanh(),
            nn.Linear(hidden // 2, 1),
        )

        self._init_weights()

    def _init_weights(self):
        # Orthogonal init for every Linear (including MultiheadAttention's
        # out_proj, which is a Linear subclass), zero biases.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=1.0)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Very small initial actions -- start by doing almost nothing.
        # [-2] is the last Linear in each head (index -1 is the final Tanh).
        nn.init.orthogonal_(self.steer_head[-2].weight, gain=0.01)
        nn.init.orthogonal_(self.drive_head[-2].weight, gain=0.01)
        # Critic starts near zero
        nn.init.orthogonal_(self.critic[-1].weight, gain=0.1)

    def _attend(self, attn_module, norm_module, query, tk_emb, is_padding, all_empty):
        """Run one attention pass with NaN-safe masking.

        query: (B, E) single-query vector (unsqueezed to (B, 1, E) here);
        tk_emb: (B, max_tickets, E) encoded ticket rows;
        is_padding: (B, max_tickets) bool mask of all-zero ticket rows;
        all_empty: (B,) bool, True where a row has no real tickets at all.
        Returns a (B, E) normalized context vector; rows with no tickets get
        an all-zero context.
        """
        B = query.shape[0]
        q = query if query.dim() == 3 else query.unsqueeze(1)

        # Short-circuit: nothing to attend to in the whole batch.
        if all_empty.all():
            return torch.zeros(B, self.ego_embed, device=query.device)

        # Masking every key produces NaN attention; un-mask slot 0 for rows
        # with no tickets, then overwrite their context with zeros below.
        safe_mask = is_padding.clone()
        safe_mask[all_empty, 0] = False
        attn_out, _ = attn_module(
            query=q, key=tk_emb, value=tk_emb,
            key_padding_mask=safe_mask,
        )
        context = attn_out.squeeze(1)
        context[all_empty] = 0.0
        return norm_module(context)

    def forward(self, obs: torch.Tensor):
        # obs: (B, EGO_STATE_DIM + max_tickets * ticket_dim) flat observation.
        B = obs.shape[0]

        # Split observation
        ego_raw = obs[:, :EGO_STATE_DIM]
        tk_raw = obs[:, EGO_STATE_DIM:].view(B, self.max_tickets, self.ticket_dim)

        # Encode
        ego_emb = self.ego_encoder(ego_raw)
        tk_emb = self.ticket_encoder(tk_raw)

        # Padding mask: an all-zero ticket row marks an empty slot.
        is_padding = (tk_raw.abs().sum(dim=-1) == 0)
        all_empty = is_padding.all(dim=-1)

        # -- Pass 1: perceive ----------------------------------------------
        ctx1 = self._attend(self.attn_pass1, self.norm1,
                            ego_emb, tk_emb, is_padding, all_empty)

        # -- Reflect: combine ego + initial context into refined query -----
        reflected = self.reflect_proj(torch.cat([ego_emb, ctx1], dim=-1))

        # -- Pass 2: plan (re-attend with richer query) --------------------
        ctx2 = self._attend(self.attn_pass2, self.norm2,
                            reflected, tk_emb, is_padding, all_empty)

        # -- Fuse and decode -----------------------------------------------
        fused = torch.cat([ego_emb, ctx2], dim=-1)

        steer = self.steer_head(fused)              # (B, 1)
        drive = self.drive_head(fused)              # (B, 2)
        action = torch.cat([steer, drive], dim=-1)  # (B, 3)
        value = self.critic(fused)                  # (B, 1)

        return action, value

    def get_attention_weights(self, obs: torch.Tensor) -> torch.Tensor:
        """Returns pass-2 attention weights for interpretability.

        Recomputes pass 1 + the reflection projection, then re-runs pass 2
        with need_weights=True (per-head weights, not averaged). Rows with
        no tickets get all-zero weights.
        """
        B = obs.shape[0]
        ego_raw = obs[:, :EGO_STATE_DIM]
        tk_raw = obs[:, EGO_STATE_DIM:].view(B, self.max_tickets, self.ticket_dim)
        ego_emb = self.ego_encoder(ego_raw)
        tk_emb = self.ticket_encoder(tk_raw)
        is_padding = (tk_raw.abs().sum(dim=-1) == 0)
        all_empty = is_padding.all(dim=-1)

        # Pass 1
        ctx1 = self._attend(self.attn_pass1, self.norm1,
                            ego_emb, tk_emb, is_padding, all_empty)
        reflected = self.reflect_proj(torch.cat([ego_emb, ctx1], dim=-1))

        # Pass 2 -- get weights (same NaN-safe masking trick as _attend)
        safe_mask = is_padding.clone()
        safe_mask[all_empty, 0] = False
        query = reflected.unsqueeze(1)
        _, weights = self.attn_pass2(
            query=query, key=tk_emb, value=tk_emb,
            key_padding_mask=safe_mask,
            need_weights=True, average_attn_weights=False,
        )
        weights[all_empty] = 0.0
        return weights
|
pyproject.toml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-overflow-env"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Overflow Environment for OpenEnv β autonomous vehicle fleet oversight on a 2D road grid"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"openenv-core[core]>=0.2.1",
|
| 12 |
+
"fastapi>=0.115.0",
|
| 13 |
+
"pydantic>=2.0.0",
|
| 14 |
+
"uvicorn[standard]>=0.24.0",
|
| 15 |
+
"requests>=2.31.0",
|
| 16 |
+
"torch>=2.5.1",
|
| 17 |
+
"numpy>=2.2.6",
|
| 18 |
+
"pillow>=12.1.1",
|
| 19 |
+
"gymnasium>=1.2.3",
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
[project.optional-dependencies]
|
| 23 |
+
dev = [
|
| 24 |
+
"pytest>=8.0.0",
|
| 25 |
+
"pytest-cov>=4.0.0",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
[project.scripts]
|
| 29 |
+
server = "overflow_env.server.app:main"
|
| 30 |
+
|
| 31 |
+
[tool.setuptools]
|
| 32 |
+
packages = ["overflow_env", "overflow_env.server"]
|
| 33 |
+
package-dir = { "overflow_env" = ".", "overflow_env.server" = "server" }
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
| 2 |
+
torch==2.5.1+cpu
|
| 3 |
+
numpy>=1.24.0
|
| 4 |
+
pillow==10.4.0
|
| 5 |
+
matplotlib>=3.8.0
|
| 6 |
+
pydantic>=2.0.0
|
| 7 |
+
requests>=2.31.0
|
| 8 |
+
gymnasium>=0.29.0
|
server/__init__.py
ADDED
|
File without changes
|
server/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (167 Bytes). View file
|
|
|
server/__pycache__/overflow_environment.cpython-314.pyc
ADDED
|
Binary file (23.6 kB). View file
|
|
|
server/__pycache__/policy_adapter.cpython-314.pyc
ADDED
|
Binary file (4.89 kB). View file
|
|
|
server/app.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI application for the Overflow Environment.
|
| 3 |
+
|
| 4 |
+
Exposes the OverflowEnvironment over HTTP and WebSocket endpoints.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import inspect
|
| 11 |
+
|
| 12 |
+
from openenv.core.env_server.http_server import create_app
|
| 13 |
+
|
| 14 |
+
from ..models import OverflowAction, OverflowObservation
|
| 15 |
+
from .overflow_environment import OverflowEnvironment
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _create_overflow_app():
    """Build app across create_app variants that may expect a factory or an instance."""
    annotation_text = "typing.Callable"
    try:
        params = inspect.signature(create_app).parameters
        first_param = next(iter(params.values()))
        annotation_text = str(first_param.annotation)
    except (StopIteration, TypeError, ValueError):
        pass

    # Older/newer openenv variants differ: some type the first argument as an
    # Environment instance, others as a callable factory.
    wants_instance = (
        "Environment" in annotation_text and "Callable" not in annotation_text
    )
    env_arg = OverflowEnvironment() if wants_instance else OverflowEnvironment
    return create_app(
        env_arg, OverflowAction, OverflowObservation, env_name="overflow_env"
    )
+
|
| 34 |
+
|
| 35 |
+
app = _create_overflow_app()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main():
    """Entry point for direct execution via uv run or python -m."""
    import uvicorn

    host, port = "0.0.0.0", 8000
    uvicorn.run(app, host=host, port=port)
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == "__main__":
|
| 46 |
+
main()
|
server/overflow_environment.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Overflow Environment Implementation.
|
| 3 |
+
|
| 4 |
+
A 2D road grid with N cars. One car (Car 0) is the LLM agent, others follow
|
| 5 |
+
scripted rules. An observer checks for collisions each step. The environment
|
| 6 |
+
returns text observations describing the traffic scene and rewards based on safety.
|
| 7 |
+
|
| 8 |
+
Observations carry both text (for the LLM) and structured data (for the frontend).
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import random
|
| 13 |
+
import re
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
from typing import Any, List, Optional
|
| 16 |
+
from uuid import uuid4
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from openenv.core.env_server.interfaces import Environment
|
| 20 |
+
from openenv.core.env_server.types import State
|
| 21 |
+
except ImportError:
|
| 22 |
+
class Environment: # stub for training-only mode
|
| 23 |
+
pass
|
| 24 |
+
class State:
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from ..models import (
|
| 29 |
+
CarStateData, LaneOccupancyData, OverflowAction,
|
| 30 |
+
OverflowObservation, OverflowState, Position, ProximityData,
|
| 31 |
+
)
|
| 32 |
+
from ..policies.flat_mlp_policy import FlatMLPPolicy
|
| 33 |
+
from ..policies.ticket_attention_policy import TicketAttentionPolicy
|
| 34 |
+
from ..policies.policy_spec import OBS_DIM
|
| 35 |
+
from .policy_adapter import overflow_obs_to_policy_obs, policy_action_to_decision
|
| 36 |
+
except ImportError:
|
| 37 |
+
from models import (
|
| 38 |
+
CarStateData, LaneOccupancyData, OverflowAction,
|
| 39 |
+
OverflowObservation, OverflowState, Position, ProximityData,
|
| 40 |
+
)
|
| 41 |
+
from policies.flat_mlp_policy import FlatMLPPolicy
|
| 42 |
+
from policies.ticket_attention_policy import TicketAttentionPolicy
|
| 43 |
+
from policies.policy_spec import OBS_DIM
|
| 44 |
+
from server.policy_adapter import overflow_obs_to_policy_obs, policy_action_to_decision
|
| 45 |
+
|
| 46 |
+
# --- Constants ---
|
| 47 |
+
NUM_LANES = 3
|
| 48 |
+
ROAD_LENGTH = 200
|
| 49 |
+
NUM_CARS = 5
|
| 50 |
+
MAX_STEPS = 100
|
| 51 |
+
CRASH_DISTANCE = 5.0
|
| 52 |
+
NEAR_MISS_DISTANCE = 15.0
|
| 53 |
+
LANE_WIDTH = 3.7 # metres β matches frontend's makeCar convention
|
| 54 |
+
|
| 55 |
+
# Reward values
|
| 56 |
+
REWARD_CRASH = -5.0
|
| 57 |
+
REWARD_NEAR_MISS = -1.0
|
| 58 |
+
REWARD_SAFE_STEP = 0.5
|
| 59 |
+
REWARD_REACHED_GOAL = 3.0
|
| 60 |
+
REWARD_REASONING_MAX = 0.3
|
| 61 |
+
|
| 62 |
+
# Speed bounds
|
| 63 |
+
MIN_SPEED = 20
|
| 64 |
+
MAX_SPEED = 90
|
| 65 |
+
SPEED_DELTA = 5
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclass
class Car:
    """Represents a car on the road grid."""

    car_id: int
    lane: int  # 1-indexed: 1, 2, or 3
    position: float       # longitudinal position along the road
    speed: float          # current speed
    goal_position: float  # longitudinal target position
    is_agent: bool = False       # True for the LLM-controlled car
    reached_goal: bool = False
    prev_speed: float = 0.0  # speed last step, for acceleration calc

    def distance_to(self, other: "Car") -> float:
        """Euclidean-ish distance considering lane and position."""
        # NOTE(review): lateral spacing uses 10.0 units per lane here, while
        # LANE_WIDTH is 3.7 m -- confirm the exaggerated lane separation for
        # crash/near-miss detection is intentional.
        lane_diff = abs(self.lane - other.lane) * 10.0  # lanes are ~10 units apart
        pos_diff = abs(self.position - other.position)
        return math.sqrt(lane_diff**2 + pos_diff**2)

    @property
    def acceleration(self) -> float:
        """Speed delta since last step."""
        return self.speed - self.prev_speed

    def to_state_data(self) -> CarStateData:
        """Convert to frontend-compatible CarStateData."""
        return CarStateData(
            carId=self.car_id,
            lane=self.lane,
            # y is derived from the lane index using the frontend's lane-width
            # convention (LANE_WIDTH metres per lane).
            position=Position(x=self.position, y=self.lane * LANE_WIDTH),
            speed=self.speed,
            acceleration=self.acceleration,
        )
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _parse_decision(action: OverflowAction) -> str:
|
| 104 |
+
"""Extract a valid decision from the action, being forgiving about format."""
|
| 105 |
+
valid = {"accelerate", "brake", "lane_change_left", "lane_change_right", "maintain"}
|
| 106 |
+
|
| 107 |
+
# Try the decision field directly
|
| 108 |
+
decision = action.decision.strip().lower().replace(" ", "_")
|
| 109 |
+
if decision in valid:
|
| 110 |
+
return decision
|
| 111 |
+
|
| 112 |
+
# Try to extract from free text (the LLM might wrap it in tags)
|
| 113 |
+
text = f"{action.decision} {action.reasoning}".lower()
|
| 114 |
+
|
| 115 |
+
# Check for <action>...</action> tags
|
| 116 |
+
match = re.search(r"<action>\s*(\w+)\s*</action>", text)
|
| 117 |
+
if match:
|
| 118 |
+
candidate = match.group(1).strip().replace(" ", "_")
|
| 119 |
+
if candidate in valid:
|
| 120 |
+
return candidate
|
| 121 |
+
|
| 122 |
+
# Check for keywords anywhere (ordered: most specific first to avoid ambiguity)
|
| 123 |
+
for v in ["lane_change_left", "lane_change_right", "accelerate", "brake", "maintain"]:
|
| 124 |
+
if v in text:
|
| 125 |
+
return v
|
| 126 |
+
|
| 127 |
+
return "maintain"
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _compute_reasoning_bonus(reasoning: str) -> float:
    """
    Compute a small reasoning quality bonus (0.0 to 0.3).

    A minor reward for providing structured reasoning, deliberately capped
    low so driving performance remains the dominant training signal.
    """
    if not reasoning:
        return 0.0

    lowered = reasoning.lower()
    bonus = 0.0

    # Any non-trivial amount of text earns a small baseline bonus.
    if len(reasoning) > 20:
        bonus += 0.1

    # Signs of structured thought rather than keyword stuffing.
    if "<think>" in lowered or "because" in lowered:
        bonus += 0.1
    if any(marker in lowered for marker in ("therefore", "so i should", "best option", "i will")):
        bonus += 0.1

    return min(bonus, REWARD_REASONING_MAX)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _scripted_car_action(car: Car, all_cars: List[Car], rng: random.Random) -> str:
|
| 157 |
+
"""
|
| 158 |
+
Simple scripted AI for non-agent cars.
|
| 159 |
+
|
| 160 |
+
Rules:
|
| 161 |
+
- If car ahead in same lane is close (< 20 units): brake
|
| 162 |
+
- If speed is low and random chance: accelerate
|
| 163 |
+
- Otherwise: maintain
|
| 164 |
+
"""
|
| 165 |
+
# Find nearest car ahead in same lane
|
| 166 |
+
nearest_ahead_dist = float("inf")
|
| 167 |
+
for other in all_cars:
|
| 168 |
+
if other.car_id == car.car_id:
|
| 169 |
+
continue
|
| 170 |
+
if other.lane == car.lane and other.position > car.position:
|
| 171 |
+
dist = other.position - car.position
|
| 172 |
+
if dist < nearest_ahead_dist:
|
| 173 |
+
nearest_ahead_dist = dist
|
| 174 |
+
|
| 175 |
+
if nearest_ahead_dist < 20:
|
| 176 |
+
return "brake"
|
| 177 |
+
|
| 178 |
+
if car.speed < 60 and rng.random() < 0.1:
|
| 179 |
+
return "accelerate"
|
| 180 |
+
|
| 181 |
+
# Occasionally change lanes to make traffic more dynamic
|
| 182 |
+
if rng.random() < 0.05:
|
| 183 |
+
if car.lane > 1 and rng.random() < 0.5:
|
| 184 |
+
return "lane_change_left"
|
| 185 |
+
elif car.lane < NUM_LANES:
|
| 186 |
+
return "lane_change_right"
|
| 187 |
+
|
| 188 |
+
return "maintain"
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _apply_action(car: Car, decision: str) -> None:
|
| 192 |
+
"""Apply a driving decision to a car, mutating it in place."""
|
| 193 |
+
if decision == "accelerate":
|
| 194 |
+
car.speed = min(car.speed + SPEED_DELTA, MAX_SPEED)
|
| 195 |
+
elif decision == "brake":
|
| 196 |
+
car.speed = max(car.speed - SPEED_DELTA, MIN_SPEED)
|
| 197 |
+
elif decision == "lane_change_left":
|
| 198 |
+
if car.lane > 1:
|
| 199 |
+
car.lane -= 1
|
| 200 |
+
elif decision == "lane_change_right":
|
| 201 |
+
if car.lane < NUM_LANES:
|
| 202 |
+
car.lane += 1
|
| 203 |
+
# "maintain" β no change
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def _generate_scene_description(agent_car: Car, cars: List[Car]) -> str:
|
| 207 |
+
"""Generate a text description of the current traffic scene."""
|
| 208 |
+
lines = [
|
| 209 |
+
f"You are Car 0 in lane {agent_car.lane}, position {agent_car.position:.0f}, speed {agent_car.speed:.0f}.",
|
| 210 |
+
f"Goal: reach position {agent_car.goal_position:.0f}.",
|
| 211 |
+
"Nearby cars:",
|
| 212 |
+
]
|
| 213 |
+
|
| 214 |
+
for car in cars:
|
| 215 |
+
if car.car_id == agent_car.car_id:
|
| 216 |
+
continue
|
| 217 |
+
|
| 218 |
+
detail = f"- Car {car.car_id}: lane {car.lane}, position {car.position:.0f}, speed {car.speed:.0f}"
|
| 219 |
+
|
| 220 |
+
# Add context about relative position
|
| 221 |
+
if car.lane == agent_car.lane:
|
| 222 |
+
pos_diff = car.position - agent_car.position
|
| 223 |
+
if pos_diff > 0:
|
| 224 |
+
detail += f" [AHEAD IN YOUR LANE - {pos_diff:.0f} units away]"
|
| 225 |
+
else:
|
| 226 |
+
detail += f" [BEHIND IN YOUR LANE - {abs(pos_diff):.0f} units away]"
|
| 227 |
+
|
| 228 |
+
if car.reached_goal:
|
| 229 |
+
detail += " [REACHED GOAL]"
|
| 230 |
+
|
| 231 |
+
lines.append(detail)
|
| 232 |
+
|
| 233 |
+
return "\n".join(lines)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _build_structured_data(
|
| 237 |
+
cars: List[Car],
|
| 238 |
+
proximity_pairs: List[ProximityData],
|
| 239 |
+
) -> tuple[List[CarStateData], List[LaneOccupancyData]]:
|
| 240 |
+
"""Build structured arrays for the observation."""
|
| 241 |
+
cars_data = [c.to_state_data() for c in cars]
|
| 242 |
+
|
| 243 |
+
# Lane occupancies
|
| 244 |
+
lane_map: dict[int, list[int]] = {}
|
| 245 |
+
for car in cars:
|
| 246 |
+
if not car.reached_goal:
|
| 247 |
+
lane_map.setdefault(car.lane, []).append(car.car_id)
|
| 248 |
+
lane_occupancies = [
|
| 249 |
+
LaneOccupancyData(lane=lane, carIds=ids)
|
| 250 |
+
for lane, ids in sorted(lane_map.items())
|
| 251 |
+
]
|
| 252 |
+
|
| 253 |
+
return cars_data, lane_occupancies
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class OverflowEnvironment(Environment):
    """
    Autonomous vehicle fleet oversight environment.

    A 2D road grid with N cars. Car 0 is the LLM agent, others follow
    scripted rules. The observer detects crashes and near-misses and
    computes rewards based on safety.
    """

    def __init__(self):
        super().__init__()
        # Episode-level counters (steps, crashes, near misses, goals reached).
        self._state = OverflowState(episode_id=str(uuid4()))
        self._cars: List[Car] = []
        self._rng = random.Random()
        self._done = False
        # Cached most recent observation; the "policy:" intercept in step()
        # uses it to build the policy's input vector.
        self._last_obs: Optional[OverflowObservation] = None
        # Neural policies selectable at runtime via decision="policy:<name>".
        self._policies = {
            "flat_mlp": FlatMLPPolicy(obs_dim=OBS_DIM),
            "ticket_attention": TicketAttentionPolicy(obs_dim=OBS_DIM),
        }

    def _build_observation(
        self,
        incident_report: str,
        reward: float,
        proximities: Optional[List[ProximityData]] = None,
    ) -> OverflowObservation:
        """Build a full observation with text + structured data."""
        agent = self._cars[0]
        scene = _generate_scene_description(agent, self._cars)
        prox = proximities or []
        cars_data, lane_occ = _build_structured_data(self._cars, prox)

        return OverflowObservation(
            scene_description=scene,
            incident_report=incident_report,
            done=self._done,
            reward=reward,
            cars=cars_data,
            proximities=prox,
            lane_occupancies=lane_occ,
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> OverflowObservation:
        """Reset the environment: create road and spawn cars."""
        # Fresh RNG each episode; seeded for reproducibility when requested.
        if seed is not None:
            self._rng = random.Random(seed)
        else:
            self._rng = random.Random()

        self._state = OverflowState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            crash_count=0,
            near_miss_count=0,
            cars_reached_goal=0,
            total_cars=NUM_CARS,
        )
        self._done = False

        # Spawn cars with random positions, speeds, lanes, and goals
        self._cars = []

        for i in range(NUM_CARS):
            # Rejection-sample a spawn point (up to 100 tries) so no two cars
            # start within twice the crash distance of each other.
            for _attempt in range(100):
                lane = self._rng.randint(1, NUM_LANES)
                position = float(self._rng.randint(10, 80))
                too_close = False
                for existing in self._cars:
                    # Same lane-spacing metric as Car.distance_to.
                    lane_diff = abs(lane - existing.lane) * 10.0
                    pos_diff = abs(position - existing.position)
                    dist = math.sqrt(lane_diff**2 + pos_diff**2)
                    if dist < CRASH_DISTANCE * 2:
                        too_close = True
                        break
                if not too_close:
                    break

            speed = float(self._rng.randint(40, 70))
            goal = float(self._rng.randint(160, 195))

            self._cars.append(
                Car(
                    car_id=i,
                    lane=lane,
                    position=position,
                    speed=speed,
                    goal_position=goal,
                    is_agent=(i == 0),
                    prev_speed=speed,  # no delta on first step
                )
            )

        self._last_obs = self._build_observation(incident_report="", reward=0.0)
        return self._last_obs

    def step(
        self,
        action: OverflowAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> OverflowObservation:
        """Execute one simulation step."""
        # Stepping a finished episode is a no-op with an explanatory report.
        if self._done:
            return self._build_observation(
                incident_report="Episode is over. Call reset() to start a new one.",
                reward=0.0,
            )

        # Policy intercept: decision="policy:flat_mlp" or "policy:ticket_attention"
        # replaces the incoming action with one chosen by a pretrained policy,
        # computed from the cached previous observation.
        if action.decision.startswith("policy:") and self._last_obs is not None:
            policy_name = action.decision.split(":", 1)[1].lower()
            if policy_name in self._policies:
                obs_vec = overflow_obs_to_policy_obs(self._last_obs)
                act_vec = self._policies[policy_name].predict(obs_vec)
                decision, reasoning = policy_action_to_decision(act_vec)
                action = OverflowAction(
                    decision=decision,
                    reasoning=f"[{policy_name}] {reasoning}",
                )

        self._state.step_count += 1
        reward = 0.0
        incidents = []

        # Snapshot previous speeds for acceleration tracking
        for car in self._cars:
            car.prev_speed = car.speed

        # 1. Parse and apply the agent's action to Car 0
        decision = _parse_decision(action)
        _apply_action(self._cars[0], decision)

        # 2. Compute and apply scripted actions for Cars 1-N
        for car in self._cars[1:]:
            if car.reached_goal:
                continue
            scripted_decision = _scripted_car_action(car, self._cars, self._rng)
            _apply_action(car, scripted_decision)

        # 3. Move all cars forward based on speed (speed is in units/step, scaled down)
        for car in self._cars:
            if car.reached_goal:
                continue
            car.position += car.speed * 0.1  # scale factor for reasonable movement

        # 4. Collision detection (pairwise over cars still on the road)
        agent_crash = False
        proximity_list: List[ProximityData] = []
        active_cars = [c for c in self._cars if not c.reached_goal]
        agent_id = self._cars[0].car_id
        for i in range(len(active_cars)):
            for j in range(i + 1, len(active_cars)):
                dist = active_cars[i].distance_to(active_cars[j])
                involves_agent = active_cars[i].car_id == agent_id or active_cars[j].car_id == agent_id
                if dist < CRASH_DISTANCE:
                    # Crashes are always counted and reported, agent or not.
                    self._state.crash_count += 1
                    proximity_list.append(
                        ProximityData(
                            carA=active_cars[i].car_id,
                            carB=active_cars[j].car_id,
                            distance=round(dist, 2),
                        )
                    )
                    incidents.append(
                        f"CRASH between Car {active_cars[i].car_id} and Car {active_cars[j].car_id}! "
                        f"(distance: {dist:.1f})"
                    )
                    if involves_agent:
                        agent_crash = True
                elif dist < NEAR_MISS_DISTANCE:
                    self._state.near_miss_count += 1
                    # Only penalize near misses involving the agent
                    if involves_agent:
                        reward += REWARD_NEAR_MISS
                        proximity_list.append(
                            ProximityData(
                                carA=active_cars[i].car_id,
                                carB=active_cars[j].car_id,
                                distance=round(dist, 2),
                            )
                        )
                        incidents.append(
                            f"NEAR MISS between Car {active_cars[i].car_id} and Car {active_cars[j].car_id} "
                            f"(distance: {dist:.1f})"
                        )

        # An agent crash ends the episode immediately; otherwise check whether
        # the agent reached its goal this step (also terminal).
        if agent_crash:
            reward += REWARD_CRASH
            self._done = True
        else:
            # 5. Goal check for agent car
            agent = self._cars[0]
            if agent.position >= agent.goal_position:
                agent.reached_goal = True
                self._state.cars_reached_goal += 1
                reward += REWARD_REACHED_GOAL
                incidents.append(
                    f"Car 0 reached its goal at position {agent.goal_position:.0f}!"
                )
                self._done = True

        # Check goal for scripted cars too (for state tracking)
        for car in self._cars[1:]:
            if not car.reached_goal and car.position >= car.goal_position:
                car.reached_goal = True
                self._state.cars_reached_goal += 1

        # 6. Safe step bonus (no crash, agent still active)
        if not self._done:
            reward += REWARD_SAFE_STEP

        # 7. Reasoning quality bonus (granted even on terminal steps)
        reasoning_bonus = _compute_reasoning_bonus(action.reasoning)
        reward += reasoning_bonus

        # 8. Max steps check
        if self._state.step_count >= MAX_STEPS and not self._done:
            self._done = True
            incidents.append(f"Maximum steps ({MAX_STEPS}) reached.")

        incident_report = (
            "\n".join(incidents) if incidents else "Observer: No incidents this step."
        )

        self._last_obs = self._build_observation(
            incident_report=incident_report,
            reward=reward,
            proximities=proximity_list,
        )
        return self._last_obs

    @property
    def state(self) -> OverflowState:
        """Get the current environment state."""
        return self._state
|
server/policy_adapter.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapter between OverflowObservation (2D road grid) and the OpenENV policy
|
| 3 |
+
observation format (ego state + ticket matrix).
|
| 4 |
+
|
| 5 |
+
Nearby cars are converted to collision_risk tickets so TicketAttentionPolicy
|
| 6 |
+
can reason about them using the same mechanism it was designed for.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from ..policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM
|
| 16 |
+
except ImportError:
|
| 17 |
+
from policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def overflow_obs_to_policy_obs(obs) -> np.ndarray:
    """OverflowObservation → 603-dim numpy vector for our policies.

    Ego state comes from Car 0; every other car within 80 units becomes a
    collision_risk ticket in the policy's ticket matrix.
    """
    cars = obs.cars
    if not cars:
        # No cars at all (e.g. before the first reset) → all-zero observation.
        return np.zeros(OBS_DIM, dtype=np.float32)

    # Car 0 is the agent; fall back to the first listed car if it is absent.
    ego = next((c for c in cars if c.carId == 0), cars[0])
    ego_speed_ms = ego.speed / 4.5  # OverflowEnv speed units → m/s
    ego_x = ego.position.x
    ego_y = (ego.lane - 2) * 3.7  # lane → lateral metres

    ticket_vectors = []
    for car in cars:
        if car.carId == 0:
            continue
        rel_x = car.position.x - ego.position.x
        rel_y = (car.lane - ego.lane) * 3.7
        car_spd = car.speed / 4.5
        distance = math.sqrt(rel_x ** 2 + rel_y ** 2)
        if distance > 80:
            # Too far to matter — keep the ticket matrix sparse.
            continue
        # Approximate closing speed, floored at 0.1 to avoid division by zero.
        # NOTE(review): max(rel_x, 0.01) is always positive, so copysign always
        # yields +1 and any intended sign flip for cars behind never engages —
        # confirm whether that is deliberate.
        closing = max(ego_speed_ms - car_spd * math.copysign(1, max(rel_x, 0.01)), 0.1)
        ttc = min(distance / closing, 30.0)  # capped time-to-collision, seconds
        # Severity buckets by raw distance: <8 critical, <15 elevated, else mild.
        severity = 1.0 if distance < 8 else (0.75 if distance < 15 else 0.5)
        ticket_vectors.append(build_ticket_vector(
            severity_weight=severity, ttl=5.0,
            pos_x=rel_x, pos_y=rel_y, pos_z=0.0,
            vel_x=car_spd, vel_y=0.0, vel_z=0.0,
            heading=0.0,
            size_length=4.0, size_width=2.0, size_height=1.5,  # nominal car dims
            distance=distance, time_to_collision=ttc,
            bearing=math.atan2(rel_y, max(rel_x, 0.01)),
            ticket_type="collision_risk", entity_type="vehicle", confidence=1.0,
        ))

    tv = np.array(ticket_vectors, dtype=np.float32) if ticket_vectors else None
    return build_obs(
        ego_x=ego_x, ego_y=ego_y, ego_z=0.0,
        ego_vx=ego_speed_ms, ego_vy=0.0,
        heading=0.0, speed=ego_speed_ms,
        steer=0.0, throttle=0.5, brake=0.0,  # controls not tracked by the env
        ticket_vectors=tv,
    )
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def policy_action_to_decision(action_vec: np.ndarray) -> tuple[str, str]:
    """Continuous [steer, throttle, brake] → (text decision, reasoning)."""
    steer = float(action_vec[0])
    throttle = float(action_vec[1])
    brake = float(action_vec[2])

    # Priority order: lateral avoidance first, then braking, then throttle.
    if abs(steer) > 0.35:
        side = "lane_change_left" if steer < 0 else "lane_change_right"
        return side, f"steer={steer:.2f}: lateral avoidance"
    if brake > 0.25:
        return "brake", f"brake={brake:.2f}: closing gap"
    if throttle > 0.20:
        return "accelerate", f"throttle={throttle:.2f}: clear ahead"
    return "maintain", f"s={steer:.2f} t={throttle:.2f} b={brake:.2f}: holding course"
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
| 2 |
+
torch>=2.5.0
|
| 3 |
+
gymnasium>=0.29.0
|
| 4 |
+
openenv-core[core]>=0.2.1
|
| 5 |
+
fastapi>=0.115.0
|
| 6 |
+
pydantic>=2.0.0
|
| 7 |
+
uvicorn[standard]>=0.24.0
|
| 8 |
+
requests>=2.31.0
|
training/__init__.py
ADDED
|
File without changes
|
training/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (169 Bytes). View file
|
|
|
training/__pycache__/curriculum.cpython-314.pyc
ADDED
|
Binary file (5.69 kB). View file
|
|
|
training/__pycache__/overflow_gym_env.cpython-314.pyc
ADDED
|
Binary file (9.75 kB). View file
|
|
|
training/__pycache__/ppo_trainer.cpython-314.pyc
ADDED
|
Binary file (20.6 kB). View file
|
|
|
training/__pycache__/reward.cpython-314.pyc
ADDED
|
Binary file (3.47 kB). View file
|
|
|
training/curriculum.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CurriculumManager β ported from openenv/training/curriculum.py.
|
| 3 |
+
|
| 4 |
+
Same 4-stage progression and same reward thresholds. Adapted for
|
| 5 |
+
OverflowEnvironment: no ticket injection (the env has its own scripted
|
| 6 |
+
NPCs), stages instead control training logging and advancement criteria.
|
| 7 |
+
|
| 8 |
+
Stage 1 No extra pressure. Goal: learn basic speed + lane keeping.
|
| 9 |
+
Stage 2 Standard traffic. Goal: survive without crashing.
|
| 10 |
+
Stage 3 Evaluate more. Goal: consistent goal-reaching.
|
| 11 |
+
Stage 4 Full evaluation. Goal: high mean reward over long window.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import List
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class StageConfig:
    """Configuration for one curriculum stage."""

    stage: int  # 1-based stage number
    name: str  # short human-readable stage name
    description: str  # what the agent should learn at this stage
    advance_threshold: float  # mean episode reward to advance
    advance_window: int  # consecutive episodes required
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Ordered curriculum stages; CurriculumManager walks this list front to back.
STAGES: List[StageConfig] = [
    StageConfig(
        stage=1, name="Survival",
        description="Learn basic speed control and lane keeping.",
        advance_threshold=50.0, advance_window=8,
    ),
    StageConfig(
        stage=2, name="Crash Avoidance",
        description="Navigate traffic without colliding.",
        advance_threshold=120.0, advance_window=15,
    ),
    StageConfig(
        stage=3, name="Goal Reaching",
        description="Consistently reach the goal position.",
        advance_threshold=200.0, advance_window=15,
    ),
    StageConfig(
        stage=4, name="Mastery",
        description="High reward, smooth driving, minimal near-misses.",
        advance_threshold=280.0, advance_window=15,
    ),
]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class CurriculumManager:
    """
    Tracks stage progression based on episode rewards.
    Same API as openenv CurriculumManager — PPOTrainer calls it unchanged.
    """

    def __init__(self, seed: int = 0):
        # `seed` is accepted for API compatibility; nothing here is random.
        self._stage_idx = 0
        self._rewards: List[float] = []
        self._auto_advance = True

    @property
    def current_stage(self) -> int:
        """1-based number of the active stage."""
        return STAGES[self._stage_idx].stage

    @property
    def config(self) -> StageConfig:
        """Full StageConfig of the active stage."""
        return STAGES[self._stage_idx]

    def step(self, sim_time: float) -> list:
        """No ticket injection in OverflowEnvironment — always returns []."""
        return []

    def record_episode_reward(self, reward: float) -> bool:
        """Record an episode reward; advance when the rolling-window mean
        clears the stage threshold. Returns True iff the stage advanced."""
        self._rewards.append(reward)
        cfg = self.config
        recent = self._rewards[-cfg.advance_window:]

        can_advance = (
            self._auto_advance
            and len(recent) >= cfg.advance_window
            and sum(recent) / len(recent) >= cfg.advance_threshold
            and self._stage_idx < len(STAGES) - 1
        )
        if not can_advance:
            return False

        self._stage_idx += 1
        self._rewards = []  # fresh window for the new stage
        print(f"[Curriculum] Advanced to Stage {self.current_stage}: {self.config.name}")
        return True

    def force_stage(self, stage: int) -> None:
        """Jump directly to *stage* (1-based), clearing reward history.
        Out-of-range stages are silently ignored."""
        idx = stage - 1
        if 0 <= idx < len(STAGES):
            self._stage_idx = idx
            self._rewards = []
            print(f"[Curriculum] Forced to Stage {stage}: {self.config.name}")
|
training/overflow_gym_env.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gymnasium wrapper around OverflowEnvironment.
|
| 3 |
+
|
| 4 |
+
Bridges the gap between OverflowEnvironment (text actions, structured obs)
|
| 5 |
+
and our PPO trainer (continuous actions, numeric obs vector).
|
| 6 |
+
|
| 7 |
+
Observation: 603-dim float32 vector (same layout as CarEnv3D β ego state +
|
| 8 |
+
collision-risk ticket matrix built from nearby cars)
|
| 9 |
+
|
| 10 |
+
Action: [steer, throttle, brake] all in [-1, 1]
|
| 11 |
+
β mapped to text decision for OverflowEnvironment
|
| 12 |
+
|
| 13 |
+
This makes OverflowEnvironment a drop-in replacement for CarEnv3D so that
|
| 14 |
+
FlatMLPPolicy and TicketAttentionPolicy train with the exact same PPO loop.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
from typing import Any, Dict, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import gymnasium as gym
|
| 24 |
+
from gymnasium import spaces
|
| 25 |
+
|
| 26 |
+
from ..server.overflow_environment import OverflowEnvironment
|
| 27 |
+
from ..models import OverflowAction
|
| 28 |
+
from ..policies.policy_spec import (
|
| 29 |
+
build_obs, build_ticket_vector, OBS_DIM,
|
| 30 |
+
)
|
| 31 |
+
from .reward import compute_reward
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ββ Action mapping ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 35 |
+
|
| 36 |
+
def _action_to_decision(action: np.ndarray) -> str:
|
| 37 |
+
steer, throttle, brake = float(action[0]), float(action[1]), float(action[2])
|
| 38 |
+
if abs(steer) > 0.35:
|
| 39 |
+
return "lane_change_left" if steer < 0 else "lane_change_right"
|
| 40 |
+
if brake > 0.25:
|
| 41 |
+
return "brake"
|
| 42 |
+
if throttle > 0.20:
|
| 43 |
+
return "accelerate"
|
| 44 |
+
return "maintain"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ββ Observation extraction ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 48 |
+
|
| 49 |
+
def _obs_to_vector(overflow_obs) -> np.ndarray:
    """OverflowObservation → 603-dim numpy vector matching policy_spec layout.

    Mirrors server/policy_adapter.overflow_obs_to_policy_obs — keep in sync.
    """
    cars = overflow_obs.cars
    if not cars:
        # No cars (e.g. before reset) → all-zero observation.
        return np.zeros(OBS_DIM, dtype=np.float32)

    # Car 0 is the agent; fall back to the first listed car if it is absent.
    ego = next((c for c in cars if c.carId == 0), cars[0])
    ego_speed_ms = ego.speed / 4.5  # env speed units → m/s
    ego_x = ego.position.x
    ego_y = (ego.lane - 2) * 3.7  # lane → lateral metres

    ticket_vectors = []
    for car in cars:
        if car.carId == 0:
            continue
        rel_x = car.position.x - ego.position.x
        rel_y = (car.lane - ego.lane) * 3.7
        car_spd = car.speed / 4.5
        distance = math.sqrt(rel_x ** 2 + rel_y ** 2)
        if distance > 80:
            # Too far to influence the policy — keep the ticket matrix sparse.
            continue
        # Approximate closing speed, floored at 0.1 to avoid division by zero.
        # NOTE(review): max(rel_x, 0.01) is always positive, so copysign always
        # yields +1 — any intended sign flip for trailing cars never engages;
        # confirm whether that is deliberate.
        closing = max(ego_speed_ms - car_spd * math.copysign(1, max(rel_x, 0.01)), 0.1)
        ttc = min(distance / closing, 30.0)  # capped time-to-collision, seconds
        # Severity buckets by raw distance: <8 critical, <15 elevated, else mild.
        severity = 1.0 if distance < 8 else (0.75 if distance < 15 else 0.5)
        ticket_vectors.append(build_ticket_vector(
            severity_weight=severity, ttl=5.0,
            pos_x=rel_x, pos_y=rel_y, pos_z=0.0,
            vel_x=car_spd, vel_y=0.0, vel_z=0.0,
            heading=0.0,
            size_length=4.0, size_width=2.0, size_height=1.5,  # nominal car dims
            distance=distance, time_to_collision=ttc,
            bearing=math.atan2(rel_y, max(rel_x, 0.01)),
            ticket_type="collision_risk", entity_type="vehicle", confidence=1.0,
        ))

    tv = np.array(ticket_vectors, dtype=np.float32) if ticket_vectors else None
    return build_obs(
        ego_x=ego_x, ego_y=ego_y, ego_z=0.0,
        ego_vx=ego_speed_ms, ego_vy=0.0,
        heading=0.0, speed=ego_speed_ms,
        steer=0.0, throttle=0.5, brake=0.0,  # controls not tracked by the env
        ticket_vectors=tv,
    )
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# ββ Gymnasium wrapper βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 95 |
+
|
| 96 |
+
class OverflowGymEnv(gym.Env):
    """
    Gymnasium-compatible wrapper around OverflowEnvironment.

    Provides the same interface as CarEnv3D so PPOTrainer works unchanged.
    Observations are 603-dim float32 vectors; actions are continuous
    [steer, throttle, brake] in [-1, 1], mapped to text decisions.
    """

    metadata = {"render_modes": []}

    def __init__(self):
        super().__init__()
        self._env = OverflowEnvironment()
        self._last_overflow_obs = None
        # Previous action, kept so the reward can penalise jerky control.
        self._prev_action = np.zeros(3, dtype=np.float32)
        self._sim_time = 0.0  # incremented each step (mirrors CarEnv3D._sim_time)
        self._step_dt = 0.1  # seconds per step

        self.observation_space = spaces.Box(
            low=-1.0, high=1.0, shape=(OBS_DIM,), dtype=np.float32
        )
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(3,), dtype=np.float32
        )

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[np.ndarray, Dict]:
        """Reset the wrapped environment and return (obs_vector, info)."""
        super().reset(seed=seed)
        self._last_overflow_obs = self._env.reset(seed=seed)
        self._prev_action = np.zeros(3, dtype=np.float32)
        self._sim_time = 0.0
        return _obs_to_vector(self._last_overflow_obs), {}

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Apply one continuous action; return the Gymnasium 5-tuple."""
        decision = _action_to_decision(action)
        overflow_action = OverflowAction(decision=decision, reasoning="")
        overflow_obs = self._env.step(overflow_action)
        self._last_overflow_obs = overflow_obs
        self._sim_time += self._step_dt

        obs_vec = _obs_to_vector(overflow_obs)

        # Extract signals for reward shaping
        ego = next((c for c in overflow_obs.cars if c.carId == 0), None)
        ego_speed_ms = (ego.speed / 4.5) if ego else 0.0  # env units → m/s
        ego_y = ((ego.lane - 2) * 3.7) if ego else 0.0  # lateral offset, metres

        # An agent crash shows up as a "CRASH ... Car 0 ..." incident line.
        collision = any("CRASH" in p for p in (overflow_obs.incident_report or "").split("\n")
                        if "Car 0" in p)
        # NOTE(review): done-without-crash also covers the max-steps cutoff,
        # so a timed-out episode is flagged as goal_reached — confirm intended.
        goal_reached = overflow_obs.done and not collision

        reward = compute_reward(
            ego_speed = ego_speed_ms,
            ego_y = ego_y,
            action = action,
            prev_action = self._prev_action,
            collision = collision,
            goal_reached = goal_reached,
            near_miss = "NEAR MISS" in (overflow_obs.incident_report or ""),
            raw_reward = overflow_obs.reward or 0.0,
        )

        self._prev_action = action.copy()

        terminated = overflow_obs.done
        truncated = False  # step limit is handled inside OverflowEnvironment
        info: Dict[str, Any] = {
            "collision": collision,
            "goal_reached": goal_reached,
            "incident": overflow_obs.incident_report,
        }

        return obs_vec, reward, terminated, truncated, info
|
training/ppo_trainer.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PPO trainer β ported directly from openenv/training/ppo_trainer.py.
|
| 3 |
+
|
| 4 |
+
Same algorithm, same hyperparameters, same GAE implementation.
|
| 5 |
+
Only change: uses OverflowGymEnv instead of CarEnv3D.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
from overflow_env.training.ppo_trainer import run_training
|
| 9 |
+
run_training(policy_type="attention", total_steps=2_000_000)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import time
|
| 15 |
+
from collections import deque
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.optim as optim
|
| 23 |
+
|
| 24 |
+
from .overflow_gym_env import OverflowGymEnv
|
| 25 |
+
from .curriculum import CurriculumManager
|
| 26 |
+
from .reward import compute_episode_bonus
|
| 27 |
+
from ..policies.base_policy import BasePolicy
|
| 28 |
+
from ..policies.policy_spec import OBS_DIM
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ββ Rollout buffer βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
# Identical to openenv/training/ppo_trainer.py
|
| 33 |
+
|
| 34 |
+
class RolloutBuffer:
    """
    Fixed-size, on-device transition storage for single-environment PPO.

    Call `add` once per env step, `compute_returns` when full, and read
    `obs/acts/rew/val/logp/done/ret` during the update.  `reset` rewinds the
    write pointer without clearing the tensors (they are overwritten).

    Fixes vs. the original port: `self.ret` is allocated in `__init__`
    (originally it only appeared after the first `compute_returns`, so
    accessing it earlier raised AttributeError), and `add` raises a clear
    error when the buffer is full instead of an opaque tensor IndexError.
    """

    def __init__(self, n_steps: int, obs_dim: int, device: torch.device):
        self.n = n_steps
        self.obs = torch.zeros(n_steps, obs_dim, device=device)
        self.acts = torch.zeros(n_steps, 3, device=device)  # 3-d continuous action
        self.rew = torch.zeros(n_steps, device=device)
        self.val = torch.zeros(n_steps, device=device)
        self.logp = torch.zeros(n_steps, device=device)
        self.done = torch.zeros(n_steps, device=device)
        # Discounted returns (advantage + value); filled by compute_returns().
        self.ret = torch.zeros(n_steps, device=device)
        self.ptr = 0

    def add(self, obs, act, rew, val, logp, done):
        """Append one transition; raises IndexError when the buffer is full."""
        if self.ptr >= self.n:
            raise IndexError("RolloutBuffer is full; call reset() before adding")
        i = self.ptr
        self.obs[i] = torch.as_tensor(obs, dtype=torch.float32)
        self.acts[i] = torch.as_tensor(act, dtype=torch.float32)
        self.rew[i] = float(rew)
        self.val[i] = float(val)
        self.logp[i] = float(logp)
        self.done[i] = float(done)
        self.ptr += 1

    def full(self) -> bool:
        """True once n_steps transitions have been stored."""
        return self.ptr >= self.n

    def reset(self):
        """Rewind the write pointer (old data is overwritten by later adds)."""
        self.ptr = 0

    def compute_returns(self, last_val: float, gamma: float, gae_lambda: float):
        """
        Generalized Advantage Estimation — identical math to openenv.

        `last_val` bootstraps the value of the state after the final stored
        transition; `done` masks stop both the bootstrap and the GAE carry
        at episode boundaries.  Result is stored in `self.ret` (= adv + val).
        """
        adv = torch.zeros_like(self.rew)
        gae = 0.0
        for t in reversed(range(self.n)):
            next_val = last_val if t == self.n - 1 else float(self.val[t + 1])
            delta = self.rew[t] + gamma * next_val * (1 - self.done[t]) - self.val[t]
            gae = delta + gamma * gae_lambda * (1 - self.done[t]) * gae
            adv[t] = gae
        self.ret = adv + self.val
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ββ PPO Trainer ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 74 |
+
|
| 75 |
+
class PPOTrainer:
    """
    PPO trainer for OverflowGymEnv.

    Identical to the openenv PPOTrainer — same hyperparameters, same clipped
    surrogate + clipped value loss — with the environment swapped from
    CarEnv3D to OverflowGymEnv.

    Fix vs. the original port: `_ppo_update` re-draws the minibatch
    permutation every epoch.  The original drew `torch.randperm` once before
    the epoch loop, so all `n_epochs` passes reused identical minibatch
    partitions (standard PPO implementations reshuffle each epoch).
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: OverflowGymEnv,
        curriculum: Optional[CurriculumManager] = None,
        # PPO hyperparameters — same defaults as openenv
        lr: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: float = 0.2,
        clip_range_vf: float = 0.2,
        ent_coef: float = 0.02,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        n_steps: int = 2048,
        batch_size: int = 256,
        n_epochs: int = 10,
        save_dir: str = "checkpoints",
        log_interval: int = 10,
        device: str = "auto",
    ):
        self.policy = policy
        self.env = env
        self.curriculum = curriculum or CurriculumManager()
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip = clip_range
        self.clip_vf = clip_range_vf
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad = max_grad_norm
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.log_every = log_interval
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # "auto": prefer CUDA, then Apple MPS, else CPU.
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else \
                "mps" if torch.backends.mps.is_available() else "cpu"
        self.device = torch.device(device)
        self.policy.to(self.device)

        self.optimizer = optim.Adam(policy.parameters(), lr=lr, eps=1e-5)
        # Linear LR decay to 10% over the first 500 updates.
        self.scheduler = optim.lr_scheduler.LinearLR(
            self.optimizer, start_factor=1.0, end_factor=0.1, total_iters=500,
        )

        self.buffer = RolloutBuffer(n_steps, OBS_DIM, self.device)

        # Rolling stats over the last 100 episodes for logging.
        self.ep_rewards = deque(maxlen=100)
        self.ep_lengths = deque(maxlen=100)
        self.total_steps = 0
        self.n_updates = 0

    # ── Main training loop ───────────────────────────────────────────────────

    def train(self, total_steps: int = 2_000_000) -> None:
        """Alternate rollout collection and PPO updates until total_steps."""
        print(f"\n{'='*70}", flush=True)
        print(f" OpenENV PPO Training — policy={self.policy.__class__.__name__}", flush=True)
        print(f" total_steps={total_steps} n_steps={self.n_steps} lr={self.optimizer.param_groups[0]['lr']:.0e}", flush=True)
        print(f" gamma={self.gamma} gae_lambda={self.gae_lambda} clip={self.clip} ent_coef={self.ent_coef}", flush=True)
        print(f"{'='*70}\n", flush=True)

        obs, _ = self.env.reset()
        ep_reward = 0.0
        ep_steps = 0
        t0 = time.time()

        while self.total_steps < total_steps:
            self.buffer.reset()
            self.policy.eval()

            # ── Collect rollout ──────────────────────────────────────────────
            for _ in range(self.n_steps):
                # Curriculum step (returns [] for OverflowEnv — kept for API compat)
                self.curriculum.step(self.env._sim_time)

                obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
                with torch.no_grad():
                    act_mean, val = self.policy(obs_t.unsqueeze(0))
                    act_mean = act_mean.squeeze(0)
                    val = val.squeeze(0)

                # Gaussian exploration with a fixed std of 0.3 (same as update).
                dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3)
                action = dist.sample().clamp(-1, 1)
                logp = dist.log_prob(action).sum()

                next_obs, reward, term, trunc, info = self.env.step(action.cpu().numpy())

                self.buffer.add(
                    obs, action.cpu().numpy(), reward,
                    float(val), float(logp), float(term or trunc),
                )

                obs = next_obs
                ep_reward += reward
                ep_steps += 1
                self.total_steps += 1

                if term or trunc:
                    # End-of-episode bonus is added to the LOGGED reward only;
                    # it is not written into the buffer, so it does not reach
                    # the gradient (same as the original).
                    bonus = compute_episode_bonus(
                        total_steps=ep_steps,
                        survived=not info.get("collision", False),
                    )
                    ep_reward += bonus
                    self.ep_rewards.append(ep_reward)
                    self.ep_lengths.append(ep_steps)
                    self.curriculum.record_episode_reward(ep_reward)

                    outcome = "CRASH" if info.get("collision") else ("GOAL" if info.get("goal_reached") else "timeout")
                    print(
                        f" ep#{len(self.ep_rewards):>4d} | "
                        f"steps={ep_steps:>3d} | "
                        f"reward={ep_reward:>8.2f} | "
                        f"outcome={outcome:<8} | "
                        f"stage={self.curriculum.current_stage} | "
                        f"total_steps={self.total_steps}",
                        flush=True,
                    )

                    obs, _ = self.env.reset()
                    ep_reward = 0.0
                    ep_steps = 0

            # ── PPO update ───────────────────────────────────────────────────
            with torch.no_grad():
                obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
                _, last_val = self.policy(obs_t.unsqueeze(0))
            self.buffer.compute_returns(float(last_val), self.gamma, self.gae_lambda)

            self.policy.train()
            self._ppo_update()
            self.n_updates += 1
            self.scheduler.step()

            elapsed = time.time() - t0
            sps = self.total_steps / max(elapsed, 1)
            mean_r = np.mean(self.ep_rewards) if self.ep_rewards else 0.0
            mean_l = np.mean(self.ep_lengths) if self.ep_lengths else 0.0
            print(
                f"\n[PPO update #{self.n_updates}] "
                f"step={self.total_steps} "
                f"mean_reward={mean_r:.2f} "
                f"mean_ep_len={mean_l:.0f} "
                f"stage={self.curriculum.current_stage} "
                f"sps={sps:.0f}\n",
                flush=True,
            )

            # ── Checkpoint ───────────────────────────────────────────────────
            if self.n_updates % 50 == 0:
                ckpt = self.save_dir / f"policy_step{self.total_steps}_stage{self.curriculum.current_stage}.pt"
                torch.save({
                    "step": self.total_steps,
                    "stage": self.curriculum.current_stage,
                    "policy": self.policy.state_dict(),
                    "optim": self.optimizer.state_dict(),
                }, ckpt)
                print(f"[PPO] Saved checkpoint → {ckpt}", flush=True)

    # ── PPO update pass ──────────────────────────────────────────────────────

    def _ppo_update(self):
        """One PPO optimisation phase over the filled rollout buffer."""
        obs = self.buffer.obs
        acts = self.buffer.acts
        old_logp = self.buffer.logp
        # Normalise advantages once over the whole rollout.
        adv = self.buffer.ret - self.buffer.val
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)
        ret = self.buffer.ret
        old_val = self.buffer.val

        for _ in range(self.n_epochs):
            # Fresh shuffle each epoch so minibatch composition varies
            # (the original shuffled once, reusing identical minibatches).
            indices = torch.randperm(self.n_steps, device=self.device)
            for start in range(0, self.n_steps, self.batch_size):
                idx = indices[start: start + self.batch_size]

                act_mean, val = self.policy(obs[idx])
                val = val.squeeze(-1)

                # Same fixed-std Gaussian as used during collection.
                dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3)
                logp = dist.log_prob(acts[idx]).sum(dim=-1)
                entropy = dist.entropy().sum(dim=-1).mean()

                # Clipped surrogate objective.
                ratio = torch.exp(logp - old_logp[idx])
                pg_loss1 = -adv[idx] * ratio
                pg_loss2 = -adv[idx] * ratio.clamp(1 - self.clip, 1 + self.clip)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Clipped value loss (pessimistic max of clipped/unclipped).
                val_unclipped = (val - ret[idx]) ** 2
                val_clipped = (
                    old_val[idx]
                    + (val - old_val[idx]).clamp(-self.clip_vf, self.clip_vf)
                    - ret[idx]
                ) ** 2
                vf_loss = 0.5 * torch.max(val_unclipped, val_clipped).mean()

                loss = pg_loss + self.vf_coef * vf_loss - self.ent_coef * entropy

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad)
                self.optimizer.step()
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# ββ Entry point ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 289 |
+
|
| 290 |
+
def run_training(
    policy_type: str = "attention",
    total_steps: int = 2_000_000,
    start_stage: int = 1,
    checkpoint: Optional[str] = None,
    device: str = "auto",
) -> None:
    """
    Build a policy + environment and launch PPO training.

    `policy_type` is "attention" or "mlp" (KeyError otherwise); an optional
    checkpoint restores policy weights before training begins.
    """
    from ..policies.ticket_attention_policy import TicketAttentionPolicy
    from ..policies.flat_mlp_policy import FlatMLPPolicy

    builders = {
        "attention": lambda: TicketAttentionPolicy(obs_dim=OBS_DIM),
        "mlp": lambda: FlatMLPPolicy(obs_dim=OBS_DIM),
    }
    policy = builders[policy_type]()

    if checkpoint:
        state = torch.load(checkpoint, map_location="cpu")
        policy.load_state_dict(state["policy"])
        print(f"[PPO] Loaded checkpoint from {checkpoint}")

    curriculum = CurriculumManager()
    if start_stage > 1:
        curriculum.force_stage(start_stage)

    trainer = PPOTrainer(
        policy=policy,
        env=OverflowGymEnv(),
        curriculum=curriculum,
        device=device,
        n_steps=512,
    )
    trainer.train(total_steps=total_steps)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
if __name__ == "__main__":
    import argparse

    # CLI entry point: mirrors run_training's signature.
    parser = argparse.ArgumentParser()
    parser.add_argument("--policy", default="attention", choices=["attention", "mlp"])
    parser.add_argument("--steps", default=2_000_000, type=int)
    parser.add_argument("--stage", default=1, type=int)
    parser.add_argument("--checkpoint", default=None)
    parser.add_argument("--device", default="auto")
    ns = parser.parse_args()

    run_training(
        policy_type=ns.policy,
        total_steps=ns.steps,
        start_stage=ns.stage,
        checkpoint=ns.checkpoint,
        device=ns.device,
    )
|
training/reward.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reward shaping for OverflowEnvironment β ported from openenv/training/reward.py.
|
| 3 |
+
|
| 4 |
+
Same core principle: BASE + THREAT_RESPONSE with clear gradient direction.
|
| 5 |
+
Adapted to OverflowEnvironment's signals (no EventTicket objects β uses
|
| 6 |
+
collision/near-miss flags and raw reward from the environment).
|
| 7 |
+
|
| 8 |
+
BASE: survival + speed + lane ~+0.4/step
|
| 9 |
+
COLLISION: -50 (terminal)
|
| 10 |
+
NEAR MISS: -0.8 per event
|
| 11 |
+
GOAL REACHED: +5.0 (terminal bonus)
|
| 12 |
+
SMOOTH DRIVING: small bonus when no threats
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
# ββ Same weights as openenv/training/reward.py ββββββββββββββββββββββββββββββββ
|
| 20 |
+
|
| 21 |
+
W_ALIVE = 0.40
|
| 22 |
+
W_SPEED = 0.10
|
| 23 |
+
W_LANE = 0.15
|
| 24 |
+
W_SMOOTH = 0.03
|
| 25 |
+
TARGET_SPEED = 11.0 # m/s (~40 km/h)
|
| 26 |
+
TARGET_SPEED_TOL = 3.0
|
| 27 |
+
|
| 28 |
+
W_COLLISION = -50.0
|
| 29 |
+
W_NEAR_MISS = -0.8
|
| 30 |
+
W_GOAL = 5.0
|
| 31 |
+
W_SURVIVE_BONUS = 5.0
|
| 32 |
+
|
| 33 |
+
ROAD_HALF_WIDTH = 3.7 * 1.5 # ~2.5 lanes worth of tolerance
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def compute_reward(
|
| 37 |
+
ego_speed: float,
|
| 38 |
+
ego_y: float,
|
| 39 |
+
action: np.ndarray,
|
| 40 |
+
prev_action: np.ndarray,
|
| 41 |
+
collision: bool,
|
| 42 |
+
goal_reached: bool,
|
| 43 |
+
near_miss: bool,
|
| 44 |
+
raw_reward: float, # OverflowEnvironment's built-in reward (used as baseline)
|
| 45 |
+
) -> float:
|
| 46 |
+
"""
|
| 47 |
+
Shaped reward. Mirrors openenv reward structure:
|
| 48 |
+
- collision β large terminal penalty
|
| 49 |
+
- base survival + speed + lane keeping
|
| 50 |
+
- near-miss penalty
|
| 51 |
+
- goal bonus
|
| 52 |
+
- smooth driving bonus when clear
|
| 53 |
+
"""
|
| 54 |
+
if collision:
|
| 55 |
+
return W_COLLISION
|
| 56 |
+
|
| 57 |
+
reward = 0.0
|
| 58 |
+
|
| 59 |
+
# 1. Survival
|
| 60 |
+
reward += W_ALIVE
|
| 61 |
+
|
| 62 |
+
# 2. Speed maintenance (same formula as openenv)
|
| 63 |
+
speed_err = abs(ego_speed - TARGET_SPEED)
|
| 64 |
+
if speed_err < TARGET_SPEED_TOL:
|
| 65 |
+
reward += W_SPEED * (1.0 - speed_err / TARGET_SPEED_TOL)
|
| 66 |
+
else:
|
| 67 |
+
reward -= 0.03 * min(speed_err - TARGET_SPEED_TOL, 5.0)
|
| 68 |
+
|
| 69 |
+
# 3. Lane keeping
|
| 70 |
+
norm_y = abs(ego_y) / ROAD_HALF_WIDTH
|
| 71 |
+
reward += W_LANE * max(0.0, 1.0 - norm_y ** 2)
|
| 72 |
+
|
| 73 |
+
# 4. Near miss penalty
|
| 74 |
+
if near_miss:
|
| 75 |
+
reward += W_NEAR_MISS
|
| 76 |
+
|
| 77 |
+
# 5. Goal bonus
|
| 78 |
+
if goal_reached:
|
| 79 |
+
reward += W_GOAL
|
| 80 |
+
|
| 81 |
+
# 6. Smooth driving
|
| 82 |
+
action_delta = np.abs(action - prev_action).sum()
|
| 83 |
+
reward += W_SMOOTH * max(0.0, 1.0 - action_delta * 3.0)
|
| 84 |
+
|
| 85 |
+
return float(reward)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def compute_episode_bonus(total_steps: int, survived: bool) -> float:
    """
    End-of-episode bonus — same as openenv.

    Episodes ending in a crash earn nothing; surviving earns a fixed bonus
    plus a longevity reward of 0.02/step, capped at 500 steps.
    """
    if not survived:
        return 0.0
    longevity = 0.02 * min(total_steps, 500)
    return float(W_SURVIVE_BONUS + longevity)
|