Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +5 -0
- Dockerfile +49 -0
- README.md +186 -6
- __init__.py +1 -0
- agents/README.md +20 -0
- agents/__init__.py +15 -0
- agents/district_controller.py +187 -0
- agents/district_coordinator.py +11 -0
- agents/heuristic_controller.py +91 -0
- agents/local_policy.py +110 -0
- agents/message_protocol.py +116 -0
- artifacts/district_llm_adapter_v3/main_run/adapter/README.md +210 -0
- artifacts/district_llm_adapter_v3/main_run/adapter/adapter_config.json +50 -0
- artifacts/district_llm_adapter_v3/main_run/adapter/adapter_model.safetensors +3 -0
- artifacts/district_llm_adapter_v3/main_run/adapter/tokenizer.json +3 -0
- artifacts/district_llm_adapter_v3/main_run/adapter/tokenizer_config.json +18 -0
- artifacts/dqn_shared/best_validation.pt +3 -0
- client.py +52 -0
- data/splits/README.md +21 -0
- data/splits/test_cities.txt +15 -0
- data/splits/train_cities.txt +70 -0
- data/splits/val_cities.txt +15 -0
- district_llm/FINAL_ABLATION_RUNBOOK.md +79 -0
- district_llm/RL_GUIDANCE_EVAL_RUNBOOK.md +126 -0
- district_llm/RL_LLM_WRAPPER_SWEEP_RUNBOOK.md +121 -0
- district_llm/__init__.py +18 -0
- district_llm/data.py +27 -0
- district_llm/derivation.py +228 -0
- district_llm/eval.py +436 -0
- district_llm/generate_dataset.py +390 -0
- district_llm/guided_control.py +67 -0
- district_llm/heuristic_guidance.py +73 -0
- district_llm/inference.py +223 -0
- district_llm/metrics.py +97 -0
- district_llm/prompting.py +66 -0
- district_llm/repair.py +392 -0
- district_llm/rl_guidance_wrapper.py +1004 -0
- district_llm/schema.py +429 -0
- district_llm/summary_builder.py +413 -0
- district_llm/teachers.py +227 -0
- district_llm/train_unsloth.py +129 -0
- env/README.md +43 -0
- env/__init__.py +18 -0
- env/cityflow_adapter.py +102 -0
- env/district_summary.py +9 -0
- env/intersection_config.py +49 -0
- env/observation_builder.py +224 -0
- env/reward.py +244 -0
- env/scenarios.py +10 -0
- env/traffic_env.py +356 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
artifacts/district_llm_adapter_v3/main_run/adapter/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
third_party/CityFlow/examples/replay.txt filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
third_party/CityFlow/extern/pybind11/.git.bak/objects/pack/pack-0b2353194187af7e228cfabd32025bb4d3af8551.idx filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
third_party/CityFlow/extern/pybind11/.git.bak/objects/pack/pack-0b2353194187af7e228cfabd32025bb4d3af8551.pack filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
third_party/CityFlow/extern/pybind11/.git.bak/objects/pack/pack-0b2353194187af7e228cfabd32025bb4d3af8551.rev filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-slim AS builder
|
| 2 |
+
|
| 3 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 4 |
+
build-essential \
|
| 5 |
+
cmake \
|
| 6 |
+
libboost-all-dev \
|
| 7 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 8 |
+
|
| 9 |
+
WORKDIR /build
|
| 10 |
+
|
| 11 |
+
COPY third_party/CityFlow ./CityFlow
|
| 12 |
+
RUN rm -rf ./CityFlow/build
|
| 13 |
+
RUN pip install --no-cache-dir ./CityFlow
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
FROM python:3.12-slim AS runtime
|
| 17 |
+
|
| 18 |
+
WORKDIR /app
|
| 19 |
+
|
| 20 |
+
COPY --from=builder /usr/local/lib/python3.12/site-packages/cityflow* \
|
| 21 |
+
/usr/local/lib/python3.12/site-packages/
|
| 22 |
+
|
| 23 |
+
COPY openenv_app/requirements.txt ./requirements.txt
|
| 24 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 25 |
+
|
| 26 |
+
COPY __init__.py ./__init__.py
|
| 27 |
+
COPY client.py ./client.py
|
| 28 |
+
COPY models.py ./models.py
|
| 29 |
+
COPY agents/ ./agents/
|
| 30 |
+
COPY district_llm/ ./district_llm/
|
| 31 |
+
COPY env/ ./env/
|
| 32 |
+
COPY openenv_app/ ./openenv_app/
|
| 33 |
+
COPY server/ ./server/
|
| 34 |
+
COPY training/ ./training/
|
| 35 |
+
COPY data/splits/ ./data/splits/
|
| 36 |
+
COPY artifacts/dqn_shared/best_validation.pt ./artifacts/dqn_shared/best_validation.pt
|
| 37 |
+
COPY artifacts/district_llm_adapter_v3/main_run/adapter/ ./artifacts/district_llm_adapter_v3/main_run/adapter/
|
| 38 |
+
|
| 39 |
+
RUN mkdir -p /app/data/generated /app/data/splits
|
| 40 |
+
|
| 41 |
+
ENV DATA_DIR=/app/data/generated
|
| 42 |
+
ENV SPLITS_DIR=/app/data/splits
|
| 43 |
+
ENV CHECKPOINT_PATH=/app/artifacts/dqn_shared/best_validation.pt
|
| 44 |
+
ENV DISTRICT_LLM_ADAPTER_PATH=/app/artifacts/district_llm_adapter_v3/main_run/adapter
|
| 45 |
+
|
| 46 |
+
EXPOSE 7860
|
| 47 |
+
|
| 48 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 49 |
+
CMD ["sh", "-c", "uvicorn server.app:app --host 0.0.0.0 --port ${PORT:-7860}"]
|
README.md
CHANGED
|
@@ -1,11 +1,191 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
-
short_description:
|
|
|
|
|
|
|
| 9 |
---
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Agentic Traffic
|
| 3 |
+
emoji: 🏢
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
short_description: Agentic AI to control traffic lights
|
| 9 |
+
app_port: 7860
|
| 10 |
+
base_path: /web
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# traffic-llm
|
| 14 |
+
|
| 15 |
+
CityFlow-based traffic-control project with intersection-level multi-agent DQN training and district-aware policy variants.
|
| 16 |
+
|
| 17 |
+
## Training
|
| 18 |
+
|
| 19 |
+
The default local-policy trainer now uses parameter-shared dueling Double DQN with prioritized replay and n-step returns:
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
python3 -m training.train_local_policy train
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
That trains against `data/generated`, uses `data/splits`, writes checkpoints to `artifacts/dqn_shared`, enables TensorBoard logging, uses parallel CPU rollout workers by default, shows `tqdm` progress bars, and now validates plus checkpoints every 40 updates by default.
|
| 26 |
+
|
| 27 |
+
For a broader but still manageable validation pass:
|
| 28 |
+
|
| 29 |
+
```bash
|
| 30 |
+
python3 -m training.train_local_policy train --max-val-cities 3 --val-scenarios-per-city 7
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
That evaluates 3 validation cities across all 7 scenario types. This gives 21 learned-policy validation episodes per eval, or 63 total episodes if random and fixed baselines are also enabled.
|
| 34 |
+
|
| 35 |
+
Phase-3-style full training with the same 40-update eval/checkpoint cadence:
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
python3 -m training.train_local_policy train \
|
| 39 |
+
--max-train-cities 70 \
|
| 40 |
+
--max-val-cities 3 \
|
| 41 |
+
--val-scenarios-per-city 7 \
|
| 42 |
+
--policy-arch single_head_with_district_feature \
|
| 43 |
+
--reward-variant wait_queue_throughput
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
Useful ablations:
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
python3 -m training.train_local_policy train --policy-arch multi_head --reward-variant current
|
| 50 |
+
python3 -m training.train_local_policy train --policy-arch single_head --reward-variant current
|
| 51 |
+
python3 -m training.train_local_policy train --policy-arch single_head_with_district_feature --reward-variant wait_queue_throughput
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
For a fast phase-1 overfit run on one fixed world:
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
python3 -m training.train_local_policy train \
|
| 58 |
+
--total-updates 25 \
|
| 59 |
+
--train-city-id city_0072 \
|
| 60 |
+
--train-scenario-name normal \
|
| 61 |
+
--overfit-val-on-train-scenario \
|
| 62 |
+
--fast-overfit \
|
| 63 |
+
--policy-arch single_head_with_district_feature \
|
| 64 |
+
--reward-variant wait_queue_throughput
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
To create or refresh dataset splits:
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
python3 -m training.train_local_policy make-splits
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
To evaluate the best checkpoint:
|
| 74 |
+
|
| 75 |
+
```bash
|
| 76 |
+
python3 -m training.train_local_policy evaluate \
|
| 77 |
+
--checkpoint artifacts/dqn_shared/best_validation.pt \
|
| 78 |
+
--split val
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
To evaluate a heuristic baseline directly:
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
python3 -m training.train_local_policy evaluate --baseline queue_greedy --split val
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
## TensorBoard
|
| 88 |
+
|
| 89 |
+
TensorBoard logs are written to `artifacts/dqn_shared/tensorboard` by default.
|
| 90 |
+
|
| 91 |
+
```bash
|
| 92 |
+
tensorboard --logdir artifacts/dqn_shared/tensorboard
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
## District LLM
|
| 96 |
+
|
| 97 |
+
The district LLM stack lives under `district_llm/`. It treats the learned DQN local controller as the low-level executor, derives district-scale SFT labels automatically from DQN rollout windows, and defaults district-model fine-tuning to DQN-derived rows only.
|
| 98 |
+
|
| 99 |
+
Generate district-LLM data from a learned checkpoint:
|
| 100 |
+
|
| 101 |
+
```bash
|
| 102 |
+
python3 -m district_llm.generate_dataset \
|
| 103 |
+
--controller rl_checkpoint \
|
| 104 |
+
--checkpoint artifacts/dqn_shared/best_validation.pt \
|
| 105 |
+
--episodes 100 \
|
| 106 |
+
--decision-interval 10 \
|
| 107 |
+
--use-checkpoint-env-config \
|
| 108 |
+
--output data/district_llm_train.jsonl
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
Generate from fixed or heuristic baselines:
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
python3 -m district_llm.generate_dataset --controller fixed --episodes 50 --decision-interval 10 --output data/district_llm_fixed.jsonl
|
| 115 |
+
python3 -m district_llm.generate_dataset --controller queue_greedy --episodes 50 --decision-interval 10 --output data/district_llm_heuristic.jsonl
|
| 116 |
+
python3 -m district_llm.generate_dataset --teacher-spec fixed --teacher-spec random --episodes 50 --decision-interval 10 --output data/district_llm_multi_teacher.jsonl
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
Train a first-pass district model with Unsloth/QLoRA:
|
| 120 |
+
|
| 121 |
+
```bash
|
| 122 |
+
python3 -m training.train_district_llm \
|
| 123 |
+
--dataset data/district_llm_train.jsonl \
|
| 124 |
+
--output-dir artifacts/district_llm_qwen \
|
| 125 |
+
--model-name Qwen/Qwen2.5-7B-Instruct \
|
| 126 |
+
--load-in-4bit \
|
| 127 |
+
--lora-rank 16 \
|
| 128 |
+
--max-seq-length 1024 \
|
| 129 |
+
--max-steps 1000
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
Run single-sample inference:
|
| 133 |
+
|
| 134 |
+
```bash
|
| 135 |
+
python3 -m district_llm.inference \
|
| 136 |
+
--model artifacts/district_llm_qwen \
|
| 137 |
+
--city-id city_0006 \
|
| 138 |
+
--scenario-name accident \
|
| 139 |
+
--district-id d_00
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
Run the OpenEnv-compatible district wrapper on top of the current DQN stack:
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
uvicorn openenv_app.app:app --reload
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
## Algorithm
|
| 149 |
+
|
| 150 |
+
- Training algorithm: parameter-shared dueling Double DQN.
|
| 151 |
+
- Replay: prioritized replay over per-intersection transitions gathered from full CityFlow worlds.
|
| 152 |
+
- Return target: n-step bootstrap target with target-network updates.
|
| 153 |
+
- Execution: all controllable intersections act simultaneously every RL decision interval.
|
| 154 |
+
- Action space: `0 = hold current phase`, `1 = switch to next green phase`.
|
| 155 |
+
- Safety: `min_green_time` is enforced in the environment and exposed through action masking.
|
| 156 |
+
|
| 157 |
+
Policy architecture modes:
|
| 158 |
+
|
| 159 |
+
- `multi_head`: shared trunk with district-type-specific Q heads.
|
| 160 |
+
- `single_head`: one shared Q head for all intersections, with district type removed from the observation.
|
| 161 |
+
- `single_head_with_district_feature`: one shared Q head for all intersections, with district type left in the observation as an explicit feature.
|
| 162 |
+
|
| 163 |
+
Reward variants:
|
| 164 |
+
|
| 165 |
+
- `current`: backward-compatible waiting and queue penalty.
|
| 166 |
+
- `normalized_wait_queue`: normalized queue and waiting reduction reward.
|
| 167 |
+
- `wait_queue_throughput`: normalized queue/wait reduction plus throughput bonus and imbalance penalty.
|
| 168 |
+
|
| 169 |
+
## Smoke Test
|
| 170 |
+
|
| 171 |
+
To sanity-check one generated scenario with the real CityFlow environment:
|
| 172 |
+
|
| 173 |
+
```bash
|
| 174 |
+
python3 scripts/smoke_test_env.py --city-id city_0001 --scenario-name normal --policy random
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
## Project layout
|
| 178 |
+
|
| 179 |
+
- `agents/`: heuristic local policies and simple baselines.
|
| 180 |
+
- `env/`: CityFlow environment, topology parsing, observation building, and reward logic.
|
| 181 |
+
- `training/`: dataset utilities, replay-based DQN training, evaluation helpers, TensorBoard logging, and CLIs.
|
| 182 |
+
- `data/`: generated synthetic cities, split files, and dataset generation utilities.
|
| 183 |
+
- `scripts/`: utility scripts, including the CityFlow smoke test.
|
| 184 |
+
- `third_party/`: vendored dependencies, including CityFlow source.
|
| 185 |
+
|
| 186 |
+
## Notes
|
| 187 |
+
|
| 188 |
+
- The generated dataset is assumed to already exist under `data/generated`.
|
| 189 |
+
- District membership comes from `district_map.json`.
|
| 190 |
+
- District types come from `metadata.json`.
|
| 191 |
+
- Runtime training and evaluation require the `cityflow` Python module to be installed in the active environment.
|
__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""OpenEnv package root for the lean agentic traffic environment."""
|
agents/README.md
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# agents
|
| 2 |
+
|
| 3 |
+
Local traffic-control policies and compatibility shims.
|
| 4 |
+
|
| 5 |
+
## Main files
|
| 6 |
+
|
| 7 |
+
- [local_policy.py](/Users/aditya/Developer/traffic-llm/agents/local_policy.py)
|
| 8 |
+
Active v1 policy interfaces and simple baselines:
|
| 9 |
+
- `HoldPhasePolicy`
|
| 10 |
+
- `FixedCyclePolicy`
|
| 11 |
+
- `QueueGreedyPolicy`
|
| 12 |
+
- [district_controller.py](/Users/aditya/Developer/traffic-llm/agents/district_controller.py)
|
| 13 |
+
Older district-level prototype logic kept for compatibility.
|
| 14 |
+
- [district_coordinator.py](/Users/aditya/Developer/traffic-llm/agents/district_coordinator.py)
|
| 15 |
+
Import shim for older code paths.
|
| 16 |
+
|
| 17 |
+
## Notes
|
| 18 |
+
|
| 19 |
+
- The learned local-policy network itself lives in [training/models.py](/Users/aditya/Developer/traffic-llm/training/models.py), not here.
|
| 20 |
+
- For active training, use the parameter-shared DQN path in `training/`, not the district-controller prototypes.
|
agents/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from agents.local_policy import (
|
| 2 |
+
BaseLocalPolicy,
|
| 3 |
+
FixedCyclePolicy,
|
| 4 |
+
HoldPhasePolicy,
|
| 5 |
+
QueueGreedyPolicy,
|
| 6 |
+
SharedHeuristicLocalPolicy,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"BaseLocalPolicy",
|
| 11 |
+
"FixedCyclePolicy",
|
| 12 |
+
"HoldPhasePolicy",
|
| 13 |
+
"QueueGreedyPolicy",
|
| 14 |
+
"SharedHeuristicLocalPolicy",
|
| 15 |
+
]
|
agents/district_controller.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Any, Callable
|
| 5 |
+
|
| 6 |
+
from agents.message_protocol import DistrictDirective, parse_district_directive
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BaseDistrictCoordinator(ABC):
|
| 10 |
+
@abstractmethod
|
| 11 |
+
def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
|
| 12 |
+
raise NotImplementedError
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RuleBasedDistrictCoordinator(BaseDistrictCoordinator):
|
| 16 |
+
"""
|
| 17 |
+
Fast, deterministic, and robust.
|
| 18 |
+
Good first coordinator and good fallback if the LLM output fails.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
imbalance_threshold: float = 0.15,
|
| 24 |
+
border_pressure_threshold: float = 0.65,
|
| 25 |
+
default_duration: int = 2,
|
| 26 |
+
):
|
| 27 |
+
self.imbalance_threshold = imbalance_threshold
|
| 28 |
+
self.border_pressure_threshold = border_pressure_threshold
|
| 29 |
+
self.default_duration = default_duration
|
| 30 |
+
|
| 31 |
+
def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
|
| 32 |
+
district_id = district_summary.get("district_id", "unknown")
|
| 33 |
+
intersection_ids = district_summary.get("intersection_ids", [])
|
| 34 |
+
|
| 35 |
+
emergency = district_summary.get("emergency_vehicle", {})
|
| 36 |
+
if emergency.get("present", False):
|
| 37 |
+
return (
|
| 38 |
+
DistrictDirective(
|
| 39 |
+
mode="emergency_route",
|
| 40 |
+
target_intersections=emergency.get("route", intersection_ids),
|
| 41 |
+
duration=2,
|
| 42 |
+
rationale=f"Emergency vehicle detected in district {district_id}.",
|
| 43 |
+
corridor=emergency.get("corridor"),
|
| 44 |
+
district_weight=1.0,
|
| 45 |
+
)
|
| 46 |
+
.validate()
|
| 47 |
+
.to_dict()
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
corridor_loads = district_summary.get("corridor_loads", {})
|
| 51 |
+
ns = float(corridor_loads.get("ns", corridor_loads.get("north_south", 0.0)))
|
| 52 |
+
ew = float(corridor_loads.get("ew", corridor_loads.get("east_west", 0.0)))
|
| 53 |
+
|
| 54 |
+
border_pressure = district_summary.get("border_pressure", {})
|
| 55 |
+
border_max = 0.0
|
| 56 |
+
if isinstance(border_pressure, dict) and border_pressure:
|
| 57 |
+
border_max = max(float(v) for v in border_pressure.values())
|
| 58 |
+
|
| 59 |
+
if ew - ns > self.imbalance_threshold:
|
| 60 |
+
return (
|
| 61 |
+
DistrictDirective(
|
| 62 |
+
mode="prioritize_ew",
|
| 63 |
+
target_intersections=intersection_ids,
|
| 64 |
+
duration=self.default_duration,
|
| 65 |
+
rationale="East-west corridor is currently more congested than north-south.",
|
| 66 |
+
corridor="ew",
|
| 67 |
+
district_weight=(
|
| 68 |
+
0.7 if border_max < self.border_pressure_threshold else 0.9
|
| 69 |
+
),
|
| 70 |
+
)
|
| 71 |
+
.validate()
|
| 72 |
+
.to_dict()
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
if ns - ew > self.imbalance_threshold:
|
| 76 |
+
return (
|
| 77 |
+
DistrictDirective(
|
| 78 |
+
mode="prioritize_ns",
|
| 79 |
+
target_intersections=intersection_ids,
|
| 80 |
+
duration=self.default_duration,
|
| 81 |
+
rationale="North-south corridor is currently more congested than east-west.",
|
| 82 |
+
corridor="ns",
|
| 83 |
+
district_weight=(
|
| 84 |
+
0.7 if border_max < self.border_pressure_threshold else 0.9
|
| 85 |
+
),
|
| 86 |
+
)
|
| 87 |
+
.validate()
|
| 88 |
+
.to_dict()
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
if border_max >= self.border_pressure_threshold:
|
| 92 |
+
return (
|
| 93 |
+
DistrictDirective(
|
| 94 |
+
mode="damp_border_inflow",
|
| 95 |
+
target_intersections=intersection_ids,
|
| 96 |
+
duration=2,
|
| 97 |
+
rationale="Border pressure is high; reduce spill-in and smooth cross-district flow.",
|
| 98 |
+
district_weight=0.8,
|
| 99 |
+
)
|
| 100 |
+
.validate()
|
| 101 |
+
.to_dict()
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
return (
|
| 105 |
+
DistrictDirective(
|
| 106 |
+
mode="none",
|
| 107 |
+
target_intersections=[],
|
| 108 |
+
duration=1,
|
| 109 |
+
rationale="District is reasonably balanced.",
|
| 110 |
+
district_weight=0.5,
|
| 111 |
+
)
|
| 112 |
+
.validate()
|
| 113 |
+
.to_dict()
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class LLMDistrictCoordinator(BaseDistrictCoordinator):
|
| 118 |
+
"""
|
| 119 |
+
LLM-backed coordinator.
|
| 120 |
+
|
| 121 |
+
`generator_fn` should accept a prompt string and return either:
|
| 122 |
+
- a JSON string, or
|
| 123 |
+
- a dict
|
| 124 |
+
|
| 125 |
+
Example:
|
| 126 |
+
coordinator = LLMDistrictCoordinator(generator_fn=my_model_call)
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(
|
| 130 |
+
self,
|
| 131 |
+
generator_fn: Callable[[str], str | dict[str, Any]],
|
| 132 |
+
fallback: BaseDistrictCoordinator | None = None,
|
| 133 |
+
max_prompt_chars: int = 4000,
|
| 134 |
+
):
|
| 135 |
+
self.generator_fn = generator_fn
|
| 136 |
+
self.fallback = fallback or RuleBasedDistrictCoordinator()
|
| 137 |
+
self.max_prompt_chars = max_prompt_chars
|
| 138 |
+
|
| 139 |
+
def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
|
| 140 |
+
prompt = self.build_prompt(district_summary)
|
| 141 |
+
try:
|
| 142 |
+
raw = self.generator_fn(prompt)
|
| 143 |
+
directive = parse_district_directive(raw).to_dict()
|
| 144 |
+
|
| 145 |
+
# If the LLM returns a no-op too often or malformed content,
|
| 146 |
+
# the parser still makes it safe. We keep that behavior.
|
| 147 |
+
return directive
|
| 148 |
+
except Exception:
|
| 149 |
+
return self.fallback.decide(district_summary)
|
| 150 |
+
|
| 151 |
+
def build_prompt(self, district_summary: dict[str, Any]) -> str:
|
| 152 |
+
summary_text = repr(district_summary)
|
| 153 |
+
if len(summary_text) > self.max_prompt_chars:
|
| 154 |
+
summary_text = summary_text[: self.max_prompt_chars] + " ...[truncated]"
|
| 155 |
+
|
| 156 |
+
return f"""You are a district-level traffic coordinator.
|
| 157 |
+
|
| 158 |
+
Your job is to choose a single strategic directive for the next few cycles.
|
| 159 |
+
|
| 160 |
+
Allowed modes:
|
| 161 |
+
- none
|
| 162 |
+
- prioritize_ns
|
| 163 |
+
- prioritize_ew
|
| 164 |
+
- green_wave
|
| 165 |
+
- emergency_route
|
| 166 |
+
- damp_border_inflow
|
| 167 |
+
|
| 168 |
+
Return ONLY valid JSON with these fields:
|
| 169 |
+
{{
|
| 170 |
+
"mode": string,
|
| 171 |
+
"target_intersections": list[string],
|
| 172 |
+
"duration": int,
|
| 173 |
+
"rationale": string,
|
| 174 |
+
"corridor": string or null,
|
| 175 |
+
"district_weight": float
|
| 176 |
+
}}
|
| 177 |
+
|
| 178 |
+
Guidelines:
|
| 179 |
+
- Use emergency_route if an emergency vehicle is present.
|
| 180 |
+
- Use prioritize_ns or prioritize_ew when one corridor is clearly more congested.
|
| 181 |
+
- Use damp_border_inflow when cross-district border pressure is high.
|
| 182 |
+
- Keep duration between 1 and 5.
|
| 183 |
+
- district_weight should be between 0.0 and 1.0.
|
| 184 |
+
|
| 185 |
+
District summary:
|
| 186 |
+
{summary_text}
|
| 187 |
+
"""
|
agents/district_coordinator.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from agents.district_controller import (
|
| 2 |
+
BaseDistrictCoordinator,
|
| 3 |
+
LLMDistrictCoordinator,
|
| 4 |
+
RuleBasedDistrictCoordinator,
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"BaseDistrictCoordinator",
|
| 9 |
+
"LLMDistrictCoordinator",
|
| 10 |
+
"RuleBasedDistrictCoordinator",
|
| 11 |
+
]
|
agents/heuristic_controller.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class HeuristicController:
|
| 7 |
+
"""
|
| 8 |
+
Simple local traffic-light controller.
|
| 9 |
+
|
| 10 |
+
Action space:
|
| 11 |
+
0 -> choose NS green
|
| 12 |
+
1 -> choose EW green
|
| 13 |
+
|
| 14 |
+
Assumes:
|
| 15 |
+
queue_lengths = [N, S, E, W]
|
| 16 |
+
waiting_counts = [N, S, E, W]
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
min_green_steps: int = 5,
|
| 22 |
+
switch_margin: float = 1.0,
|
| 23 |
+
district_bonus_scale: float = 3.0,
|
| 24 |
+
neighbor_pressure_scale: float = 0.25,
|
| 25 |
+
):
|
| 26 |
+
self.min_green_steps = min_green_steps
|
| 27 |
+
self.switch_margin = switch_margin
|
| 28 |
+
self.district_bonus_scale = district_bonus_scale
|
| 29 |
+
self.neighbor_pressure_scale = neighbor_pressure_scale
|
| 30 |
+
|
| 31 |
+
def act(self, obs: dict[str, Any]) -> int:
|
| 32 |
+
queue_lengths = obs.get("queue_lengths", [0, 0, 0, 0])
|
| 33 |
+
waiting_counts = obs.get("waiting_counts", [0, 0, 0, 0])
|
| 34 |
+
current_phase = int(obs.get("current_phase", 0))
|
| 35 |
+
time_since_switch = int(obs.get("time_since_switch", 0))
|
| 36 |
+
district_mode = obs.get("district_mode", "none")
|
| 37 |
+
district_weight = float(obs.get("district_weight", 0.5))
|
| 38 |
+
neighbor_pressure = obs.get("neighbor_pressure", [0.0, 0.0])
|
| 39 |
+
|
| 40 |
+
ns_score = (
|
| 41 |
+
queue_lengths[0]
|
| 42 |
+
+ queue_lengths[1]
|
| 43 |
+
+ 1.5 * (waiting_counts[0] + waiting_counts[1])
|
| 44 |
+
)
|
| 45 |
+
ew_score = (
|
| 46 |
+
queue_lengths[2]
|
| 47 |
+
+ queue_lengths[3]
|
| 48 |
+
+ 1.5 * (waiting_counts[2] + waiting_counts[3])
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Optional small neighbor-pressure bias
|
| 52 |
+
if isinstance(neighbor_pressure, list) and len(neighbor_pressure) >= 2:
|
| 53 |
+
ns_score += self.neighbor_pressure_scale * float(neighbor_pressure[0])
|
| 54 |
+
ew_score += self.neighbor_pressure_scale * float(neighbor_pressure[1])
|
| 55 |
+
|
| 56 |
+
# District-level strategic bias
|
| 57 |
+
district_bonus = self.district_bonus_scale * district_weight
|
| 58 |
+
if district_mode == "prioritize_ns":
|
| 59 |
+
ns_score += district_bonus
|
| 60 |
+
elif district_mode == "prioritize_ew":
|
| 61 |
+
ew_score += district_bonus
|
| 62 |
+
elif district_mode == "green_wave":
|
| 63 |
+
corridor = obs.get("district_corridor")
|
| 64 |
+
if corridor == "ns":
|
| 65 |
+
ns_score += district_bonus
|
| 66 |
+
elif corridor == "ew":
|
| 67 |
+
ew_score += district_bonus
|
| 68 |
+
elif district_mode == "emergency_route":
|
| 69 |
+
corridor = obs.get("district_corridor")
|
| 70 |
+
if corridor in {"north_to_south", "south_to_north", "ns"}:
|
| 71 |
+
ns_score += district_bonus * 1.5
|
| 72 |
+
elif corridor in {"west_to_east", "east_to_west", "ew"}:
|
| 73 |
+
ew_score += district_bonus * 1.5
|
| 74 |
+
|
| 75 |
+
desired_phase = 0 if ns_score >= ew_score else 1
|
| 76 |
+
|
| 77 |
+
# Avoid thrashing
|
| 78 |
+
if time_since_switch < self.min_green_steps:
|
| 79 |
+
return current_phase
|
| 80 |
+
|
| 81 |
+
# Only switch if the other phase is meaningfully better
|
| 82 |
+
current_score = ns_score if current_phase == 0 else ew_score
|
| 83 |
+
desired_score = ns_score if desired_phase == 0 else ew_score
|
| 84 |
+
|
| 85 |
+
if (
|
| 86 |
+
desired_phase != current_phase
|
| 87 |
+
and desired_score < current_score + self.switch_margin
|
| 88 |
+
):
|
| 89 |
+
return current_phase
|
| 90 |
+
|
| 91 |
+
return desired_phase
|
agents/local_policy.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BaseLocalPolicy(ABC):
|
| 9 |
+
@abstractmethod
|
| 10 |
+
def act(self, observation_batch: dict[str, np.ndarray]) -> np.ndarray:
|
| 11 |
+
raise NotImplementedError
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class HoldPhasePolicy(BaseLocalPolicy):
|
| 15 |
+
def act(self, observation_batch: dict[str, np.ndarray]) -> np.ndarray:
|
| 16 |
+
intersection_count = len(observation_batch["intersection_ids"])
|
| 17 |
+
return np.zeros(intersection_count, dtype=np.int64)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RandomPhasePolicy(BaseLocalPolicy):
|
| 21 |
+
def __init__(self, seed: int = 7):
|
| 22 |
+
self.rng = np.random.default_rng(seed)
|
| 23 |
+
|
| 24 |
+
def act(self, observation_batch: dict[str, np.ndarray]) -> np.ndarray:
|
| 25 |
+
action_mask = observation_batch["action_mask"]
|
| 26 |
+
actions = np.zeros(action_mask.shape[0], dtype=np.int64)
|
| 27 |
+
for row_index, mask in enumerate(action_mask):
|
| 28 |
+
valid_actions = np.flatnonzero(mask > 0.0)
|
| 29 |
+
actions[row_index] = int(self.rng.choice(valid_actions))
|
| 30 |
+
return actions
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class FixedCyclePolicy(BaseLocalPolicy):
|
| 34 |
+
def __init__(self, green_time: int = 20):
|
| 35 |
+
self.green_time = int(green_time)
|
| 36 |
+
|
| 37 |
+
def act(self, observation_batch: dict[str, np.ndarray]) -> np.ndarray:
|
| 38 |
+
elapsed = observation_batch["phase_elapsed"]
|
| 39 |
+
action_mask = observation_batch["action_mask"]
|
| 40 |
+
should_switch = (elapsed >= self.green_time) & (action_mask[:, 1] > 0.0)
|
| 41 |
+
return should_switch.astype(np.int64)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class QueueGreedyPolicy(BaseLocalPolicy):
|
| 45 |
+
def __init__(self, switch_margin: float = 1.0):
|
| 46 |
+
self.switch_margin = float(switch_margin)
|
| 47 |
+
|
| 48 |
+
def act(self, observation_batch: dict[str, np.ndarray]) -> np.ndarray:
|
| 49 |
+
counts = observation_batch["incoming_counts"]
|
| 50 |
+
waiting = observation_batch["incoming_waiting"]
|
| 51 |
+
lane_mask = observation_batch["lane_mask"]
|
| 52 |
+
current_phase = observation_batch["current_phase"]
|
| 53 |
+
action_mask = observation_batch["action_mask"]
|
| 54 |
+
|
| 55 |
+
midpoint = counts.shape[1] // 2
|
| 56 |
+
ns_score = (
|
| 57 |
+
counts[:, :midpoint].sum(axis=1)
|
| 58 |
+
+ 1.5 * waiting[:, :midpoint].sum(axis=1)
|
| 59 |
+
)
|
| 60 |
+
ew_score = (
|
| 61 |
+
counts[:, midpoint:].sum(axis=1)
|
| 62 |
+
+ 1.5 * waiting[:, midpoint:].sum(axis=1)
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
valid_midpoint = lane_mask[:, :midpoint].sum(axis=1) > 0
|
| 66 |
+
ns_score = np.where(valid_midpoint, ns_score, 0.0)
|
| 67 |
+
|
| 68 |
+
desired_switch = np.where(
|
| 69 |
+
current_phase == 0,
|
| 70 |
+
ew_score > ns_score + self.switch_margin,
|
| 71 |
+
ns_score > ew_score + self.switch_margin,
|
| 72 |
+
)
|
| 73 |
+
desired_switch = desired_switch & (action_mask[:, 1] > 0.0)
|
| 74 |
+
return desired_switch.astype(np.int64)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class SharedHeuristicLocalPolicy(QueueGreedyPolicy):
|
| 78 |
+
def __init__(
|
| 79 |
+
self,
|
| 80 |
+
min_green_steps: int = 5,
|
| 81 |
+
switch_margin: float = 1.0,
|
| 82 |
+
district_bonus_scale: float = 0.0,
|
| 83 |
+
neighbor_pressure_scale: float = 0.0,
|
| 84 |
+
):
|
| 85 |
+
self.min_green_steps = int(min_green_steps)
|
| 86 |
+
del district_bonus_scale, neighbor_pressure_scale
|
| 87 |
+
super().__init__(switch_margin=switch_margin)
|
| 88 |
+
|
| 89 |
+
def act_batch(self, observation_batch):
|
| 90 |
+
if "intersection_ids" in observation_batch:
|
| 91 |
+
return self.act(observation_batch)
|
| 92 |
+
|
| 93 |
+
actions: dict[str, int] = {}
|
| 94 |
+
for intersection_id, payload in observation_batch.items():
|
| 95 |
+
waiting = payload.get("waiting_counts", [0, 0, 0, 0])
|
| 96 |
+
queues = payload.get("queue_lengths", [0, 0, 0, 0])
|
| 97 |
+
current_phase = int(payload.get("current_phase", 0))
|
| 98 |
+
time_since_switch = int(payload.get("time_since_switch", 0))
|
| 99 |
+
|
| 100 |
+
ns_score = float(sum(queues[:2]) + 1.5 * sum(waiting[:2]))
|
| 101 |
+
ew_score = float(sum(queues[2:4]) + 1.5 * sum(waiting[2:4]))
|
| 102 |
+
desired_phase = 0 if ns_score >= ew_score else 1
|
| 103 |
+
|
| 104 |
+
if time_since_switch < self.min_green_steps:
|
| 105 |
+
actions[intersection_id] = current_phase
|
| 106 |
+
elif desired_phase != current_phase and abs(ns_score - ew_score) <= self.switch_margin:
|
| 107 |
+
actions[intersection_id] = current_phase
|
| 108 |
+
else:
|
| 109 |
+
actions[intersection_id] = desired_phase
|
| 110 |
+
return actions
|
agents/message_protocol.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from dataclasses import asdict, dataclass, field
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
VALID_MODES = {
|
| 9 |
+
"none",
|
| 10 |
+
"prioritize_ns",
|
| 11 |
+
"prioritize_ew",
|
| 12 |
+
"green_wave",
|
| 13 |
+
"emergency_route",
|
| 14 |
+
"damp_border_inflow",
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class NeighborMessage:
|
| 20 |
+
sender_intersection: str
|
| 21 |
+
receiver_intersection: str
|
| 22 |
+
congestion_level: float
|
| 23 |
+
spillback_risk: bool
|
| 24 |
+
dominant_direction: str # "ns", "ew", or "balanced"
|
| 25 |
+
queue_total: int
|
| 26 |
+
|
| 27 |
+
def to_dict(self) -> dict[str, Any]:
|
| 28 |
+
return asdict(self)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class DistrictDirective:
|
| 33 |
+
mode: str = "none"
|
| 34 |
+
target_intersections: list[str] = field(default_factory=list)
|
| 35 |
+
duration: int = 1
|
| 36 |
+
rationale: str = ""
|
| 37 |
+
corridor: str | None = None
|
| 38 |
+
district_weight: float = 0.5
|
| 39 |
+
|
| 40 |
+
def validate(self) -> "DistrictDirective":
|
| 41 |
+
if self.mode not in VALID_MODES:
|
| 42 |
+
self.mode = "none"
|
| 43 |
+
|
| 44 |
+
if not isinstance(self.target_intersections, list):
|
| 45 |
+
self.target_intersections = []
|
| 46 |
+
|
| 47 |
+
if not isinstance(self.duration, int):
|
| 48 |
+
self.duration = 1
|
| 49 |
+
self.duration = max(1, min(self.duration, 10))
|
| 50 |
+
|
| 51 |
+
if not isinstance(self.rationale, str):
|
| 52 |
+
self.rationale = ""
|
| 53 |
+
|
| 54 |
+
if self.corridor is not None and self.corridor not in {
|
| 55 |
+
"ns",
|
| 56 |
+
"ew",
|
| 57 |
+
"west_to_east",
|
| 58 |
+
"east_to_west",
|
| 59 |
+
"north_to_south",
|
| 60 |
+
"south_to_north",
|
| 61 |
+
}:
|
| 62 |
+
self.corridor = None
|
| 63 |
+
|
| 64 |
+
if not isinstance(self.district_weight, (int, float)):
|
| 65 |
+
self.district_weight = 0.5
|
| 66 |
+
self.district_weight = float(max(0.0, min(1.0, self.district_weight)))
|
| 67 |
+
|
| 68 |
+
return self
|
| 69 |
+
|
| 70 |
+
def to_dict(self) -> dict[str, Any]:
|
| 71 |
+
return asdict(self)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def parse_district_directive(payload: str | dict[str, Any]) -> DistrictDirective:
|
| 75 |
+
"""
|
| 76 |
+
Accept either raw JSON text or a dict and return a validated DistrictDirective.
|
| 77 |
+
Falls back safely to a no-op directive.
|
| 78 |
+
"""
|
| 79 |
+
try:
|
| 80 |
+
if isinstance(payload, str):
|
| 81 |
+
payload = payload.strip()
|
| 82 |
+
if not payload:
|
| 83 |
+
return DistrictDirective().validate()
|
| 84 |
+
|
| 85 |
+
# Try direct JSON parse
|
| 86 |
+
try:
|
| 87 |
+
data = json.loads(payload)
|
| 88 |
+
except json.JSONDecodeError:
|
| 89 |
+
# Try to extract JSON object from surrounding text
|
| 90 |
+
start = payload.find("{")
|
| 91 |
+
end = payload.rfind("}")
|
| 92 |
+
if start == -1 or end == -1 or end <= start:
|
| 93 |
+
return DistrictDirective().validate()
|
| 94 |
+
data = json.loads(payload[start : end + 1])
|
| 95 |
+
elif isinstance(payload, dict):
|
| 96 |
+
data = payload
|
| 97 |
+
else:
|
| 98 |
+
return DistrictDirective().validate()
|
| 99 |
+
|
| 100 |
+
directive = DistrictDirective(
|
| 101 |
+
mode=data.get("mode", "none"),
|
| 102 |
+
target_intersections=data.get("target_intersections", []),
|
| 103 |
+
duration=data.get("duration", 1),
|
| 104 |
+
rationale=data.get("rationale", ""),
|
| 105 |
+
corridor=data.get("corridor"),
|
| 106 |
+
district_weight=data.get("district_weight", 0.5),
|
| 107 |
+
)
|
| 108 |
+
return directive.validate()
|
| 109 |
+
except Exception:
|
| 110 |
+
return DistrictDirective().validate()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def safe_directive_dict(payload: str | dict[str, Any] | None) -> dict[str, Any]:
|
| 114 |
+
if payload is None:
|
| 115 |
+
return DistrictDirective().validate().to_dict()
|
| 116 |
+
return parse_district_directive(payload).to_dict()
|
artifacts/district_llm_adapter_v3/main_run/adapter/README.md
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: unsloth/llama-3.1-8b-unsloth-bnb-4bit
|
| 3 |
+
library_name: peft
|
| 4 |
+
pipeline_tag: text-generation
|
| 5 |
+
tags:
|
| 6 |
+
- base_model:adapter:unsloth/llama-3.1-8b-unsloth-bnb-4bit
|
| 7 |
+
- lora
|
| 8 |
+
- sft
|
| 9 |
+
- transformers
|
| 10 |
+
- trl
|
| 11 |
+
- unsloth
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# Model Card for Model ID
|
| 15 |
+
|
| 16 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
## Model Details
|
| 21 |
+
|
| 22 |
+
### Model Description
|
| 23 |
+
|
| 24 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
- **Developed by:** [More Information Needed]
|
| 29 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 30 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 31 |
+
- **Model type:** [More Information Needed]
|
| 32 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 33 |
+
- **License:** [More Information Needed]
|
| 34 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
### Model Sources [optional]
|
| 37 |
+
|
| 38 |
+
<!-- Provide the basic links for the model. -->
|
| 39 |
+
|
| 40 |
+
- **Repository:** [More Information Needed]
|
| 41 |
+
- **Paper [optional]:** [More Information Needed]
|
| 42 |
+
- **Demo [optional]:** [More Information Needed]
|
| 43 |
+
|
| 44 |
+
## Uses
|
| 45 |
+
|
| 46 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 47 |
+
|
| 48 |
+
### Direct Use
|
| 49 |
+
|
| 50 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 51 |
+
|
| 52 |
+
[More Information Needed]
|
| 53 |
+
|
| 54 |
+
### Downstream Use [optional]
|
| 55 |
+
|
| 56 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 57 |
+
|
| 58 |
+
[More Information Needed]
|
| 59 |
+
|
| 60 |
+
### Out-of-Scope Use
|
| 61 |
+
|
| 62 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 63 |
+
|
| 64 |
+
[More Information Needed]
|
| 65 |
+
|
| 66 |
+
## Bias, Risks, and Limitations
|
| 67 |
+
|
| 68 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 69 |
+
|
| 70 |
+
[More Information Needed]
|
| 71 |
+
|
| 72 |
+
### Recommendations
|
| 73 |
+
|
| 74 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 75 |
+
|
| 76 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 77 |
+
|
| 78 |
+
## How to Get Started with the Model
|
| 79 |
+
|
| 80 |
+
Use the code below to get started with the model.
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
## Training Details
|
| 85 |
+
|
| 86 |
+
### Training Data
|
| 87 |
+
|
| 88 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
### Training Procedure
|
| 93 |
+
|
| 94 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 95 |
+
|
| 96 |
+
#### Preprocessing [optional]
|
| 97 |
+
|
| 98 |
+
[More Information Needed]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
#### Training Hyperparameters
|
| 102 |
+
|
| 103 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 104 |
+
|
| 105 |
+
#### Speeds, Sizes, Times [optional]
|
| 106 |
+
|
| 107 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 108 |
+
|
| 109 |
+
[More Information Needed]
|
| 110 |
+
|
| 111 |
+
## Evaluation
|
| 112 |
+
|
| 113 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 114 |
+
|
| 115 |
+
### Testing Data, Factors & Metrics
|
| 116 |
+
|
| 117 |
+
#### Testing Data
|
| 118 |
+
|
| 119 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 120 |
+
|
| 121 |
+
[More Information Needed]
|
| 122 |
+
|
| 123 |
+
#### Factors
|
| 124 |
+
|
| 125 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 126 |
+
|
| 127 |
+
[More Information Needed]
|
| 128 |
+
|
| 129 |
+
#### Metrics
|
| 130 |
+
|
| 131 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 132 |
+
|
| 133 |
+
[More Information Needed]
|
| 134 |
+
|
| 135 |
+
### Results
|
| 136 |
+
|
| 137 |
+
[More Information Needed]
|
| 138 |
+
|
| 139 |
+
#### Summary
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
## Model Examination [optional]
|
| 144 |
+
|
| 145 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 146 |
+
|
| 147 |
+
[More Information Needed]
|
| 148 |
+
|
| 149 |
+
## Environmental Impact
|
| 150 |
+
|
| 151 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 152 |
+
|
| 153 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 154 |
+
|
| 155 |
+
- **Hardware Type:** [More Information Needed]
|
| 156 |
+
- **Hours used:** [More Information Needed]
|
| 157 |
+
- **Cloud Provider:** [More Information Needed]
|
| 158 |
+
- **Compute Region:** [More Information Needed]
|
| 159 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 160 |
+
|
| 161 |
+
## Technical Specifications [optional]
|
| 162 |
+
|
| 163 |
+
### Model Architecture and Objective
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
### Compute Infrastructure
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
#### Hardware
|
| 172 |
+
|
| 173 |
+
[More Information Needed]
|
| 174 |
+
|
| 175 |
+
#### Software
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
## Citation [optional]
|
| 180 |
+
|
| 181 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 182 |
+
|
| 183 |
+
**BibTeX:**
|
| 184 |
+
|
| 185 |
+
[More Information Needed]
|
| 186 |
+
|
| 187 |
+
**APA:**
|
| 188 |
+
|
| 189 |
+
[More Information Needed]
|
| 190 |
+
|
| 191 |
+
## Glossary [optional]
|
| 192 |
+
|
| 193 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## More Information [optional]
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
|
| 201 |
+
## Model Card Authors [optional]
|
| 202 |
+
|
| 203 |
+
[More Information Needed]
|
| 204 |
+
|
| 205 |
+
## Model Card Contact
|
| 206 |
+
|
| 207 |
+
[More Information Needed]
|
| 208 |
+
### Framework versions
|
| 209 |
+
|
| 210 |
+
- PEFT 0.18.1
|
artifacts/district_llm_adapter_v3/main_run/adapter/adapter_config.json
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alora_invocation_tokens": null,
|
| 3 |
+
"alpha_pattern": {},
|
| 4 |
+
"arrow_config": null,
|
| 5 |
+
"auto_mapping": {
|
| 6 |
+
"base_model_class": "LlamaForCausalLM",
|
| 7 |
+
"parent_library": "transformers.models.llama.modeling_llama",
|
| 8 |
+
"unsloth_fixed": true
|
| 9 |
+
},
|
| 10 |
+
"base_model_name_or_path": "unsloth/llama-3.1-8b-unsloth-bnb-4bit",
|
| 11 |
+
"bias": "none",
|
| 12 |
+
"corda_config": null,
|
| 13 |
+
"ensure_weight_tying": false,
|
| 14 |
+
"eva_config": null,
|
| 15 |
+
"exclude_modules": null,
|
| 16 |
+
"fan_in_fan_out": false,
|
| 17 |
+
"inference_mode": true,
|
| 18 |
+
"init_lora_weights": true,
|
| 19 |
+
"layer_replication": null,
|
| 20 |
+
"layers_pattern": null,
|
| 21 |
+
"layers_to_transform": null,
|
| 22 |
+
"loftq_config": {},
|
| 23 |
+
"lora_alpha": 16,
|
| 24 |
+
"lora_bias": false,
|
| 25 |
+
"lora_dropout": 0.05,
|
| 26 |
+
"megatron_config": null,
|
| 27 |
+
"megatron_core": "megatron.core",
|
| 28 |
+
"modules_to_save": null,
|
| 29 |
+
"peft_type": "LORA",
|
| 30 |
+
"peft_version": "0.18.1",
|
| 31 |
+
"qalora_group_size": 16,
|
| 32 |
+
"r": 16,
|
| 33 |
+
"rank_pattern": {},
|
| 34 |
+
"revision": null,
|
| 35 |
+
"target_modules": [
|
| 36 |
+
"down_proj",
|
| 37 |
+
"q_proj",
|
| 38 |
+
"up_proj",
|
| 39 |
+
"gate_proj",
|
| 40 |
+
"v_proj",
|
| 41 |
+
"k_proj",
|
| 42 |
+
"o_proj"
|
| 43 |
+
],
|
| 44 |
+
"target_parameters": null,
|
| 45 |
+
"task_type": "CAUSAL_LM",
|
| 46 |
+
"trainable_token_indices": null,
|
| 47 |
+
"use_dora": false,
|
| 48 |
+
"use_qalora": false,
|
| 49 |
+
"use_rslora": false
|
| 50 |
+
}
|
artifacts/district_llm_adapter_v3/main_run/adapter/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:69163c214c0ad5462dc64dfb110bea79e50a2c4d5affdf084a8b351352231777
|
| 3 |
+
size 167832240
|
artifacts/district_llm_adapter_v3/main_run/adapter/tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b9e4e7fb171f92fd137b777cc2714bf87d11576700a1dcd7a399e7bbe39537b
|
| 3 |
+
size 17209920
|
artifacts/district_llm_adapter_v3/main_run/adapter/tokenizer_config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"bos_token": "<|begin_of_text|>",
|
| 4 |
+
"clean_up_tokenization_spaces": true,
|
| 5 |
+
"eos_token": "<|end_of_text|>",
|
| 6 |
+
"from_slow": true,
|
| 7 |
+
"is_local": false,
|
| 8 |
+
"legacy": false,
|
| 9 |
+
"model_input_names": [
|
| 10 |
+
"input_ids",
|
| 11 |
+
"attention_mask"
|
| 12 |
+
],
|
| 13 |
+
"model_max_length": 131072,
|
| 14 |
+
"pad_token": "<|finetune_right_pad_id|>",
|
| 15 |
+
"padding_side": "left",
|
| 16 |
+
"tokenizer_class": "TokenizersBackend",
|
| 17 |
+
"unk_token": null
|
| 18 |
+
}
|
artifacts/dqn_shared/best_validation.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8a478cd3149c74ef0f0a57e5510bccd35262272a0a25bfb76b6aac2e8417af85
|
| 3 |
+
size 1320091
|
client.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import requests
|
| 6 |
+
|
| 7 |
+
from models import (
|
| 8 |
+
AgenticTrafficAction,
|
| 9 |
+
AgenticTrafficObservation,
|
| 10 |
+
AgenticTrafficState,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AgenticTrafficClient:
|
| 15 |
+
"""Thin HTTP client for the DistrictFlow OpenEnv server."""
|
| 16 |
+
|
| 17 |
+
def __init__(self, base_url: str):
|
| 18 |
+
self.base_url = base_url.rstrip("/")
|
| 19 |
+
|
| 20 |
+
def reset(self, seed: int | None = None) -> AgenticTrafficObservation:
|
| 21 |
+
response = requests.post(
|
| 22 |
+
f"{self.base_url}/reset",
|
| 23 |
+
json={"seed": seed},
|
| 24 |
+
timeout=60,
|
| 25 |
+
)
|
| 26 |
+
response.raise_for_status()
|
| 27 |
+
payload = response.json()
|
| 28 |
+
return AgenticTrafficObservation.model_validate(payload["observation"])
|
| 29 |
+
|
| 30 |
+
def step(self, action: AgenticTrafficAction) -> AgenticTrafficObservation:
|
| 31 |
+
response = requests.post(
|
| 32 |
+
f"{self.base_url}/step",
|
| 33 |
+
json={"action": action.model_dump()},
|
| 34 |
+
timeout=60,
|
| 35 |
+
)
|
| 36 |
+
response.raise_for_status()
|
| 37 |
+
payload = response.json()
|
| 38 |
+
observation = AgenticTrafficObservation.model_validate(payload["observation"])
|
| 39 |
+
observation.done = bool(payload.get("done", False))
|
| 40 |
+
observation.reward = float(payload.get("reward", 0.0))
|
| 41 |
+
return observation
|
| 42 |
+
|
| 43 |
+
def state(self) -> AgenticTrafficState:
|
| 44 |
+
response = requests.get(f"{self.base_url}/state", timeout=60)
|
| 45 |
+
response.raise_for_status()
|
| 46 |
+
payload = response.json()
|
| 47 |
+
return AgenticTrafficState.model_validate(payload["state"])
|
| 48 |
+
|
| 49 |
+
def health(self) -> dict[str, Any]:
|
| 50 |
+
response = requests.get(f"{self.base_url}/health", timeout=30)
|
| 51 |
+
response.raise_for_status()
|
| 52 |
+
return response.json()
|
data/splits/README.md
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# data/splits
|
| 2 |
+
|
| 3 |
+
City-level train/validation/test splits for the generated dataset.
|
| 4 |
+
|
| 5 |
+
## Files
|
| 6 |
+
|
| 7 |
+
- [train_cities.txt](/Users/aditya/Developer/traffic-llm/data/splits/train_cities.txt)
|
| 8 |
+
- [val_cities.txt](/Users/aditya/Developer/traffic-llm/data/splits/val_cities.txt)
|
| 9 |
+
- [test_cities.txt](/Users/aditya/Developer/traffic-llm/data/splits/test_cities.txt)
|
| 10 |
+
|
| 11 |
+
## Important rule
|
| 12 |
+
|
| 13 |
+
Splits are by city only. All scenarios for a given city belong to the same split.
|
| 14 |
+
|
| 15 |
+
## Regeneration
|
| 16 |
+
|
| 17 |
+
Use:
|
| 18 |
+
|
| 19 |
+
`python3 -m training.train_local_policy make-splits`
|
| 20 |
+
|
| 21 |
+
The split logic is implemented in [training/dataset.py](/Users/aditya/Developer/traffic-llm/training/dataset.py).
|
data/splits/test_cities.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
city_0005
|
| 2 |
+
city_0007
|
| 3 |
+
city_0008
|
| 4 |
+
city_0010
|
| 5 |
+
city_0012
|
| 6 |
+
city_0013
|
| 7 |
+
city_0020
|
| 8 |
+
city_0028
|
| 9 |
+
city_0042
|
| 10 |
+
city_0047
|
| 11 |
+
city_0051
|
| 12 |
+
city_0065
|
| 13 |
+
city_0069
|
| 14 |
+
city_0075
|
| 15 |
+
city_0084
|
data/splits/train_cities.txt
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
city_0001
|
| 2 |
+
city_0002
|
| 3 |
+
city_0003
|
| 4 |
+
city_0004
|
| 5 |
+
city_0006
|
| 6 |
+
city_0011
|
| 7 |
+
city_0014
|
| 8 |
+
city_0015
|
| 9 |
+
city_0017
|
| 10 |
+
city_0018
|
| 11 |
+
city_0019
|
| 12 |
+
city_0021
|
| 13 |
+
city_0022
|
| 14 |
+
city_0023
|
| 15 |
+
city_0024
|
| 16 |
+
city_0025
|
| 17 |
+
city_0026
|
| 18 |
+
city_0027
|
| 19 |
+
city_0030
|
| 20 |
+
city_0032
|
| 21 |
+
city_0033
|
| 22 |
+
city_0034
|
| 23 |
+
city_0035
|
| 24 |
+
city_0036
|
| 25 |
+
city_0037
|
| 26 |
+
city_0038
|
| 27 |
+
city_0039
|
| 28 |
+
city_0040
|
| 29 |
+
city_0041
|
| 30 |
+
city_0043
|
| 31 |
+
city_0044
|
| 32 |
+
city_0045
|
| 33 |
+
city_0046
|
| 34 |
+
city_0048
|
| 35 |
+
city_0049
|
| 36 |
+
city_0050
|
| 37 |
+
city_0052
|
| 38 |
+
city_0053
|
| 39 |
+
city_0057
|
| 40 |
+
city_0058
|
| 41 |
+
city_0059
|
| 42 |
+
city_0060
|
| 43 |
+
city_0061
|
| 44 |
+
city_0062
|
| 45 |
+
city_0063
|
| 46 |
+
city_0064
|
| 47 |
+
city_0066
|
| 48 |
+
city_0067
|
| 49 |
+
city_0068
|
| 50 |
+
city_0070
|
| 51 |
+
city_0072
|
| 52 |
+
city_0074
|
| 53 |
+
city_0076
|
| 54 |
+
city_0077
|
| 55 |
+
city_0079
|
| 56 |
+
city_0080
|
| 57 |
+
city_0081
|
| 58 |
+
city_0082
|
| 59 |
+
city_0083
|
| 60 |
+
city_0085
|
| 61 |
+
city_0087
|
| 62 |
+
city_0088
|
| 63 |
+
city_0089
|
| 64 |
+
city_0092
|
| 65 |
+
city_0093
|
| 66 |
+
city_0094
|
| 67 |
+
city_0095
|
| 68 |
+
city_0097
|
| 69 |
+
city_0099
|
| 70 |
+
city_0100
|
data/splits/val_cities.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
city_0009
|
| 2 |
+
city_0016
|
| 3 |
+
city_0029
|
| 4 |
+
city_0031
|
| 5 |
+
city_0054
|
| 6 |
+
city_0055
|
| 7 |
+
city_0056
|
| 8 |
+
city_0071
|
| 9 |
+
city_0073
|
| 10 |
+
city_0078
|
| 11 |
+
city_0086
|
| 12 |
+
city_0090
|
| 13 |
+
city_0091
|
| 14 |
+
city_0096
|
| 15 |
+
city_0098
|
district_llm/FINAL_ABLATION_RUNBOOK.md
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Final Ablation Runbook
|
| 2 |
+
|
| 3 |
+
## Dataset
|
| 4 |
+
|
| 5 |
+
Generate the constrained v3 dataset:
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
python scripts/generate_large_district_dataset.py \
|
| 9 |
+
--num-train 10000 \
|
| 10 |
+
--num-val 2500 \
|
| 11 |
+
--output-dir data/district_llm_dataset_v3 \
|
| 12 |
+
--checkpoint artifacts/dqn_shared/best_validation.pt \
|
| 13 |
+
--max-candidate-intersections 6 \
|
| 14 |
+
--max-target-intersections 3
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
Defaults:
|
| 18 |
+
|
| 19 |
+
- candidate pool is visible in the prompt via `candidate_intersections`
|
| 20 |
+
- labels are constrained to visible candidates
|
| 21 |
+
- DQN teacher sources are preferred by default
|
| 22 |
+
|
| 23 |
+
## Notebook
|
| 24 |
+
|
| 25 |
+
Use [notebooks/llama_finetune.ipynb](/root/aditya/agentic-traffic/notebooks/llama_finetune.ipynb).
|
| 26 |
+
|
| 27 |
+
Recommended defaults for the A100 main run:
|
| 28 |
+
|
| 29 |
+
- `RUN_MODE = "main_run"`
|
| 30 |
+
- `num_train_epochs = 2`
|
| 31 |
+
- `per_device_train_batch_size = 8`
|
| 32 |
+
- `gradient_accumulation_steps = 4`
|
| 33 |
+
- effective batch size = 32
|
| 34 |
+
- `learning_rate = 1e-4`
|
| 35 |
+
- `warmup_ratio = 0.05`
|
| 36 |
+
- `eval_steps = 100`
|
| 37 |
+
- `save_steps = 100`
|
| 38 |
+
|
| 39 |
+
Smoke test mode:
|
| 40 |
+
|
| 41 |
+
- `RUN_MODE = "smoke_test"`
|
| 42 |
+
- short `max_steps`
|
| 43 |
+
- verifies formatting, checkpointing, and eval wiring
|
| 44 |
+
|
| 45 |
+
Optional max-step override:
|
| 46 |
+
|
| 47 |
+
- set `MAX_STEPS_OVERRIDE = 5000` only for explicit experimentation
|
| 48 |
+
- do not use it as the default main run
|
| 49 |
+
|
| 50 |
+
Artifacts:
|
| 51 |
+
|
| 52 |
+
- checkpoints: `artifacts/district_llm_adapter_v3/<run_mode>/checkpoints`
|
| 53 |
+
- saved adapter: `artifacts/district_llm_adapter_v3/<run_mode>/adapter`
|
| 54 |
+
|
| 55 |
+
## Evaluation
|
| 56 |
+
|
| 57 |
+
Run offline eval with repair enabled:
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
python -m district_llm.eval \
|
| 61 |
+
--model-path artifacts/district_llm_adapter_v3/main_run/adapter \
|
| 62 |
+
--val-jsonl data/district_llm_dataset_v3/val.jsonl \
|
| 63 |
+
--generated-root data/generated \
|
| 64 |
+
--max-examples 250 \
|
| 65 |
+
--debug-examples 10 \
|
| 66 |
+
--allow-only-visible-candidates \
|
| 67 |
+
--max-target-intersections 3 \
|
| 68 |
+
--fallback-on-empty-targets \
|
| 69 |
+
--fallback-mode heuristic \
|
| 70 |
+
--restrict-targets-to-visible-summary \
|
| 71 |
+
--report-before-after-repair
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
Key outputs:
|
| 75 |
+
|
| 76 |
+
- raw vs repaired target metrics
|
| 77 |
+
- invalid target-id rate before and after repair
|
| 78 |
+
- visible-candidate-restricted metrics
|
| 79 |
+
- target failure buckets and debug examples
|
district_llm/RL_GUIDANCE_EVAL_RUNBOOK.md
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RL Guidance Eval Runbook
|
| 2 |
+
|
| 3 |
+
This ablation keeps the RL checkpoint fixed.
|
| 4 |
+
|
| 5 |
+
District guidance is only used at inference time through the wrapper in
|
| 6 |
+
`district_llm/rl_guidance_wrapper.py`. The safest default is
|
| 7 |
+
`target_only_soft`, which applies a small local Q-value bias only at
|
| 8 |
+
`target_intersections`.
|
| 9 |
+
|
| 10 |
+
## Wrapper Modes
|
| 11 |
+
|
| 12 |
+
- `no_op`: guidance is computed and logged, but RL actions are unchanged.
|
| 13 |
+
- `target_only_soft`: weak local prior on target intersections. Default debug mode.
|
| 14 |
+
- `target_only_medium`: same scope, slightly stronger.
|
| 15 |
+
- `corridor_soft`: small corridor prior on targets plus a few aligned boundary intersections.
|
| 16 |
+
- `global_soft`: weak district-wide prior. Use only as an ablation.
|
| 17 |
+
- `current_legacy`: reference mode approximating the old strong/global wrapper.
|
| 18 |
+
|
| 19 |
+
## Fast Debug Matrix
|
| 20 |
+
|
| 21 |
+
Use a short horizon first so the wrapper can be debugged quickly:
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
python scripts/eval_rl_guidance_ablation.py \
|
| 25 |
+
--rl-checkpoint artifacts/dqn_shared/best_validation.pt \
|
| 26 |
+
--llm-model-path artifacts/district_llm_adapter_v3/main_run/adapter \
|
| 27 |
+
--modes rl_only rl_heuristic rl_llm \
|
| 28 |
+
--wrapper-modes no_op target_only_soft current_legacy \
|
| 29 |
+
--split val \
|
| 30 |
+
--cities city_0001 \
|
| 31 |
+
--scenarios normal \
|
| 32 |
+
--seeds 7 11 13 \
|
| 33 |
+
--num-episodes 1 \
|
| 34 |
+
--max-episode-seconds 300 \
|
| 35 |
+
--guidance-refresh-steps 10 \
|
| 36 |
+
--guidance-persistence-steps 3 \
|
| 37 |
+
--bias-strength 0.12 \
|
| 38 |
+
--target-only-bias-strength 0.18 \
|
| 39 |
+
--corridor-bias-strength 0.05 \
|
| 40 |
+
--max-intersections-affected 3 \
|
| 41 |
+
--fallback-policy hold_previous \
|
| 42 |
+
--save-guidance-traces \
|
| 43 |
+
--output-dir artifacts/rl_guidance_eval/debug_matrix_300s
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
This expands into the paired comparison:
|
| 47 |
+
|
| 48 |
+
- `rl_only`
|
| 49 |
+
- `rl_heuristic+no_op`
|
| 50 |
+
- `rl_heuristic+target_only_soft`
|
| 51 |
+
- `rl_heuristic+current_legacy`
|
| 52 |
+
- `rl_llm+no_op`
|
| 53 |
+
- `rl_llm+target_only_soft`
|
| 54 |
+
- `rl_llm+current_legacy`
|
| 55 |
+
|
| 56 |
+
That command runs a superset of the exact smaller matrix from the wrapper audit prompt. Focus analysis on:
|
| 57 |
+
|
| 58 |
+
- `rl_only`
|
| 59 |
+
- `rl_heuristic+no_op`
|
| 60 |
+
- `rl_heuristic+target_only_soft`
|
| 61 |
+
- `rl_llm+no_op`
|
| 62 |
+
- `rl_llm+target_only_soft`
|
| 63 |
+
- `rl_llm+current_legacy`
|
| 64 |
+
|
| 65 |
+
## What To Look At
|
| 66 |
+
|
| 67 |
+
Primary files:
|
| 68 |
+
|
| 69 |
+
- `summary.json`
|
| 70 |
+
- `episode_metrics.csv`
|
| 71 |
+
- `guidance_traces.jsonl`
|
| 72 |
+
- `config.json`
|
| 73 |
+
|
| 74 |
+
Key wrapper metrics in `episode_metrics.csv`:
|
| 75 |
+
|
| 76 |
+
- `wrapper_mode`
|
| 77 |
+
- `mean_bias_magnitude`
|
| 78 |
+
- `max_bias_magnitude`
|
| 79 |
+
- `avg_num_targeted_intersections`
|
| 80 |
+
- `avg_num_affected_intersections`
|
| 81 |
+
- `percent_steps_with_active_guidance`
|
| 82 |
+
- `num_guidance_refreshes`
|
| 83 |
+
- `num_noop_guidance_events`
|
| 84 |
+
- `fallback_policy_used_count`
|
| 85 |
+
|
| 86 |
+
Interpretation:
|
| 87 |
+
|
| 88 |
+
- If `rl_heuristic+no_op` and `rl_llm+no_op` match `rl_only`, the harness itself is fine.
|
| 89 |
+
- If `current_legacy` collapses while `target_only_soft` stays near `rl_only`, the wrapper was too strong/global.
|
| 90 |
+
- If `rl_llm+target_only_soft` diverges from `rl_heuristic+target_only_soft`, the LLM is adding signal under safe integration.
|
| 91 |
+
- If `avg_num_affected_intersections` is large or `percent_steps_with_active_guidance` is near `1.0`, the wrapper is still too persistent or too broad.
|
| 92 |
+
- If `fallback_policy_used_count` stays high in `rl_llm`, inspect `guidance_traces.jsonl` before trusting traffic metrics.
|
| 93 |
+
|
| 94 |
+
## Cheap Follow-Up Ablations
|
| 95 |
+
|
| 96 |
+
Softer local prior:
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
--wrapper-modes no_op target_only_soft target_only_medium
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
Scope ablation:
|
| 103 |
+
|
| 104 |
+
```bash
|
| 105 |
+
--wrapper-modes target_only_soft corridor_soft global_soft current_legacy
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
More conservative persistence:
|
| 109 |
+
|
| 110 |
+
```bash
|
| 111 |
+
--guidance-refresh-steps 8 --guidance-persistence-steps 2
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
## Output Layout
|
| 115 |
+
|
| 116 |
+
Outputs are saved under the requested directory, for example:
|
| 117 |
+
|
| 118 |
+
```text
|
| 119 |
+
artifacts/rl_guidance_eval/debug_matrix_300s/
|
| 120 |
+
config.json
|
| 121 |
+
summary.json
|
| 122 |
+
episode_metrics.csv
|
| 123 |
+
episode_metrics.jsonl
|
| 124 |
+
guidance_traces.jsonl
|
| 125 |
+
seeded_configs/
|
| 126 |
+
```
|
district_llm/RL_LLM_WRAPPER_SWEEP_RUNBOOK.md
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RL+LLM Wrapper Sweep
|
| 2 |
+
|
| 3 |
+
This sweep keeps both checkpoints fixed:
|
| 4 |
+
|
| 5 |
+
- RL weights stay fixed.
|
| 6 |
+
- LLM weights stay fixed.
|
| 7 |
+
- Only the inference-time `target_only_soft` wrapper settings change.
|
| 8 |
+
|
| 9 |
+
## Recommended First Sweep
|
| 10 |
+
|
| 11 |
+
Run the default cheap preset on one city, one scenario, and three seeds:
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
python scripts/sweep_rl_llm_wrapper.py \
|
| 15 |
+
--rl-checkpoint artifacts/dqn_shared/best_validation.pt \
|
| 16 |
+
--llm-model-path artifacts/district_llm_adapter_v3/main_run/adapter \
|
| 17 |
+
--preset strength_targets_gating \
|
| 18 |
+
--split val \
|
| 19 |
+
--cities city_0001 \
|
| 20 |
+
--scenarios normal \
|
| 21 |
+
--seeds 7 11 13 \
|
| 22 |
+
--episodes-per-seed 1 \
|
| 23 |
+
--max-episode-seconds 300 \
|
| 24 |
+
--guidance-refresh-steps 10 \
|
| 25 |
+
--queue-threshold 150 \
|
| 26 |
+
--imbalance-threshold 20 \
|
| 27 |
+
--fallback-policy no_op \
|
| 28 |
+
--output-dir artifacts/rl_llm_wrapper_sweep/first_pass
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
This preset sweeps a small curated grid over:
|
| 32 |
+
|
| 33 |
+
- `bias_strength` in `{0.025, 0.05, 0.075}`
|
| 34 |
+
- `max_intersections_affected` in `{1, 2}`
|
| 35 |
+
- `gating_mode` in `{always_on, incident_or_spillback, queue_or_imbalance}`
|
| 36 |
+
- `guidance_persistence_steps = 5`
|
| 37 |
+
- `enable_bias_decay = false`
|
| 38 |
+
|
| 39 |
+
It also includes `baseline_current_soft` as a reference row.
|
| 40 |
+
|
| 41 |
+
## Cheaper Probe
|
| 42 |
+
|
| 43 |
+
If you only want the fastest possible read on strength sensitivity:
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
python scripts/sweep_rl_llm_wrapper.py \
|
| 47 |
+
--rl-checkpoint artifacts/dqn_shared/best_validation.pt \
|
| 48 |
+
--llm-model-path artifacts/district_llm_adapter_v3/main_run/adapter \
|
| 49 |
+
--preset strength_only \
|
| 50 |
+
--cities city_0001 \
|
| 51 |
+
--scenarios normal \
|
| 52 |
+
--seeds 7 11 13 \
|
| 53 |
+
--episodes-per-seed 1 \
|
| 54 |
+
--max-episode-seconds 300 \
|
| 55 |
+
--output-dir artifacts/rl_llm_wrapper_sweep/strength_only
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
## Broader Conservative Follow-Up
|
| 59 |
+
|
| 60 |
+
After the first pass identifies a promising strength/gating region:
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
python scripts/sweep_rl_llm_wrapper.py \
|
| 64 |
+
--rl-checkpoint artifacts/dqn_shared/best_validation.pt \
|
| 65 |
+
--llm-model-path artifacts/district_llm_adapter_v3/main_run/adapter \
|
| 66 |
+
--preset full_conservative \
|
| 67 |
+
--cities city_0001 \
|
| 68 |
+
--scenarios normal \
|
| 69 |
+
--seeds 7 11 13 \
|
| 70 |
+
--episodes-per-seed 1 \
|
| 71 |
+
--max-episode-seconds 300 \
|
| 72 |
+
--output-dir artifacts/rl_llm_wrapper_sweep/full_conservative
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
## Outputs
|
| 76 |
+
|
| 77 |
+
Each sweep writes:
|
| 78 |
+
|
| 79 |
+
- `config.json`
|
| 80 |
+
- `sweep_results.csv`
|
| 81 |
+
- `sweep_results.parquet` when parquet support is available
|
| 82 |
+
- `paired_episode_metrics.csv`
|
| 83 |
+
- `ranking.json`
|
| 84 |
+
- `summary_report.json`
|
| 85 |
+
- optional `step_metrics.*`
|
| 86 |
+
- optional `guidance_traces.jsonl`
|
| 87 |
+
|
| 88 |
+
## What To Inspect
|
| 89 |
+
|
| 90 |
+
Start with:
|
| 91 |
+
|
| 92 |
+
- `summary_report.json`
|
| 93 |
+
- `ranking.json`
|
| 94 |
+
- `paired_episode_metrics.csv`
|
| 95 |
+
|
| 96 |
+
Key fields:
|
| 97 |
+
|
| 98 |
+
- `mean_return_delta_vs_rl_only`
|
| 99 |
+
- `mean_throughput_delta_vs_rl_only`
|
| 100 |
+
- `mean_avg_queue_delta_vs_rl_only`
|
| 101 |
+
- `mean_avg_wait_delta_vs_rl_only`
|
| 102 |
+
- `mean_percent_steps_with_active_guidance`
|
| 103 |
+
- `mean_avg_num_affected_intersections`
|
| 104 |
+
- `mean_num_steps_guidance_blocked_by_gate`
|
| 105 |
+
|
| 106 |
+
## Interpretation
|
| 107 |
+
|
| 108 |
+
The most promising configs should usually look like:
|
| 109 |
+
|
| 110 |
+
- small negative or positive `mean_return_delta_vs_rl_only`
|
| 111 |
+
- low `mean_avg_num_affected_intersections`
|
| 112 |
+
- moderate or low `mean_percent_steps_with_active_guidance`
|
| 113 |
+
- low fallback / invalid guidance counts
|
| 114 |
+
|
| 115 |
+
If the best configs cluster around:
|
| 116 |
+
|
| 117 |
+
- lower `bias_strength`
|
| 118 |
+
- `max_intersections_affected = 1`
|
| 119 |
+
- gated modes like `incident_or_spillback` or `queue_or_imbalance`
|
| 120 |
+
|
| 121 |
+
then the wrapper was still too active and guidance needs to remain a rare local prior.
|
district_llm/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from district_llm.derivation import DistrictWindowData, LocalIntersectionAction, derive_district_action
|
| 2 |
+
from district_llm.prompting import build_system_prompt, format_district_prompt, format_sft_text
|
| 3 |
+
from district_llm.schema import CandidateIntersection, CongestedIntersection, DistrictAction, DistrictStateSummary
|
| 4 |
+
from district_llm.summary_builder import DistrictStateSummaryBuilder
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"CandidateIntersection",
|
| 8 |
+
"CongestedIntersection",
|
| 9 |
+
"DistrictAction",
|
| 10 |
+
"DistrictStateSummary",
|
| 11 |
+
"DistrictStateSummaryBuilder",
|
| 12 |
+
"DistrictWindowData",
|
| 13 |
+
"LocalIntersectionAction",
|
| 14 |
+
"derive_district_action",
|
| 15 |
+
"build_system_prompt",
|
| 16 |
+
"format_district_prompt",
|
| 17 |
+
"format_sft_text",
|
| 18 |
+
]
|
district_llm/data.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
def load_jsonl_text_dataset(
    path: str | Path,
    controller_families: list[str] | None = None,
    controller_types: list[str] | None = None,
):
    """Load a JSONL SFT dataset and optionally filter rows by controller metadata.

    The file must contain a `text` column. When `controller_families` or
    `controller_types` is provided, only rows whose corresponding metadata
    column matches one of the requested values are kept.

    Raises:
        ValueError: if the `text` column is missing, or if filtering removes
            every row.
    """
    from datasets import load_dataset

    rows = load_dataset("json", data_files=str(Path(path)), split="train")
    if "text" not in rows.column_names:
        raise ValueError("Expected a JSONL dataset with a 'text' field.")
    # Apply both metadata filters through one loop; defaults on the lambda
    # bind the loop variables eagerly so each filter sees its own column/values.
    for column_name, requested_values in (
        ("controller_family", controller_families),
        ("controller_type", controller_types),
    ):
        if requested_values:
            allowed = set(requested_values)
            rows = rows.filter(
                lambda row, col=column_name, ok=allowed: row.get(col) in ok
            )
    if len(rows) == 0:
        raise ValueError("No dataset rows remain after applying the requested filters.")
    return rows
|
district_llm/derivation.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from district_llm.repair import fallback_target_intersections
|
| 7 |
+
from district_llm.schema import DistrictAction, DistrictStateSummary
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class LocalIntersectionAction:
    """One recorded local-controller decision for a single intersection."""

    intersection_id: str
    district_id: str
    action: int
    current_phase: int
    next_phase: int
    queue_total: float
    wait_total: float
    outgoing_load: float
    is_boundary: bool

    @property
    def switched(self) -> bool:
        """True when the controller requested a switch and the phase actually changed."""
        requested_switch = int(self.action) == 1
        phase_changed = self.next_phase != self.current_phase
        return requested_switch and phase_changed
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class DistrictWindowData:
    """Aggregated district state over one labeling window (start vs end summary)."""

    district_id: str
    start_summary: DistrictStateSummary
    end_summary: DistrictStateSummary
    controller_actions: list[LocalIntersectionAction] = field(default_factory=list)
    step_count: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Serialize the end-minus-start window deltas, rounded to 3 decimals."""
        start = self.start_summary
        end = self.end_summary
        return {
            "district_id": self.district_id,
            "step_count": int(self.step_count),
            "queue_delta": round(end.total_queue - start.total_queue, 3),
            "wait_delta": round(end.total_wait - start.total_wait, 3),
            "throughput_delta": round(
                end.recent_throughput - start.recent_throughput,
                3,
            ),
        }
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def derive_district_action(
    window_data: DistrictWindowData,
    controller_actions: list[LocalIntersectionAction] | None = None,
    district_state: DistrictStateSummary | None = None,
    max_target_intersections: int = 3,
) -> DistrictAction:
    """
    Deterministic first-pass label extraction from local-controller behavior.

    Heuristic order:
    1. Incident-heavy windows map to `incident_response`.
    2. Strong spillback / boundary pressure maps to `clear_spillback`.
    3. Rising boundary demand maps to `drain_inbound`.
    4. Persistently high outgoing pressure maps to `drain_outbound`.
    5. Boundary-heavy rush windows map to `arterial_priority`.
    6. Clear NS/EW directional dominance maps to `favor_NS` / `favor_EW`.
    7. Otherwise emit `hold`.
    """
    # Explicit overrides win; otherwise fall back to the data captured in the window.
    actions = controller_actions if controller_actions is not None else window_data.controller_actions
    state = district_state if district_state is not None else window_data.start_summary
    end_state = window_data.end_summary

    # Clamp the emitted action duration to [1, 20] steps.
    duration_steps = max(1, min(int(window_data.step_count or 1), 20))
    phase_counts = {"NS": 0, "EW": 0}
    focus_scores: dict[str, float] = {}  # per-intersection pressure score, used to rank targets
    boundary_focus = 0
    switch_count = 0

    # Single pass over the recorded local actions: tally phase usage, switches,
    # boundary involvement, and a weighted focus score per intersection.
    for item in actions:
        # Phase 0 is treated as NS; every other phase counts as EW.
        phase_key = "NS" if int(item.next_phase) == 0 else "EW"
        phase_counts[phase_key] += 1
        switch_count += int(item.switched)
        if item.is_boundary:
            boundary_focus += 1
        # Weighting: wait is 1.5x queue; an actual switch adds a flat 2.0 bonus.
        focus_scores[item.intersection_id] = focus_scores.get(item.intersection_id, 0.0) + (
            item.queue_total + 1.5 * item.wait_total + 2.0 * float(item.switched)
        )

    # Ratios are guarded against an empty action list via the max(1, ...) floor.
    total_action_records = max(1, len(actions))
    ns_phase_ratio = phase_counts["NS"] / float(total_action_records)
    ew_phase_ratio = phase_counts["EW"] / float(total_action_records)
    boundary_focus_ratio = boundary_focus / float(total_action_records)
    queue_delta = end_state.total_queue - state.total_queue
    wait_delta = end_state.total_wait - state.total_wait
    # Shares are guarded against zero queues via max(1.0, ...).
    boundary_share = state.boundary_queue_total / max(1.0, state.total_queue)
    outgoing_pressure = end_state.total_outgoing_load / max(1.0, end_state.total_queue)

    # A 0.1 margin avoids flapping between NS/EW on near-even splits.
    if ns_phase_ratio > ew_phase_ratio + 0.1:
        phase_bias = "NS"
    elif ew_phase_ratio > ns_phase_ratio + 0.1:
        phase_bias = "EW"
    else:
        phase_bias = "NONE"

    # With no clear bias from actions, defer to the summary's dominant flow.
    if phase_bias == "NONE" and state.dominant_flow in {"NS", "EW"}:
        phase_bias = state.dominant_flow

    def select_targets(
        strategy: str,
        priority_corridor: str | None,
        selected_phase_bias: str,
    ) -> list[str]:
        # Delegates target ranking to the shared repair fallback so labels and
        # repaired predictions pick targets the same way.
        return fallback_target_intersections(
            summary=state,
            max_target_intersections=max_target_intersections,
            strategy=strategy,
            priority_corridor=priority_corridor,
            phase_bias=selected_phase_bias,
            focus_scores=focus_scores,
        )

    # 1. Incident anywhere in the window dominates every other signal.
    if state.incident_flag or end_state.incident_flag:
        target_intersections = select_targets(
            strategy="incident_response",
            priority_corridor=phase_bias if phase_bias in {"NS", "EW"} else "arterial",
            selected_phase_bias=phase_bias,
        )
        return DistrictAction(
            strategy="incident_response",
            priority_corridor=phase_bias if phase_bias in {"NS", "EW"} else "arterial",
            target_intersections=target_intersections,
            phase_bias=phase_bias,
            duration_steps=duration_steps,
        ).validate()

    # 2. Explicit spillback flags, or boundary-heavy queues with high outgoing load.
    if state.spillback_risk or end_state.spillback_risk or (boundary_share >= 0.55 and outgoing_pressure >= 0.45):
        priority_corridor = "inbound" if boundary_share >= 0.55 else phase_bias if phase_bias in {"NS", "EW"} else None
        target_intersections = select_targets(
            strategy="clear_spillback",
            priority_corridor=priority_corridor,
            selected_phase_bias=phase_bias,
        )
        return DistrictAction(
            strategy="clear_spillback",
            priority_corridor=priority_corridor,
            target_intersections=target_intersections,
            phase_bias=phase_bias,
            duration_steps=duration_steps,
        ).validate()

    # 3. Boundary-dominated demand that is not improving (queue or wait non-decreasing).
    if boundary_share >= 0.55 and (queue_delta >= 0.0 or wait_delta >= 0.0):
        target_intersections = select_targets(
            strategy="drain_inbound",
            priority_corridor="inbound",
            selected_phase_bias=phase_bias,
        )
        return DistrictAction(
            strategy="drain_inbound",
            priority_corridor="inbound",
            target_intersections=target_intersections,
            phase_bias=phase_bias,
            duration_steps=duration_steps,
        ).validate()

    # 4. Sustained outgoing pressure while the queue has not shrunk by more than 10%.
    if outgoing_pressure >= 0.65 and end_state.total_queue >= state.total_queue * 0.9:
        target_intersections = select_targets(
            strategy="drain_outbound",
            priority_corridor="outbound",
            selected_phase_bias=phase_bias,
        )
        return DistrictAction(
            strategy="drain_outbound",
            priority_corridor="outbound",
            target_intersections=target_intersections,
            phase_bias=phase_bias,
            duration_steps=duration_steps,
        ).validate()

    # 5. Event/overload flags, or a boundary-heavy window with frequent switching.
    if (
        state.event_flag
        or state.overload_flag
        or end_state.overload_flag
        or (boundary_focus_ratio >= 0.6 and switch_count >= max(2, duration_steps))
    ):
        priority_corridor = phase_bias if phase_bias in {"NS", "EW"} else "arterial"
        target_intersections = select_targets(
            strategy="arterial_priority",
            priority_corridor=priority_corridor,
            selected_phase_bias=phase_bias,
        )
        return DistrictAction(
            strategy="arterial_priority",
            priority_corridor=priority_corridor,
            target_intersections=target_intersections,
            phase_bias=phase_bias,
            duration_steps=duration_steps,
        ).validate()

    # 6. Directional imbalance: pressure combines queue with 1.5x-weighted wait;
    # the threshold scales with total pressure but never drops below 5.0.
    ns_pressure = state.ns_queue + 1.5 * state.ns_wait
    ew_pressure = state.ew_queue + 1.5 * state.ew_wait
    imbalance_threshold = max(5.0, 0.15 * max(1.0, ns_pressure + ew_pressure))

    if ns_pressure - ew_pressure >= imbalance_threshold:
        target_intersections = select_targets(
            strategy="favor_NS",
            priority_corridor="NS",
            selected_phase_bias="NS",
        )
        return DistrictAction(
            strategy="favor_NS",
            priority_corridor="NS",
            target_intersections=target_intersections,
            phase_bias="NS",
            duration_steps=duration_steps,
        ).validate()

    if ew_pressure - ns_pressure >= imbalance_threshold:
        target_intersections = select_targets(
            strategy="favor_EW",
            priority_corridor="EW",
            selected_phase_bias="EW",
        )
        return DistrictAction(
            strategy="favor_EW",
            priority_corridor="EW",
            target_intersections=target_intersections,
            phase_bias="EW",
            duration_steps=duration_steps,
        ).validate()

    # 7. Nothing decisive: emit the neutral hold action.
    return DistrictAction.default_hold(duration_steps=duration_steps)
|
district_llm/eval.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
from collections import Counter
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
from district_llm.metrics import aggregate_target_metrics, compute_target_metrics, safe_ratio, target_failure_buckets
|
| 10 |
+
from district_llm.repair import RepairConfig, extract_visible_candidate_ids, sanitize_action_payload
|
| 11 |
+
from district_llm.schema import DistrictAction
|
| 12 |
+
from env.utils import build_topology
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from tqdm.auto import tqdm
|
| 16 |
+
except ImportError: # pragma: no cover
|
| 17 |
+
tqdm = None
|
| 18 |
+
|
| 19 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for offline district-LLM evaluation.

    Boolean flags use ``argparse.BooleanOptionalAction`` so each one also gets
    an automatic ``--no-...`` negation; all three default to True.
    """
    parser = argparse.ArgumentParser(
        description="Offline evaluation for district-LLM outputs."
    )
    # Model and data locations.
    parser.add_argument("--model-path", required=True)
    parser.add_argument("--val-jsonl", required=True)
    # Evaluation size / generation settings.
    parser.add_argument("--max-examples", type=int, default=200)
    parser.add_argument("--debug-examples", type=int, default=10)
    parser.add_argument("--max-new-tokens", type=int, default=128)
    parser.add_argument("--device", default=None)
    parser.add_argument("--generated-root", default="data/generated")
    # Target-repair behavior.
    parser.add_argument("--restrict-targets-to-visible-summary", action="store_true")
    parser.add_argument(
        "--allow-only-visible-candidates",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument("--max-target-intersections", type=int, default=3)
    parser.add_argument(
        "--fallback-on-empty-targets",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument(
        "--fallback-mode",
        choices=("heuristic", "hold", "none"),
        default="heuristic",
    )
    parser.add_argument(
        "--report-before-after-repair",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    return parser.parse_args()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def load_rows(path: str | Path, max_examples: int | None = None) -> list[dict[str, Any]]:
    """Read JSON objects from a JSONL file, skipping blank lines.

    Stops early once `max_examples` rows have been collected; `None` means
    read the whole file.
    """
    collected: list[dict[str, Any]] = []
    with Path(path).open("r", encoding="utf-8") as stream:
        for raw_line in stream:
            if not raw_line.strip():
                continue
            collected.append(json.loads(raw_line))
            if max_examples is not None and len(collected) >= max_examples:
                break
    return collected
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def extract_json_object(payload: str) -> str:
    """Return the outermost `{...}` span of `payload`.

    Raises:
        ValueError: when no opening brace precedes a closing brace.
    """
    open_idx = payload.find("{")
    close_idx = payload.rfind("}")
    if open_idx < 0 or close_idx < 0 or close_idx <= open_idx:
        raise ValueError("No JSON object found.")
    return payload[open_idx : close_idx + 1]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def load_model_and_tokenizer(model_path: str, device: str | None = None):
    """Load the model and tokenizer to evaluate from `model_path`.

    If `model_path` contains an `adapter_config.json`, the model is loaded as
    a PEFT LoRA adapter; otherwise it is loaded as a plain causal-LM.
    `device` selects the target device, defaulting to CUDA when available.

    Raises:
        ImportError: when an adapter is requested but `peft` is not installed.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_dir = Path(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Greedy generation still needs a pad token; reuse EOS when none is set.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # Resolve the device once so both load paths honor the `device` argument.
    target_device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    if (model_dir / "adapter_config.json").exists():
        try:
            from peft import AutoPeftModelForCausalLM
        except ImportError as exc:
            raise ImportError(
                "Evaluating a LoRA adapter requires the 'peft' package."
            ) from exc
        # Bug fix: the adapter branch previously ignored `device` entirely and
        # left the model wherever peft placed it; move it explicitly so both
        # branches behave consistently.
        model = AutoPeftModelForCausalLM.from_pretrained(model_path).to(target_device)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path).to(target_device)
    model.eval()
    return model, tokenizer
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def build_generation_prompt(tokenizer, messages: list[dict[str, str]]) -> str:
    """Render chat messages into a single generation-ready prompt string.

    Uses the tokenizer's chat template when one is configured; otherwise falls
    back to a plain `role: content` transcript ending with `assistant:`.
    """
    chat_template = getattr(tokenizer, "chat_template", None)
    if chat_template:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    rendered_turns = [f"{turn['role']}: {turn['content']}" for turn in messages]
    return "\n".join(rendered_turns) + "\nassistant:"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def generate_response(model, tokenizer, messages: list[dict[str, str]], max_new_tokens: int) -> str:
    """Greedily generate a completion for `messages` and return only the new text.

    Generation is deterministic (`do_sample=False`) and the prompt tokens are
    stripped before decoding, so the return value is just the model's reply.
    """
    import torch

    prompt = build_generation_prompt(tokenizer, messages)
    # Move the tokenized inputs onto the model's device when one is exposed.
    device = getattr(model, "device", None)
    inputs = tokenizer(prompt, return_tensors="pt")
    if device is not None:
        inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            # Fall back to EOS when the tokenizer has no dedicated pad token.
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
    # Slice off the prompt: everything past the input length is newly generated.
    generated = outputs[0][inputs["input_ids"].shape[1] :]
    return tokenizer.decode(generated, skip_special_tokens=True)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def parse_prediction(payload: str) -> tuple[bool, bool, dict[str, Any] | None]:
    """Parse a raw completion into ``(json_valid, schema_valid, payload)``.

    Returns:
        ``(False, False, None)`` when no JSON object can be decoded;
        ``(True, False, dict)`` when JSON parses but violates the action schema;
        ``(True, True, dict)`` when the payload is a valid ``DistrictAction``
        (returned normalized via a schema round-trip).
    """
    try:
        decoded = json.loads(extract_json_object(payload))
    except Exception:
        return False, False, None
    try:
        action = DistrictAction.from_dict(decoded)
    except Exception:
        return True, False, decoded
    return True, True, action.to_dict()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class DistrictTopologyIndex:
    """Lazy per-city lookup of district -> intersection-id sets.

    Topology files under ``generated_root/<city_id>/`` are parsed at most
    once per city and memoised for later queries.
    """

    def __init__(self, generated_root: str | Path):
        self.generated_root = Path(generated_root)
        # city_id -> {district_id -> set of intersection ids}
        self._cache: dict[str, dict[str, set[str]]] = {}

    def district_intersections(self, city_id: str, district_id: str) -> set[str]:
        """Return the intersection ids of *district_id*, or an empty set."""
        cached = self._cache.get(city_id)
        if cached is None:
            city_dir = self.generated_root / city_id
            _, districts = build_topology(
                roadnet_path=city_dir / "roadnet.json",
                district_map_path=city_dir / "district_map.json",
                metadata_path=city_dir / "metadata.json",
            )
            cached = {
                district: set(info.intersection_ids)
                for district, info in districts.items()
            }
            self._cache[city_id] = cached
        return cached.get(district_id, set())
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def field_accuracy(pred: dict[str, Any] | None, gt: dict[str, Any], field: str) -> float:
    """Return 1.0 when *pred* exists and agrees with *gt* on *field*, else 0.0."""
    return 0.0 if pred is None else float(pred.get(field) == gt.get(field))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def invalid_target_fraction(pred_targets: list[str], district_candidates: set[str]) -> float:
    """Fraction of predicted target ids that are not intersections of the district.

    An empty prediction counts as fully valid (0.0).
    """
    if not pred_targets:
        return 0.0
    invalid = [item for item in pred_targets if item not in district_candidates]
    return safe_ratio(len(invalid), len(pred_targets))
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def evaluate_rows(
    rows: list[dict[str, Any]],
    model,
    tokenizer,
    max_new_tokens: int,
    topology_index: DistrictTopologyIndex,
    restrict_targets_to_visible_summary: bool,
    debug_examples: int,
    repair_config: RepairConfig,
    report_before_after_repair: bool,
) -> dict[str, Any]:
    """Generate a prediction for every SFT row and aggregate action metrics.

    For each row the model is prompted with the system+user messages, the
    completion is parsed into a DistrictAction, then run through the repair
    pipeline; metrics are tracked both before and after repair (JSON/schema
    validity, per-field accuracy, target-set overlap, invalid-id rates,
    failure buckets, exact-object accuracy).

    Args:
        rows: SFT rows; each has ``messages`` (system, user, assistant) plus
            ``city_id``/``district_id`` used for topology lookups.
        model, tokenizer: loaded via ``load_model_and_tokenizer``.
        max_new_tokens: generation budget per example.
        topology_index: resolves the valid intersections of each district.
        restrict_targets_to_visible_summary: additionally score targets after
            filtering both prediction and ground truth to ids visible in the
            prompt summary.
        debug_examples: number of per-example debug records to retain.
        repair_config: configuration for ``sanitize_action_payload``.
        report_before_after_repair: include a before/after comparison section.

    Returns:
        A JSON-serializable dict of aggregate metrics plus ``debug_examples``.
    """
    # Validity / accuracy accumulators; "_before" means raw model output,
    # "_after" means post-repair.
    json_valid_count = 0
    schema_valid_count = 0
    field_totals_before = Counter()
    field_totals_after = Counter()
    full_object_correct_before = 0
    full_object_correct_after = 0
    target_rows_before: list[dict[str, float]] = []
    target_rows_after: list[dict[str, float]] = []
    restricted_target_rows_before: list[dict[str, float]] = []
    restricted_target_rows_after: list[dict[str, float]] = []
    invalid_rates_before: list[float] = []
    invalid_rates_after: list[float] = []
    fallback_used_count = 0
    failure_buckets = Counter()
    debug_rows = []

    # tqdm is optional at module level; degrade gracefully when absent.
    progress = (
        tqdm(total=len(rows), desc="eval", dynamic_ncols=True)
        if tqdm is not None
        else None
    )

    try:
        for row in rows:
            messages = row["messages"]
            # The assistant turn (messages[2]) holds the ground-truth action JSON.
            ground_truth = json.loads(messages[2]["content"])
            raw_prediction = generate_response(
                model=model,
                tokenizer=tokenizer,
                messages=messages[:2],
                max_new_tokens=max_new_tokens,
            )
            json_valid, schema_valid, prediction_before = parse_prediction(raw_prediction)
            # Repair always runs; on unparseable output it starts from None
            # and may synthesize a fallback action.
            repaired_action, repair_report = sanitize_action_payload(
                payload=prediction_before if json_valid else None,
                summary=row,
                prompt_text=messages[1]["content"],
                config=repair_config,
            )
            prediction_after = repaired_action.to_dict()
            json_valid_count += int(json_valid)
            schema_valid_count += int(schema_valid)
            fallback_used_count += int(repair_report.fallback_used)

            # Scalar-field accuracy, raw model output.
            field_totals_before["strategy"] += field_accuracy(prediction_before, ground_truth, "strategy")
            field_totals_before["priority_corridor"] += field_accuracy(prediction_before, ground_truth, "priority_corridor")
            field_totals_before["phase_bias"] += field_accuracy(prediction_before, ground_truth, "phase_bias")
            field_totals_before["duration_steps"] += field_accuracy(prediction_before, ground_truth, "duration_steps")

            # Scalar-field accuracy, after repair.
            field_totals_after["strategy"] += field_accuracy(prediction_after, ground_truth, "strategy")
            field_totals_after["priority_corridor"] += field_accuracy(prediction_after, ground_truth, "priority_corridor")
            field_totals_after["phase_bias"] += field_accuracy(prediction_after, ground_truth, "phase_bias")
            field_totals_after["duration_steps"] += field_accuracy(prediction_after, ground_truth, "duration_steps")

            if prediction_before == ground_truth:
                full_object_correct_before += 1
            if prediction_after == ground_truth:
                full_object_correct_after += 1

            # Target-set comparison: predicted vs ground-truth intersections.
            pred_targets_before = [] if prediction_before is None else list(prediction_before.get("target_intersections", []))
            pred_targets_after = list(prediction_after.get("target_intersections", []))
            gt_targets = list(ground_truth.get("target_intersections", []))
            # Ids that were actually shown in the prompt summary.
            visible_candidates = set(
                extract_visible_candidate_ids(summary=row, prompt_text=messages[1]["content"])
            )
            # All intersections belonging to this district (topology truth).
            district_candidates = topology_index.district_intersections(
                city_id=row["city_id"],
                district_id=row["district_id"],
            )
            invalid_before = [item for item in pred_targets_before if item not in district_candidates]
            invalid_after = [item for item in pred_targets_after if item not in district_candidates]
            # Only meaningful when the summary exposed candidates at all.
            non_visible_before = [
                item for item in pred_targets_before
                if visible_candidates and item not in visible_candidates
            ]

            metrics_before = compute_target_metrics(pred_targets_before, gt_targets)
            metrics_after = compute_target_metrics(pred_targets_after, gt_targets)
            target_rows_before.append(metrics_before)
            target_rows_after.append(metrics_after)
            invalid_rates_before.append(invalid_target_fraction(pred_targets_before, district_candidates))
            invalid_rates_after.append(invalid_target_fraction(pred_targets_after, district_candidates))

            if restrict_targets_to_visible_summary:
                # Re-score with both sides restricted to prompt-visible ids.
                filtered_pred_before = [item for item in pred_targets_before if item in visible_candidates]
                filtered_pred_after = [item for item in pred_targets_after if item in visible_candidates]
                filtered_gt = [item for item in gt_targets if item in visible_candidates]
                restricted_target_rows_before.append(
                    compute_target_metrics(filtered_pred_before, filtered_gt)
                )
                restricted_target_rows_after.append(
                    compute_target_metrics(filtered_pred_after, filtered_gt)
                )

            # Categorize this example's target failures (deduplicated per row).
            for failure_bucket in set(
                target_failure_buckets(
                    pred_list=pred_targets_before,
                    gt_list=gt_targets,
                    visible_candidates=visible_candidates,
                    invalid_ids=invalid_before,
                    non_visible_ids=non_visible_before,
                    repaired_targets=pred_targets_after,
                    fallback_used=repair_report.fallback_used,
                )
            ):
                failure_buckets[failure_bucket] += 1

            # Keep the first `debug_examples` rows in full for inspection.
            if len(debug_rows) < debug_examples:
                debug_rows.append(
                    {
                        "district_summary": messages[1]["content"],
                        "predicted_json_raw": raw_prediction,
                        "predicted_json_parsed_before_repair": prediction_before,
                        "predicted_json_parsed_after_repair": prediction_after,
                        "ground_truth_json": ground_truth,
                        "target_intersections_metrics_before_repair": metrics_before,
                        "target_intersections_metrics_after_repair": metrics_after,
                        "repair_report": repair_report.to_dict(),
                        "visible_candidate_ids": sorted(visible_candidates),
                        "failure_buckets": sorted(
                            set(
                                target_failure_buckets(
                                    pred_list=pred_targets_before,
                                    gt_list=gt_targets,
                                    visible_candidates=visible_candidates,
                                    invalid_ids=invalid_before,
                                    non_visible_ids=non_visible_before,
                                    repaired_targets=pred_targets_after,
                                    fallback_used=repair_report.fallback_used,
                                )
                            )
                        ),
                    }
                )
            if progress is not None:
                progress.update(1)
    finally:
        # Close the progress bar even when generation raises part-way.
        if progress is not None:
            progress.close()

    # Guard against division by zero on an empty eval set.
    total_rows = max(1, len(rows))
    results = {
        "num_examples": len(rows),
        "json_validity_rate": float(json_valid_count) / total_rows,
        "schema_validity_rate": float(schema_valid_count) / total_rows,
        "field_accuracy": {
            "strategy": float(field_totals_before["strategy"]) / total_rows,
            "priority_corridor": float(field_totals_before["priority_corridor"]) / total_rows,
            "phase_bias": float(field_totals_before["phase_bias"]) / total_rows,
            "duration_steps": float(field_totals_before["duration_steps"]) / total_rows,
        },
        "field_accuracy_after_repair": {
            "strategy": float(field_totals_after["strategy"]) / total_rows,
            "priority_corridor": float(field_totals_after["priority_corridor"]) / total_rows,
            "phase_bias": float(field_totals_after["phase_bias"]) / total_rows,
            "duration_steps": float(field_totals_after["duration_steps"]) / total_rows,
        },
        "target_intersections_before_repair": aggregate_target_metrics(target_rows_before),
        "target_intersections_after_repair": aggregate_target_metrics(target_rows_after),
        # Unsuffixed key mirrors the post-repair numbers (primary metric).
        "target_intersections": aggregate_target_metrics(target_rows_after),
        "target_intersections_failure_buckets": dict(sorted(failure_buckets.items())),
        "exact_full_object_accuracy": float(full_object_correct_before) / total_rows,
        "exact_full_object_accuracy_after_repair": float(full_object_correct_after) / total_rows,
        "debug_examples": debug_rows,
    }
    if restrict_targets_to_visible_summary:
        results["target_intersections_restricted_to_visible_summary_before_repair"] = aggregate_target_metrics(
            restricted_target_rows_before
        )
        results["target_intersections_restricted_to_visible_summary_after_repair"] = aggregate_target_metrics(
            restricted_target_rows_after
        )
        results["target_intersections_restricted_to_visible_summary"] = aggregate_target_metrics(
            restricted_target_rows_after
        )
    if report_before_after_repair:
        results["target_intersections_before_after_repair"] = {
            "invalid_id_rate_before_repair": float(sum(invalid_rates_before) / total_rows),
            "invalid_id_rate_after_repair": float(sum(invalid_rates_after) / total_rows),
            "exact_set_match_before_repair": aggregate_target_metrics(target_rows_before).get("exact_set_match", 0.0),
            "exact_set_match_after_repair": aggregate_target_metrics(target_rows_after).get("exact_set_match", 0.0),
            "jaccard_before_repair": aggregate_target_metrics(target_rows_before).get("jaccard", 0.0),
            "jaccard_after_repair": aggregate_target_metrics(target_rows_after).get("jaccard", 0.0),
            "fallback_used_rate": float(fallback_used_count) / total_rows,
        }
    return results
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def print_debug_examples(debug_rows: list[dict[str, Any]]) -> None:
|
| 375 |
+
for index, item in enumerate(debug_rows, start=1):
|
| 376 |
+
print(f"[debug {index}] district_summary:")
|
| 377 |
+
print(item["district_summary"])
|
| 378 |
+
print(f"[debug {index}] predicted_json_raw={item['predicted_json_raw']}")
|
| 379 |
+
print(
|
| 380 |
+
f"[debug {index}] predicted_json_parsed_before_repair="
|
| 381 |
+
f"{json.dumps(item['predicted_json_parsed_before_repair'], sort_keys=True)}"
|
| 382 |
+
)
|
| 383 |
+
print(
|
| 384 |
+
f"[debug {index}] predicted_json_parsed_after_repair="
|
| 385 |
+
f"{json.dumps(item['predicted_json_parsed_after_repair'], sort_keys=True)}"
|
| 386 |
+
)
|
| 387 |
+
print(
|
| 388 |
+
f"[debug {index}] ground_truth_json="
|
| 389 |
+
f"{json.dumps(item['ground_truth_json'], sort_keys=True)}"
|
| 390 |
+
)
|
| 391 |
+
print(
|
| 392 |
+
f"[debug {index}] target_intersections_metrics_before_repair="
|
| 393 |
+
f"{json.dumps(item['target_intersections_metrics_before_repair'], sort_keys=True)}"
|
| 394 |
+
)
|
| 395 |
+
print(
|
| 396 |
+
f"[debug {index}] target_intersections_metrics_after_repair="
|
| 397 |
+
f"{json.dumps(item['target_intersections_metrics_after_repair'], sort_keys=True)}"
|
| 398 |
+
)
|
| 399 |
+
print(
|
| 400 |
+
f"[debug {index}] repair_report="
|
| 401 |
+
f"{json.dumps(item['repair_report'], sort_keys=True)}"
|
| 402 |
+
)
|
| 403 |
+
print(
|
| 404 |
+
f"[debug {index}] visible_candidate_ids="
|
| 405 |
+
f"{json.dumps(item['visible_candidate_ids'], sort_keys=True)}"
|
| 406 |
+
)
|
| 407 |
+
print(f"[debug {index}] failure_buckets={json.dumps(item['failure_buckets'])}")
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def main() -> None:
    """CLI entry point: evaluate a model on a JSONL split and print a report."""
    args = parse_args()
    rows = load_rows(args.val_jsonl, max_examples=args.max_examples)
    model, tokenizer = load_model_and_tokenizer(args.model_path, device=args.device)
    topology_index = DistrictTopologyIndex(args.generated_root)
    repair_config = RepairConfig(
        allow_only_visible_candidates=args.allow_only_visible_candidates,
        max_target_intersections=args.max_target_intersections,
        fallback_on_empty_targets=args.fallback_on_empty_targets,
        fallback_mode=args.fallback_mode,
    )
    results = evaluate_rows(
        rows=rows,
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=args.max_new_tokens,
        topology_index=topology_index,
        restrict_targets_to_visible_summary=args.restrict_targets_to_visible_summary,
        debug_examples=args.debug_examples,
        repair_config=repair_config,
        report_before_after_repair=args.report_before_after_repair,
    )
    # Machine-readable summary first (debug examples excluded), then the
    # verbose per-example debug dump.
    report = {key: value for key, value in results.items() if key != "debug_examples"}
    print(json.dumps(report, indent=2, sort_keys=True))
    print_debug_examples(results["debug_examples"])
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
# Allow running this module directly as a CLI script.
if __name__ == "__main__":
    main()
|
district_llm/generate_dataset.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import random
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from district_llm.derivation import DistrictWindowData, LocalIntersectionAction, derive_district_action
|
| 13 |
+
from district_llm.prompting import format_district_prompt, format_sft_text
|
| 14 |
+
from district_llm.summary_builder import DistrictStateSummaryBuilder
|
| 15 |
+
from district_llm.teachers import BaseTeacher, build_teacher, parse_teacher_spec
|
| 16 |
+
from env.observation_builder import ObservationConfig
|
| 17 |
+
from env.reward import RewardConfig
|
| 18 |
+
from env.traffic_env import EnvConfig, TrafficEnv
|
| 19 |
+
from training.cityflow_dataset import CityFlowDataset, ScenarioSpec
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class _WindowBuffer:
    """Accumulates one district's activity between district-LLM decisions."""

    # Summary snapshot taken at the start of the current window.
    start_summary: Any
    # Local-controller actions observed inside the window, in step order.
    controller_actions: list[LocalIntersectionAction] = field(default_factory=list)
    # Number of environment decision steps accumulated so far.
    step_count: int = 0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def parse_args() -> argparse.Namespace:
    """Parse the CLI flags for district-LLM SFT dataset generation."""
    parser = argparse.ArgumentParser(
        description="Generate district-LLM SFT data from CityFlow rollouts."
    )
    # --- Teacher / controller selection ---
    parser.add_argument(
        "--controller",
        default="queue_greedy",
        choices=("rl_checkpoint", "hold", "fixed", "random", "queue_greedy"),
        help="Single controller source used when --teacher-spec is not provided.",
    )
    parser.add_argument("--checkpoint", default=None)
    parser.add_argument(
        "--teacher-spec",
        action="append",
        default=[],
        help="Repeatable source spec, e.g. rl_checkpoint=artifacts/dqn_shared/best_validation.pt or fixed.",
    )
    # --- Rollout / output configuration ---
    parser.add_argument("--episodes", type=int, default=10)
    parser.add_argument(
        "--decision-interval",
        "--district-decision-interval",
        dest="district_decision_interval",
        type=int,
        default=10,
        help="District-LLM decision interval in local-controller decision steps.",
    )
    parser.add_argument("--output", required=True)
    parser.add_argument("--generated-root", default="data/generated")
    parser.add_argument("--splits-root", default="data/splits")
    parser.add_argument("--split", default="train", choices=("train", "val", "test"))
    parser.add_argument("--city-id", default=None)
    parser.add_argument("--scenario-name", default=None)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--fixed-green-time", type=int, default=20)
    parser.add_argument("--device", default=None)
    parser.add_argument("--append", action="store_true")
    # --- District-summary / action-space limits ---
    parser.add_argument("--top-k-congested", type=int, default=3)
    parser.add_argument("--max-candidate-intersections", type=int, default=6)
    parser.add_argument("--max-target-intersections", type=int, default=3)
    parser.add_argument("--use-checkpoint-env-config", action="store_true")

    # --- Environment / observation configuration (see build_env_config) ---
    parser.add_argument("--env-decision-interval", type=int, default=5)
    parser.add_argument("--simulator-interval", type=int, default=1)
    parser.add_argument("--min-green-time", type=int, default=10)
    parser.add_argument("--thread-num", type=int, default=1)
    parser.add_argument("--max-episode-seconds", type=int, default=None)
    parser.add_argument("--max-incoming-lanes", type=int, default=16)
    parser.add_argument("--count-scale", type=float, default=20.0)
    parser.add_argument("--elapsed-time-scale", type=float, default=60.0)
    parser.add_argument("--disable-district-context", action="store_true")
    parser.add_argument("--disable-outgoing-congestion", action="store_true")
    # --- Reward shaping weights (forwarded into RewardConfig) ---
    parser.add_argument("--reward-variant", default="wait_queue_throughput")
    parser.add_argument("--waiting-weight", type=float, default=1.0)
    parser.add_argument("--vehicle-weight", type=float, default=0.1)
    parser.add_argument("--pressure-weight", type=float, default=0.0)
    parser.add_argument("--reward-scale", type=float, default=0.1)
    parser.add_argument("--disable-lane-reward-normalization", action="store_true")
    parser.add_argument("--reward-clip", type=float, default=5.0)
    parser.add_argument("--queue-delta-weight", type=float, default=2.0)
    parser.add_argument("--wait-delta-weight", type=float, default=4.0)
    parser.add_argument("--queue-level-weight", type=float, default=0.5)
    parser.add_argument("--wait-level-weight", type=float, default=1.0)
    parser.add_argument("--throughput-weight", type=float, default=0.1)
    parser.add_argument("--imbalance-weight", type=float, default=0.1)
    parser.add_argument("--reward-delta-clip", type=float, default=2.0)
    parser.add_argument("--reward-level-normalizer", type=float, default=10.0)
    parser.add_argument("--throughput-normalizer", type=float, default=2.0)
    return parser.parse_args()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def build_env_config(args: argparse.Namespace) -> EnvConfig:
    """Translate the parsed CLI flags into an EnvConfig.

    The `--disable-*` flags are inverted here into the positive `include_*`
    / `normalize_*` config fields.
    """
    return EnvConfig(
        simulator_interval=args.simulator_interval,
        decision_interval=args.env_decision_interval,
        min_green_time=args.min_green_time,
        thread_num=args.thread_num,
        max_episode_seconds=args.max_episode_seconds,
        observation=ObservationConfig(
            max_incoming_lanes=args.max_incoming_lanes,
            count_scale=args.count_scale,
            elapsed_time_scale=args.elapsed_time_scale,
            include_outgoing_congestion=not args.disable_outgoing_congestion,
            include_district_context=not args.disable_district_context,
            # Always on for dataset generation; not CLI-configurable here.
            include_district_type_feature=True,
        ),
        reward=RewardConfig(
            variant=args.reward_variant,
            waiting_weight=args.waiting_weight,
            vehicle_weight=args.vehicle_weight,
            pressure_weight=args.pressure_weight,
            reward_scale=args.reward_scale,
            normalize_by_lane_count=not args.disable_lane_reward_normalization,
            clip_reward=args.reward_clip,
            queue_delta_weight=args.queue_delta_weight,
            wait_delta_weight=args.wait_delta_weight,
            queue_level_weight=args.queue_level_weight,
            wait_level_weight=args.wait_level_weight,
            throughput_weight=args.throughput_weight,
            imbalance_weight=args.imbalance_weight,
            delta_clip=args.reward_delta_clip,
            level_normalizer=args.reward_level_normalizer,
            throughput_normalizer=args.throughput_normalizer,
        ),
    )
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def build_env(env_config: EnvConfig, scenario_spec: ScenarioSpec) -> TrafficEnv:
    """Instantiate a TrafficEnv for *scenario_spec* using *env_config*."""
    spec_kwargs = {
        name: getattr(scenario_spec, name)
        for name in (
            "city_id",
            "scenario_name",
            "city_dir",
            "scenario_dir",
            "config_path",
            "roadnet_path",
            "district_map_path",
            "metadata_path",
        )
    }
    return TrafficEnv(env_config=env_config, **spec_kwargs)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def resolve_teachers(args: argparse.Namespace) -> list[BaseTeacher]:
    """Build the teacher controllers requested on the command line.

    Falls back to the single `--controller` choice when no `--teacher-spec`
    was given, rewriting `rl_checkpoint` to carry its checkpoint path.
    """
    specs = list(args.teacher_spec)
    if not specs:
        if args.controller == "rl_checkpoint":
            specs = [f"rl_checkpoint={args.checkpoint}"]
        else:
            specs = [args.controller]

    teachers: list[BaseTeacher] = []
    for spec in specs:
        controller_type, checkpoint = parse_teacher_spec(spec)
        if controller_type == "rl_checkpoint" and not checkpoint:
            # A bare `rl_checkpoint` spec inherits the global --checkpoint.
            checkpoint = args.checkpoint
        teacher = build_teacher(
            controller_type=controller_type,
            checkpoint=checkpoint,
            fixed_green_time=args.fixed_green_time,
            seed=args.seed,
            device=args.device,
        )
        teachers.append(teacher)
    return teachers
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def resolve_env_config(args: argparse.Namespace, teachers: list[BaseTeacher]) -> EnvConfig:
    """Pick the env config to roll out with.

    Returns the CLI-built config by default; with --use-checkpoint-env-config
    it returns the (single, shared) config embedded in checkpoint teachers.

    Raises:
        ValueError: when checkpoint teachers disagree on their env config.
    """
    cli_config = build_env_config(args)
    if not args.use_checkpoint_env_config:
        return cli_config

    candidates = [teacher.env_config for teacher in teachers if teacher.env_config is not None]
    if not candidates:
        return cli_config

    reference = candidates[0]
    assert reference is not None
    if any(candidate != reference for candidate in candidates[1:]):
        raise ValueError("Checkpoint teachers use different env configs. Generate separate datasets.")
    return reference
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def sample_scenario(
    dataset: CityFlowDataset,
    rng: random.Random,
    split: str,
    city_id: str | None,
    scenario_name: str | None,
) -> ScenarioSpec:
    """Return the fully pinned scenario, or sample one from *split* otherwise.

    A scenario counts as pinned only when BOTH city and scenario name are
    given; partial pins are forwarded to the dataset sampler as constraints.
    """
    fully_pinned = bool(city_id) and bool(scenario_name)
    if fully_pinned:
        return dataset.build_scenario_spec(city_id, scenario_name)
    return dataset.sample_scenario(
        split_name=split, rng=rng, city_id=city_id, scenario_name=scenario_name
    )
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def extract_step_actions(
    env: TrafficEnv,
    observation_batch: dict[str, Any],
    next_observation_batch: dict[str, Any],
    actions: np.ndarray,
) -> dict[str, list[LocalIntersectionAction]]:
    """Convert one env step into per-district LocalIntersectionAction records.

    Args:
        env: The live environment (used for topology and lane-count queries).
        observation_batch: Observations the controller acted on.
        next_observation_batch: Observations after the step (provides the
            realized next phase per intersection).
        actions: Controller action per intersection, aligned with
            ``observation_batch["intersection_ids"]``.

    Returns:
        Mapping of district id -> actions of its intersections this step
        (every district is present, possibly with an empty list).
    """
    grouped: dict[str, list[LocalIntersectionAction]] = {district_id: [] for district_id in env.districts}
    # Queried once per step; reused for every intersection's outgoing load.
    lane_vehicle_count = env.adapter.get_lane_vehicle_count()

    for index, intersection_id in enumerate(observation_batch["intersection_ids"]):
        district_id = observation_batch["district_ids"][index]
        grouped[district_id].append(
            LocalIntersectionAction(
                intersection_id=intersection_id,
                district_id=district_id,
                action=int(actions[index]),
                current_phase=int(observation_batch["current_phase"][index]),
                # Phase realized after the step, from the next observation.
                next_phase=int(next_observation_batch["current_phase"][index]),
                queue_total=float(np.asarray(observation_batch["incoming_counts"][index], dtype=np.float32).sum()),
                wait_total=float(np.asarray(observation_batch["incoming_waiting"][index], dtype=np.float32).sum()),
                # Vehicles currently on all outgoing lanes; lanes missing from
                # the simulator snapshot count as empty.
                outgoing_load=float(
                    sum(
                        float(lane_vehicle_count.get(lane_id, 0))
                        for lane_id in env.intersections[intersection_id].outgoing_lanes
                    )
                ),
                is_boundary=bool(env.intersections[intersection_id].is_boundary),
            )
        )
    return grouped
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def generate_examples_for_episode(
    env: TrafficEnv,
    teacher: BaseTeacher,
    district_interval: int,
    top_k_congested: int,
    max_candidate_intersections: int,
    max_target_intersections: int,
    episode_index: int,
) -> list[dict[str, Any]]:
    """Roll out one episode under *teacher* and emit one SFT sample per district window.

    A window spans ``district_interval`` env steps; the final window may be
    shorter because it is flushed when the episode ends. For each completed
    window, the ground-truth district action is derived from what actually
    happened over the window and paired with the prompt built from the
    window's start-of-window summary.

    Returns a list of JSON-serializable sample dicts.
    """
    summary_builder = DistrictStateSummaryBuilder(
        top_k=top_k_congested,
        candidate_limit=max_candidate_intersections,
    )
    observation_batch = env.reset()
    summary_builder.reset()
    current_summaries = summary_builder.build_all(env, observation_batch)
    # One rolling window buffer per district, anchored at the reset-time summary.
    windows = {
        district_id: _WindowBuffer(start_summary=summary)
        for district_id, summary in current_summaries.items()
    }
    samples: list[dict[str, Any]] = []
    done = False
    # Shared across districts: incremented once per emitted sample, not per step.
    window_index = 0

    while not done:
        actions = teacher.act(observation_batch)
        next_observation_batch, rewards, done, info = env.step(actions)
        del rewards, info  # unused here: only transitions matter for SFT data
        step_actions = extract_step_actions(env, observation_batch, next_observation_batch, actions)
        next_summaries = summary_builder.build_all(env, next_observation_batch)

        for district_id, buffer in windows.items():
            buffer.controller_actions.extend(step_actions[district_id])
            buffer.step_count += 1
            # Emit on a full interval, or flush a partial window at episode end.
            should_emit = buffer.step_count >= district_interval or done
            if not should_emit:
                continue

            end_summary = next_summaries[district_id]
            window_data = DistrictWindowData(
                district_id=district_id,
                start_summary=buffer.start_summary,
                end_summary=end_summary,
                controller_actions=list(buffer.controller_actions),
                step_count=buffer.step_count,
            )
            # Derived (not predicted) label for this window's district action.
            action = derive_district_action(
                window_data=window_data,
                max_target_intersections=max_target_intersections,
            )
            prompt = format_district_prompt(
                buffer.start_summary,
                max_target_intersections=max_target_intersections,
                allow_only_visible_candidates=True,
            )
            samples.append(
                {
                    "text": format_sft_text(
                        buffer.start_summary,
                        action,
                        max_target_intersections=max_target_intersections,
                        allow_only_visible_candidates=True,
                    ),
                    "prompt": prompt,
                    "response_json": action.to_dict(),
                    "state": buffer.start_summary.to_dict(),
                    "candidate_intersections": buffer.start_summary.to_dict().get("candidate_intersections", []),
                    "window_summary": window_data.to_dict(),
                    "city_id": env.city_id,
                    "district_id": district_id,
                    "district_type": env.districts[district_id].district_type,
                    "scenario": env.scenario_name,
                    # Teacher provenance, kept so samples can be filtered by source.
                    "controller_type": teacher.metadata.controller_type,
                    "controller_id": teacher.metadata.controller_id,
                    "controller_family": teacher.metadata.controller_family,
                    "teacher_algorithm": teacher.metadata.teacher_algorithm,
                    "checkpoint_path": teacher.metadata.checkpoint_path,
                    "episode_index": int(episode_index),
                    "window_index": int(window_index),
                    "decision_interval": int(district_interval),
                    "sim_time": int(buffer.start_summary.sim_time),
                }
            )
            # Start the next window where this one ended. Re-assigning an
            # existing key while iterating is safe (no dict resize occurs).
            windows[district_id] = _WindowBuffer(start_summary=end_summary)
            window_index += 1

        observation_batch = next_observation_batch

    return samples
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def append_jsonl(path: Path, records: list[dict[str, Any]], append: bool) -> None:
    """Write *records* to *path* as JSON Lines (one object per line).

    Keys are serialized in sorted order for stable diffs. The parent
    directory is created when missing; ``append`` selects append vs.
    truncate mode.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a" if append else "w", encoding="utf-8") as handle:
        for record in records:
            handle.write(json.dumps(record, sort_keys=True) + "\n")
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def main() -> None:
    """CLI entry point: roll teacher episodes and dump district SFT samples as JSONL."""
    args = parse_args()
    dataset = CityFlowDataset(
        generated_root=args.generated_root,
        splits_root=args.splits_root,
    )
    dataset.generate_default_splits()
    rng = random.Random(args.seed)
    teachers = resolve_teachers(args)
    env_config = resolve_env_config(args, teachers)

    output_path = Path(args.output)
    # The first episode honours --append; later episodes must always append so
    # the output accumulates instead of being truncated every iteration.
    append_mode = bool(args.append)

    for episode_index in range(args.episodes):
        scenario_spec = sample_scenario(
            dataset=dataset,
            rng=rng,
            split=args.split,
            city_id=args.city_id,
            scenario_name=args.scenario_name,
        )
        episode_records: list[dict[str, Any]] = []
        for teacher in teachers:
            # Fresh env per teacher so each rollout starts from the same scenario.
            env = build_env(env_config=env_config, scenario_spec=scenario_spec)
            records = generate_examples_for_episode(
                env=env,
                teacher=teacher,
                district_interval=args.district_decision_interval,
                top_k_congested=args.top_k_congested,
                max_candidate_intersections=args.max_candidate_intersections,
                max_target_intersections=args.max_target_intersections,
                episode_index=episode_index,
            )
            episode_records.extend(records)
        append_jsonl(output_path, episode_records, append=append_mode)
        append_mode = True
        progress = {
            "episode_index": episode_index,
            "city_id": scenario_spec.city_id,
            "scenario_name": scenario_spec.scenario_name,
            "records_written": len(episode_records),
        }
        print(json.dumps(progress))


if __name__ == "__main__":
    main()
|
district_llm/guided_control.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from district_llm.schema import DistrictAction
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DistrictGuidedLocalController:
    """Wrap a low-level controller and bias its actions with district directives.

    The shared DQN still produces the base per-intersection action, and the
    district plan only nudges hold/switch decisions toward the requested phase.
    """

    def __init__(self, base_teacher):
        # Low-level policy supplying the default action for every intersection.
        self.base_teacher = base_teacher

    def act(
        self,
        observation_batch: dict[str, Any],
        district_actions: dict[str, DistrictAction] | None = None,
    ) -> np.ndarray:
        """Return per-intersection actions, overridden where a directive applies."""
        base_actions = np.asarray(self.base_teacher.act(observation_batch), dtype=np.int64)
        if not district_actions:
            # No district plan: pass the base policy through untouched.
            return base_actions

        guided = base_actions.copy()
        for index, district_id in enumerate(observation_batch["district_ids"]):
            directive = district_actions.get(district_id)
            if directive is None:
                continue
            guided[index] = self._apply_directive(
                observation_batch=observation_batch,
                index=index,
                base_action=int(base_actions[index]),
                directive=directive,
            )
        return guided

    @staticmethod
    def _apply_directive(
        observation_batch: dict[str, Any],
        index: int,
        base_action: int,
        directive: DistrictAction,
    ) -> int:
        """Resolve one intersection's action under a district directive.

        Action 0 means "hold", action 1 means "switch" and is only taken when
        the action mask allows it. Phase 0 is treated as the NS-serving phase
        here (assumption from the 0-vs-nonzero checks — confirm against the
        env's phase convention).
        """
        mask = observation_batch["action_mask"][index]
        phase = int(observation_batch["current_phase"][index])
        switch_allowed = bool(mask[1] > 0.0)

        if directive.strategy == "hold" or directive.phase_bias == "NONE":
            return int(base_action)

        if directive.phase_bias == "NS":
            # Already serving NS: hold. Otherwise switch toward it if legal.
            return 0 if phase == 0 else (1 if switch_allowed else 0)

        if directive.phase_bias == "EW":
            # Already off phase 0 (EW side): hold; otherwise switch if legal.
            return 0 if phase != 0 else (1 if switch_allowed else 0)

        return int(base_action)
|
district_llm/heuristic_guidance.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
from district_llm.repair import fallback_target_intersections
|
| 6 |
+
from district_llm.schema import DistrictAction, DistrictStateSummary
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass(frozen=True)
class HeuristicGuidanceConfig:
    """Tunable knobs for the rule-based district guidance generator."""

    # Cap on how many intersections a directive may target.
    max_target_intersections: int = 3
    # Directive lifetimes (env steps) per triggering condition.
    incident_duration_steps: int = 12
    spillback_duration_steps: int = 10
    default_duration_steps: int = 8
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def generate_heuristic_guidance(
    summary: DistrictStateSummary,
    config: HeuristicGuidanceConfig | None = None,
) -> DistrictAction:
    """Derive a district directive from a state summary with fixed rules.

    Priority order: incident/construction, then spillback, then event or
    overload, then plain directional dominance, else hold. The phase bias
    always follows the dominant flow whenever it is directional.
    """
    cfg = config or HeuristicGuidanceConfig()
    flow = summary.dominant_flow
    directional = flow in {"NS", "EW"}
    bias = flow if directional else "NONE"

    if summary.incident_flag or summary.construction_flag:
        strategy = "incident_response"
        corridor = flow if directional else "arterial"
        duration = cfg.incident_duration_steps
    elif summary.spillback_risk:
        strategy = "clear_spillback"
        # When boundary queues dominate overall queueing, prioritise inbound
        # traffic rather than a directional corridor.
        boundary_share = summary.boundary_queue_total / max(1.0, summary.total_queue)
        if boundary_share >= 0.45:
            corridor = "inbound"
        elif directional:
            corridor = flow
        else:
            corridor = None
        duration = cfg.spillback_duration_steps
    elif summary.event_flag or summary.overload_flag:
        strategy = "arterial_priority"
        corridor = flow if directional else "arterial"
        duration = cfg.spillback_duration_steps
    elif flow == "NS":
        strategy, corridor, duration = "favor_NS", "NS", cfg.default_duration_steps
    elif flow == "EW":
        strategy, corridor, duration = "favor_EW", "EW", cfg.default_duration_steps
    else:
        strategy, corridor, duration = "hold", None, cfg.default_duration_steps

    targets = fallback_target_intersections(
        summary=summary,
        max_target_intersections=cfg.max_target_intersections,
        strategy=strategy,
        priority_corridor=corridor,
        phase_bias=bias,
    )
    return DistrictAction(
        strategy=strategy,
        priority_corridor=corridor,
        target_intersections=targets,
        phase_bias=bias,
        duration_steps=duration,
    ).validate()
|
district_llm/inference.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Callable
|
| 8 |
+
|
| 9 |
+
from district_llm.prompting import format_district_prompt
|
| 10 |
+
from district_llm.repair import RepairConfig, RepairReport, sanitize_action_payload
|
| 11 |
+
from district_llm.schema import DistrictAction, DistrictStateSummary
|
| 12 |
+
from district_llm.summary_builder import DistrictStateSummaryBuilder
|
| 13 |
+
from env.observation_builder import ObservationConfig
|
| 14 |
+
from env.reward import RewardConfig
|
| 15 |
+
from env.traffic_env import EnvConfig, TrafficEnv
|
| 16 |
+
from training.cityflow_dataset import CityFlowDataset
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _extract_json_object(payload: str) -> str:
|
| 20 |
+
start = payload.find("{")
|
| 21 |
+
end = payload.rfind("}")
|
| 22 |
+
if start == -1 or end == -1 or end <= start:
|
| 23 |
+
raise ValueError("No JSON object found in model output.")
|
| 24 |
+
return payload[start : end + 1]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass(frozen=True)
class DistrictLLMInferenceResult:
    """Complete record of one LLM decision, including pre-repair diagnostics."""

    # Final (possibly repaired) action.
    action: DistrictAction
    # Verbatim model output text.
    raw_text: str
    # JSON payload parsed from the raw text, before any sanitization.
    parsed_payload_before_repair: dict[str, Any]
    repair_report: RepairReport
    # Whether the raw output contained parseable JSON at all.
    json_valid: bool
    schema_valid_before_repair: bool
|
| 36 |
+
|
| 37 |
+
class DistrictLLMInference:
    """Generate, parse, and repair district actions from an LLM.

    Exactly one of ``generator_fn`` (a prompt -> text callable, useful for
    tests and remote backends) or ``model_name_or_path`` (local HF model or
    LoRA adapter directory) must be provided.
    """

    def __init__(
        self,
        generator_fn: Callable[[str], str] | None = None,
        model_name_or_path: str | None = None,
        device: str | None = None,
        fallback_action: DistrictAction | None = None,
        repair_config: RepairConfig | None = None,
    ):
        self.fallback_action = fallback_action or DistrictAction.default_hold()
        self.generator_fn = generator_fn
        self.repair_config = repair_config or RepairConfig()
        self.tokenizer = None
        self.model = None
        self.device = device or "cpu"

        if self.generator_fn is None:
            if not model_name_or_path:
                raise ValueError("Provide either generator_fn or model_name_or_path.")
            # Heavy imports are deferred so callers that pass generator_fn
            # never need torch/transformers installed.
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer

            self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
            model_dir = Path(model_name_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            # Some base models ship without a pad token; reuse EOS for padding.
            if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # An adapter_config.json marks a PEFT/LoRA checkpoint directory.
            if (model_dir / "adapter_config.json").exists():
                try:
                    from peft import AutoPeftModelForCausalLM
                except ImportError as exc:
                    raise ImportError("Loading a LoRA adapter requires the 'peft' package.") from exc
                self.model = AutoPeftModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
            else:
                self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
            self.model.eval()

    def generate_raw(self, prompt: str, max_new_tokens: int = 128) -> str:
        """Return the raw model completion for *prompt* (greedy decoding)."""
        if self.generator_fn is not None:
            return self.generator_fn(prompt)
        import torch

        assert self.model is not None and self.tokenizer is not None
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Strip the echoed prompt: keep only newly generated tokens.
        generated = outputs[0][inputs["input_ids"].shape[1] :]
        return self.tokenizer.decode(generated, skip_special_tokens=True)

    def parse_action(
        self,
        payload: str,
        summary: DistrictStateSummary | None = None,
    ) -> tuple[DistrictAction, RepairReport, dict[str, Any], bool, bool]:
        """Parse *payload* into a sanitized action.

        Returns (action, repair_report, parsed_payload, json_valid,
        schema_valid_before_repair). When no JSON can be extracted, the
        fallback action's dict is sanitized instead and both flags are False.

        NOTE(review): schema validity is not independently checked here —
        the flag only tracks JSON extraction/parse success.
        """
        json_valid = True
        schema_valid_before_repair = True
        try:
            parsed_payload = json.loads(_extract_json_object(payload))
        except Exception:
            json_valid = False
            schema_valid_before_repair = False
            parsed_payload = self.fallback_action.to_dict()
        action, repair_report = sanitize_action_payload(
            payload=parsed_payload,
            summary=summary,
            config=self.repair_config,
        )
        return action, repair_report, parsed_payload, json_valid, schema_valid_before_repair

    def predict_with_result(
        self,
        summary: DistrictStateSummary,
        max_new_tokens: int = 128,
    ) -> DistrictLLMInferenceResult:
        """Run prompt -> generate -> parse/repair and return the full diagnostic result."""
        prompt = format_district_prompt(
            summary,
            max_target_intersections=self.repair_config.max_target_intersections,
            allow_only_visible_candidates=self.repair_config.allow_only_visible_candidates,
        )
        raw = self.generate_raw(prompt=prompt, max_new_tokens=max_new_tokens)
        action, repair_report, parsed_payload, json_valid, schema_valid_before_repair = self.parse_action(
            raw,
            summary=summary,
        )
        return DistrictLLMInferenceResult(
            action=action,
            raw_text=raw,
            parsed_payload_before_repair=parsed_payload,
            repair_report=repair_report,
            json_valid=json_valid,
            schema_valid_before_repair=schema_valid_before_repair,
        )

    def predict(self, summary: DistrictStateSummary, max_new_tokens: int = 128) -> DistrictAction:
        """Convenience wrapper returning only the final (repaired) action."""
        return self.predict_with_result(summary=summary, max_new_tokens=max_new_tokens).action
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for single-sample district LLM inference."""
    parser = argparse.ArgumentParser(description="Run single-sample district LLM inference.")
    parser.add_argument("--model", required=True, help="Model name, local path, or LoRA adapter path.")
    parser.add_argument("--generated-root", default="data/generated")
    parser.add_argument("--splits-root", default="data/splits")
    parser.add_argument("--city-id", required=True)
    parser.add_argument("--scenario-name", required=True)
    parser.add_argument("--district-id", required=True)
    parser.add_argument("--device", default=None)
    parser.add_argument("--max-new-tokens", type=int, default=128)
    # BooleanOptionalAction auto-generates matching --no-* negation flags.
    parser.add_argument(
        "--allow-only-visible-candidates",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument("--max-target-intersections", type=int, default=3)
    parser.add_argument(
        "--fallback-on-empty-targets",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument(
        "--fallback-mode",
        choices=("heuristic", "hold", "none"),
        default="heuristic",
    )
    return parser.parse_args()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def build_env(scenario_spec) -> TrafficEnv:
    """Construct a TrafficEnv for *scenario_spec* with fixed inference-time settings."""
    config = EnvConfig(
        simulator_interval=1,
        decision_interval=5,
        min_green_time=10,
        thread_num=1,
        observation=ObservationConfig(),
        reward=RewardConfig(variant="wait_queue_throughput"),
    )
    return TrafficEnv(
        city_id=scenario_spec.city_id,
        scenario_name=scenario_spec.scenario_name,
        city_dir=scenario_spec.city_dir,
        scenario_dir=scenario_spec.scenario_dir,
        config_path=scenario_spec.config_path,
        roadnet_path=scenario_spec.roadnet_path,
        district_map_path=scenario_spec.district_map_path,
        metadata_path=scenario_spec.metadata_path,
        env_config=config,
    )
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def main() -> None:
    """CLI entry point: build one district summary and print the LLM's action JSON."""
    args = parse_args()
    dataset = CityFlowDataset(
        generated_root=args.generated_root,
        splits_root=args.splits_root,
    )
    scenario_spec = dataset.build_scenario_spec(args.city_id, args.scenario_name)
    env = build_env(scenario_spec)
    # Keep at least 6 visible candidates so the target cap never exceeds them.
    summary_builder = DistrictStateSummaryBuilder(
        candidate_limit=max(6, args.max_target_intersections)
    )
    observation_batch = env.reset()
    summaries = summary_builder.build_all(env, observation_batch)
    if args.district_id not in summaries:
        raise ValueError(f"Unknown district_id '{args.district_id}' for {args.city_id}/{args.scenario_name}.")
    inference = DistrictLLMInference(
        model_name_or_path=args.model,
        device=args.device,
        fallback_action=DistrictAction.default_hold(),
        repair_config=RepairConfig(
            allow_only_visible_candidates=args.allow_only_visible_candidates,
            max_target_intersections=args.max_target_intersections,
            fallback_on_empty_targets=args.fallback_on_empty_targets,
            fallback_mode=args.fallback_mode,
        ),
    )
    action = inference.predict(
        summary=summaries[args.district_id],
        max_new_tokens=args.max_new_tokens,
    )
    print(action.to_pretty_json())


if __name__ == "__main__":
    main()
|
district_llm/metrics.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def safe_ratio(numerator: int | float, denominator: int | float, default_if_empty: float = 0.0) -> float:
|
| 7 |
+
if denominator <= 0:
|
| 8 |
+
return default_if_empty
|
| 9 |
+
return float(numerator) / float(denominator)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def compute_target_metrics(pred_list: list[str], gt_list: list[str]) -> dict[str, float]:
    """Score predicted target-intersection ids against ground truth.

    All values are floats so rows can be averaged directly. When both lists
    are empty, the set-based ratios default to 1.0 (a vacuous perfect match).
    """
    pred = list(pred_list)
    gt = list(gt_list)
    pred_set, gt_set = set(pred), set(gt)
    hits = len(pred_set & gt_set)
    union_size = len(pred_set | gt_set)

    vacuous = 1.0 if not pred_set and not gt_set else 0.0

    def _ratio(num: float, den: float) -> float:
        # Same semantics as safe_ratio with the vacuous default.
        return float(num) / float(den) if den > 0 else vacuous

    return {
        "exact_list_match": float(pred == gt),
        "exact_set_match": float(pred_set == gt_set),
        "overlap_count": float(hits),
        "overlap_rate": _ratio(hits, len(gt_set)),
        "precision": _ratio(hits, len(pred_set)),
        "recall": _ratio(hits, len(gt_set)),
        "jaccard": _ratio(hits, union_size),
        "hit_at_1": float(hits >= 1),
        "hit_at_2": float(hits >= 2),
        "hit_at_3": float(hits >= 3),
    }
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def aggregate_target_metrics(metric_rows: list[dict[str, float]]) -> dict[str, float]:
    """Mean of each metric across rows; empty dict for no rows.

    Assumes every row carries the same keys as the first one — TODO confirm
    callers never mix metric schemas.
    """
    if not metric_rows:
        return {}
    count = len(metric_rows)
    return {
        key: float(sum(row[key] for row in metric_rows) / count)
        for key in metric_rows[0]
    }
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def target_failure_buckets(
|
| 52 |
+
pred_list: list[str],
|
| 53 |
+
gt_list: list[str],
|
| 54 |
+
visible_candidates: set[str],
|
| 55 |
+
invalid_ids: list[str] | None = None,
|
| 56 |
+
non_visible_ids: list[str] | None = None,
|
| 57 |
+
repaired_targets: list[str] | None = None,
|
| 58 |
+
fallback_used: bool = False,
|
| 59 |
+
) -> list[str]:
|
| 60 |
+
buckets: list[str] = []
|
| 61 |
+
pred_set = set(pred_list)
|
| 62 |
+
gt_set = set(gt_list)
|
| 63 |
+
|
| 64 |
+
if not pred_list:
|
| 65 |
+
buckets.append("prediction_empty")
|
| 66 |
+
if not gt_list:
|
| 67 |
+
buckets.append("ground_truth_empty")
|
| 68 |
+
if pred_list and gt_list and pred_set == gt_set and pred_list != gt_list:
|
| 69 |
+
buckets.append("same_set_different_order")
|
| 70 |
+
elif pred_set & gt_set:
|
| 71 |
+
buckets.append("partial_overlap")
|
| 72 |
+
elif pred_list and gt_list:
|
| 73 |
+
buckets.append("no_overlap")
|
| 74 |
+
|
| 75 |
+
if invalid_ids:
|
| 76 |
+
buckets.append("prediction_contains_invalid_ids")
|
| 77 |
+
if non_visible_ids:
|
| 78 |
+
buckets.append("prediction_contains_ids_not_visible_in_summary")
|
| 79 |
+
if pred_list and visible_candidates and any(item not in visible_candidates for item in pred_list):
|
| 80 |
+
buckets.append("prediction_contains_ids_not_visible_in_summary")
|
| 81 |
+
if fallback_used:
|
| 82 |
+
buckets.append("fallback_used")
|
| 83 |
+
|
| 84 |
+
if repaired_targets is not None:
|
| 85 |
+
repaired_set = set(repaired_targets)
|
| 86 |
+
if repaired_set == gt_set and pred_set != gt_set:
|
| 87 |
+
buckets.append("repaired_successfully")
|
| 88 |
+
elif (invalid_ids or non_visible_ids or fallback_used) and repaired_set != gt_set:
|
| 89 |
+
buckets.append("repair_failed")
|
| 90 |
+
|
| 91 |
+
return buckets
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def average_item_rate(values: list[list[Any]]) -> float:
    """Mean fill rate where each empty inner list still counts as one slot.

    Equivalent to ``sum(len(v)) / sum(max(len(v), 1))``; returns 0.0 when
    *values* is empty.
    """
    total_items = sum(len(inner) for inner in values)
    total_slots = sum(max(len(inner), 1) for inner in values)
    if total_slots <= 0:
        return 0.0
    return float(total_items) / float(total_slots)
|
district_llm/prompting.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from district_llm.schema import DISTRICT_STRATEGIES, PHASE_BIASES, PRIORITY_CORRIDORS, DistrictAction, DistrictStateSummary
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
DEFAULT_MAX_TARGET_INTERSECTIONS = 3


def build_system_prompt(
    max_target_intersections: int = DEFAULT_MAX_TARGET_INTERSECTIONS,
    allow_only_visible_candidates: bool = True,
) -> str:
    """Compose the system prompt describing the JSON contract the LLM must follow."""
    if allow_only_visible_candidates:
        candidate_rule = (
            " If candidate_intersections is present, target_intersections must use only ids from that list."
        )
    else:
        candidate_rule = ""
    pieces = [
        "You are a district traffic coordinator for RL traffic lights. ",
        "Return only valid JSON with exactly these keys: ",
        "strategy, priority_corridor, target_intersections, phase_bias, duration_steps. ",
        f"target_intersections must be a JSON array with at most {int(max_target_intersections)} unique ids.",
        f"{candidate_rule} ",
        "Do not invent intersection ids. Deduplicate ids. If uncertain, prefer the most congested valid candidates.",
    ]
    return "".join(pieces)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def format_district_prompt(
    summary: DistrictStateSummary,
    max_target_intersections: int = DEFAULT_MAX_TARGET_INTERSECTIONS,
    allow_only_visible_candidates: bool = True,
) -> str:
    """Render the full user prompt: action schema, district state, and decision cue."""
    limit = int(max_target_intersections)
    if allow_only_visible_candidates:
        target_rule = f"target_intersections: up to {limit} ids from candidate_intersections only"
    else:
        target_rule = f"target_intersections: up to {limit} valid ids"
    prompt_lines = [
        "### DISTRICT ACTION SCHEMA",
        f"strategy: {'|'.join(DISTRICT_STRATEGIES)}",
        f"phase_bias: {'|'.join(PHASE_BIASES)}",
        f"priority_corridor: {'|'.join(PRIORITY_CORRIDORS)}|none",
        "duration_steps: integer 1..20",
        target_rule,
        "rules: return only valid JSON; do not invent ids; deduplicate target_intersections",
        "fallback: if uncertain, prefer the most congested visible candidates",
        "",
        "### DISTRICT STATE",
        summary.to_prompt_text(),
        "",
        "### DECISION",
    ]
    return "\n".join(prompt_lines)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def format_sft_text(
    summary: DistrictStateSummary,
    action: DistrictAction,
    max_target_intersections: int = DEFAULT_MAX_TARGET_INTERSECTIONS,
    allow_only_visible_candidates: bool = True,
) -> str:
    """Concatenate the prompt and the pretty-printed action JSON into one SFT text."""
    prompt = format_district_prompt(
        summary,
        max_target_intersections=max_target_intersections,
        allow_only_visible_candidates=allow_only_visible_candidates,
    )
    return f"{prompt}\n{action.to_pretty_json()}"
|
district_llm/repair.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from district_llm.schema import (
|
| 8 |
+
DISTRICT_STRATEGIES,
|
| 9 |
+
PHASE_BIASES,
|
| 10 |
+
PRIORITY_CORRIDORS,
|
| 11 |
+
CandidateIntersection,
|
| 12 |
+
DistrictAction,
|
| 13 |
+
DistrictStateSummary,
|
| 14 |
+
candidate_priority_score,
|
| 15 |
+
canonicalize_target_intersections,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Matches canonical intersection ids of the form "i_<digits>", word-bounded so
# ids embedded in longer tokens are not picked up.
INTERSECTION_ID_PATTERN = re.compile(r"\bi_\d+\b")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass(frozen=True)
class RepairConfig:
    """Knobs controlling how LLM-proposed target lists are sanitized."""

    # Drop target ids that are not among the visible candidate intersections.
    allow_only_visible_candidates: bool = True
    # Hard cap on the repaired target list length.
    max_target_intersections: int = 3
    # Synthesize targets when filtering leaves the list empty.
    fallback_on_empty_targets: bool = True
    # One of "heuristic" | "hold" | "none" (enforced in sanitize_action_payload).
    fallback_mode: str = "heuristic"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class RepairReport:
    """Audit trail describing how a model-proposed target list was repaired."""

    raw_targets: list[str] = field(default_factory=list)
    repaired_targets: list[str] = field(default_factory=list)
    invalid_ids_removed: list[str] = field(default_factory=list)
    non_visible_ids_removed: list[str] = field(default_factory=list)
    deduplicated: bool = False
    truncated: bool = False
    fallback_used: bool = False
    fallback_mode: str | None = None
    empty_after_filtering: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable snapshot (lists copied, flags coerced to bool)."""
        snapshot: dict[str, Any] = {
            "raw_targets": [*self.raw_targets],
            "repaired_targets": [*self.repaired_targets],
            "invalid_ids_removed": [*self.invalid_ids_removed],
            "non_visible_ids_removed": [*self.non_visible_ids_removed],
            "deduplicated": bool(self.deduplicated),
            "truncated": bool(self.truncated),
            "fallback_used": bool(self.fallback_used),
            "fallback_mode": self.fallback_mode,
            "empty_after_filtering": bool(self.empty_after_filtering),
        }
        return snapshot
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def normalize_candidate_intersections(
    payload: list[CandidateIntersection | dict[str, Any]] | None,
) -> list[dict[str, Any]]:
    """Coerce candidate entries into plain dicts; entries of any other type
    are silently dropped."""
    result: list[dict[str, Any]] = []
    for entry in payload or []:
        if isinstance(entry, CandidateIntersection):
            result.append(entry.to_dict())
            continue
        if isinstance(entry, dict):
            result.append(dict(entry))
    return result
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def parse_candidate_intersections_from_text(text: str) -> list[dict[str, Any]]:
    """Recover candidate intersection rows from a prompt's
    ``candidate_intersections:`` section.

    Each candidate line looks like ``- i_12 q=3.0 w=1.5 boundary=1 reasons=a|b``.
    Capture starts after the ``candidate_intersections:`` header and stops at
    the next ``something:`` section header. Unknown keys are ignored, and —
    because this parser runs on possibly-degraded model-facing text — a
    malformed numeric token (e.g. ``q=abc``) keeps the field's default instead
    of raising.

    Returns validated plain-dict rows via
    normalized_candidate_intersections_from_dicts.
    """
    if "candidate_intersections:" not in text:
        return []

    candidates: list[dict[str, Any]] = []
    capture = False
    for line in text.splitlines():
        stripped = line.strip()
        if stripped == "candidate_intersections:":
            capture = True
            continue
        if not capture:
            continue
        if stripped == "- none":
            # Placeholder emitted when there are no candidates.
            continue
        if not stripped.startswith("- "):
            # A new "section:" header ends the candidate list; other noise is skipped.
            if stripped.endswith(":"):
                break
            continue
        fields = stripped[2:].split()
        if not fields:
            continue
        # Defaults mirror CandidateIntersection's field semantics.
        candidate: dict[str, Any] = {
            "intersection_id": fields[0],
            "queue_total": 0.0,
            "wait_total": 0.0,
            "outgoing_load": 0.0,
            "current_phase": 0,
            "is_boundary": False,
            "spillback_risk": False,
            "incident_proximity": False,
            "overload_marker": False,
            "event_proximity": False,
            "corridor_alignment": "BALANCED",
            "selection_reasons": [],
        }
        for token in fields[1:]:
            if "=" not in token:
                continue
            key, value = token.split("=", 1)
            try:
                if key == "q":
                    candidate["queue_total"] = float(value)
                elif key == "w":
                    candidate["wait_total"] = float(value)
                elif key == "out":
                    candidate["outgoing_load"] = float(value)
                elif key == "phase":
                    candidate["current_phase"] = int(value)
                elif key == "boundary":
                    candidate["is_boundary"] = value == "1"
                elif key == "spillback":
                    candidate["spillback_risk"] = value == "1"
                elif key == "incident":
                    candidate["incident_proximity"] = value == "1"
                elif key == "overload":
                    candidate["overload_marker"] = value == "1"
                elif key == "event":
                    candidate["event_proximity"] = value == "1"
                elif key == "align":
                    candidate["corridor_alignment"] = value
                elif key == "reasons":
                    candidate["selection_reasons"] = [] if value == "none" else value.split("|")
            except ValueError:
                # Malformed numeric value: keep the default rather than failing
                # the whole parse.
                continue
        candidates.append(candidate)
    return normalized_candidate_intersections_from_dicts(candidates)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def normalized_candidate_intersections_from_dicts(
    payload: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Round-trip raw dict rows through CandidateIntersection to validate and
    normalize them; rows that fail construction are dropped (best-effort)."""
    validated: list[dict[str, Any]] = []
    for raw in payload:
        try:
            candidate = CandidateIntersection(
                intersection_id=str(raw.get("intersection_id", "")).strip(),
                queue_total=float(raw.get("queue_total", 0.0)),
                wait_total=float(raw.get("wait_total", 0.0)),
                outgoing_load=float(raw.get("outgoing_load", 0.0)),
                current_phase=int(raw.get("current_phase", 0)),
                is_boundary=bool(raw.get("is_boundary", False)),
                spillback_risk=bool(raw.get("spillback_risk", False)),
                incident_proximity=bool(raw.get("incident_proximity", False)),
                overload_marker=bool(raw.get("overload_marker", False)),
                event_proximity=bool(raw.get("event_proximity", False)),
                corridor_alignment=str(raw.get("corridor_alignment", "BALANCED")),
                selection_reasons=list(raw.get("selection_reasons", [])),
            )
            validated.append(candidate.to_dict())
        except Exception:
            # Deliberate best-effort: a single bad row must not sink the batch.
            continue
    return validated
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def candidate_intersections_from_context(
    summary: DistrictStateSummary | dict[str, Any] | None = None,
    prompt_text: str | None = None,
) -> list[dict[str, Any]]:
    """Locate candidate intersections from whichever context form is available.

    Precedence: structured summary object, then a summary dict (top level
    first, then nested under "state"), then a best-effort parse of the prompt
    text. Returns [] when nothing usable is found.
    """
    if isinstance(summary, DistrictStateSummary):
        return normalize_candidate_intersections(summary.candidate_intersections)
    if isinstance(summary, dict):
        if "candidate_intersections" in summary:
            return normalize_candidate_intersections(summary.get("candidate_intersections"))
        nested = summary.get("state")
        if isinstance(nested, dict) and "candidate_intersections" in nested:
            return normalize_candidate_intersections(nested.get("candidate_intersections"))
    if not prompt_text:
        return []
    return parse_candidate_intersections_from_text(prompt_text)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def fallback_target_intersections(
    summary: DistrictStateSummary | dict[str, Any] | None = None,
    prompt_text: str | None = None,
    max_target_intersections: int = 3,
    strategy: str | None = None,
    priority_corridor: str | None = None,
    phase_bias: str | None = None,
    focus_scores: dict[str, float] | None = None,
) -> list[str]:
    """Pick heuristic target intersections when the model supplied none.

    Preference order for the source of candidates:
      1. candidate_intersections recovered from summary/prompt context, ranked
         by priority score plus focus/strategy bonuses;
      2. the summary's top_congested_intersections (object or dict form);
      3. raw "i_<n>" ids scraped from the prompt text;
      4. empty list.
    """
    candidate_intersections = candidate_intersections_from_context(summary=summary, prompt_text=prompt_text)
    if candidate_intersections:
        # Descending composite score; remaining keys break ties by descending
        # queue/wait/outgoing load, then ascending id for determinism.
        ordered_candidates = sorted(
            candidate_intersections,
            key=lambda item: (
                -(
                    candidate_priority_score(item)
                    + _focus_score_bonus(item, focus_scores)
                    + _strategy_target_bonus(
                        candidate=item,
                        strategy=strategy,
                        priority_corridor=priority_corridor,
                        phase_bias=phase_bias,
                    )
                ),
                -float(item.get("queue_total", 0.0)),
                -float(item.get("wait_total", 0.0)),
                -float(item.get("outgoing_load", 0.0)),
                str(item.get("intersection_id", "")),
            ),
        )
        # Canonicalization may reorder/filter; re-slice to enforce the cap.
        ordered_ids = canonicalize_target_intersections(
            [item["intersection_id"] for item in ordered_candidates],
            ordered_candidates,
            limit=max_target_intersections,
        )
        return ordered_ids[:max_target_intersections]

    if isinstance(summary, DistrictStateSummary):
        return [item.intersection_id for item in summary.top_congested_intersections[:max_target_intersections]]
    if isinstance(summary, dict):
        # Accept top_congested_intersections at the top level or under "state".
        top_congested = summary.get("top_congested_intersections") or summary.get("state", {}).get("top_congested_intersections", [])
        return [
            str(item.get("intersection_id"))
            for item in top_congested[:max_target_intersections]
            if str(item.get("intersection_id", "")).strip()
        ]
    if prompt_text:
        # dict.fromkeys preserves first-seen order while deduplicating.
        return list(dict.fromkeys(INTERSECTION_ID_PATTERN.findall(prompt_text)))[:max_target_intersections]
    return []
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _focus_score_bonus(candidate: dict[str, Any], focus_scores: dict[str, float] | None) -> float:
|
| 230 |
+
if not focus_scores:
|
| 231 |
+
return 0.0
|
| 232 |
+
max_focus = max(max(focus_scores.values()), 1.0)
|
| 233 |
+
return 4.0 * float(focus_scores.get(str(candidate.get("intersection_id", "")), 0.0)) / max_focus
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _strategy_target_bonus(
|
| 237 |
+
candidate: dict[str, Any],
|
| 238 |
+
strategy: str | None,
|
| 239 |
+
priority_corridor: str | None,
|
| 240 |
+
phase_bias: str | None,
|
| 241 |
+
) -> float:
|
| 242 |
+
reasons = set(candidate.get("selection_reasons", []))
|
| 243 |
+
corridor_alignment = str(candidate.get("corridor_alignment", "BALANCED"))
|
| 244 |
+
bonus = 0.0
|
| 245 |
+
|
| 246 |
+
if strategy == "incident_response":
|
| 247 |
+
bonus += 4.0 * float(bool(candidate.get("incident_proximity", False)))
|
| 248 |
+
elif strategy == "clear_spillback":
|
| 249 |
+
bonus += 4.0 * float(bool(candidate.get("spillback_risk", False)))
|
| 250 |
+
bonus += 1.0 * float(bool(candidate.get("is_boundary", False)))
|
| 251 |
+
elif strategy == "drain_inbound":
|
| 252 |
+
bonus += 4.0 * float(bool(candidate.get("is_boundary", False)))
|
| 253 |
+
bonus += 1.0 * float(bool(candidate.get("spillback_risk", False)))
|
| 254 |
+
elif strategy == "drain_outbound":
|
| 255 |
+
bonus += 4.0 * float("outgoing" in reasons)
|
| 256 |
+
bonus += 1.0 * float(bool(candidate.get("spillback_risk", False)))
|
| 257 |
+
elif strategy == "arterial_priority":
|
| 258 |
+
bonus += 2.0 * float(bool(candidate.get("is_boundary", False)))
|
| 259 |
+
bonus += 1.5 * float(bool(candidate.get("overload_marker", False)))
|
| 260 |
+
bonus += 1.5 * float(bool(candidate.get("event_proximity", False)))
|
| 261 |
+
elif strategy == "favor_NS":
|
| 262 |
+
bonus += 4.0 * float(corridor_alignment == "NS")
|
| 263 |
+
elif strategy == "favor_EW":
|
| 264 |
+
bonus += 4.0 * float(corridor_alignment == "EW")
|
| 265 |
+
|
| 266 |
+
if priority_corridor in {"NS", "EW"}:
|
| 267 |
+
bonus += 1.5 * float(corridor_alignment == priority_corridor)
|
| 268 |
+
elif priority_corridor == "inbound":
|
| 269 |
+
bonus += 1.5 * float(bool(candidate.get("is_boundary", False)))
|
| 270 |
+
elif priority_corridor == "outbound":
|
| 271 |
+
bonus += 1.5 * float("outgoing" in reasons)
|
| 272 |
+
elif priority_corridor == "arterial":
|
| 273 |
+
bonus += 0.75 * float(bool(candidate.get("is_boundary", False)))
|
| 274 |
+
|
| 275 |
+
if phase_bias in {"NS", "EW"}:
|
| 276 |
+
bonus += 0.5 * float(corridor_alignment == phase_bias)
|
| 277 |
+
|
| 278 |
+
return bonus
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def extract_visible_candidate_ids(
    summary: DistrictStateSummary | dict[str, Any] | None = None,
    prompt_text: str | None = None,
) -> list[str]:
    """List the ids the model may legitimately target, preferring structured
    context; falls back to scraping "i_<n>" ids from the prompt text."""
    structured = candidate_intersections_from_context(summary=summary, prompt_text=prompt_text)
    if structured:
        return [entry["intersection_id"] for entry in structured]
    if prompt_text:
        # Ordered de-dup of every id mentioned anywhere in the prompt.
        return list(dict.fromkeys(INTERSECTION_ID_PATTERN.findall(prompt_text)))
    return []
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def sanitize_action_payload(
    payload: dict[str, Any] | None,
    summary: DistrictStateSummary | dict[str, Any] | None = None,
    prompt_text: str | None = None,
    config: RepairConfig | None = None,
) -> tuple[DistrictAction, RepairReport]:
    """Repair a raw (possibly malformed) action payload into a valid DistrictAction.

    Pipeline: extract raw target ids -> deduplicate -> drop malformed or
    non-visible ids -> truncate to the cap -> optionally fall back to
    heuristic targets -> coerce enum fields and duration to valid values.
    Returns the validated action plus a RepairReport describing every change.

    Raises ValueError only for an unsupported config.fallback_mode.
    """
    config = config or RepairConfig()
    payload = dict(payload or {})
    candidate_intersections = candidate_intersections_from_context(summary=summary, prompt_text=prompt_text)
    visible_candidate_ids = [item["intersection_id"] for item in candidate_intersections]
    visible_candidate_set = set(visible_candidate_ids)

    # Accept targets as a list/tuple of ids or as free text containing ids.
    raw_target_payload = payload.get("target_intersections", [])
    if isinstance(raw_target_payload, str):
        raw_targets = INTERSECTION_ID_PATTERN.findall(raw_target_payload)
    elif isinstance(raw_target_payload, (list, tuple)):
        raw_targets = [str(item).strip() for item in raw_target_payload if str(item).strip()]
    else:
        raw_targets = []

    report = RepairReport(raw_targets=list(raw_targets))
    # First-seen-order deduplication; repeats are only flagged, not kept.
    deduped_targets: list[str] = []
    seen: set[str] = set()
    for item in raw_targets:
        if item in seen:
            report.deduplicated = True
            continue
        seen.add(item)
        deduped_targets.append(item)

    # Drop ids that don't match the canonical pattern, then (optionally) ids
    # outside the visible candidate set. Visibility filtering is skipped when
    # no candidates are known, so valid-looking ids survive.
    filtered_targets: list[str] = []
    for item in deduped_targets:
        if not INTERSECTION_ID_PATTERN.fullmatch(item):
            report.invalid_ids_removed.append(item)
            continue
        if config.allow_only_visible_candidates and visible_candidate_set and item not in visible_candidate_set:
            report.non_visible_ids_removed.append(item)
            continue
        filtered_targets.append(item)

    # Over-long lists are re-ranked via canonicalization rather than blindly sliced.
    if len(filtered_targets) > int(config.max_target_intersections):
        report.truncated = True
        filtered_targets = canonicalize_target_intersections(
            filtered_targets,
            candidate_intersections,
            limit=int(config.max_target_intersections),
        )

    if not filtered_targets:
        # empty_after_filtering means "the model proposed targets but none survived".
        report.empty_after_filtering = bool(raw_targets)
        if config.fallback_on_empty_targets:
            report.fallback_used = True
            report.fallback_mode = config.fallback_mode
            if config.fallback_mode == "heuristic":
                filtered_targets = fallback_target_intersections(
                    summary=summary,
                    prompt_text=prompt_text,
                    max_target_intersections=int(config.max_target_intersections),
                )
            elif config.fallback_mode == "hold":
                filtered_targets = []
            elif config.fallback_mode == "none":
                filtered_targets = []
            else:
                raise ValueError(f"Unsupported fallback_mode '{config.fallback_mode}'.")

    # Coerce enum-like fields to safe defaults instead of rejecting the payload.
    strategy = str(payload.get("strategy", "hold"))
    if strategy not in DISTRICT_STRATEGIES:
        strategy = "hold"

    priority_corridor = payload.get("priority_corridor")
    if priority_corridor is not None:
        priority_corridor = str(priority_corridor)
    if priority_corridor not in PRIORITY_CORRIDORS:
        priority_corridor = None

    phase_bias = str(payload.get("phase_bias", "NONE"))
    if phase_bias not in PHASE_BIASES:
        phase_bias = "NONE"

    # Clamp duration into the schema's documented 1..20 range.
    duration_steps_raw = payload.get("duration_steps", 1)
    try:
        duration_steps = int(duration_steps_raw)
    except (TypeError, ValueError):
        duration_steps = 1
    duration_steps = max(1, min(duration_steps, 20))

    # "hold" fallback with no surviving targets collapses to the default hold action.
    if config.fallback_mode == "hold" and report.fallback_used and not filtered_targets:
        action = DistrictAction.default_hold(duration_steps=duration_steps)
    else:
        action = DistrictAction(
            strategy=strategy,
            priority_corridor=priority_corridor,
            target_intersections=filtered_targets,
            phase_bias=phase_bias,
            duration_steps=duration_steps,
        ).validate()

    report.repaired_targets = list(action.target_intersections)
    return action, report
|
district_llm/rl_guidance_wrapper.py
ADDED
|
@@ -0,0 +1,1004 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from dataclasses import asdict, dataclass, field, replace
|
| 5 |
+
import hashlib
|
| 6 |
+
from time import perf_counter
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from district_llm.heuristic_guidance import (
|
| 13 |
+
HeuristicGuidanceConfig,
|
| 14 |
+
generate_heuristic_guidance,
|
| 15 |
+
)
|
| 16 |
+
from district_llm.inference import DistrictLLMInference, DistrictLLMInferenceResult
|
| 17 |
+
from district_llm.repair import RepairReport
|
| 18 |
+
from district_llm.schema import CandidateIntersection, DistrictAction, DistrictStateSummary
|
| 19 |
+
from district_llm.summary_builder import DistrictStateSummaryBuilder
|
| 20 |
+
from district_llm.teachers import RLCheckpointTeacher
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Q-value shaping modes exposed by the guidance wrapper.
WRAPPER_MODES: tuple[str, ...] = (
    "no_op",
    "target_only_soft",
    "target_only_medium",
    "corridor_soft",
    "global_soft",
    "current_legacy",
)
# Behavior when guidance generation fails (see GuidanceInfluenceConfig.fallback_policy).
FALLBACK_POLICIES: tuple[str, ...] = (
    "no_op",
    "hold_previous",
    "heuristic_weak",
)
# Conditions under which guidance is allowed to influence the policy at all.
GATING_MODES: tuple[str, ...] = (
    "always_on",
    "incident_or_spillback",
    "queue_threshold",
    "imbalance_threshold",
    "queue_or_imbalance",
    "combined",
)
# Supported decay shapes for the guidance bias over its active duration.
BIAS_DECAY_SCHEDULES: tuple[str, ...] = (
    "linear",
)
# Per-strategy multiplier applied to the base bias strength; "hold" disables
# biasing entirely, escalation strategies bias slightly harder.
STRATEGY_BIAS_MULTIPLIERS: dict[str, float] = {
    "hold": 0.0,
    "favor_NS": 1.0,
    "favor_EW": 1.0,
    "drain_inbound": 1.05,
    "drain_outbound": 1.05,
    "clear_spillback": 1.1,
    "incident_response": 1.15,
    "arterial_priority": 1.05,
}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass(frozen=True)
class GuidanceInfluenceConfig:
    """
    Conservative inference-time wrapper around the fixed DQN policy.

    The DQN checkpoint remains unchanged. Guidance is treated as a weak prior
    and only biases Q-values slightly before greedy action selection.
    """

    wrapper_mode: str = "target_only_soft"
    bias_strength: float = 0.12
    target_only_bias_strength: float = 0.18
    corridor_bias_strength: float = 0.05
    max_intersections_affected: int = 3
    guidance_refresh_steps: int = 5
    guidance_persistence_steps: int = 3
    max_guidance_duration: int = 6
    apply_global_bias: bool = False
    apply_target_only: bool = True
    gating_mode: str = "always_on"
    min_avg_queue_for_guidance: float = 150.0
    min_queue_imbalance_for_guidance: float = 20.0
    require_incident_or_spillback: bool = False
    allow_guidance_in_normal_conditions: bool = True
    enable_bias_decay: bool = True
    bias_decay_schedule: str = "linear"
    fallback_policy: str = "hold_previous"
    log_guidance_debug: bool = False
    max_debug_chars: int = 240

    def validate(self) -> "GuidanceInfluenceConfig":
        """Raise ValueError on any unsupported mode or non-positive step/count
        field; return self so callers can chain."""
        enum_checks = (
            ("wrapper_mode", self.wrapper_mode, WRAPPER_MODES),
            ("fallback_policy", self.fallback_policy, FALLBACK_POLICIES),
            ("gating_mode", self.gating_mode, GATING_MODES),
            ("bias_decay_schedule", self.bias_decay_schedule, BIAS_DECAY_SCHEDULES),
        )
        for field_name, value, allowed in enum_checks:
            if value not in allowed:
                raise ValueError(
                    f"Unsupported {field_name} '{value}'. Expected one of {allowed}."
                )
        minimum_checks = (
            ("guidance_refresh_steps", self.guidance_refresh_steps),
            ("guidance_persistence_steps", self.guidance_persistence_steps),
            ("max_guidance_duration", self.max_guidance_duration),
            ("max_intersections_affected", self.max_intersections_affected),
        )
        for field_name, value in minimum_checks:
            if value < 1:
                raise ValueError(f"{field_name} must be at least 1.")
        return self
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@dataclass(frozen=True)
class RLPolicyDecision:
    """One evaluation of the underlying RL policy: raw Q-values plus the
    actions selected from them."""

    # Q-values before any guidance bias is applied.
    # NOTE(review): array shapes are not established here — presumably one row
    # per intersection; confirm against the producing policy code.
    q_values: np.ndarray
    # Action indices chosen from q_values.
    actions: np.ndarray
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@dataclass
class GuidanceDecision:
    """One guidance step: where the action came from, what it is, and how the
    raw model output was (or was not) repaired along the way."""

    source: str
    action: DistrictAction
    runtime_seconds: float
    raw_text: str | None = None
    parsed_payload_before_repair: dict[str, Any] | None = None
    repair_report: RepairReport | None = None
    json_valid: bool = True
    schema_valid_before_repair: bool = True
    provider_error: str | None = None
    fallback_policy_applied: str | None = None

    @property
    def repair_applied(self) -> bool:
        """True when the repair pass changed anything about the raw action."""
        report = self.repair_report
        if report is None:
            return False
        signals = (
            report.invalid_ids_removed,
            report.non_visible_ids_removed,
            report.deduplicated,
            report.truncated,
            report.fallback_used,
            report.empty_after_filtering,
        )
        return any(signals)

    @property
    def invalid_before_repair(self) -> bool:
        """True when the raw model output was unusable before repair: provider
        failure, bad JSON/schema, or targets that had to be removed/replaced."""
        if self.provider_error:
            return True
        if not (self.json_valid and self.schema_valid_before_repair):
            return True
        report = self.repair_report
        if report is None:
            return False
        had_bad_ids = bool(report.invalid_ids_removed) or bool(report.non_visible_ids_removed)
        return had_bad_ids or bool(report.empty_after_filtering)

    def to_trace_payload(self) -> dict[str, Any]:
        """Flatten this decision into a JSON-friendly dict for trace logging."""
        report = self.repair_report
        payload: dict[str, Any] = {
            "source": self.source,
            "runtime_seconds": float(self.runtime_seconds),
            "action": self.action.to_dict(),
            "raw_text": self.raw_text,
            "parsed_payload_before_repair": self.parsed_payload_before_repair,
            "repair_report": report.to_dict() if report is not None else None,
            "json_valid": bool(self.json_valid),
            "schema_valid_before_repair": bool(self.schema_valid_before_repair),
            "repair_applied": bool(self.repair_applied),
            "invalid_before_repair": bool(self.invalid_before_repair),
            "provider_error": self.provider_error,
            "fallback_policy_applied": self.fallback_policy_applied,
        }
        return payload
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@dataclass(frozen=True)
class GuidanceApplicationPlan:
    """Immutable plan describing how a guidance action is mapped onto Q-value biases.

    Built once per guidance refresh by `_build_application_plan`; consumed per
    row in `_row_action_bias`.
    """

    # Influence mode (e.g. "no_op", "current_legacy", "target_only_soft", ...).
    wrapper_mode: str
    # Spatial scope label: "none", "targeted", "corridor_local", or "global".
    scope: str
    # Intersections that may receive a bias this refresh window.
    affected_intersections: tuple[str, ...]
    # Intersections explicitly named by the guidance action (validated subset).
    targeted_intersections: tuple[str, ...]
    # All candidate intersection ids visible in the district summary.
    target_candidate_ids: tuple[str, ...]
    # Preferred corridor direction ("NS"/"EW") or None when no preference.
    priority_direction: str | None
    # Global multiplier applied to every bias component for this mode.
    strength_scale: float
    # Bias applied to every affected intersection.
    base_bias_strength: float
    # Extra bias for explicitly targeted intersections.
    target_bias_strength: float
    # Extra bias for corridor-aligned / boundary intersections.
    corridor_bias_strength: float
    # Whether the bias covers the whole district.
    apply_global_bias: bool
    # Whether only targeted intersections are biased.
    apply_target_only: bool
    # Upper bound on how many intersections may be affected.
    max_intersections_affected: int

    def to_dict(self) -> dict[str, Any]:
        """Serialize the plan for trace logging (tuples become lists)."""
        return {
            "wrapper_mode": self.wrapper_mode,
            "scope": self.scope,
            "affected_intersections": list(self.affected_intersections),
            "targeted_intersections": list(self.targeted_intersections),
            "target_candidate_ids": list(self.target_candidate_ids),
            "priority_direction": self.priority_direction,
            "strength_scale": float(self.strength_scale),
            "base_bias_strength": float(self.base_bias_strength),
            "target_bias_strength": float(self.target_bias_strength),
            "corridor_bias_strength": float(self.corridor_bias_strength),
            "apply_global_bias": bool(self.apply_global_bias),
            "apply_target_only": bool(self.apply_target_only),
            "max_intersections_affected": int(self.max_intersections_affected),
        }
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
@dataclass
class ActiveDistrictGuidance:
    """Guidance currently in force for one district, with its validity window."""

    # District this guidance applies to.
    district_id: str
    # Summary the guidance was generated from (used for per-row candidate lookup).
    summary: DistrictStateSummary
    # The provider decision (action + provenance).
    decision: GuidanceDecision
    # Precomputed mapping of the decision onto bias strengths/scopes.
    application_plan: GuidanceApplicationPlan
    # Decision step at which this guidance was generated.
    generated_step: int
    # Decision step at which this guidance expires and must be refreshed.
    expires_step: int
    # True if a fallback policy replaced the provider's original decision.
    fallback_used: bool = False
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@dataclass(frozen=True)
class GuidanceGateDecision:
    """Result of the guidance gate: whether guidance may run, and why (not)."""

    # True when guidance generation is permitted for this district/refresh.
    allowed: bool
    # The configured gating mode that produced this decision.
    gating_mode: str
    # Names of gate conditions that fired (e.g. "queue_threshold").
    triggered_conditions: tuple[str, ...]
    # Human-readable reasons guidance was blocked (empty when allowed).
    blocked_reasons: tuple[str, ...]
    # District-average queue length at evaluation time.
    avg_queue: float
    # |NS queue - EW queue| at evaluation time.
    queue_imbalance: float
    # Summary flags snapshotted for tracing.
    incident_flag: bool
    spillback_risk: bool
    overload_flag: bool

    def to_dict(self) -> dict[str, Any]:
        """Serialize the gate decision for trace logging."""
        return {
            "allowed": bool(self.allowed),
            "gating_mode": self.gating_mode,
            "triggered_conditions": list(self.triggered_conditions),
            "blocked_reasons": list(self.blocked_reasons),
            "avg_queue": float(self.avg_queue),
            "queue_imbalance": float(self.queue_imbalance),
            "incident_flag": bool(self.incident_flag),
            "spillback_risk": bool(self.spillback_risk),
            "overload_flag": bool(self.overload_flag),
        }
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@dataclass
class GuidanceRefreshTrace:
    """Full audit record for one guidance refresh (one district, one step)."""

    # Controller mode label (e.g. which provider family was configured).
    mode_source: str
    district_id: str
    # Decision step at which the refresh happened.
    decision_step: int
    # Hash of the district summary (for deduplication/diffing across runs).
    summary_hash: str
    # First ~240 chars of the prompt text, for quick human inspection.
    summary_excerpt: str
    # Full serialized summary.
    summary_payload: dict[str, Any]
    # Raw decision trace payload (pre-repair view included).
    guidance: dict[str, Any]
    # Action dict after repair — what was actually applied.
    repaired_guidance: dict[str, Any]
    fallback_used: bool
    # Fallback policy name, or "none" / "gate_blocked".
    fallback_policy: str
    # Serialized GuidanceApplicationPlan.
    application_plan: dict[str, Any]
    # Bias components actually in force ("base", "target", "corridor", "strength_scale").
    applied_biases: dict[str, float]
    # Serialized GuidanceGateDecision, when gating was evaluated.
    gate_decision: dict[str, Any] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize the trace; note `guidance` is emitted under key "raw_guidance"
        and `summary_payload` under key "summary"."""
        return {
            "mode_source": self.mode_source,
            "district_id": self.district_id,
            "decision_step": int(self.decision_step),
            "summary_hash": self.summary_hash,
            "summary_excerpt": self.summary_excerpt,
            "summary": self.summary_payload,
            "raw_guidance": self.guidance,
            "repaired_guidance": self.repaired_guidance,
            "fallback_used": bool(self.fallback_used),
            "fallback_policy": self.fallback_policy,
            "application_plan": self.application_plan,
            "applied_biases": self.applied_biases,
            "gate_decision": self.gate_decision,
        }
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
@dataclass
class WrapperEpisodeStats:
    """Mutable per-episode counters for guidance activity; reset each episode."""

    # Total decision steps observed.
    step_count: int = 0
    # Steps during which at least one district had active guidance.
    steps_with_active_guidance: int = 0
    # Guidance refreshes attempted (includes gate-blocked refreshes).
    guidance_refresh_count: int = 0
    # Steps on which the gate blocked at least one refresh.
    guidance_blocked_step_count: int = 0
    # Individual refreshes blocked by the gate.
    guidance_blocked_refresh_count: int = 0
    # Per-row bias applications.
    bias_application_count: int = 0
    # Refreshes that resulted in no effective guidance (no-op mode or empty scope).
    noop_guidance_events: int = 0
    # Refreshes where a fallback policy replaced the provider decision.
    fallback_event_count: int = 0
    total_affected_intersections: int = 0
    total_targeted_intersections: int = 0
    total_bias_magnitude: float = 0.0
    max_bias_magnitude: float = 0.0

    def to_dict(self) -> dict[str, float]:
        """Return aggregate metrics; denominators are clamped to >= 1 to avoid
        division by zero on episodes with no refreshes/applications."""
        refresh_count = max(1, self.guidance_refresh_count)
        return {
            "num_guidance_refreshes": float(self.guidance_refresh_count),
            "num_steps_guidance_blocked_by_gate": float(self.guidance_blocked_step_count),
            "num_guidance_refreshes_blocked_by_gate": float(self.guidance_blocked_refresh_count),
            "num_bias_applications": float(self.bias_application_count),
            "num_noop_guidance_events": float(self.noop_guidance_events),
            "fallback_policy_used_count": float(self.fallback_event_count),
            "avg_num_affected_intersections": float(self.total_affected_intersections) / float(refresh_count),
            "avg_num_targeted_intersections": float(self.total_targeted_intersections) / float(refresh_count),
            "mean_bias_magnitude": float(self.total_bias_magnitude) / float(max(1, self.bias_application_count)),
            "max_bias_magnitude": float(self.max_bias_magnitude),
            "percent_steps_with_active_guidance": float(self.steps_with_active_guidance)
            / float(max(1, self.step_count)),
        }
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@dataclass
class GuidedActionBatch:
    """Per-step output of the guided controller: guided and base actions/Q-values."""

    # Final actions after guidance bias (argmax of `guided_q_values`).
    actions: np.ndarray
    # Actions the unbiased base policy would have taken.
    base_actions: np.ndarray
    # Raw Q-values from the base policy.
    base_q_values: np.ndarray
    # Q-values after adding `q_bias`.
    guided_q_values: np.ndarray
    # Additive bias applied per row/action (zeros where no guidance applied).
    q_bias: np.ndarray
    # Traces for any guidance refreshes that happened this step.
    refresh_traces: list[GuidanceRefreshTrace] = field(default_factory=list)
    # Wall-clock time for the whole act() call.
    runtime_seconds: float = 0.0
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class BaseGuidanceProvider(ABC):
    """Interface for district-guidance sources (heuristic rules or an LLM)."""

    # Short identifier recorded in GuidanceDecision.source.
    source_name: str

    @abstractmethod
    def generate(self, summary: DistrictStateSummary) -> GuidanceDecision:
        """Produce a guidance decision for one district summary."""
        raise NotImplementedError
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class HeuristicGuidanceProvider(BaseGuidanceProvider):
    """Rule-based guidance provider; always yields a schema-valid action."""

    source_name = "heuristic"

    def __init__(self, config: HeuristicGuidanceConfig | None = None):
        # Fall back to default heuristic thresholds when no config is given.
        self.config = config or HeuristicGuidanceConfig()

    def generate(self, summary: DistrictStateSummary) -> GuidanceDecision:
        """Generate a heuristic decision; never sets error/validity flags since
        the heuristic constructs the action directly (no parsing/repair needed)."""
        started = perf_counter()
        action = generate_heuristic_guidance(summary=summary, config=self.config)
        return GuidanceDecision(
            source=self.source_name,
            action=action,
            runtime_seconds=perf_counter() - started,
            # Recorded so traces show the same shape as LLM decisions.
            parsed_payload_before_repair=action.to_dict(),
        )
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class LLMGuidanceProvider(BaseGuidanceProvider):
    """Guidance provider backed by a fine-tuned district LLM."""

    source_name = "llm"

    def __init__(self, inference: DistrictLLMInference, max_new_tokens: int = 128):
        self.inference = inference
        # Generation budget per decision; coerced to int defensively.
        self.max_new_tokens = int(max_new_tokens)

    def generate(self, summary: DistrictStateSummary) -> GuidanceDecision:
        """Run LLM inference and wrap the result, propagating parse/repair metadata.

        Exceptions from the inference backend are NOT caught here — the
        controller's `_generate_guidance` handles provider failures.
        """
        started = perf_counter()
        result: DistrictLLMInferenceResult = self.inference.predict_with_result(
            summary=summary,
            max_new_tokens=self.max_new_tokens,
        )
        return GuidanceDecision(
            source=self.source_name,
            action=result.action,
            runtime_seconds=perf_counter() - started,
            raw_text=result.raw_text,
            parsed_payload_before_repair=result.parsed_payload_before_repair,
            repair_report=result.repair_report,
            json_valid=result.json_valid,
            schema_valid_before_repair=result.schema_valid_before_repair,
        )
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class FixedRLPolicyAdapter:
    """Inference-only adapter around a frozen RL checkpoint teacher.

    Loads the teacher once and exposes a single `decide` call that maps a raw
    observation batch to greedy actions and Q-values.
    """

    def __init__(self, checkpoint_path: str, device: str | None = None):
        self.teacher = RLCheckpointTeacher(checkpoint_path=checkpoint_path, device=device)
        # Mirror the teacher's resolved device so callers can place tensors.
        self.device = self.teacher.device

    @property
    def env_config(self) -> Any | None:
        """Environment config stored in the checkpoint, if any."""
        return self.teacher.env_config

    def decide(self, observation_batch: dict[str, Any]) -> RLPolicyDecision:
        """Run a frozen forward pass and return Q-values plus greedy actions.

        Expects `observation_batch` to carry "observations",
        "district_type_indices", and "action_mask" arrays with one row per
        intersection.
        """
        obs = observation_batch["observations"].astype(np.float32)
        # Apply the checkpoint's observation normalizer when one was saved.
        normalizer = self.teacher.obs_normalizer
        if normalizer is not None:
            obs = normalizer.normalize(obs)

        dev = self.device
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=dev)
        type_t = torch.as_tensor(
            observation_batch["district_type_indices"], dtype=torch.int64, device=dev
        )
        mask_t = torch.as_tensor(
            observation_batch["action_mask"], dtype=torch.float32, device=dev
        )

        # Inference only: no gradients are ever needed from the frozen teacher.
        with torch.no_grad():
            q = self.teacher.model.forward(
                observations=obs_t,
                district_type_indices=type_t,
                action_mask=mask_t,
            )

        q_np = q.detach().cpu().numpy().astype(np.float32)
        greedy = q_np.argmax(axis=1).astype(np.int64)
        return RLPolicyDecision(q_values=q_np, actions=greedy)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
class DistrictGuidedRLController:
    """Runs a frozen RL policy and softly biases its Q-values with district guidance.

    Guidance is refreshed per district on a schedule, optionally gated by
    traffic-state conditions, then converted into small additive Q-value biases
    per intersection. The base policy is never overridden outright — biases
    only shift the argmax when the base Q-gap is smaller than the bias.
    """

    def __init__(
        self,
        policy: FixedRLPolicyAdapter,
        mode_source: str,
        summary_builder: DistrictStateSummaryBuilder | None = None,
        guidance_provider: BaseGuidanceProvider | None = None,
        influence_config: GuidanceInfluenceConfig | None = None,
        heuristic_provider: BaseGuidanceProvider | None = None,
    ):
        self.policy = policy
        self.mode_source = mode_source
        self.summary_builder = summary_builder
        self.guidance_provider = guidance_provider
        # validate() raises early on malformed configs and returns the config.
        self.influence_config = (influence_config or GuidanceInfluenceConfig()).validate()
        # Optional weak-heuristic provider for the "heuristic_weak" fallback policy.
        self.heuristic_provider = heuristic_provider
        # district_id -> guidance currently in force.
        self._active_guidance: dict[str, ActiveDistrictGuidance] = {}
        # district_id -> decision step at which guidance must be regenerated.
        self._next_refresh_step_by_district: dict[str, int] = {}
        self._episode_stats = WrapperEpisodeStats()

    def reset(self) -> None:
        """Clear all per-episode state (active guidance, schedules, stats)."""
        self._active_guidance = {}
        self._next_refresh_step_by_district = {}
        self._episode_stats = WrapperEpisodeStats()
        if self.summary_builder is not None:
            self.summary_builder.reset()

    def active_guidance_snapshot(self) -> dict[str, dict[str, Any]]:
        """Return currently active guidance actions, keyed and sorted by district."""
        return {
            district_id: active.decision.action.to_dict()
            for district_id, active in sorted(self._active_guidance.items())
        }

    def episode_debug_summary(self) -> dict[str, Any]:
        """Episode stats plus the configured mode/fallback labels, for logging."""
        payload = self._episode_stats.to_dict()
        payload.update(
            {
                "wrapper_mode": self.influence_config.wrapper_mode,
                "fallback_policy": self.influence_config.fallback_policy,
            }
        )
        return payload

    def act(self, env, observation_batch: dict[str, Any]) -> GuidedActionBatch:
        """Select actions for one decision step, applying active guidance biases.

        Refreshes guidance first (if due), then adds per-row biases to the base
        policy's Q-values and takes the argmax of the biased values.
        """
        started = perf_counter()
        base_decision = self.policy.decide(observation_batch)
        base_q_values = base_decision.q_values
        guided_q_values = base_q_values.copy()
        q_bias = np.zeros_like(guided_q_values, dtype=np.float32)

        refresh_traces = self._refresh_guidance_if_needed(env=env, observation_batch=observation_batch)
        if self.guidance_provider is None:
            # Pure-RL mode: return base actions untouched (bias stays zero).
            self._episode_stats.step_count += 1
            return GuidedActionBatch(
                actions=base_decision.actions.copy(),
                base_actions=base_decision.actions,
                base_q_values=base_q_values,
                guided_q_values=guided_q_values,
                q_bias=q_bias,
                refresh_traces=refresh_traces,
                runtime_seconds=perf_counter() - started,
            )

        active_any = False
        decision_step = int(observation_batch.get("decision_step", 0))
        # Pre-index candidates per district so the row loop does O(1) lookups.
        candidate_lookup_by_district = {
            district_id: {
                item.intersection_id: item
                for item in active.summary.candidate_intersections
            }
            for district_id, active in self._active_guidance.items()
        }
        for row_index, intersection_id in enumerate(observation_batch["intersection_ids"]):
            district_id = str(observation_batch["district_ids"][row_index])
            active = self._active_guidance.get(district_id)
            if active is None:
                continue
            active_any = True
            candidate = candidate_lookup_by_district[district_id].get(str(intersection_id))
            row_bias = self._row_action_bias(
                active=active,
                candidate=candidate,
                intersection_id=str(intersection_id),
                current_phase=int(observation_batch["current_phase"][row_index]),
                decision_step=decision_step,
            )
            if row_bias is None:
                continue
            q_bias[row_index] = row_bias
            guided_q_values[row_index] = guided_q_values[row_index] + row_bias
            magnitude = float(np.abs(row_bias).max())
            self._episode_stats.bias_application_count += 1
            self._episode_stats.total_bias_magnitude += magnitude
            self._episode_stats.max_bias_magnitude = max(self._episode_stats.max_bias_magnitude, magnitude)

        self._episode_stats.step_count += 1
        if active_any:
            self._episode_stats.steps_with_active_guidance += 1
        actions = guided_q_values.argmax(axis=1).astype(np.int64)
        return GuidedActionBatch(
            actions=actions,
            base_actions=base_decision.actions,
            base_q_values=base_q_values,
            guided_q_values=guided_q_values,
            q_bias=q_bias,
            refresh_traces=refresh_traces,
            runtime_seconds=perf_counter() - started,
        )

    def _refresh_guidance_if_needed(
        self,
        env,
        observation_batch: dict[str, Any],
    ) -> list[GuidanceRefreshTrace]:
        """Regenerate guidance for every district whose refresh window expired.

        Gate-blocked districts have their guidance cleared and a no-op trace
        recorded; allowed districts get a fresh decision, application plan, and
        expiry schedule. Returns one trace per refreshed district.
        """
        if self.guidance_provider is None or self.summary_builder is None:
            return []

        decision_step = int(observation_batch.get("decision_step", 0))
        # Sorted for deterministic refresh order across runs.
        due_districts = [
            district_id
            for district_id in tuple(sorted(env.districts))
            if self._district_requires_refresh(district_id=district_id, decision_step=decision_step)
        ]
        if not due_districts:
            return []

        summaries = self.summary_builder.build_all(env, observation_batch)
        refresh_traces: list[GuidanceRefreshTrace] = []
        gate_blocked_this_step = False
        for district_id in due_districts:
            summary = summaries[district_id]
            previous_active = self._active_guidance.get(district_id)
            gate_decision = _evaluate_guidance_gate(summary=summary, config=self.influence_config)
            if not gate_decision.allowed:
                # Gate blocked: drop any stale guidance and record a no-op trace
                # so evaluation can still account for the blocked refresh.
                gate_blocked_this_step = True
                self._active_guidance.pop(district_id, None)
                self._next_refresh_step_by_district[district_id] = (
                    decision_step + self._resolve_blocked_refresh_horizon()
                )
                # Blocked refreshes still count as refreshes (and as no-ops).
                self._episode_stats.guidance_refresh_count += 1
                self._episode_stats.guidance_blocked_refresh_count += 1
                self._episode_stats.noop_guidance_events += 1
                decision = GuidanceDecision(
                    source=f"{self.mode_source}_gate_blocked",
                    action=DistrictAction.default_hold(),
                    runtime_seconds=0.0,
                    fallback_policy_applied="gate_blocked",
                )
                # Force a no-op plan regardless of the configured wrapper mode.
                application_plan = _build_application_plan(
                    summary=summary,
                    action=decision.action,
                    config=replace(self.influence_config, wrapper_mode="no_op"),
                    district_intersection_ids=tuple(env.districts[district_id].intersection_ids),
                )
                trace = GuidanceRefreshTrace(
                    mode_source=self.mode_source,
                    district_id=district_id,
                    decision_step=decision_step,
                    summary_hash=_summary_hash(summary),
                    summary_excerpt=summary.to_prompt_text()[:240],
                    summary_payload=summary.to_dict(),
                    guidance=decision.to_trace_payload(),
                    repaired_guidance=decision.action.to_dict(),
                    fallback_used=False,
                    fallback_policy="gate_blocked",
                    application_plan=application_plan.to_dict(),
                    applied_biases={
                        "base": 0.0,
                        "target": 0.0,
                        "corridor": 0.0,
                        "strength_scale": 0.0,
                    },
                    gate_decision=gate_decision.to_dict(),
                )
                refresh_traces.append(trace)
                if self.influence_config.log_guidance_debug:
                    _log_guidance_debug(trace)
                continue

            decision, fallback_used = self._generate_guidance(
                district_id=district_id,
                summary=summary,
                previous_active=previous_active,
            )
            application_plan = _build_application_plan(
                summary=summary,
                action=decision.action,
                config=self.influence_config,
                district_intersection_ids=tuple(env.districts[district_id].intersection_ids),
            )
            expires_step = decision_step + self._resolve_refresh_horizon(decision.action)
            active = ActiveDistrictGuidance(
                district_id=district_id,
                summary=summary,
                decision=decision,
                application_plan=application_plan,
                generated_step=decision_step,
                expires_step=expires_step,
                fallback_used=fallback_used,
            )
            self._active_guidance[district_id] = active
            self._next_refresh_step_by_district[district_id] = int(expires_step)

            self._episode_stats.guidance_refresh_count += 1
            self._episode_stats.total_affected_intersections += len(application_plan.affected_intersections)
            self._episode_stats.total_targeted_intersections += len(application_plan.targeted_intersections)
            if application_plan.wrapper_mode == "no_op" or not application_plan.affected_intersections:
                self._episode_stats.noop_guidance_events += 1
            if fallback_used:
                self._episode_stats.fallback_event_count += 1

            trace = GuidanceRefreshTrace(
                mode_source=self.mode_source,
                district_id=district_id,
                decision_step=decision_step,
                summary_hash=_summary_hash(summary),
                summary_excerpt=summary.to_prompt_text()[:240],
                summary_payload=summary.to_dict(),
                guidance=decision.to_trace_payload(),
                repaired_guidance=decision.action.to_dict(),
                fallback_used=fallback_used,
                fallback_policy=self.influence_config.fallback_policy if fallback_used else "none",
                application_plan=application_plan.to_dict(),
                applied_biases={
                    "base": float(application_plan.base_bias_strength),
                    "target": float(application_plan.target_bias_strength),
                    "corridor": float(application_plan.corridor_bias_strength),
                    "strength_scale": float(application_plan.strength_scale),
                },
                gate_decision=gate_decision.to_dict(),
            )
            refresh_traces.append(trace)
            if self.influence_config.log_guidance_debug:
                _log_guidance_debug(trace)
        if gate_blocked_this_step:
            self._episode_stats.guidance_blocked_step_count += 1
        return refresh_traces

    def _generate_guidance(
        self,
        district_id: str,
        summary: DistrictStateSummary,
        previous_active: ActiveDistrictGuidance | None,
    ) -> tuple[GuidanceDecision, bool]:
        """Generate a decision, applying the configured fallback when it is invalid.

        Returns (decision, fallback_used). Fallback policies: "hold_previous"
        reuses the last active action, "heuristic_weak" delegates to the
        heuristic provider, anything else degrades to a default-hold no-op.
        """
        fallback_used = False
        try:
            decision = self.guidance_provider.generate(summary)
        # Broad catch is intentional: a provider crash must never take down the
        # control loop — it is converted into an invalid decision + fallback.
        except Exception as exc:
            decision = GuidanceDecision(
                source=self.guidance_provider.source_name,
                action=DistrictAction.default_hold(),
                runtime_seconds=0.0,
                provider_error=str(exc),
                json_valid=False,
                schema_valid_before_repair=False,
            )
        if not _should_fallback(decision):
            return decision, fallback_used

        fallback_used = True
        fallback_policy = self.influence_config.fallback_policy
        if fallback_policy == "hold_previous" and previous_active is not None:
            # Keep provenance from the failed decision; only the action is reused.
            fallback_decision = GuidanceDecision(
                source=f"{decision.source}_fallback_hold_previous",
                action=previous_active.decision.action,
                runtime_seconds=decision.runtime_seconds,
                raw_text=decision.raw_text,
                parsed_payload_before_repair=decision.parsed_payload_before_repair,
                repair_report=decision.repair_report,
                json_valid=decision.json_valid,
                schema_valid_before_repair=decision.schema_valid_before_repair,
                provider_error=decision.provider_error,
                fallback_policy_applied=fallback_policy,
            )
            return fallback_decision, fallback_used

        if fallback_policy == "heuristic_weak" and self.heuristic_provider is not None:
            fallback_decision = self.heuristic_provider.generate(summary)
            fallback_decision.fallback_policy_applied = fallback_policy
            return fallback_decision, fallback_used

        # Default / last-resort fallback: a no-op hold action.
        fallback_decision = GuidanceDecision(
            source=f"{decision.source}_fallback_no_op",
            action=DistrictAction.default_hold(),
            runtime_seconds=decision.runtime_seconds,
            raw_text=decision.raw_text,
            parsed_payload_before_repair=decision.parsed_payload_before_repair,
            repair_report=decision.repair_report,
            json_valid=decision.json_valid,
            schema_valid_before_repair=decision.schema_valid_before_repair,
            provider_error=decision.provider_error,
            fallback_policy_applied=fallback_policy,
        )
        return fallback_decision, fallback_used

    def _district_requires_refresh(self, district_id: str, decision_step: int) -> bool:
        """True when the district has no schedule yet or its window has expired."""
        next_refresh_step = self._next_refresh_step_by_district.get(district_id)
        if next_refresh_step is None:
            return True
        return decision_step >= int(next_refresh_step)

    def _resolve_refresh_horizon(self, action: DistrictAction) -> int:
        """Clamp the action's requested duration to the configured refresh caps."""
        requested = max(1, min(int(action.duration_steps), self.influence_config.max_guidance_duration))
        return min(
            requested,
            int(self.influence_config.guidance_refresh_steps),
            int(self.influence_config.guidance_persistence_steps),
        )

    def _resolve_blocked_refresh_horizon(self) -> int:
        """Steps to wait before re-evaluating a gate-blocked district (>= 1)."""
        return max(
            1,
            min(
                int(self.influence_config.guidance_refresh_steps),
                int(self.influence_config.guidance_persistence_steps),
            ),
        )

    def _row_action_bias(
        self,
        active: ActiveDistrictGuidance,
        candidate: CandidateIntersection | None,
        intersection_id: str,
        current_phase: int,
        decision_step: int,
    ) -> np.ndarray | None:
        """Compute the additive Q-bias vector for one intersection, or None.

        The magnitude stacks base + target + corridor components, scaled by the
        plan's strength scale, an optional linear age decay, and a per-strategy
        multiplier.
        """
        plan = active.application_plan
        if plan.wrapper_mode == "no_op":
            return None
        if intersection_id not in set(plan.affected_intersections):
            return None

        preferred_action = _preferred_action_for_direction(
            direction=plan.priority_direction,
            current_phase=current_phase,
        )
        if preferred_action is None:
            return None

        decay = 1.0
        if self.influence_config.enable_bias_decay:
            # Linear decay over the guidance window, floored at 0.25 so active
            # guidance never fully vanishes before expiry.
            horizon = max(1, active.expires_step - active.generated_step)
            age = max(0, decision_step - active.generated_step)
            if self.influence_config.bias_decay_schedule == "linear":
                decay = max(0.25, 1.0 - (float(age) / float(horizon)))

        magnitude = plan.base_bias_strength * plan.strength_scale * decay
        if intersection_id in set(plan.targeted_intersections):
            magnitude += plan.target_bias_strength * plan.strength_scale * decay
        if candidate is not None and plan.priority_direction in {"NS", "EW"}:
            if candidate.corridor_alignment == plan.priority_direction:
                magnitude += plan.corridor_bias_strength * plan.strength_scale * decay
            # Boundary intersections get a half-strength corridor bonus in
            # corridor/global scopes.
            if candidate.is_boundary and plan.scope in {"corridor_local", "global"}:
                magnitude += 0.5 * plan.corridor_bias_strength * plan.strength_scale * decay

        strategy_multiplier = STRATEGY_BIAS_MULTIPLIERS.get(active.decision.action.strategy, 1.0)
        magnitude *= strategy_multiplier
        if magnitude <= 0.0:
            return None

        # Two-element bias: one slot per action (binary action space — TODO confirm).
        bias = np.zeros(2, dtype=np.float32)
        bias[preferred_action] += float(magnitude)
        return bias
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
def _build_application_plan(
    summary: DistrictStateSummary,
    action: DistrictAction,
    config: GuidanceInfluenceConfig,
    district_intersection_ids: tuple[str, ...],
) -> GuidanceApplicationPlan:
    """Translate a guidance action + influence config into a concrete bias plan.

    Branches on `config.wrapper_mode`:
    - "no_op": zero biases, nothing affected.
    - "current_legacy": district-wide biases with legacy strength floors.
    - "target_only_soft"/"target_only_medium": only targeted intersections,
      at half/full scale respectively.
    - "corridor_soft": targets plus aligned boundary intersections, 0.6 scale.
    - any other mode falls through to "global_soft": district-wide, 0.35 scale.
    """
    wrapper_mode = config.wrapper_mode
    # Keep only targets that are actually visible candidates in the summary,
    # preserving the action's target order.
    target_ids = tuple(
        intersection_id
        for intersection_id in action.target_intersections
        if intersection_id in {item.intersection_id for item in summary.candidate_intersections}
    )
    candidate_lookup = {
        item.intersection_id: item
        for item in summary.candidate_intersections
    }
    priority_direction = _resolve_guidance_direction(action=action, summary=summary)
    if wrapper_mode == "no_op":
        return GuidanceApplicationPlan(
            wrapper_mode=wrapper_mode,
            scope="none",
            affected_intersections=(),
            targeted_intersections=target_ids,
            target_candidate_ids=tuple(candidate_lookup),
            priority_direction=priority_direction,
            strength_scale=0.0,
            base_bias_strength=0.0,
            target_bias_strength=0.0,
            corridor_bias_strength=0.0,
            apply_global_bias=False,
            apply_target_only=True,
            max_intersections_affected=0,
        )

    if wrapper_mode == "current_legacy":
        # Legacy behavior: bias every intersection in the district and enforce
        # minimum strengths regardless of the configured values.
        affected = tuple(district_intersection_ids)
        return GuidanceApplicationPlan(
            wrapper_mode=wrapper_mode,
            scope="global",
            affected_intersections=affected,
            targeted_intersections=target_ids,
            target_candidate_ids=tuple(candidate_lookup),
            priority_direction=priority_direction,
            strength_scale=1.0,
            base_bias_strength=float(max(config.bias_strength, 0.75)),
            target_bias_strength=float(max(config.target_only_bias_strength, 1.25)),
            corridor_bias_strength=float(max(config.corridor_bias_strength, 0.5)),
            apply_global_bias=True,
            apply_target_only=False,
            max_intersections_affected=max(len(affected), config.max_intersections_affected),
        )

    if wrapper_mode in {"target_only_soft", "target_only_medium"}:
        # Soft mode halves the effective strength; medium uses it unchanged.
        strength_scale = 0.5 if wrapper_mode == "target_only_soft" else 1.0
        affected = target_ids[: config.max_intersections_affected]
        return GuidanceApplicationPlan(
            wrapper_mode=wrapper_mode,
            scope="targeted",
            affected_intersections=affected,
            targeted_intersections=target_ids,
            target_candidate_ids=tuple(candidate_lookup),
            priority_direction=priority_direction,
            strength_scale=strength_scale,
            base_bias_strength=float(config.bias_strength),
            target_bias_strength=float(config.target_only_bias_strength),
            corridor_bias_strength=float(config.corridor_bias_strength),
            apply_global_bias=False,
            apply_target_only=True,
            max_intersections_affected=config.max_intersections_affected,
        )

    if wrapper_mode == "corridor_soft":
        # Extend the target set with boundary candidates aligned with the
        # priority direction, then cap at the configured maximum.
        ranked = list(target_ids)
        extras = [
            item.intersection_id
            for item in summary.candidate_intersections
            if item.intersection_id not in ranked
            and item.is_boundary
            and (priority_direction is None or item.corridor_alignment == priority_direction)
        ]
        affected = tuple((ranked + extras)[: config.max_intersections_affected])
        return GuidanceApplicationPlan(
            wrapper_mode=wrapper_mode,
            scope="corridor_local",
            affected_intersections=affected,
            targeted_intersections=target_ids,
            target_candidate_ids=tuple(candidate_lookup),
            priority_direction=priority_direction,
            strength_scale=0.6,
            base_bias_strength=float(config.bias_strength),
            target_bias_strength=float(config.target_only_bias_strength),
            corridor_bias_strength=float(config.corridor_bias_strength),
            apply_global_bias=False,
            apply_target_only=False,
            max_intersections_affected=config.max_intersections_affected,
        )

    # Fallthrough: any unrecognized mode is treated as a weak global bias.
    affected_global = tuple(district_intersection_ids)
    return GuidanceApplicationPlan(
        wrapper_mode="global_soft",
        scope="global",
        affected_intersections=affected_global,
        targeted_intersections=target_ids,
        target_candidate_ids=tuple(candidate_lookup),
        priority_direction=priority_direction,
        strength_scale=0.35,
        base_bias_strength=float(config.bias_strength),
        target_bias_strength=float(config.target_only_bias_strength),
        corridor_bias_strength=float(config.corridor_bias_strength),
        apply_global_bias=True,
        apply_target_only=False,
        max_intersections_affected=config.max_intersections_affected,
    )
|
| 899 |
+
|
| 900 |
+
|
| 901 |
+
def _should_fallback(decision: GuidanceDecision) -> bool:
|
| 902 |
+
if decision.provider_error is not None:
|
| 903 |
+
return True
|
| 904 |
+
if not decision.json_valid or not decision.schema_valid_before_repair:
|
| 905 |
+
return True
|
| 906 |
+
report = decision.repair_report
|
| 907 |
+
if report is None:
|
| 908 |
+
return False
|
| 909 |
+
return bool(
|
| 910 |
+
report.fallback_used
|
| 911 |
+
or report.empty_after_filtering
|
| 912 |
+
)
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
def _evaluate_guidance_gate(
    summary: DistrictStateSummary,
    config: GuidanceInfluenceConfig,
) -> GuidanceGateDecision:
    """Decide whether LLM guidance is allowed to influence control this step.

    Evaluates three trigger conditions (incident/spillback, average-queue
    threshold, NS/EW queue imbalance threshold), combines them according to
    ``config.gating_mode``, then applies the hard overrides that can veto an
    otherwise-allowed gate.  Returns a GuidanceGateDecision carrying the
    verdict plus the raw signals for logging.
    """
    imbalance = abs(float(summary.ns_queue) - float(summary.ew_queue))
    queue_hit = float(summary.avg_queue) >= float(config.min_avg_queue_for_guidance)
    imbalance_hit = imbalance >= float(config.min_queue_imbalance_for_guidance)
    incident_hit = bool(summary.incident_flag or summary.spillback_risk or summary.overload_flag)

    triggers = {
        "incident_or_spillback": incident_hit,
        "queue_threshold": queue_hit,
        "imbalance_threshold": imbalance_hit,
    }
    triggered_conditions = tuple(name for name, active in triggers.items() if active)

    # The three single-trigger modes share their name with a triggers key, so
    # they can be dispatched through the dict; unknown modes fall back to
    # "any trigger active".
    mode = config.gating_mode
    if mode == "always_on":
        allowed = True
    elif mode in triggers:
        allowed = triggers[mode]
    elif mode == "queue_or_imbalance":
        allowed = queue_hit or imbalance_hit
    else:
        allowed = incident_hit or queue_hit or imbalance_hit

    blocked_reasons: list[str] = []
    if config.require_incident_or_spillback and not incident_hit:
        allowed = False
        blocked_reasons.append("requires_incident_or_spillback")
    if not config.allow_guidance_in_normal_conditions and not triggered_conditions:
        allowed = False
        blocked_reasons.append("normal_conditions_blocked")
    # When the gating mode alone blocked guidance, record it as the reason.
    if not allowed and not blocked_reasons:
        blocked_reasons.append(f"gating_mode:{config.gating_mode}")

    return GuidanceGateDecision(
        allowed=allowed,
        gating_mode=config.gating_mode,
        triggered_conditions=triggered_conditions,
        blocked_reasons=tuple(blocked_reasons),
        avg_queue=float(summary.avg_queue),
        queue_imbalance=float(imbalance),
        incident_flag=bool(summary.incident_flag),
        spillback_risk=bool(summary.spillback_risk),
        overload_flag=bool(summary.overload_flag),
    )
|
| 964 |
+
|
| 965 |
+
|
| 966 |
+
def _resolve_guidance_direction(action: DistrictAction, summary: DistrictStateSummary) -> str | None:
|
| 967 |
+
if action.phase_bias in {"NS", "EW"}:
|
| 968 |
+
return action.phase_bias
|
| 969 |
+
if action.priority_corridor in {"NS", "EW"}:
|
| 970 |
+
return action.priority_corridor
|
| 971 |
+
if summary.dominant_flow in {"NS", "EW"}:
|
| 972 |
+
return summary.dominant_flow
|
| 973 |
+
return None
|
| 974 |
+
|
| 975 |
+
|
| 976 |
+
def _preferred_action_for_direction(direction: str | None, current_phase: int) -> int | None:
|
| 977 |
+
if direction == "NS":
|
| 978 |
+
return 0 if current_phase == 0 else 1
|
| 979 |
+
if direction == "EW":
|
| 980 |
+
return 0 if current_phase != 0 else 1
|
| 981 |
+
return None
|
| 982 |
+
|
| 983 |
+
|
| 984 |
+
def _summary_hash(summary: DistrictStateSummary) -> str:
|
| 985 |
+
return hashlib.sha1(summary.to_json().encode("utf-8")).hexdigest()[:16]
|
| 986 |
+
|
| 987 |
+
|
| 988 |
+
def guidance_config_payload(config: GuidanceInfluenceConfig) -> dict[str, Any]:
    """Validate *config* and return it as a plain, JSON-serializable dict."""
    validated = config.validate()
    return asdict(validated)
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
def _log_guidance_debug(trace: GuidanceRefreshTrace) -> None:
|
| 993 |
+
print(
|
| 994 |
+
"[guidance-debug] "
|
| 995 |
+
f"mode={trace.mode_source} "
|
| 996 |
+
f"district={trace.district_id} "
|
| 997 |
+
f"wrapper_mode={trace.application_plan['wrapper_mode']} "
|
| 998 |
+
f"gate_allowed={trace.gate_decision.get('allowed') if trace.gate_decision else True} "
|
| 999 |
+
f"scope={trace.application_plan['scope']} "
|
| 1000 |
+
f"targets={trace.repaired_guidance.get('target_intersections', [])} "
|
| 1001 |
+
f"affected={trace.application_plan['affected_intersections']} "
|
| 1002 |
+
f"fallback_used={trace.fallback_used} "
|
| 1003 |
+
f"fallback_policy={trace.fallback_policy}"
|
| 1004 |
+
)
|
district_llm/schema.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Closed vocabularies for the district-action schema.  These tuples are the
# single source of truth that the validators below check against.
DISTRICT_STRATEGIES: tuple[str, ...] = (
    "hold",
    "favor_NS",
    "favor_EW",
    "drain_inbound",
    "drain_outbound",
    "clear_spillback",
    "incident_response",
    "arterial_priority",
)
# Allowed DistrictAction.phase_bias values ("NONE" = no bias).
PHASE_BIASES: tuple[str, ...] = ("NONE", "NS", "EW")
# Allowed DistrictAction.priority_corridor values (None is also accepted).
PRIORITY_CORRIDORS: tuple[str, ...] = (
    "NS",
    "EW",
    "inbound",
    "outbound",
    "arterial",
)
# Allowed values for dominant_flow / corridor_alignment fields.
DOMINANT_FLOWS: tuple[str, ...] = ("NS", "EW", "BALANCED")
# Canonical ordering of candidate selection-reason tags; _stable_reason_list
# filters free-form reasons against this tuple and emits them in this order.
CANDIDATE_REASON_TAGS: tuple[str, ...] = (
    "congested",
    "boundary",
    "spillback",
    "incident",
    "outgoing",
    "overload",
    "event",
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _round_float(value: float, digits: int = 3) -> float:
|
| 39 |
+
return round(float(value), digits)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _dedupe_string_list(values: list[str] | tuple[str, ...] | None, limit: int | None = None) -> list[str]:
|
| 43 |
+
normalized: list[str] = []
|
| 44 |
+
seen: set[str] = set()
|
| 45 |
+
for item in values or []:
|
| 46 |
+
value = str(item).strip()
|
| 47 |
+
if not value or value in seen:
|
| 48 |
+
continue
|
| 49 |
+
normalized.append(value)
|
| 50 |
+
seen.add(value)
|
| 51 |
+
if limit is not None and len(normalized) >= limit:
|
| 52 |
+
break
|
| 53 |
+
return normalized
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _stable_reason_list(values: list[str] | tuple[str, ...] | None) -> list[str]:
    """Filter reasons to the known tag set, ordered canonically (not input order)."""
    cleaned: set[str] = set()
    for raw in values or []:
        text = str(raw).strip()
        if text:
            cleaned.add(text)
    return [tag for tag in CANDIDATE_REASON_TAGS if tag in cleaned]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def candidate_priority_score(candidate: "CandidateIntersection | dict[str, Any]") -> float:
    """Score a candidate for targeting; higher means more urgent.

    Weighted blend of queue, wait and outgoing load, plus fixed bonuses for
    each risk marker that is set (spillback, incident, boundary, event,
    overload).  Accepts either a CandidateIntersection or its dict form.
    """
    item = candidate.to_dict() if hasattr(candidate, "to_dict") else dict(candidate)
    weighted = (
        float(item.get("queue_total", 0.0))
        + 1.5 * float(item.get("wait_total", 0.0))
        + 0.5 * float(item.get("outgoing_load", 0.0))
    )
    # (flag key, bonus) pairs in the original accumulation order.
    bonuses = (
        ("spillback_risk", 2.0),
        ("incident_proximity", 1.5),
        ("is_boundary", 1.0),
        ("event_proximity", 0.75),
        ("overload_marker", 0.75),
    )
    for key, bonus in bonuses:
        if item.get(key, False):
            weighted += bonus
    return weighted
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def candidate_priority_tuple(candidate: "CandidateIntersection | dict[str, Any]") -> tuple[float, float, float, float, str]:
    """Build a tie-breaking key: (score, queue, wait, outgoing load, id).

    Callers negate the numeric fields to sort descending; the id breaks ties
    deterministically.  Accepts a CandidateIntersection or its dict form.
    """
    item = candidate.to_dict() if hasattr(candidate, "to_dict") else dict(candidate)
    metrics = tuple(
        float(item.get(key, 0.0))
        for key in ("queue_total", "wait_total", "outgoing_load")
    )
    return (candidate_priority_score(item), *metrics, str(item.get("intersection_id", "")))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def canonicalize_target_intersections(
    targets: list[str] | tuple[str, ...] | None,
    candidates: list["CandidateIntersection | dict[str, Any]"] | None = None,
    limit: int | None = None,
) -> list[str]:
    """Normalize, priority-order, and truncate a list of target intersection ids.

    Targets are first deduped/stripped via _dedupe_string_list.  When
    *candidates* are given, targets are re-ordered by descending candidate
    priority; ids not present among the candidates sort after all known ones
    (then alphabetically).  Finally the list is cut to *limit* entries.

    Args:
        targets: Raw target ids (may be None, may contain duplicates/blanks).
        candidates: Optional candidate records used to rank the targets.
        limit: Optional cap on the number of returned ids.

    Returns:
        The canonical ordered list of target ids.
    """
    normalized = _dedupe_string_list(targets, limit=None)
    if not candidates:
        return normalized[:limit] if limit is not None else normalized

    # Compute each candidate's sort key exactly once: candidate_priority_tuple
    # re-serializes the candidate and recomputes the full score, so the
    # original's five calls per candidate were pure rework.
    candidate_order: dict[str, tuple[float, float, float, float, str]] = {}
    for candidate in candidates:
        score, queue, wait, outgoing, identifier = candidate_priority_tuple(candidate)
        item = candidate.to_dict() if hasattr(candidate, "to_dict") else dict(candidate)
        # Negate numeric fields so an ascending sort ranks highest priority first.
        candidate_order[str(item["intersection_id"])] = (-score, -queue, -wait, -outgoing, identifier)

    # Unknown ids get a sentinel key that sorts after every real candidate
    # (real keys start with -score <= 0.0 < 1.0).
    normalized.sort(key=lambda item: candidate_order.get(item, (1.0, 1.0, 1.0, 1.0, item)))
    if limit is not None:
        normalized = normalized[:limit]
    return normalized
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@dataclass
class CongestedIntersection:
    """Compact per-intersection congestion record for district summaries."""

    intersection_id: str
    queue_total: float
    wait_total: float
    outgoing_load: float
    current_phase: int
    is_boundary: bool

    def to_dict(self) -> dict[str, Any]:
        """Serialize with floats rounded (3 decimals) for stable JSON output."""
        return {
            "intersection_id": self.intersection_id,
            "queue_total": _round_float(self.queue_total),
            "wait_total": _round_float(self.wait_total),
            "outgoing_load": _round_float(self.outgoing_load),
            "current_phase": int(self.current_phase),
            "is_boundary": bool(self.is_boundary),
        }

    def to_prompt_line(self) -> str:
        """Render one compact "- id q=.. w=.. ..." bullet line for LLM prompts."""
        return (
            f"- {self.intersection_id} "
            f"q={self.queue_total:.2f} "
            f"w={self.wait_total:.2f} "
            f"out={self.outgoing_load:.2f} "
            f"phase={self.current_phase} "
            f"boundary={int(self.is_boundary)}"
        )
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@dataclass
class CandidateIntersection:
    """A candidate for targeted guidance: metrics plus risk markers and the
    reasons it was selected."""

    intersection_id: str
    queue_total: float
    wait_total: float
    outgoing_load: float
    current_phase: int
    is_boundary: bool
    spillback_risk: bool = False
    incident_proximity: bool = False
    overload_marker: bool = False
    event_proximity: bool = False
    # Must be one of DOMINANT_FLOWS ("NS" / "EW" / "BALANCED").
    corridor_alignment: str = "BALANCED"
    selection_reasons: list[str] = field(default_factory=list)

    def validate(self) -> "CandidateIntersection":
        """Check enum fields and canonicalize selection_reasons in place.

        Note: mutates self (reasons are filtered/re-ordered against
        CANDIDATE_REASON_TAGS) and returns self for chaining.

        Raises:
            ValueError: if corridor_alignment is not one of DOMINANT_FLOWS.
        """
        if self.corridor_alignment not in DOMINANT_FLOWS:
            raise ValueError(
                f"Invalid corridor_alignment '{self.corridor_alignment}'. Expected one of {DOMINANT_FLOWS}."
            )
        self.selection_reasons = _stable_reason_list(self.selection_reasons)
        return self

    def to_dict(self) -> dict[str, Any]:
        """Serialize (after validation) with floats rounded for stable JSON."""
        self.validate()
        return {
            "intersection_id": self.intersection_id,
            "queue_total": _round_float(self.queue_total),
            "wait_total": _round_float(self.wait_total),
            "outgoing_load": _round_float(self.outgoing_load),
            "current_phase": int(self.current_phase),
            "is_boundary": bool(self.is_boundary),
            "spillback_risk": bool(self.spillback_risk),
            "incident_proximity": bool(self.incident_proximity),
            "overload_marker": bool(self.overload_marker),
            "event_proximity": bool(self.event_proximity),
            "corridor_alignment": self.corridor_alignment,
            "selection_reasons": list(self.selection_reasons),
        }

    def to_prompt_line(self) -> str:
        """Render one compact metrics bullet line for inclusion in an LLM prompt."""
        self.validate()
        reasons = "|".join(self.selection_reasons) if self.selection_reasons else "none"
        return (
            f"- {self.intersection_id} "
            f"q={self.queue_total:.2f} "
            f"w={self.wait_total:.2f} "
            f"out={self.outgoing_load:.2f} "
            f"phase={self.current_phase} "
            f"boundary={int(self.is_boundary)} "
            f"spillback={int(self.spillback_risk)} "
            f"incident={int(self.incident_proximity)} "
            f"overload={int(self.overload_marker)} "
            f"event={int(self.event_proximity)} "
            f"align={self.corridor_alignment} "
            f"reasons={reasons}"
        )
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@dataclass
class DistrictAction:
    """High-level district control directive (the LLM's structured output)."""

    # One of DISTRICT_STRATEGIES.
    strategy: str = "hold"
    # One of PRIORITY_CORRIDORS, or None.
    priority_corridor: str | None = None
    # Deduped and capped at 8 ids by validate().
    target_intersections: list[str] = field(default_factory=list)
    # One of PHASE_BIASES.
    phase_bias: str = "NONE"
    # Integer in [1, 20].
    duration_steps: int = 1

    def validate(self) -> "DistrictAction":
        """Check every field against its vocabulary/range; dedupe targets in place.

        Returns self for chaining.

        Raises:
            ValueError: on any invalid strategy, corridor, bias, or duration.
        """
        if self.strategy not in DISTRICT_STRATEGIES:
            raise ValueError(
                f"Invalid strategy '{self.strategy}'. Expected one of {DISTRICT_STRATEGIES}."
            )
        if self.priority_corridor is not None and self.priority_corridor not in PRIORITY_CORRIDORS:
            raise ValueError(
                f"Invalid priority_corridor '{self.priority_corridor}'. "
                f"Expected one of {PRIORITY_CORRIDORS} or None."
            )
        if self.phase_bias not in PHASE_BIASES:
            raise ValueError(
                f"Invalid phase_bias '{self.phase_bias}'. Expected one of {PHASE_BIASES}."
            )
        if not isinstance(self.duration_steps, int):
            raise ValueError("duration_steps must be an integer.")
        if not 1 <= self.duration_steps <= 20:
            raise ValueError("duration_steps must be between 1 and 20.")
        self.target_intersections = _dedupe_string_list(self.target_intersections, limit=8)
        return self

    @classmethod
    def default_hold(cls, duration_steps: int = 1) -> "DistrictAction":
        """Safe no-op fallback action; duration is clamped (not rejected) to [1, 20]."""
        return cls(
            strategy="hold",
            priority_corridor=None,
            target_intersections=[],
            phase_bias="NONE",
            duration_steps=max(1, min(int(duration_steps), 20)),
        )

    @classmethod
    def from_dict(cls, payload: dict[str, Any]) -> "DistrictAction":
        """Build from a loosely-typed payload; validates before returning."""
        return cls(
            strategy=str(payload.get("strategy", "hold")),
            priority_corridor=payload.get("priority_corridor"),
            target_intersections=list(payload.get("target_intersections", [])),
            phase_bias=str(payload.get("phase_bias", "NONE")),
            duration_steps=int(payload.get("duration_steps", 1)),
        ).validate()

    @classmethod
    def from_json(cls, payload: str) -> "DistrictAction":
        """Parse a JSON string and build a validated action from it."""
        return cls.from_dict(json.loads(payload))

    def to_dict(self) -> dict[str, Any]:
        """Serialize (after validation) to plain JSON-ready types."""
        self.validate()
        return {
            "strategy": self.strategy,
            "priority_corridor": self.priority_corridor,
            "target_intersections": list(self.target_intersections),
            "phase_bias": self.phase_bias,
            "duration_steps": int(self.duration_steps),
        }

    def to_json(self) -> str:
        """Canonical compact JSON (sorted keys, no whitespace)."""
        return json.dumps(self.to_dict(), sort_keys=True, separators=(",", ":"))

    def to_pretty_json(self) -> str:
        """Human-readable JSON (sorted keys, 2-space indent)."""
        return json.dumps(self.to_dict(), sort_keys=True, indent=2)

    def to_rl_context(self) -> dict[str, Any]:
        """Dict for the RL layer: renames strategy/duration_steps with a
        'district_' prefix — presumably to namespace them against local
        per-intersection keys (confirm against the RL consumer)."""
        payload = self.to_dict()
        payload["district_strategy"] = payload.pop("strategy")
        payload["district_duration_steps"] = payload.pop("duration_steps")
        return payload
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
@dataclass
class DistrictStateSummary:
    """Aggregated traffic state for one district at one decision step.

    This is the structured input handed to the district-level LLM: identity
    fields, aggregate queue/wait/load metrics, NS/EW directional splits,
    boundary totals, scenario flags, and ranked per-intersection detail lists.
    """

    city_id: str
    district_id: str
    district_type: str
    scenario_name: str
    scenario_type: str
    decision_step: int
    sim_time: int
    intersection_count: int
    avg_queue: float
    max_queue: float
    total_queue: float
    avg_wait: float
    max_wait: float
    total_wait: float
    avg_outgoing_load: float
    max_outgoing_load: float
    total_outgoing_load: float
    recent_throughput: float
    # *_change fields are deltas vs. the previous decision step.
    queue_change: float
    wait_change: float
    throughput_change: float
    ns_queue: float
    ew_queue: float
    ns_wait: float
    ew_wait: float
    # One of DOMINANT_FLOWS.
    dominant_flow: str
    boundary_queue_total: float
    boundary_wait_total: float
    spillback_risk: bool
    incident_flag: bool
    construction_flag: bool
    overload_flag: bool
    event_flag: bool
    # Capped at 5 / 8 entries respectively by validate().
    top_congested_intersections: list[CongestedIntersection] = field(default_factory=list)
    candidate_intersections: list[CandidateIntersection] = field(default_factory=list)

    def validate(self) -> "DistrictStateSummary":
        """Check dominant_flow and truncate the detail lists in place.

        Returns self for chaining.

        Raises:
            ValueError: if dominant_flow is not one of DOMINANT_FLOWS.
        """
        if self.dominant_flow not in DOMINANT_FLOWS:
            raise ValueError(
                f"Invalid dominant_flow '{self.dominant_flow}'. Expected one of {DOMINANT_FLOWS}."
            )
        self.top_congested_intersections = list(self.top_congested_intersections[:5])
        self.candidate_intersections = list(self.candidate_intersections[:8])
        return self

    def candidate_ids(self) -> list[str]:
        """Ordered candidate intersection ids (validates first)."""
        self.validate()
        return [item.intersection_id for item in self.candidate_intersections]

    def candidate_lookup(self) -> dict[str, CandidateIntersection]:
        """Candidate records keyed by intersection id (validates first)."""
        self.validate()
        return {
            item.intersection_id: item
            for item in self.candidate_intersections
        }

    def to_dict(self) -> dict[str, Any]:
        """Serialize (after validation) with floats rounded for stable JSON."""
        self.validate()
        return {
            "city_id": self.city_id,
            "district_id": self.district_id,
            "district_type": self.district_type,
            "scenario_name": self.scenario_name,
            "scenario_type": self.scenario_type,
            "decision_step": int(self.decision_step),
            "sim_time": int(self.sim_time),
            "intersection_count": int(self.intersection_count),
            "avg_queue": _round_float(self.avg_queue),
            "max_queue": _round_float(self.max_queue),
            "total_queue": _round_float(self.total_queue),
            "avg_wait": _round_float(self.avg_wait),
            "max_wait": _round_float(self.max_wait),
            "total_wait": _round_float(self.total_wait),
            "avg_outgoing_load": _round_float(self.avg_outgoing_load),
            "max_outgoing_load": _round_float(self.max_outgoing_load),
            "total_outgoing_load": _round_float(self.total_outgoing_load),
            "recent_throughput": _round_float(self.recent_throughput),
            "queue_change": _round_float(self.queue_change),
            "wait_change": _round_float(self.wait_change),
            "throughput_change": _round_float(self.throughput_change),
            "ns_queue": _round_float(self.ns_queue),
            "ew_queue": _round_float(self.ew_queue),
            "ns_wait": _round_float(self.ns_wait),
            "ew_wait": _round_float(self.ew_wait),
            "dominant_flow": self.dominant_flow,
            "boundary_queue_total": _round_float(self.boundary_queue_total),
            "boundary_wait_total": _round_float(self.boundary_wait_total),
            "spillback_risk": bool(self.spillback_risk),
            "incident_flag": bool(self.incident_flag),
            "construction_flag": bool(self.construction_flag),
            "overload_flag": bool(self.overload_flag),
            "event_flag": bool(self.event_flag),
            "top_congested_intersections": [
                item.to_dict() for item in self.top_congested_intersections
            ],
            "candidate_intersections": [
                item.to_dict() for item in self.candidate_intersections
            ],
        }

    def to_json(self) -> str:
        """Canonical compact JSON (sorted keys, no whitespace) — this exact
        byte string is what _summary_hash fingerprints."""
        return json.dumps(self.to_dict(), sort_keys=True, separators=(",", ":"))

    def to_prompt_text(self) -> str:
        """Render a key: value text block (one field per line) for LLM prompts."""
        self.validate()
        top_lines = [item.to_prompt_line() for item in self.top_congested_intersections]
        candidate_lines = [item.to_prompt_line() for item in self.candidate_intersections]
        # Empty lists still render a placeholder bullet so the section headers
        # are never dangling.
        if not top_lines:
            top_lines = ["- none"]
        if not candidate_lines:
            candidate_lines = ["- none"]
        return "\n".join(
            [
                f"city_id: {self.city_id}",
                f"district_id: {self.district_id}",
                f"district_type: {self.district_type}",
                f"scenario: {self.scenario_name}",
                f"scenario_type: {self.scenario_type}",
                f"decision_step: {self.decision_step}",
                f"sim_time: {self.sim_time}",
                f"intersection_count: {self.intersection_count}",
                f"avg_queue: {self.avg_queue:.2f}",
                f"max_queue: {self.max_queue:.2f}",
                f"total_queue: {self.total_queue:.2f}",
                f"avg_wait: {self.avg_wait:.2f}",
                f"max_wait: {self.max_wait:.2f}",
                f"total_wait: {self.total_wait:.2f}",
                f"avg_outgoing_load: {self.avg_outgoing_load:.2f}",
                f"max_outgoing_load: {self.max_outgoing_load:.2f}",
                f"total_outgoing_load: {self.total_outgoing_load:.2f}",
                f"recent_throughput: {self.recent_throughput:.2f}",
                f"queue_change: {self.queue_change:.2f}",
                f"wait_change: {self.wait_change:.2f}",
                f"throughput_change: {self.throughput_change:.2f}",
                f"ns_queue: {self.ns_queue:.2f}",
                f"ew_queue: {self.ew_queue:.2f}",
                f"ns_wait: {self.ns_wait:.2f}",
                f"ew_wait: {self.ew_wait:.2f}",
                f"dominant_flow: {self.dominant_flow}",
                f"boundary_queue_total: {self.boundary_queue_total:.2f}",
                f"boundary_wait_total: {self.boundary_wait_total:.2f}",
                f"spillback_risk: {int(self.spillback_risk)}",
                f"incident_flag: {int(self.incident_flag)}",
                f"construction_flag: {int(self.construction_flag)}",
                f"overload_flag: {int(self.overload_flag)}",
                f"event_flag: {int(self.event_flag)}",
                "top_congested_intersections:",
                *top_lines,
                "candidate_intersections:",
                *candidate_lines,
            ]
        )
|
district_llm/summary_builder.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from district_llm.schema import CandidateIntersection, CongestedIntersection, DistrictStateSummary, candidate_priority_score
|
| 10 |
+
from env.utils import load_json
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class _SummaryContext:
    """Mutable carry-over state between successive summary builds."""

    # Summaries from the previous decision step, keyed by district id.
    previous_summaries: dict[str, DistrictStateSummary]
    # Cumulative finished-vehicle count at the previous build — presumably
    # used to derive throughput deltas; confirm in _build_single.
    previous_finished_vehicles: int
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DistrictStateSummaryBuilder:
|
| 20 |
+
    def __init__(self, top_k: int = 3, candidate_limit: int = 6):
        """Create a builder.

        Args:
            top_k: presumably caps the top-congested list per district —
                confirm in _build_single.
            candidate_limit: presumably caps the candidate list per district —
                confirm in _build_single.
        """
        self.top_k = int(top_k)
        self.candidate_limit = int(candidate_limit)
        # Rolling per-episode state; cleared by reset().
        self._context = _SummaryContext(previous_summaries={}, previous_finished_vehicles=0)
        # Scenario-level caches, lazily populated on the first build_all().
        self._scenario_metadata: dict[str, Any] | None = None
        self._road_endpoints: dict[str, tuple[str, str]] | None = None
        self._incident_intersections: set[str] = set()
|
| 27 |
+
|
| 28 |
+
    def reset(self) -> None:
        """Clear rolling state and scenario caches (call between episodes/scenarios)."""
        self._context = _SummaryContext(previous_summaries={}, previous_finished_vehicles=0)
        # Forces the lazy scenario-metadata reload on the next build_all().
        self._scenario_metadata = None
        self._road_endpoints = None
        self._incident_intersections = set()
|
| 33 |
+
|
| 34 |
+
    def build_all(self, env, observation_batch: dict[str, Any]) -> dict[str, DistrictStateSummary]:
        """Build a DistrictStateSummary for every district in *env*.

        Lazily loads scenario metadata, road endpoints, and incident
        intersections on the first call, then delegates per-district work to
        _build_single and records this step's results as context for the next
        call (used for the summary's *_change fields — confirm in
        _build_single).
        """
        # One-time per-scenario initialization; a missing metadata file is
        # tolerated and treated as empty metadata.
        if self._scenario_metadata is None:
            metadata_path = Path(env.scenario_dir) / "scenario_metadata.json"
            self._scenario_metadata = load_json(metadata_path) if metadata_path.exists() else {}
            self._road_endpoints = self._load_road_endpoints(Path(env.roadnet_path))
            self._incident_intersections = self._derive_incident_intersections()

        lane_vehicle_count = env.adapter.get_lane_vehicle_count()
        finished_vehicles = int(env.adapter.get_finished_vehicle_count())
        district_summaries: dict[str, DistrictStateSummary] = {}

        for district_id in env.districts:
            district_summaries[district_id] = self._build_single(
                env=env,
                observation_batch=observation_batch,
                lane_vehicle_count=lane_vehicle_count,
                district_id=district_id,
                finished_vehicles=finished_vehicles,
            )

        # Remember this step so the next build can compute deltas.
        self._context.previous_summaries = district_summaries
        self._context.previous_finished_vehicles = finished_vehicles
        return district_summaries
|
| 57 |
+
|
| 58 |
+
def _build_single(
|
| 59 |
+
self,
|
| 60 |
+
env,
|
| 61 |
+
observation_batch: dict[str, Any],
|
| 62 |
+
lane_vehicle_count: dict[str, int],
|
| 63 |
+
district_id: str,
|
| 64 |
+
finished_vehicles: int,
|
| 65 |
+
) -> DistrictStateSummary:
|
| 66 |
+
district = env.districts[district_id]
|
| 67 |
+
scenario_metadata = self._scenario_metadata or {}
|
| 68 |
+
intersection_ids = observation_batch["intersection_ids"]
|
| 69 |
+
district_ids = observation_batch["district_ids"]
|
| 70 |
+
incoming_counts = observation_batch["incoming_counts"]
|
| 71 |
+
incoming_waiting = observation_batch["incoming_waiting"]
|
| 72 |
+
current_phase = observation_batch["current_phase"]
|
| 73 |
+
|
| 74 |
+
queue_totals: list[float] = []
|
| 75 |
+
wait_totals: list[float] = []
|
| 76 |
+
outgoing_loads: list[float] = []
|
| 77 |
+
ns_queue = 0.0
|
| 78 |
+
ew_queue = 0.0
|
| 79 |
+
ns_wait = 0.0
|
| 80 |
+
ew_wait = 0.0
|
| 81 |
+
boundary_queue_total = 0.0
|
| 82 |
+
boundary_wait_total = 0.0
|
| 83 |
+
congestion_items: list[CongestedIntersection] = []
|
| 84 |
+
candidate_seed_items: list[dict[str, Any]] = []
|
| 85 |
+
|
| 86 |
+
for index, intersection_id in enumerate(intersection_ids):
|
| 87 |
+
if district_ids[index] != district_id:
|
| 88 |
+
continue
|
| 89 |
+
|
| 90 |
+
queue_total = float(np.asarray(incoming_counts[index], dtype=np.float32).sum())
|
| 91 |
+
wait_total = float(np.asarray(incoming_waiting[index], dtype=np.float32).sum())
|
| 92 |
+
outgoing_load = self._compute_outgoing_load(
|
| 93 |
+
env=env,
|
| 94 |
+
lane_vehicle_count=lane_vehicle_count,
|
| 95 |
+
intersection_id=intersection_id,
|
| 96 |
+
)
|
| 97 |
+
queue_totals.append(queue_total)
|
| 98 |
+
wait_totals.append(wait_total)
|
| 99 |
+
outgoing_loads.append(outgoing_load)
|
| 100 |
+
|
| 101 |
+
midpoint = incoming_counts.shape[1] // 2
|
| 102 |
+
ns_queue_local = float(np.asarray(incoming_counts[index][:midpoint], dtype=np.float32).sum())
|
| 103 |
+
ew_queue_local = float(np.asarray(incoming_counts[index][midpoint:], dtype=np.float32).sum())
|
| 104 |
+
ns_wait_local = float(np.asarray(incoming_waiting[index][:midpoint], dtype=np.float32).sum())
|
| 105 |
+
ew_wait_local = float(np.asarray(incoming_waiting[index][midpoint:], dtype=np.float32).sum())
|
| 106 |
+
ns_queue += ns_queue_local
|
| 107 |
+
ew_queue += ew_queue_local
|
| 108 |
+
ns_wait += ns_wait_local
|
| 109 |
+
ew_wait += ew_wait_local
|
| 110 |
+
|
| 111 |
+
intersection_config = env.intersections[intersection_id]
|
| 112 |
+
if intersection_config.is_boundary:
|
| 113 |
+
boundary_queue_total += queue_total
|
| 114 |
+
boundary_wait_total += wait_total
|
| 115 |
+
|
| 116 |
+
congestion_items.append(
|
| 117 |
+
CongestedIntersection(
|
| 118 |
+
intersection_id=intersection_id,
|
| 119 |
+
queue_total=queue_total,
|
| 120 |
+
wait_total=wait_total,
|
| 121 |
+
outgoing_load=outgoing_load,
|
| 122 |
+
current_phase=int(current_phase[index]),
|
| 123 |
+
is_boundary=bool(intersection_config.is_boundary),
|
| 124 |
+
)
|
| 125 |
+
)
|
| 126 |
+
candidate_seed_items.append(
|
| 127 |
+
{
|
| 128 |
+
"intersection_id": intersection_id,
|
| 129 |
+
"queue_total": queue_total,
|
| 130 |
+
"wait_total": wait_total,
|
| 131 |
+
"outgoing_load": outgoing_load,
|
| 132 |
+
"current_phase": int(current_phase[index]),
|
| 133 |
+
"is_boundary": bool(intersection_config.is_boundary),
|
| 134 |
+
"spillback_risk": bool(
|
| 135 |
+
outgoing_load >= max(6.0, queue_total * 0.6)
|
| 136 |
+
or (
|
| 137 |
+
intersection_config.is_boundary
|
| 138 |
+
and outgoing_load >= max(4.0, queue_total * 0.4)
|
| 139 |
+
)
|
| 140 |
+
),
|
| 141 |
+
"incident_proximity": intersection_id in self._incident_intersections,
|
| 142 |
+
"corridor_alignment": self._compute_corridor_alignment(
|
| 143 |
+
ns_queue=ns_queue_local,
|
| 144 |
+
ew_queue=ew_queue_local,
|
| 145 |
+
ns_wait=ns_wait_local,
|
| 146 |
+
ew_wait=ew_wait_local,
|
| 147 |
+
),
|
| 148 |
+
}
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
queue_array = np.asarray(queue_totals or [0.0], dtype=np.float32)
|
| 152 |
+
wait_array = np.asarray(wait_totals or [0.0], dtype=np.float32)
|
| 153 |
+
outgoing_array = np.asarray(outgoing_loads or [0.0], dtype=np.float32)
|
| 154 |
+
|
| 155 |
+
previous_summary = self._context.previous_summaries.get(district_id)
|
| 156 |
+
recent_throughput = float(
|
| 157 |
+
finished_vehicles - self._context.previous_finished_vehicles
|
| 158 |
+
if self._context.previous_finished_vehicles
|
| 159 |
+
else 0.0
|
| 160 |
+
)
|
| 161 |
+
queue_change = 0.0 if previous_summary is None else float(queue_array.sum() - previous_summary.total_queue)
|
| 162 |
+
wait_change = 0.0 if previous_summary is None else float(wait_array.sum() - previous_summary.total_wait)
|
| 163 |
+
throughput_change = (
|
| 164 |
+
0.0
|
| 165 |
+
if previous_summary is None
|
| 166 |
+
else recent_throughput - previous_summary.recent_throughput
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
directional_ns = ns_queue + 1.5 * ns_wait
|
| 170 |
+
directional_ew = ew_queue + 1.5 * ew_wait
|
| 171 |
+
if directional_ns > directional_ew * 1.1:
|
| 172 |
+
dominant_flow = "NS"
|
| 173 |
+
elif directional_ew > directional_ns * 1.1:
|
| 174 |
+
dominant_flow = "EW"
|
| 175 |
+
else:
|
| 176 |
+
dominant_flow = "BALANCED"
|
| 177 |
+
|
| 178 |
+
boundary_share = boundary_queue_total / max(1.0, float(queue_array.sum()))
|
| 179 |
+
spillback_risk = bool(
|
| 180 |
+
outgoing_array.max() >= max(8.0, queue_array.max() * 0.5)
|
| 181 |
+
or (boundary_share >= 0.6 and queue_change >= 0.0)
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
top_intersections = sorted(
|
| 185 |
+
congestion_items,
|
| 186 |
+
key=lambda item: (item.queue_total + 1.5 * item.wait_total + 0.5 * item.outgoing_load),
|
| 187 |
+
reverse=True,
|
| 188 |
+
)[: self.top_k]
|
| 189 |
+
|
| 190 |
+
overload_flag = bool(
|
| 191 |
+
scenario_metadata.get("overload_district") == district_id
|
| 192 |
+
or (scenario_metadata.get("name") == "district_overload" and queue_array.sum() >= 25.0)
|
| 193 |
+
)
|
| 194 |
+
event_flag = bool(scenario_metadata.get("event_district") == district_id)
|
| 195 |
+
incident_flag = bool(
|
| 196 |
+
scenario_metadata.get("name") in {"accident", "construction"}
|
| 197 |
+
or bool(scenario_metadata.get("blocked_roads"))
|
| 198 |
+
)
|
| 199 |
+
construction_flag = bool(scenario_metadata.get("name") == "construction")
|
| 200 |
+
candidate_intersections = self._build_candidate_intersections(
|
| 201 |
+
candidate_seed_items=candidate_seed_items,
|
| 202 |
+
overload_flag=overload_flag,
|
| 203 |
+
event_flag=event_flag,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
return DistrictStateSummary(
|
| 207 |
+
city_id=env.city_id,
|
| 208 |
+
district_id=district_id,
|
| 209 |
+
district_type=district.district_type,
|
| 210 |
+
scenario_name=env.scenario_name,
|
| 211 |
+
scenario_type=str(scenario_metadata.get("intensity", env.scenario_name)),
|
| 212 |
+
decision_step=int(observation_batch["decision_step"]),
|
| 213 |
+
sim_time=int(observation_batch["sim_time"]),
|
| 214 |
+
intersection_count=int(len(district.intersection_ids)),
|
| 215 |
+
avg_queue=float(queue_array.mean()),
|
| 216 |
+
max_queue=float(queue_array.max()),
|
| 217 |
+
total_queue=float(queue_array.sum()),
|
| 218 |
+
avg_wait=float(wait_array.mean()),
|
| 219 |
+
max_wait=float(wait_array.max()),
|
| 220 |
+
total_wait=float(wait_array.sum()),
|
| 221 |
+
avg_outgoing_load=float(outgoing_array.mean()),
|
| 222 |
+
max_outgoing_load=float(outgoing_array.max()),
|
| 223 |
+
total_outgoing_load=float(outgoing_array.sum()),
|
| 224 |
+
recent_throughput=recent_throughput,
|
| 225 |
+
queue_change=queue_change,
|
| 226 |
+
wait_change=wait_change,
|
| 227 |
+
throughput_change=throughput_change,
|
| 228 |
+
ns_queue=ns_queue,
|
| 229 |
+
ew_queue=ew_queue,
|
| 230 |
+
ns_wait=ns_wait,
|
| 231 |
+
ew_wait=ew_wait,
|
| 232 |
+
dominant_flow=dominant_flow,
|
| 233 |
+
boundary_queue_total=boundary_queue_total,
|
| 234 |
+
boundary_wait_total=boundary_wait_total,
|
| 235 |
+
spillback_risk=spillback_risk,
|
| 236 |
+
incident_flag=incident_flag,
|
| 237 |
+
construction_flag=construction_flag,
|
| 238 |
+
overload_flag=overload_flag,
|
| 239 |
+
event_flag=event_flag,
|
| 240 |
+
top_congested_intersections=top_intersections,
|
| 241 |
+
candidate_intersections=candidate_intersections,
|
| 242 |
+
).validate()
|
| 243 |
+
|
| 244 |
+
@staticmethod
|
| 245 |
+
def _compute_outgoing_load(env, lane_vehicle_count: dict[str, int], intersection_id: str) -> float:
|
| 246 |
+
intersection_config = env.intersections[intersection_id]
|
| 247 |
+
if not intersection_config.outgoing_lanes:
|
| 248 |
+
return 0.0
|
| 249 |
+
return float(
|
| 250 |
+
sum(float(lane_vehicle_count.get(lane_id, 0)) for lane_id in intersection_config.outgoing_lanes)
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
@staticmethod
|
| 254 |
+
def _compute_corridor_alignment(
|
| 255 |
+
ns_queue: float,
|
| 256 |
+
ew_queue: float,
|
| 257 |
+
ns_wait: float,
|
| 258 |
+
ew_wait: float,
|
| 259 |
+
) -> str:
|
| 260 |
+
ns_pressure = ns_queue + 1.5 * ns_wait
|
| 261 |
+
ew_pressure = ew_queue + 1.5 * ew_wait
|
| 262 |
+
if ns_pressure > ew_pressure * 1.1:
|
| 263 |
+
return "NS"
|
| 264 |
+
if ew_pressure > ns_pressure * 1.1:
|
| 265 |
+
return "EW"
|
| 266 |
+
return "BALANCED"
|
| 267 |
+
|
| 268 |
+
@staticmethod
|
| 269 |
+
def _load_road_endpoints(roadnet_path: Path) -> dict[str, tuple[str, str]]:
|
| 270 |
+
roadnet = load_json(roadnet_path)
|
| 271 |
+
return {
|
| 272 |
+
str(road["id"]): (
|
| 273 |
+
str(road["startIntersection"]),
|
| 274 |
+
str(road["endIntersection"]),
|
| 275 |
+
)
|
| 276 |
+
for road in roadnet.get("roads", [])
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
def _derive_incident_intersections(self) -> set[str]:
|
| 280 |
+
if not self._road_endpoints:
|
| 281 |
+
return set()
|
| 282 |
+
scenario_metadata = self._scenario_metadata or {}
|
| 283 |
+
details = scenario_metadata.get("details", {})
|
| 284 |
+
incident_roads = list(scenario_metadata.get("blocked_roads", []))
|
| 285 |
+
incident_roads.extend(details.get("accident_roads", []))
|
| 286 |
+
incident_roads.extend(details.get("construction_roads", []))
|
| 287 |
+
if not incident_roads:
|
| 288 |
+
incident_roads.extend(list((scenario_metadata.get("penalized_roads") or {}).keys()))
|
| 289 |
+
|
| 290 |
+
intersections: set[str] = set()
|
| 291 |
+
for road_id in incident_roads:
|
| 292 |
+
endpoints = self._road_endpoints.get(str(road_id))
|
| 293 |
+
if endpoints is None:
|
| 294 |
+
continue
|
| 295 |
+
intersections.update(endpoints)
|
| 296 |
+
return intersections
|
| 297 |
+
|
| 298 |
+
def _build_candidate_intersections(
|
| 299 |
+
self,
|
| 300 |
+
candidate_seed_items: list[dict[str, Any]],
|
| 301 |
+
overload_flag: bool,
|
| 302 |
+
event_flag: bool,
|
| 303 |
+
) -> list[CandidateIntersection]:
|
| 304 |
+
if not candidate_seed_items or self.candidate_limit <= 0:
|
| 305 |
+
return []
|
| 306 |
+
|
| 307 |
+
def severity_key(item: dict[str, Any]) -> tuple[float, float, float, float, str]:
|
| 308 |
+
candidate = CandidateIntersection(
|
| 309 |
+
intersection_id=str(item["intersection_id"]),
|
| 310 |
+
queue_total=float(item["queue_total"]),
|
| 311 |
+
wait_total=float(item["wait_total"]),
|
| 312 |
+
outgoing_load=float(item["outgoing_load"]),
|
| 313 |
+
current_phase=int(item["current_phase"]),
|
| 314 |
+
is_boundary=bool(item["is_boundary"]),
|
| 315 |
+
spillback_risk=bool(item["spillback_risk"]),
|
| 316 |
+
incident_proximity=bool(item["incident_proximity"]),
|
| 317 |
+
overload_marker=overload_flag,
|
| 318 |
+
event_proximity=event_flag,
|
| 319 |
+
corridor_alignment=str(item["corridor_alignment"]),
|
| 320 |
+
selection_reasons=[],
|
| 321 |
+
)
|
| 322 |
+
return (
|
| 323 |
+
candidate_priority_score(candidate),
|
| 324 |
+
float(item["queue_total"]),
|
| 325 |
+
float(item["wait_total"]),
|
| 326 |
+
float(item["outgoing_load"]),
|
| 327 |
+
str(item["intersection_id"]),
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
overall_sorted = sorted(
|
| 331 |
+
candidate_seed_items,
|
| 332 |
+
key=lambda item: (
|
| 333 |
+
-severity_key(item)[0],
|
| 334 |
+
-severity_key(item)[1],
|
| 335 |
+
-severity_key(item)[2],
|
| 336 |
+
-severity_key(item)[3],
|
| 337 |
+
severity_key(item)[4],
|
| 338 |
+
),
|
| 339 |
+
)
|
| 340 |
+
boundary_sorted = [item for item in overall_sorted if item["is_boundary"]]
|
| 341 |
+
spillback_sorted = [item for item in overall_sorted if item["spillback_risk"]]
|
| 342 |
+
incident_sorted = [item for item in overall_sorted if item["incident_proximity"]]
|
| 343 |
+
outgoing_sorted = sorted(
|
| 344 |
+
candidate_seed_items,
|
| 345 |
+
key=lambda item: (
|
| 346 |
+
-float(item["outgoing_load"]),
|
| 347 |
+
-float(item["queue_total"]),
|
| 348 |
+
-float(item["wait_total"]),
|
| 349 |
+
str(item["intersection_id"]),
|
| 350 |
+
),
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
reason_tags: dict[str, set[str]] = {}
|
| 354 |
+
selected_ids: list[str] = []
|
| 355 |
+
|
| 356 |
+
def mark(items: list[dict[str, Any]], tag: str, limit: int) -> None:
|
| 357 |
+
for item in items[:limit]:
|
| 358 |
+
intersection_id = str(item["intersection_id"])
|
| 359 |
+
reason_tags.setdefault(intersection_id, set()).add(tag)
|
| 360 |
+
if intersection_id not in selected_ids:
|
| 361 |
+
selected_ids.append(intersection_id)
|
| 362 |
+
|
| 363 |
+
mark(overall_sorted, "congested", max(1, min(self.top_k, self.candidate_limit)))
|
| 364 |
+
mark(boundary_sorted, "boundary", min(2, self.candidate_limit))
|
| 365 |
+
mark(spillback_sorted, "spillback", min(2, self.candidate_limit))
|
| 366 |
+
mark(incident_sorted, "incident", min(2, self.candidate_limit))
|
| 367 |
+
mark(outgoing_sorted, "outgoing", min(2, self.candidate_limit))
|
| 368 |
+
if overload_flag:
|
| 369 |
+
mark(overall_sorted, "overload", min(2, self.candidate_limit))
|
| 370 |
+
if event_flag:
|
| 371 |
+
event_seed = boundary_sorted if boundary_sorted else outgoing_sorted
|
| 372 |
+
mark(event_seed, "event", min(2, self.candidate_limit))
|
| 373 |
+
|
| 374 |
+
for item in overall_sorted:
|
| 375 |
+
if len(selected_ids) >= self.candidate_limit:
|
| 376 |
+
break
|
| 377 |
+
intersection_id = str(item["intersection_id"])
|
| 378 |
+
if intersection_id in selected_ids:
|
| 379 |
+
continue
|
| 380 |
+
selected_ids.append(intersection_id)
|
| 381 |
+
reason_tags.setdefault(intersection_id, {"congested"})
|
| 382 |
+
|
| 383 |
+
seed_lookup = {
|
| 384 |
+
str(item["intersection_id"]): item
|
| 385 |
+
for item in candidate_seed_items
|
| 386 |
+
}
|
| 387 |
+
candidates = [
|
| 388 |
+
CandidateIntersection(
|
| 389 |
+
intersection_id=intersection_id,
|
| 390 |
+
queue_total=float(seed_lookup[intersection_id]["queue_total"]),
|
| 391 |
+
wait_total=float(seed_lookup[intersection_id]["wait_total"]),
|
| 392 |
+
outgoing_load=float(seed_lookup[intersection_id]["outgoing_load"]),
|
| 393 |
+
current_phase=int(seed_lookup[intersection_id]["current_phase"]),
|
| 394 |
+
is_boundary=bool(seed_lookup[intersection_id]["is_boundary"]),
|
| 395 |
+
spillback_risk=bool(seed_lookup[intersection_id]["spillback_risk"]),
|
| 396 |
+
incident_proximity=bool(seed_lookup[intersection_id]["incident_proximity"]),
|
| 397 |
+
overload_marker=overload_flag,
|
| 398 |
+
event_proximity=event_flag,
|
| 399 |
+
corridor_alignment=str(seed_lookup[intersection_id]["corridor_alignment"]),
|
| 400 |
+
selection_reasons=sorted(reason_tags.get(intersection_id, {"congested"})),
|
| 401 |
+
).validate()
|
| 402 |
+
for intersection_id in selected_ids[: self.candidate_limit]
|
| 403 |
+
]
|
| 404 |
+
return sorted(
|
| 405 |
+
candidates,
|
| 406 |
+
key=lambda item: (
|
| 407 |
+
-candidate_priority_score(item),
|
| 408 |
+
-item.queue_total,
|
| 409 |
+
-item.wait_total,
|
| 410 |
+
-item.outgoing_load,
|
| 411 |
+
item.intersection_id,
|
| 412 |
+
),
|
| 413 |
+
)
|
district_llm/teachers.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from agents.local_policy import (
|
| 12 |
+
BaseLocalPolicy,
|
| 13 |
+
FixedCyclePolicy,
|
| 14 |
+
HoldPhasePolicy,
|
| 15 |
+
QueueGreedyPolicy,
|
| 16 |
+
RandomPhasePolicy,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Controller types that map to heuristic baseline policies (anything else must be "rl_checkpoint").
BASELINE_TYPES: tuple[str, ...] = ("hold", "fixed", "random", "queue_greedy")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass(frozen=True)
|
| 24 |
+
class TeacherMetadata:
|
| 25 |
+
controller_type: str
|
| 26 |
+
controller_id: str
|
| 27 |
+
controller_family: str
|
| 28 |
+
teacher_algorithm: str
|
| 29 |
+
checkpoint_path: str | None = None
|
| 30 |
+
|
| 31 |
+
def to_dict(self) -> dict[str, Any]:
|
| 32 |
+
return {
|
| 33 |
+
"controller_type": self.controller_type,
|
| 34 |
+
"controller_id": self.controller_id,
|
| 35 |
+
"controller_family": self.controller_family,
|
| 36 |
+
"teacher_algorithm": self.teacher_algorithm,
|
| 37 |
+
"checkpoint_path": self.checkpoint_path,
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class BaseTeacher(ABC):
    """Abstract action source for dataset generation.

    A teacher holds identifying metadata and maps a batched observation to one
    integer action per controlled intersection.
    """

    def __init__(self, metadata: TeacherMetadata):
        # Metadata identifying this teacher in generated trace records.
        self.metadata = metadata

    @property
    def env_config(self) -> Any | None:
        """Environment config the teacher prefers, if any (None for baselines)."""
        return None

    @abstractmethod
    def act(self, observation_batch: dict[str, Any]) -> np.ndarray:
        """Return an int action array, one entry per intersection in the batch."""
        raise NotImplementedError
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class BaselineTeacher(BaseTeacher):
    """Adapts a heuristic local policy to the teacher interface."""

    def __init__(self, policy: BaseLocalPolicy, metadata: TeacherMetadata):
        super().__init__(metadata=metadata)
        self.policy = policy

    def act(self, observation_batch: dict[str, Any]) -> np.ndarray:
        """Delegate to the wrapped policy and coerce the result to int64."""
        raw_actions = self.policy.act(observation_batch)
        return np.asarray(raw_actions, dtype=np.int64)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class RLCheckpointTeacher(BaseTeacher):
    """Teacher that replays a trained DQN checkpoint to produce actions.

    Loads the Q-network, optional observation normalizer, and optional env
    config from a single torch checkpoint file.
    """

    def __init__(
        self,
        checkpoint_path: str | Path,
        device: str | None = None,
        deterministic: bool = True,
    ):
        # Torch (and the training package) are imported lazily so baseline-only
        # usage of this module does not require them to be installed.
        try:
            import torch
        except ImportError as exc:
            raise ImportError(
                "RL checkpoint teachers require PyTorch to be installed."
            ) from exc

        from training.models import RunningNormalizer, TrafficControlQNetwork
        from training.train_local_policy import load_env_config

        checkpoint_path = Path(checkpoint_path)
        self._torch = torch
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        # NOTE(review): weights_only=False unpickles arbitrary objects -- only
        # load checkpoints from trusted sources.
        self.checkpoint = torch.load(
            checkpoint_path,
            map_location=self.device,
            weights_only=False,
        )
        # Older checkpoints stored the layout under "policy_architecture".
        network_architecture = self.checkpoint.get("network_architecture") or self.checkpoint.get(
            "policy_architecture",
            {},
        )
        trainer_config = self.checkpoint.get("dqn_config", {})
        policy_arch = network_architecture.get(
            "policy_arch",
            trainer_config.get("policy_arch", "single_head_with_district_feature"),
        )
        # Rebuild the Q-network with the architecture recorded at training time.
        self.model = TrafficControlQNetwork(
            observation_dim=int(network_architecture["observation_dim"]),
            action_dim=int(network_architecture.get("action_dim", 2)),
            hidden_dim=int(trainer_config.get("hidden_dim", 256)),
            num_layers=int(trainer_config.get("hidden_layers", 2)),
            district_types=tuple(network_architecture.get("district_types", ())),
            policy_arch=policy_arch,
            dueling=bool(network_architecture.get("dueling", True)),
        ).to(self.device)
        # Newer checkpoints use "q_network_state_dict"; fall back to the legacy key.
        self.model.load_state_dict(
            self.checkpoint.get("q_network_state_dict") or self.checkpoint["policy_state_dict"]
        )
        self.model.eval()
        self.obs_normalizer = None
        if self.checkpoint.get("obs_normalizer"):
            self.obs_normalizer = RunningNormalizer()
            self.obs_normalizer.load_state_dict(self.checkpoint["obs_normalizer"])

        # Use the checkpoint filename (without extension) as the controller id.
        checkpoint_id = checkpoint_path.stem
        super().__init__(
            metadata=TeacherMetadata(
                controller_type="rl_checkpoint",
                controller_id=checkpoint_id,
                controller_family="dqn",
                teacher_algorithm="dqn",
                checkpoint_path=str(checkpoint_path),
            )
        )
        self.deterministic = bool(deterministic)
        self._env_config = (
            load_env_config(self.checkpoint["env_config"])
            if self.checkpoint.get("env_config")
            else None
        )

    @property
    def env_config(self) -> Any | None:
        # Env config embedded in the checkpoint so callers can rebuild the same env.
        return self._env_config

    def act(self, observation_batch: dict[str, Any]) -> np.ndarray:
        """Return Q-network actions for the batch (greedy when deterministic=True)."""
        torch = self._torch
        raw_obs = observation_batch["observations"].astype(np.float32)
        # Apply the training-time observation normalizer when one was saved.
        normalized_obs = self.obs_normalizer.normalize(raw_obs) if self.obs_normalizer else raw_obs
        obs_tensor = torch.as_tensor(normalized_obs, dtype=torch.float32, device=self.device)
        district_type_tensor = torch.as_tensor(
            observation_batch["district_type_indices"],
            dtype=torch.int64,
            device=self.device,
        )
        action_mask_tensor = torch.as_tensor(
            observation_batch["action_mask"],
            dtype=torch.float32,
            device=self.device,
        )
        with torch.no_grad():
            # epsilon=0.0 disables exploration during dataset generation.
            actions = self.model.act(
                observations=obs_tensor,
                district_type_indices=district_type_tensor,
                action_mask=action_mask_tensor,
                deterministic=self.deterministic,
                epsilon=0.0,
            )
        return actions.cpu().numpy().astype(np.int64)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def build_teacher(
|
| 163 |
+
controller_type: str,
|
| 164 |
+
checkpoint: str | None = None,
|
| 165 |
+
fixed_green_time: int = 20,
|
| 166 |
+
seed: int = 7,
|
| 167 |
+
device: str | None = None,
|
| 168 |
+
) -> BaseTeacher:
|
| 169 |
+
if controller_type == "rl_checkpoint":
|
| 170 |
+
if not checkpoint:
|
| 171 |
+
raise ValueError("controller_type='rl_checkpoint' requires --checkpoint.")
|
| 172 |
+
return RLCheckpointTeacher(checkpoint_path=checkpoint, device=device)
|
| 173 |
+
if controller_type == "hold":
|
| 174 |
+
return BaselineTeacher(
|
| 175 |
+
policy=HoldPhasePolicy(),
|
| 176 |
+
metadata=TeacherMetadata(
|
| 177 |
+
controller_type="hold",
|
| 178 |
+
controller_id="hold",
|
| 179 |
+
controller_family="baseline",
|
| 180 |
+
teacher_algorithm="hold",
|
| 181 |
+
),
|
| 182 |
+
)
|
| 183 |
+
if controller_type == "fixed":
|
| 184 |
+
return BaselineTeacher(
|
| 185 |
+
policy=FixedCyclePolicy(green_time=fixed_green_time),
|
| 186 |
+
metadata=TeacherMetadata(
|
| 187 |
+
controller_type="fixed",
|
| 188 |
+
controller_id=f"fixed_{fixed_green_time}",
|
| 189 |
+
controller_family="baseline",
|
| 190 |
+
teacher_algorithm="fixed_cycle",
|
| 191 |
+
),
|
| 192 |
+
)
|
| 193 |
+
if controller_type == "random":
|
| 194 |
+
return BaselineTeacher(
|
| 195 |
+
policy=RandomPhasePolicy(seed=seed),
|
| 196 |
+
metadata=TeacherMetadata(
|
| 197 |
+
controller_type="random",
|
| 198 |
+
controller_id=f"random_{seed}",
|
| 199 |
+
controller_family="baseline",
|
| 200 |
+
teacher_algorithm="random",
|
| 201 |
+
),
|
| 202 |
+
)
|
| 203 |
+
if controller_type == "queue_greedy":
|
| 204 |
+
return BaselineTeacher(
|
| 205 |
+
policy=QueueGreedyPolicy(),
|
| 206 |
+
metadata=TeacherMetadata(
|
| 207 |
+
controller_type="queue_greedy",
|
| 208 |
+
controller_id="queue_greedy",
|
| 209 |
+
controller_family="baseline",
|
| 210 |
+
teacher_algorithm="queue_greedy",
|
| 211 |
+
),
|
| 212 |
+
)
|
| 213 |
+
raise ValueError(
|
| 214 |
+
f"Unsupported controller_type '{controller_type}'. "
|
| 215 |
+
f"Expected rl_checkpoint or one of {BASELINE_TYPES}."
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def parse_teacher_spec(spec: str) -> tuple[str, str | None]:
|
| 220 |
+
if "=" not in spec:
|
| 221 |
+
return spec.strip(), None
|
| 222 |
+
controller_type, checkpoint_path = spec.split("=", 1)
|
| 223 |
+
return controller_type.strip(), checkpoint_path.strip() or None
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def teachers_metadata_json(teachers: list[BaseTeacher]) -> str:
    """Serialize every teacher's metadata as a stable (key-sorted) JSON array."""
    records = [teacher.metadata.to_dict() for teacher in teachers]
    return json.dumps(records, sort_keys=True)
|
district_llm/train_unsloth.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from district_llm.data import load_jsonl_text_dataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_args() -> argparse.Namespace:
    """Parse the CLI for Unsloth/QLoRA SFT on district LLM traces."""
    cli = argparse.ArgumentParser(
        description="Supervised fine-tune a district LLM on DQN-derived district traces with Unsloth/QLoRA."
    )
    # Data and output locations.
    cli.add_argument("--dataset", required=True, help="JSONL dataset with a 'text' field.")
    cli.add_argument("--output-dir", required=True)
    cli.add_argument("--eval-dataset", default=None)
    # Base model and sequence settings.
    cli.add_argument("--model-name", default="Qwen/Qwen2.5-7B-Instruct")
    cli.add_argument("--max-seq-length", type=int, default=1024)
    cli.add_argument("--load-in-4bit", action="store_true")
    # LoRA hyperparameters.
    cli.add_argument("--lora-rank", type=int, default=16)
    cli.add_argument("--lora-alpha", type=int, default=16)
    cli.add_argument("--lora-dropout", type=float, default=0.0)
    # Optimization schedule.
    cli.add_argument("--batch-size", type=int, default=2)
    cli.add_argument("--gradient-accumulation-steps", type=int, default=8)
    cli.add_argument("--learning-rate", type=float, default=2e-4)
    cli.add_argument("--warmup-steps", type=int, default=50)
    cli.add_argument("--max-steps", type=int, default=500)
    cli.add_argument("--logging-steps", type=int, default=10)
    cli.add_argument("--save-steps", type=int, default=100)
    cli.add_argument("--seed", type=int, default=7)
    cli.add_argument("--dataset-num-proc", type=int, default=2)
    cli.add_argument("--resume-from-checkpoint", default=None)
    cli.add_argument(
        "--include-non-dqn-sources",
        action="store_true",
        help="By default the trainer keeps only DQN-derived rows (controller_family=dqn).",
    )
    return cli.parse_args()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def main() -> None:
    """Entry point: load JSONL traces, attach LoRA adapters, and run SFT.

    Saves the trained adapter and tokenizer to --output-dir.
    """
    args = parse_args()

    # Heavy training dependencies are imported lazily and gated with a clear error
    # so the module can be imported without a full training environment.
    try:
        import torch
        from trl import SFTTrainer
        from transformers import TrainingArguments
        from unsloth import FastLanguageModel
    except ImportError as exc:
        raise ImportError(
            "District LLM training requires 'unsloth' and 'trl'. "
            "Install them in the active environment before running this entry point."
        ) from exc

    # Default: train only on DQN-derived rows unless the caller opts in to all sources.
    controller_families = None if args.include_non_dqn_sources else ["dqn"]
    train_dataset = load_jsonl_text_dataset(
        args.dataset,
        controller_families=controller_families,
    )
    eval_dataset = (
        load_jsonl_text_dataset(
            args.eval_dataset,
            controller_families=controller_families,
        )
        if args.eval_dataset
        else None
    )
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        max_seq_length=args.max_seq_length,
        dtype=None,
        load_in_4bit=bool(args.load_in_4bit),
    )
    # Attach LoRA adapters to the standard attention + MLP projection modules.
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=args.seed,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        dataset_text_field="text",
        max_seq_length=args.max_seq_length,
        dataset_num_proc=args.dataset_num_proc,
        packing=False,
        args=TrainingArguments(
            output_dir=str(output_dir),
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            learning_rate=args.learning_rate,
            logging_steps=args.logging_steps,
            save_steps=args.save_steps,
            # Prefer bf16 where the GPU supports it, else fp16; CPU runs use neither.
            bf16=bool(torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
            fp16=bool(torch.cuda.is_available() and not torch.cuda.is_bf16_supported()),
            optim="paged_adamw_8bit",
            lr_scheduler_type="cosine",
            seed=args.seed,
            report_to="none",
            # NOTE(review): 'evaluation_strategy' was renamed 'eval_strategy' in
            # recent transformers releases -- confirm against the pinned version.
            evaluation_strategy="steps" if eval_dataset is not None else "no",
            eval_steps=args.save_steps if eval_dataset is not None else None,
        ),
    )
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    # Persist the LoRA adapter weights and tokenizer alongside the checkpoints.
    model.save_pretrained(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()
|
env/README.md
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# env
|
| 2 |
+
|
| 3 |
+
CityFlow environment implementation for intersection-level RL with district-type metadata.
|
| 4 |
+
|
| 5 |
+
## Main files
|
| 6 |
+
|
| 7 |
+
- [traffic_env.py](traffic_env.py)
|
| 8 |
+
Main environment. One episode corresponds to one `(city, scenario)` pair.
|
| 9 |
+
- [cityflow_adapter.py](cityflow_adapter.py)
|
| 10 |
+
Thin wrapper around the CityFlow Python engine.
|
| 11 |
+
- [observation_builder.py](observation_builder.py)
|
| 12 |
+
Converts variable city topology into fixed-size per-intersection tensors.
|
| 13 |
+
- [reward.py](reward.py)
|
| 14 |
+
Configurable local reward calculation.
|
| 15 |
+
- [utils.py](utils.py)
|
| 16 |
+
Topology parsing and helper functions.
|
| 17 |
+
- [intersection_config.py](intersection_config.py)
|
| 18 |
+
Internal topology dataclasses.
|
| 19 |
+
|
| 20 |
+
## How it works
|
| 21 |
+
|
| 22 |
+
- Reads `roadnet.json`, `district_map.json`, and district types from `metadata.json`.
|
| 23 |
+
- Identifies non-virtual controllable intersections with at least two green phases.
|
| 24 |
+
- Uses one action per controllable intersection.
|
| 25 |
+
- Enforces `min_green_time` inside the environment.
|
| 26 |
+
- Advances CityFlow for `decision_interval` simulator steps between policy decisions.
|
| 27 |
+
- Returns a batched observation for all controlled intersections.
|
| 28 |
+
|
| 29 |
+
## Observation model
|
| 30 |
+
|
| 31 |
+
Per intersection:
|
| 32 |
+
|
| 33 |
+
- padded incoming lane vehicle counts
|
| 34 |
+
- padded incoming lane waiting counts
|
| 35 |
+
- incoming lane mask
|
| 36 |
+
- current green phase index
|
| 37 |
+
- elapsed time in current phase
|
| 38 |
+
- optional outgoing congestion summary
|
| 39 |
+
- district-type one-hot features
|
| 40 |
+
- optional small district context
|
| 41 |
+
- boundary-intersection indicator
|
| 42 |
+
|
| 43 |
+
The observation dimension is exposed as `TrafficEnv.observation_dim`.
|
env/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public API of the CityFlow-based traffic environment package."""

from env.intersection_config import DistrictConfig, IntersectionConfig, PhaseConfig
from env.observation_builder import ObservationBuilder, ObservationConfig
from env.reward import RewardCalculator, RewardConfig
from env.traffic_env import EnvConfig, TrafficEnv
from env.utils import build_topology

# Names re-exported as the package's public surface.
__all__ = [
    "DistrictConfig",
    "EnvConfig",
    "IntersectionConfig",
    "ObservationBuilder",
    "ObservationConfig",
    "PhaseConfig",
    "RewardCalculator",
    "RewardConfig",
    "TrafficEnv",
    "build_topology",
]
|
env/cityflow_adapter.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CityFlowAdapter:
|
| 8 |
+
def __init__(self, config_path: str | Path, thread_num: int = 1):
|
| 9 |
+
self.config_path = str(config_path)
|
| 10 |
+
self.thread_num = int(thread_num)
|
| 11 |
+
self.engine = None
|
| 12 |
+
self._phase_cache: dict[str, int] = {}
|
| 13 |
+
self._active_vehicle_ids: set[str] = set()
|
| 14 |
+
self._finished_vehicle_ids: set[str] = set()
|
| 15 |
+
|
| 16 |
+
def reset(self) -> None:
|
| 17 |
+
try:
|
| 18 |
+
import cityflow
|
| 19 |
+
except ImportError as exc:
|
| 20 |
+
raise RuntimeError(
|
| 21 |
+
"CityFlow is not installed. Install the CityFlow Python bindings "
|
| 22 |
+
"before running smoke tests, training, or evaluation."
|
| 23 |
+
) from exc
|
| 24 |
+
|
| 25 |
+
self.engine = cityflow.Engine(self.config_path, thread_num=self.thread_num)
|
| 26 |
+
self._phase_cache.clear()
|
| 27 |
+
self._active_vehicle_ids = self._fetch_active_vehicle_ids()
|
| 28 |
+
self._finished_vehicle_ids.clear()
|
| 29 |
+
|
| 30 |
+
def step(self) -> None:
|
| 31 |
+
self._require_engine()
|
| 32 |
+
self.engine.next_step()
|
| 33 |
+
current_vehicle_ids = self._fetch_active_vehicle_ids()
|
| 34 |
+
self._finished_vehicle_ids.update(self._active_vehicle_ids - current_vehicle_ids)
|
| 35 |
+
self._active_vehicle_ids = current_vehicle_ids
|
| 36 |
+
|
| 37 |
+
def set_tl_phase(self, intersection_id: str, phase: int) -> None:
|
| 38 |
+
self._require_engine()
|
| 39 |
+
self.engine.set_tl_phase(intersection_id, int(phase))
|
| 40 |
+
self._phase_cache[intersection_id] = int(phase)
|
| 41 |
+
|
| 42 |
+
def get_tl_phase(self, intersection_id: str) -> int:
|
| 43 |
+
self._require_engine()
|
| 44 |
+
if hasattr(self.engine, "get_tl_phase"):
|
| 45 |
+
phase = int(self.engine.get_tl_phase(intersection_id))
|
| 46 |
+
self._phase_cache[intersection_id] = phase
|
| 47 |
+
return phase
|
| 48 |
+
return int(self._phase_cache.get(intersection_id, 0))
|
| 49 |
+
|
| 50 |
+
def get_lane_vehicle_count(self) -> dict[str, int]:
|
| 51 |
+
self._require_engine()
|
| 52 |
+
return {
|
| 53 |
+
lane_id: int(count)
|
| 54 |
+
for lane_id, count in self.engine.get_lane_vehicle_count().items()
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
def get_lane_waiting_vehicle_count(self) -> dict[str, int]:
|
| 58 |
+
self._require_engine()
|
| 59 |
+
return {
|
| 60 |
+
lane_id: int(count)
|
| 61 |
+
for lane_id, count in self.engine.get_lane_waiting_vehicle_count().items()
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
def get_current_time(self) -> int:
|
| 65 |
+
self._require_engine()
|
| 66 |
+
return int(self.engine.get_current_time())
|
| 67 |
+
|
| 68 |
+
def get_vehicle_count(self) -> int:
|
| 69 |
+
self._require_engine()
|
| 70 |
+
if hasattr(self.engine, "get_vehicle_count"):
|
| 71 |
+
return int(self.engine.get_vehicle_count())
|
| 72 |
+
return len(self._active_vehicle_ids)
|
| 73 |
+
|
| 74 |
+
def get_average_travel_time(self) -> float | None:
|
| 75 |
+
self._require_engine()
|
| 76 |
+
if hasattr(self.engine, "get_average_travel_time"):
|
| 77 |
+
return float(self.engine.get_average_travel_time())
|
| 78 |
+
return None
|
| 79 |
+
|
| 80 |
+
def get_finished_vehicle_count(self) -> int:
|
| 81 |
+
self._require_engine()
|
| 82 |
+
if hasattr(self.engine, "get_finished_vehicle_count"):
|
| 83 |
+
return int(self.engine.get_finished_vehicle_count())
|
| 84 |
+
return len(self._finished_vehicle_ids)
|
| 85 |
+
|
| 86 |
+
def get_active_vehicle_ids(self) -> set[str]:
|
| 87 |
+
return set(self._active_vehicle_ids)
|
| 88 |
+
|
| 89 |
+
def _fetch_active_vehicle_ids(self) -> set[str]:
|
| 90 |
+
if self.engine is None or not hasattr(self.engine, "get_vehicles"):
|
| 91 |
+
return set()
|
| 92 |
+
|
| 93 |
+
vehicles = self.engine.get_vehicles()
|
| 94 |
+
if isinstance(vehicles, dict):
|
| 95 |
+
return set(vehicles.keys())
|
| 96 |
+
return set(vehicles)
|
| 97 |
+
|
| 98 |
+
def _require_engine(self) -> None:
|
| 99 |
+
if self.engine is None:
|
| 100 |
+
raise RuntimeError(
|
| 101 |
+
"CityFlow engine has not been initialized. Call reset() before use."
|
| 102 |
+
)
|
env/district_summary.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class DistrictSummaryBuilder:
    """Builds a compact summary dict for one district from live simulator state."""

    def build(self, adapter, district_config):
        """Return a summary for *district_config*.

        Args:
            adapter: object exposing ``get_lane_waiting_vehicle_count()``
                returning ``dict[lane_id, int]`` (e.g. ``CityFlowAdapter``).
            district_config: object with ``id`` and ``intersection_ids``
                attributes.
                NOTE(review): this reads ``.id`` while ``DistrictConfig``
                declares ``district_id`` — confirm which object callers pass.

        Returns:
            dict with ``district_id``, ``intersection_ids`` and ``avg_wait``,
            where ``avg_wait`` is the mean waiting count across all reported
            lanes (0.0 when the engine reports no lanes, instead of raising
            ZeroDivisionError).
        """
        waiting = adapter.get_lane_waiting_vehicle_count()
        # Guard against an empty lane report (e.g. right after reset).
        avg_wait = sum(waiting.values()) / len(waiting) if waiting else 0.0

        return {
            "district_id": district_config.id,
            "intersection_ids": district_config.intersection_ids,
            "avg_wait": avg_wait,
        }
|
env/intersection_config.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
# Closed set of recognized district categories; tuple order defines the
# index layout used for the district-type one-hot features.
DISTRICT_TYPES: tuple[str, ...] = (
    "residential",
    "commercial",
    "industrial",
    "mixed",
)
# Reverse lookup: district type name -> its position in DISTRICT_TYPES.
DISTRICT_TYPE_TO_INDEX: dict[str, int] = {
    district_type: index for index, district_type in enumerate(DISTRICT_TYPES)
}
# Fallback category — presumably applied when metadata omits a district's
# type; confirm at the call sites that consume it.
DEFAULT_DISTRICT_TYPE = "mixed"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass(frozen=True)
class PhaseConfig:
    """Immutable description of a single green phase at one intersection."""

    # Index of this phase in the CityFlow engine's phase list.
    engine_phase_index: int
    # Road-link indices available (served) during this phase.
    available_road_links: tuple[int, ...]
    # Incoming lane ids served by this phase.
    incoming_lanes_served: tuple[str, ...]
    # Outgoing lane ids served by this phase.
    outgoing_lanes_served: tuple[str, ...]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass(frozen=True)
class IntersectionConfig:
    """Immutable topology/metadata record for one controllable intersection."""

    intersection_id: str
    # District membership and its categorical encoding.
    district_id: str
    district_type: str
    district_type_index: int
    # Lane ids entering / leaving this intersection.
    incoming_lanes: tuple[str, ...]
    outgoing_lanes: tuple[str, ...]
    # True when the intersection sits on a district boundary.
    is_boundary: bool
    # Green phases the policy can choose among.
    green_phases: tuple[PhaseConfig, ...]
    # All engine phase indices (including non-green), and the phase the
    # engine starts in.
    all_phase_indices: tuple[int, ...]
    initial_engine_phase_index: int

    @property
    def num_green_phases(self) -> int:
        """Number of selectable green phases."""
        return len(self.green_phases)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass(frozen=True)
class DistrictConfig:
    """Immutable record describing one district of the city."""

    district_id: str
    # Categorical type and its index into DISTRICT_TYPES.
    district_type: str
    district_type_index: int
    # Intersections belonging to this district.
    intersection_ids: tuple[str, ...]
    # Ids of adjacent districts.
    neighbor_districts: tuple[str, ...]
|
env/observation_builder.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from env.intersection_config import DistrictConfig, IntersectionConfig
|
| 8 |
+
from env.utils import normalize_scalar
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass(frozen=True)
class ObservationConfig:
    """Knobs controlling observation layout and normalization."""

    # Lane vectors are padded/truncated to this many incoming lanes.
    max_incoming_lanes: int = 16
    # Divisor applied to vehicle/waiting counts before packing.
    count_scale: float = 20.0
    # Divisor applied to time-in-phase (seconds -> ~[0, 1]).
    elapsed_time_scale: float = 60.0
    # Optional feature groups; each toggles extra observation columns.
    include_outgoing_congestion: bool = True
    include_district_context: bool = True
    include_district_type_feature: bool = True
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ObservationBuilder:
    """Converts raw per-lane simulator state into fixed-size per-intersection
    observation rows plus auxiliary masks and metadata arrays.

    Feature layout per row (must stay in sync with _compute_observation_dim):
    scaled lane counts (max_incoming_lanes) | scaled lane waiting counts
    (max_incoming_lanes) | lane mask (max_incoming_lanes) | 7 meta features |
    optional 4-wide district-type one-hot | optional 2-wide district context.
    """

    def __init__(
        self,
        intersections: dict[str, IntersectionConfig],
        districts: dict[str, DistrictConfig],
        config: ObservationConfig | None = None,
    ):
        self.intersections = intersections
        self.districts = districts
        self.config = config or ObservationConfig()
        # Fixed sorted ordering: row i of every output array corresponds to
        # self.intersection_ids[i].
        self.intersection_ids = tuple(sorted(intersections))
        self._district_lookup = {
            intersection_id: intersections[intersection_id].district_id
            for intersection_id in self.intersection_ids
        }
        # max(1, ...) guards per-district averaging against empty districts.
        self._district_sizes = {
            district_id: max(1, len(district.intersection_ids))
            for district_id, district in districts.items()
        }
        self.observation_dim = self._compute_observation_dim()

    def build(
        self,
        lane_vehicle_count: dict[str, int],
        lane_waiting_count: dict[str, int],
        phase_positions: dict[str, int],
        phase_elapsed_times: dict[str, int],
        switch_allowed: dict[str, bool],
    ) -> dict[str, np.ndarray | tuple[str, ...]]:
        """Build the batched observation for all controlled intersections.

        Args:
            lane_vehicle_count: lane id -> vehicle count.
            lane_waiting_count: lane id -> waiting-vehicle count.
            phase_positions: intersection id -> current green-phase index.
            phase_elapsed_times: intersection id -> time in current phase.
            switch_allowed: intersection id -> whether a phase switch is legal
                this step (drives the action mask).

        Returns:
            dict of arrays (rows ordered by self.intersection_ids) plus the
            id/type tuples describing each row.
        """
        # District-level aggregates, computed once and shared by all rows.
        district_context = self._compute_district_context(
            lane_vehicle_count=lane_vehicle_count,
            lane_waiting_count=lane_waiting_count,
        )

        num_intersections = len(self.intersection_ids)
        max_lanes = self.config.max_incoming_lanes

        observations = np.zeros(
            (num_intersections, self.observation_dim),
            dtype=np.float32,
        )
        incoming_counts = np.zeros((num_intersections, max_lanes), dtype=np.float32)
        incoming_waiting = np.zeros((num_intersections, max_lanes), dtype=np.float32)
        lane_mask = np.zeros((num_intersections, max_lanes), dtype=np.float32)
        # Two actions per intersection; column 1 (switch) may be masked off.
        action_mask = np.ones((num_intersections, 2), dtype=np.float32)
        current_phase = np.zeros(num_intersections, dtype=np.int64)
        phase_elapsed = np.zeros(num_intersections, dtype=np.float32)
        outgoing_congestion = np.zeros(num_intersections, dtype=np.float32)
        district_type_indices = np.zeros(num_intersections, dtype=np.int64)
        boundary_mask = np.zeros(num_intersections, dtype=np.float32)

        for row_index, intersection_id in enumerate(self.intersection_ids):
            config = self.intersections[intersection_id]
            lane_count_vector, waiting_vector, mask_vector = self._lane_vectors(
                config=config,
                lane_vehicle_count=lane_vehicle_count,
                lane_waiting_count=lane_waiting_count,
            )
            incoming_counts[row_index] = lane_count_vector
            incoming_waiting[row_index] = waiting_vector
            lane_mask[row_index] = mask_vector

            phase_index = int(phase_positions[intersection_id])
            phase_time = float(phase_elapsed_times[intersection_id])
            phase_count = max(1, config.num_green_phases)
            current_phase[row_index] = phase_index
            phase_elapsed[row_index] = phase_time
            district_type_indices[row_index] = config.district_type_index
            boundary_mask[row_index] = 1.0 if config.is_boundary else 0.0

            # Pack features left-to-right; next_col tracks the write cursor.
            next_col = 0
            observations[row_index, next_col : next_col + max_lanes] = (
                lane_count_vector / self.config.count_scale
            )
            next_col += max_lanes
            observations[row_index, next_col : next_col + max_lanes] = (
                waiting_vector / self.config.count_scale
            )
            next_col += max_lanes
            observations[row_index, next_col : next_col + max_lanes] = mask_vector
            next_col += max_lanes

            if self.config.include_outgoing_congestion:
                outgoing_congestion[row_index] = self._mean_outgoing_congestion(
                    config=config,
                    lane_vehicle_count=lane_vehicle_count,
                )

            # Seven scalar meta features (order fixed; see class docstring).
            meta_features = [
                normalize_scalar(phase_index, max(1, phase_count - 1))
                if phase_count > 1
                else 0.0,
                normalize_scalar(phase_time, self.config.elapsed_time_scale),
                normalize_scalar(float(outgoing_congestion[row_index]), self.config.count_scale),
                normalize_scalar(float(lane_count_vector.sum()), self.config.count_scale),
                normalize_scalar(float(phase_count), 4.0),
                1.0 if switch_allowed[intersection_id] else 0.0,
                boundary_mask[row_index],
            ]
            observations[row_index, next_col : next_col + len(meta_features)] = meta_features
            next_col += len(meta_features)

            if self.config.include_district_type_feature:
                # One-hot over the 4 district types.
                observations[row_index, next_col + config.district_type_index] = 1.0
                next_col += 4

            if self.config.include_district_context:
                district_values = district_context.get(
                    config.district_id,
                    (0.0, 0.0),
                )
                observations[row_index, next_col : next_col + len(district_values)] = district_values

            if not switch_allowed[intersection_id]:
                # Disallow the "switch" action when min-green has not elapsed.
                action_mask[row_index, 1] = 0.0

        return {
            "observations": observations,
            "incoming_counts": incoming_counts,
            "incoming_waiting": incoming_waiting,
            "lane_mask": lane_mask,
            "action_mask": action_mask,
            "current_phase": current_phase,
            "phase_elapsed": phase_elapsed,
            "outgoing_congestion": outgoing_congestion,
            "boundary_mask": boundary_mask,
            "district_type_indices": district_type_indices,
            "district_types": tuple(
                self.intersections[intersection_id].district_type
                for intersection_id in self.intersection_ids
            ),
            "district_ids": tuple(
                self.intersections[intersection_id].district_id
                for intersection_id in self.intersection_ids
            ),
            "intersection_ids": self.intersection_ids,
        }

    def _compute_observation_dim(self) -> int:
        """Total feature width; must mirror the packing order in build()."""
        # 3 lane blocks + 7 meta features.
        base_dim = self.config.max_incoming_lanes * 3 + 7
        if self.config.include_district_type_feature:
            base_dim += 4
        if self.config.include_district_context:
            base_dim += 2
        return base_dim

    def _lane_vectors(
        self,
        config: IntersectionConfig,
        lane_vehicle_count: dict[str, int],
        lane_waiting_count: dict[str, int],
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Pad/truncate per-lane counts to max_incoming_lanes; mask marks real lanes."""
        max_lanes = self.config.max_incoming_lanes
        count_vector = np.zeros(max_lanes, dtype=np.float32)
        waiting_vector = np.zeros(max_lanes, dtype=np.float32)
        mask_vector = np.zeros(max_lanes, dtype=np.float32)

        # Lanes beyond max_lanes are silently dropped.
        for lane_index, lane_id in enumerate(config.incoming_lanes[:max_lanes]):
            count_vector[lane_index] = float(lane_vehicle_count.get(lane_id, 0))
            waiting_vector[lane_index] = float(lane_waiting_count.get(lane_id, 0))
            mask_vector[lane_index] = 1.0

        return count_vector, waiting_vector, mask_vector

    def _mean_outgoing_congestion(
        self,
        config: IntersectionConfig,
        lane_vehicle_count: dict[str, int],
    ) -> float:
        """Mean vehicle count over outgoing lanes (0.0 when there are none)."""
        if not config.outgoing_lanes:
            return 0.0
        total = sum(float(lane_vehicle_count.get(lane_id, 0)) for lane_id in config.outgoing_lanes)
        return total / float(len(config.outgoing_lanes))

    def _compute_district_context(
        self,
        lane_vehicle_count: dict[str, int],
        lane_waiting_count: dict[str, int],
    ) -> dict[str, tuple[float, float]]:
        """Per-district (avg count, avg waiting), normalized; empty when disabled."""
        context: dict[str, tuple[float, float]] = {}
        if not self.config.include_district_context:
            return context

        for district_id, district in self.districts.items():
            total_count = 0.0
            total_waiting = 0.0
            for intersection_id in district.intersection_ids:
                config = self.intersections[intersection_id]
                total_count += sum(
                    float(lane_vehicle_count.get(lane_id, 0))
                    for lane_id in config.incoming_lanes
                )
                total_waiting += sum(
                    float(lane_waiting_count.get(lane_id, 0))
                    for lane_id in config.incoming_lanes
                )

            size = float(self._district_sizes[district_id])
            context[district_id] = (
                normalize_scalar(total_count / size, self.config.count_scale),
                normalize_scalar(total_waiting / size, self.config.count_scale),
            )

        return context
|
env/reward.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
# Supported values for RewardConfig.variant; validated in
# RewardCalculator.__init__.
REWARD_VARIANTS: tuple[str, ...] = (
    "current",
    "normalized_wait_queue",
    "wait_queue_throughput",
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass(frozen=True)
class RewardConfig:
    """Weights and normalizers for RewardCalculator (see REWARD_VARIANTS)."""

    # Which reward formulation to use; must be one of REWARD_VARIANTS.
    variant: str = "current"
    # --- "current" variant: instantaneous weighted penalties ---
    waiting_weight: float = 1.0
    vehicle_weight: float = 0.1
    pressure_weight: float = 0.0
    # Global multiplier applied to every component before summing.
    reward_scale: float = 0.1
    # Divide per-intersection totals by that intersection's lane count.
    normalize_by_lane_count: bool = True
    # Symmetric clip on the summed reward; None disables clipping.
    clip_reward: float | None = 5.0
    # --- delta-based variants: reward improvement vs. previous step ---
    queue_delta_weight: float = 2.0
    wait_delta_weight: float = 4.0
    queue_level_weight: float = 0.5
    wait_level_weight: float = 1.0
    throughput_weight: float = 0.1
    imbalance_weight: float = 0.1
    # Clip applied to per-step deltas and to the level terms.
    delta_clip: float = 2.0
    # Divisors bringing level / throughput terms onto comparable scales.
    level_normalizer: float = 10.0
    throughput_normalizer: float = 2.0
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass(frozen=True)
class RewardBreakdown:
    """Final per-intersection reward plus its named component terms."""

    # Summed (and scaled/clipped) reward, one entry per intersection.
    reward: np.ndarray
    # Individual scaled terms keyed by name (e.g. "wait_term").
    components: dict[str, np.ndarray]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class RewardCalculator:
    """Computes per-intersection rewards under one of REWARD_VARIANTS.

    Stateful: the delta-based variants compare current normalized queue/wait
    levels against those from the previous call, so reset() should be called
    at episode start to seed the baselines.
    """

    def __init__(self, config: RewardConfig | None = None):
        self.config = config or RewardConfig()
        if self.config.variant not in REWARD_VARIANTS:
            raise ValueError(
                f"Unsupported reward variant: {self.config.variant}. "
                f"Expected one of {REWARD_VARIANTS}."
            )
        # Previous normalized per-intersection levels; None until seeded.
        self._prev_queue_norm: np.ndarray | None = None
        self._prev_wait_norm: np.ndarray | None = None
        self._prev_finished_vehicle_count = 0.0

    def reset(
        self,
        incoming_waiting: np.ndarray,
        incoming_counts: np.ndarray,
        incoming_lane_counts: np.ndarray | None = None,
        finished_vehicle_count: float = 0.0,
    ) -> None:
        """Seed the previous-step baselines from the initial observation."""
        queue_norm, wait_norm, _ = self._normalized_state(
            incoming_waiting=incoming_waiting,
            incoming_counts=incoming_counts,
            incoming_lane_counts=incoming_lane_counts,
        )
        self._prev_queue_norm = queue_norm
        self._prev_wait_norm = wait_norm
        self._prev_finished_vehicle_count = float(finished_vehicle_count)

    def compute(
        self,
        incoming_waiting: np.ndarray,
        incoming_counts: np.ndarray,
        outgoing_counts: np.ndarray | None = None,
        incoming_lane_counts: np.ndarray | None = None,
        finished_vehicle_count: float = 0.0,
    ) -> np.ndarray:
        """Convenience wrapper returning only the reward array."""
        return self.compute_breakdown(
            incoming_waiting=incoming_waiting,
            incoming_counts=incoming_counts,
            outgoing_counts=outgoing_counts,
            incoming_lane_counts=incoming_lane_counts,
            finished_vehicle_count=finished_vehicle_count,
        ).reward

    def compute_breakdown(
        self,
        incoming_waiting: np.ndarray,
        incoming_counts: np.ndarray,
        outgoing_counts: np.ndarray | None = None,
        incoming_lane_counts: np.ndarray | None = None,
        finished_vehicle_count: float = 0.0,
    ) -> RewardBreakdown:
        """Dispatch to the configured variant and return reward + components."""
        if self.config.variant == "current":
            return self._compute_current(
                incoming_waiting=incoming_waiting,
                incoming_counts=incoming_counts,
                outgoing_counts=outgoing_counts,
                incoming_lane_counts=incoming_lane_counts,
            )
        # Both remaining variants are delta-based; only one adds throughput.
        return self._compute_delta_based(
            incoming_waiting=incoming_waiting,
            incoming_counts=incoming_counts,
            incoming_lane_counts=incoming_lane_counts,
            finished_vehicle_count=finished_vehicle_count,
            include_throughput=self.config.variant == "wait_queue_throughput",
        )

    def _compute_current(
        self,
        incoming_waiting: np.ndarray,
        incoming_counts: np.ndarray,
        outgoing_counts: np.ndarray | None = None,
        incoming_lane_counts: np.ndarray | None = None,
    ) -> RewardBreakdown:
        """Stateless variant: weighted penalties on current waiting/queue totals."""
        waiting_total = incoming_waiting.sum(axis=1)
        vehicle_total = incoming_counts.sum(axis=1)
        normalization = self._lane_normalization(waiting_total.shape[0], incoming_lane_counts)

        components = {
            "wait_term": (-self.config.waiting_weight * waiting_total / normalization).astype(np.float32),
            "queue_term": (-self.config.vehicle_weight * vehicle_total / normalization).astype(np.float32),
        }
        # Optional max-pressure-style term: outgoing minus incoming load.
        if outgoing_counts is not None and self.config.pressure_weight != 0.0:
            outgoing_total = outgoing_counts.sum(axis=1)
            components["pressure_term"] = (
                self.config.pressure_weight * (outgoing_total - vehicle_total) / normalization
            ).astype(np.float32)
        components = self._scale_components(components)
        reward = self._finalize_reward(components)
        return RewardBreakdown(reward=reward, components=components)

    def _compute_delta_based(
        self,
        incoming_waiting: np.ndarray,
        incoming_counts: np.ndarray,
        incoming_lane_counts: np.ndarray | None,
        finished_vehicle_count: float,
        include_throughput: bool,
    ) -> RewardBreakdown:
        """Stateful variant: rewards reductions in queue/wait vs. the prior step."""
        queue_norm, wait_norm, lane_norm = self._normalized_state(
            incoming_waiting=incoming_waiting,
            incoming_counts=incoming_counts,
            incoming_lane_counts=incoming_lane_counts,
        )

        # First call without reset(): seed baselines with the current state,
        # which makes both deltas zero for this step.
        if self._prev_queue_norm is None or self._prev_wait_norm is None:
            self._prev_queue_norm = queue_norm.copy()
            self._prev_wait_norm = wait_norm.copy()

        # Positive delta == improvement (previous level minus current level).
        queue_delta = np.clip(
            self._prev_queue_norm - queue_norm,
            -self.config.delta_clip,
            self.config.delta_clip,
        ).astype(np.float32)
        wait_delta = np.clip(
            self._prev_wait_norm - wait_norm,
            -self.config.delta_clip,
            self.config.delta_clip,
        ).astype(np.float32)

        components: dict[str, np.ndarray] = {
            "queue_term": (self.config.queue_delta_weight * queue_delta).astype(np.float32),
            "wait_term": (self.config.wait_delta_weight * wait_delta).astype(np.float32),
            # Level terms penalize absolute congestion regardless of trend.
            "queue_level_term": (
                -self.config.queue_level_weight
                * np.clip(queue_norm / self.config.level_normalizer, 0.0, self.config.delta_clip)
            ).astype(np.float32),
            "wait_level_term": (
                -self.config.wait_level_weight
                * np.clip(wait_norm / self.config.level_normalizer, 0.0, self.config.delta_clip)
            ).astype(np.float32),
        }

        if include_throughput:
            num_intersections = max(1, queue_norm.shape[0])
            # Vehicles finished since the previous step (never negative).
            finished_delta = max(
                0.0,
                float(finished_vehicle_count) - self._prev_finished_vehicle_count,
            )
            throughput_per_intersection = finished_delta / float(num_intersections)
            # Throughput is a global signal, broadcast equally to every row.
            throughput_term = np.full(
                queue_norm.shape,
                self.config.throughput_weight
                * min(1.0, throughput_per_intersection / self.config.throughput_normalizer),
                dtype=np.float32,
            )
            # Penalize uneven waiting across an intersection's lanes.
            imbalance = np.std(
                incoming_waiting / lane_norm[:, None],
                axis=1,
            ).astype(np.float32)
            components["throughput_term"] = throughput_term
            components["imbalance_term"] = (-self.config.imbalance_weight * imbalance).astype(
                np.float32
            )

        components = self._scale_components(components)
        reward = self._finalize_reward(components)
        # Update baselines only after the reward is computed.
        self._prev_queue_norm = queue_norm
        self._prev_wait_norm = wait_norm
        self._prev_finished_vehicle_count = float(finished_vehicle_count)
        return RewardBreakdown(reward=reward, components=components)

    def _normalized_state(
        self,
        incoming_waiting: np.ndarray,
        incoming_counts: np.ndarray,
        incoming_lane_counts: np.ndarray | None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Per-intersection queue/wait totals divided by lane normalization."""
        lane_norm = self._lane_normalization(incoming_counts.shape[0], incoming_lane_counts)
        queue_norm = (incoming_counts.sum(axis=1) / lane_norm).astype(np.float32)
        wait_norm = (incoming_waiting.sum(axis=1) / lane_norm).astype(np.float32)
        return queue_norm, wait_norm, lane_norm

    def _lane_normalization(
        self,
        batch_size: int,
        incoming_lane_counts: np.ndarray | None,
    ) -> np.ndarray:
        """Per-intersection divisor: lane counts (floored at 1) or all-ones."""
        normalization = np.ones(batch_size, dtype=np.float32)
        if incoming_lane_counts is not None and self.config.normalize_by_lane_count:
            normalization = np.maximum(1.0, incoming_lane_counts.astype(np.float32))
        return normalization

    def _finalize_reward(self, components: dict[str, np.ndarray]) -> np.ndarray:
        """Sum all component terms and apply the optional symmetric clip."""
        reward = np.zeros_like(next(iter(components.values())), dtype=np.float32)
        for term in components.values():
            reward += term.astype(np.float32)

        if self.config.clip_reward is not None:
            reward = np.clip(
                reward,
                -float(self.config.clip_reward),
                float(self.config.clip_reward),
            )
        return reward.astype(np.float32)

    def _scale_components(
        self,
        components: dict[str, np.ndarray],
    ) -> dict[str, np.ndarray]:
        """Multiply every component by the global reward_scale."""
        scale = float(self.config.reward_scale)
        return {
            name: (values.astype(np.float32) * scale).astype(np.float32)
            for name, values in components.items()
        }
|
env/scenarios.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class ScenarioGenerator:
    """Draws per-episode scenario settings from a seeded pseudo-random stream."""

    def generate(self, seed):
        """Return a deterministic scenario dict for *seed*.

        Fix: uses a private ``random.Random(seed)`` instance instead of
        ``random.seed(seed)``, so the shared module-level RNG (and therefore
        any other code relying on it) is no longer clobbered as a side effect.
        The drawn values are identical to the previous implementation, since
        the module functions delegate to an identically-seeded Random.

        Returns:
            dict with:
              - "traffic_bias": one of "ns", "ew", "balanced"
              - "emergency_vehicle": bool, True with probability 0.2
        """
        import random  # local import preserved from the original implementation

        rng = random.Random(seed)
        return {
            "traffic_bias": rng.choice(["ns", "ew", "balanced"]),
            "emergency_vehicle": rng.random() < 0.2,
        }
|
env/traffic_env.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import numpy as np

from env.cityflow_adapter import CityFlowAdapter
from env.intersection_config import DistrictConfig, IntersectionConfig
from env.observation_builder import ObservationBuilder, ObservationConfig
from env.reward import RewardCalculator, RewardConfig
from env.utils import build_topology, load_json
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass(frozen=True)
class EnvConfig:
    """Timing and sub-component configuration for TrafficEnv.

    Attributes:
        simulator_interval: seconds advanced per simulator tick.
        decision_interval: simulator ticks executed per agent decision step.
        min_green_time: minimum elapsed seconds before a phase switch is honored.
        thread_num: CityFlow engine thread count.
        observation: observation-builder configuration.
        reward: reward-calculator configuration.
        max_episode_seconds: episode horizon; when None/0 TrafficEnv falls back
            to the "step" value in the CityFlow config file.
    """

    simulator_interval: int = 1
    decision_interval: int = 5
    min_green_time: int = 10
    thread_num: int = 1
    # default_factory gives each EnvConfig its own default config object instead
    # of a single instance shared by every EnvConfig ever constructed.
    observation: ObservationConfig = field(default_factory=ObservationConfig)
    reward: RewardConfig = field(default_factory=RewardConfig)
    max_episode_seconds: int | None = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TrafficEnv:
    """Multi-intersection traffic-signal control environment backed by CityFlow."""

    def __init__(
        self,
        city_id: str,
        scenario_name: str,
        city_dir: str | Path,
        scenario_dir: str | Path,
        config_path: str | Path,
        roadnet_path: str | Path,
        district_map_path: str | Path | None = None,
        metadata_path: str | Path | None = None,
        env_config: EnvConfig | None = None,
    ):
        """Build topology, observation/reward helpers, and the CityFlow adapter.

        Raises:
            ValueError: if the roadnet yields no controllable intersections.
        """
        self.city_id = city_id
        self.scenario_name = scenario_name
        self.city_dir = Path(city_dir)
        self.scenario_dir = Path(scenario_dir)
        self.original_config_path = Path(config_path)
        self.roadnet_path = Path(roadnet_path)
        self.district_map_path = Path(district_map_path) if district_map_path else None
        self.metadata_path = Path(metadata_path) if metadata_path else None
        self.env_config = env_config or EnvConfig()

        # Static topology: per-intersection configs plus district grouping.
        self.intersections, self.districts = build_topology(
            roadnet_path=self.roadnet_path,
            district_map_path=self.district_map_path,
            metadata_path=self.metadata_path,
        )
        if not self.intersections:
            raise ValueError(
                f"No controllable intersections found for {self.city_id}/{self.scenario_name}."
            )

        # Sorted tuple fixes a stable ordering used by every per-intersection array.
        self.controlled_intersection_ids = tuple(sorted(self.intersections))
        self.observation_builder = ObservationBuilder(
            intersections=self.intersections,
            districts=self.districts,
            config=self.env_config.observation,
        )
        self.reward_calculator = RewardCalculator(self.env_config.reward)
        self.adapter = CityFlowAdapter(
            config_path=self.original_config_path,
            thread_num=self.env_config.thread_num,
        )

        # Episode horizon: explicit override wins, otherwise the "step" value in
        # the CityFlow config file (0 if absent).
        config_payload = load_json(self.original_config_path)
        self.max_episode_seconds = int(
            self.env_config.max_episode_seconds
            or config_payload.get("step", 0)
        )
        self.metadata = load_json(self.metadata_path) if self.metadata_path else {}
        self._district_type_labels = tuple(
            self.intersections[intersection_id].district_type
            for intersection_id in self.controlled_intersection_ids
        )
        # Lane-count divisor per intersection, floored at 1 to avoid div-by-zero.
        self._incoming_lane_counts = np.asarray(
            [
                max(1, len(self.intersections[intersection_id].incoming_lanes))
                for intersection_id in self.controlled_intersection_ids
            ],
            dtype=np.float32,
        )

        # Mutable per-episode state; reset() re-initializes all of these.
        self.current_phase_positions: dict[str, int] = {}
        self.phase_elapsed_times: dict[str, int] = {}
        self.decision_step_count = 0
        self.episode_return = 0.0
        self.total_episode_return = 0.0
        self.last_info: dict[str, Any] = {}
        self.reward_component_sums: dict[str, float] = {}
|
| 96 |
+
|
| 97 |
+
@property
def observation_dim(self) -> int:
    """Per-intersection observation size, as reported by the observation builder."""
    return self.observation_builder.observation_dim
|
| 100 |
+
|
| 101 |
+
def reset(self, seed: int | None = None) -> dict[str, Any]:
    """Start a new episode: reset the simulator, phases, counters, and reward state.

    Args:
        seed: accepted for API compatibility but ignored (deleted immediately).

    Returns:
        The initial observation dict produced by ``_build_observation``.
    """
    del seed
    self.adapter.reset()
    self.decision_step_count = 0
    self.episode_return = 0.0
    self.total_episode_return = 0.0
    self.reward_component_sums = {}

    # Every intersection starts on its first green phase with zero elapsed time.
    self.current_phase_positions = {}
    self.phase_elapsed_times = {}
    for intersection_id in self.controlled_intersection_ids:
        config = self.intersections[intersection_id]
        initial_position = 0
        initial_phase = config.green_phases[initial_position].engine_phase_index
        self.current_phase_positions[intersection_id] = initial_position
        self.phase_elapsed_times[intersection_id] = 0
        self.adapter.set_tl_phase(intersection_id, initial_phase)

    observation = self._build_observation()
    # Prime the reward calculator's "previous step" baselines from the initial state.
    self.reward_calculator.reset(
        incoming_waiting=observation["incoming_waiting"],
        incoming_counts=observation["incoming_counts"],
        incoming_lane_counts=self._incoming_lane_counts,
        finished_vehicle_count=self.adapter.get_finished_vehicle_count(),
    )
    self.last_info = self._build_info(
        rewards=np.zeros(len(self.controlled_intersection_ids), dtype=np.float32),
        avg_incoming_counts=observation["incoming_counts"],
        avg_incoming_waiting=observation["incoming_waiting"],
        reward_components={},
    )
    return observation
|
| 133 |
+
|
| 134 |
+
def step(
    self,
    actions: dict[str, int] | list[int] | np.ndarray,
) -> tuple[dict[str, Any], np.ndarray, bool, dict[str, Any]]:
    """Apply one action per intersection, advance the simulator, and score the step.

    Args:
        actions: switch signals (1 = request advance to the next green phase),
            either keyed by intersection id or ordered like
            ``controlled_intersection_ids``.

    Returns:
        (observation, per-intersection reward vector, done flag, info dict).
    """
    normalized_actions = self._normalize_actions(actions)
    self._apply_actions(normalized_actions)

    # Lane statistics averaged over the decision interval's simulator ticks.
    avg_incoming_counts, avg_incoming_waiting, avg_outgoing_counts = self._advance_simulator()
    reward_breakdown = self.reward_calculator.compute_breakdown(
        incoming_waiting=avg_incoming_waiting,
        incoming_counts=avg_incoming_counts,
        outgoing_counts=avg_outgoing_counts,
        incoming_lane_counts=self._incoming_lane_counts,
        finished_vehicle_count=self.adapter.get_finished_vehicle_count(),
    )
    rewards = reward_breakdown.reward
    self.decision_step_count += 1
    self.total_episode_return += float(rewards.sum())
    # episode_return tracks the running mean reward per (step, intersection).
    self.episode_return = self._mean_step_intersection_reward()
    self._accumulate_reward_components(reward_breakdown.components)

    observation = self._build_observation()
    done = self.adapter.get_current_time() >= self.max_episode_seconds
    info = self._build_info(
        rewards=rewards,
        avg_incoming_counts=avg_incoming_counts,
        avg_incoming_waiting=avg_incoming_waiting,
        reward_components=reward_breakdown.components,
    )
    self.last_info = info
    return observation, rewards, done, info
|
| 165 |
+
|
| 166 |
+
def _build_observation(self) -> dict[str, Any]:
    """Assemble the observation dict from simulator lane counts and phase state."""
    lane_vehicle_count = self.adapter.get_lane_vehicle_count()
    lane_waiting_count = self.adapter.get_lane_waiting_vehicle_count()
    # An intersection may switch only after holding its phase for min_green_time.
    switch_allowed = {
        intersection_id: (
            self.phase_elapsed_times[intersection_id] >= self.env_config.min_green_time
        )
        for intersection_id in self.controlled_intersection_ids
    }

    observation = self.observation_builder.build(
        lane_vehicle_count=lane_vehicle_count,
        lane_waiting_count=lane_waiting_count,
        phase_positions=self.current_phase_positions,
        phase_elapsed_times=self.phase_elapsed_times,
        switch_allowed=switch_allowed,
    )
    # Episode/scenario bookkeeping appended for downstream consumers.
    observation["city_id"] = self.city_id
    observation["scenario_name"] = self.scenario_name
    observation["decision_step"] = self.decision_step_count
    observation["sim_time"] = self.adapter.get_current_time()
    return observation
|
| 188 |
+
|
| 189 |
+
def _apply_actions(self, actions: np.ndarray) -> None:
|
| 190 |
+
for action_index, intersection_id in enumerate(self.controlled_intersection_ids):
|
| 191 |
+
config = self.intersections[intersection_id]
|
| 192 |
+
current_position = self.current_phase_positions[intersection_id]
|
| 193 |
+
can_switch = self.phase_elapsed_times[intersection_id] >= self.env_config.min_green_time
|
| 194 |
+
should_switch = int(actions[action_index]) == 1 and can_switch
|
| 195 |
+
|
| 196 |
+
if should_switch:
|
| 197 |
+
next_position = (current_position + 1) % config.num_green_phases
|
| 198 |
+
engine_phase = config.green_phases[next_position].engine_phase_index
|
| 199 |
+
self.adapter.set_tl_phase(intersection_id, engine_phase)
|
| 200 |
+
self.current_phase_positions[intersection_id] = next_position
|
| 201 |
+
self.phase_elapsed_times[intersection_id] = 0
|
| 202 |
+
else:
|
| 203 |
+
current_engine_phase = config.green_phases[current_position].engine_phase_index
|
| 204 |
+
self.adapter.set_tl_phase(intersection_id, current_engine_phase)
|
| 205 |
+
|
| 206 |
+
def _advance_simulator(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run the engine for one decision interval and return per-lane averages.

    Returns:
        (avg_incoming_counts, avg_incoming_waiting, avg_outgoing_counts), each
        shaped (num_intersections, max_incoming_lanes), averaged over the
        decision_interval simulator ticks.
    """
    num_intersections = len(self.controlled_intersection_ids)
    max_lanes = self.env_config.observation.max_incoming_lanes
    avg_incoming_counts = np.zeros((num_intersections, max_lanes), dtype=np.float32)
    avg_incoming_waiting = np.zeros((num_intersections, max_lanes), dtype=np.float32)
    avg_outgoing_counts = np.zeros((num_intersections, max_lanes), dtype=np.float32)

    for _ in range(self.env_config.decision_interval):
        self.adapter.step()
        lane_vehicle_count = self.adapter.get_lane_vehicle_count()
        lane_waiting_count = self.adapter.get_lane_waiting_vehicle_count()

        for row_index, intersection_id in enumerate(self.controlled_intersection_ids):
            config = self.intersections[intersection_id]
            # Lanes beyond max_incoming_lanes are deliberately truncated; missing
            # lane ids count as zero vehicles.
            for lane_index, lane_id in enumerate(
                config.incoming_lanes[: self.env_config.observation.max_incoming_lanes]
            ):
                avg_incoming_counts[row_index, lane_index] += float(
                    lane_vehicle_count.get(lane_id, 0)
                )
                avg_incoming_waiting[row_index, lane_index] += float(
                    lane_waiting_count.get(lane_id, 0)
                )
            for lane_index, lane_id in enumerate(
                config.outgoing_lanes[: self.env_config.observation.max_incoming_lanes]
            ):
                avg_outgoing_counts[row_index, lane_index] += float(
                    lane_vehicle_count.get(lane_id, 0)
                )

            # Elapsed green time advances once per simulator tick per intersection.
            self.phase_elapsed_times[intersection_id] += self.env_config.simulator_interval

    avg_incoming_counts /= float(self.env_config.decision_interval)
    avg_incoming_waiting /= float(self.env_config.decision_interval)
    avg_outgoing_counts /= float(self.env_config.decision_interval)
    return avg_incoming_counts, avg_incoming_waiting, avg_outgoing_counts
|
| 242 |
+
|
| 243 |
+
def _build_info(
    self,
    rewards: np.ndarray,
    avg_incoming_counts: np.ndarray,
    avg_incoming_waiting: np.ndarray,
    reward_components: dict[str, np.ndarray],
) -> dict[str, Any]:
    """Build the per-step info dict: identifiers, episode counters, and metrics.

    Metrics combine global aggregates, reward-component breakdowns, and
    per-district-type summaries.
    """
    # Guard against empty reward vectors (rewards.mean() on empty would warn/nan).
    mean_reward = float(rewards.mean()) if rewards.size else 0.0
    average_travel_time = self.adapter.get_average_travel_time()
    info = {
        "city_id": self.city_id,
        "scenario_name": self.scenario_name,
        "decision_step": self.decision_step_count,
        "sim_time": self.adapter.get_current_time(),
        "episode_return": float(self.episode_return),
        "total_episode_return": float(self.total_episode_return),
        "intersection_ids": self.controlled_intersection_ids,
        "district_types": self._district_type_labels,
        "metrics": {
            "num_controlled_intersections": len(self.controlled_intersection_ids),
            "mean_reward": mean_reward,
            "mean_step_intersection_reward": self._mean_step_intersection_reward(),
            "mean_waiting_vehicles": float(avg_incoming_waiting.sum(axis=1).mean()),
            "mean_incoming_vehicles": float(avg_incoming_counts.sum(axis=1).mean()),
            "total_waiting_vehicles": float(avg_incoming_waiting.sum()),
            "total_incoming_vehicles": float(avg_incoming_counts.sum()),
            "running_vehicles": self.adapter.get_vehicle_count(),
            "throughput": self.adapter.get_finished_vehicle_count(),
            "average_travel_time": average_travel_time,
            "reward_variant": self.env_config.reward.variant,
        },
    }
    info["metrics"].update(self._reward_component_metrics(reward_components))
    info["metrics"].update(
        per_district_type_metrics(
            district_types=self._district_type_labels,
            rewards=rewards,
            avg_incoming_counts=avg_incoming_counts,
            avg_incoming_waiting=avg_incoming_waiting,
        )
    )
    return info
|
| 285 |
+
|
| 286 |
+
def _normalize_actions(
|
| 287 |
+
self,
|
| 288 |
+
actions: dict[str, int] | list[int] | np.ndarray,
|
| 289 |
+
) -> np.ndarray:
|
| 290 |
+
if isinstance(actions, dict):
|
| 291 |
+
return np.asarray(
|
| 292 |
+
[int(actions.get(intersection_id, 0)) for intersection_id in self.controlled_intersection_ids],
|
| 293 |
+
dtype=np.int64,
|
| 294 |
+
)
|
| 295 |
+
array = np.asarray(actions, dtype=np.int64)
|
| 296 |
+
if array.shape != (len(self.controlled_intersection_ids),):
|
| 297 |
+
raise ValueError(
|
| 298 |
+
"Actions must provide exactly one action per controlled intersection."
|
| 299 |
+
)
|
| 300 |
+
return array
|
| 301 |
+
|
| 302 |
+
def _mean_step_intersection_reward(self) -> float:
|
| 303 |
+
denominator = max(
|
| 304 |
+
1,
|
| 305 |
+
self.decision_step_count * len(self.controlled_intersection_ids),
|
| 306 |
+
)
|
| 307 |
+
return float(self.total_episode_return) / float(denominator)
|
| 308 |
+
|
| 309 |
+
def _accumulate_reward_components(self, components: dict[str, np.ndarray]) -> None:
|
| 310 |
+
for name, values in components.items():
|
| 311 |
+
self.reward_component_sums[name] = self.reward_component_sums.get(name, 0.0) + float(
|
| 312 |
+
np.asarray(values, dtype=np.float32).mean()
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
def _reward_component_metrics(
|
| 316 |
+
self,
|
| 317 |
+
reward_components: dict[str, np.ndarray],
|
| 318 |
+
) -> dict[str, float]:
|
| 319 |
+
metrics: dict[str, float] = {}
|
| 320 |
+
for name, values in reward_components.items():
|
| 321 |
+
metrics[f"reward_component_step_{name}"] = float(
|
| 322 |
+
np.asarray(values, dtype=np.float32).mean()
|
| 323 |
+
)
|
| 324 |
+
if self.decision_step_count <= 0:
|
| 325 |
+
return metrics
|
| 326 |
+
for name, total in self.reward_component_sums.items():
|
| 327 |
+
metrics[f"reward_component_mean_{name}"] = float(total) / float(
|
| 328 |
+
self.decision_step_count
|
| 329 |
+
)
|
| 330 |
+
return metrics
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def per_district_type_metrics(
    district_types: tuple[str, ...],
    rewards: np.ndarray,
    avg_incoming_counts: np.ndarray,
    avg_incoming_waiting: np.ndarray,
) -> dict[str, float]:
    """Aggregate reward/traffic metrics per district type.

    For each district type present, emits the intersection count plus the mean
    reward, waiting vehicles, and incoming vehicles over that type's
    intersections.
    """
    metrics: dict[str, float] = {}
    reward_vector = np.asarray(rewards, dtype=np.float32)
    incoming_totals = avg_incoming_counts.sum(axis=1)
    waiting_totals = avg_incoming_waiting.sum(axis=1)

    for district_type in sorted(set(district_types)):
        member = np.fromiter(
            (label == district_type for label in district_types),
            dtype=bool,
            count=len(district_types),
        )
        if not member.any():
            continue
        metrics[f"num_{district_type}_intersections"] = float(member.sum())
        metrics[f"mean_reward_{district_type}"] = float(reward_vector[member].mean())
        metrics[f"mean_waiting_vehicles_{district_type}"] = float(
            waiting_totals[member].mean()
        )
        metrics[f"mean_incoming_vehicles_{district_type}"] = float(
            incoming_totals[member].mean()
        )

    return metrics
|