tokev commited on
Commit
38d40b2
·
verified ·
1 Parent(s): 9c55eea

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. Dockerfile +49 -0
  3. README.md +186 -6
  4. __init__.py +1 -0
  5. agents/README.md +20 -0
  6. agents/__init__.py +15 -0
  7. agents/district_controller.py +187 -0
  8. agents/district_coordinator.py +11 -0
  9. agents/heuristic_controller.py +91 -0
  10. agents/local_policy.py +110 -0
  11. agents/message_protocol.py +116 -0
  12. artifacts/district_llm_adapter_v3/main_run/adapter/README.md +210 -0
  13. artifacts/district_llm_adapter_v3/main_run/adapter/adapter_config.json +50 -0
  14. artifacts/district_llm_adapter_v3/main_run/adapter/adapter_model.safetensors +3 -0
  15. artifacts/district_llm_adapter_v3/main_run/adapter/tokenizer.json +3 -0
  16. artifacts/district_llm_adapter_v3/main_run/adapter/tokenizer_config.json +18 -0
  17. artifacts/dqn_shared/best_validation.pt +3 -0
  18. client.py +52 -0
  19. data/splits/README.md +21 -0
  20. data/splits/test_cities.txt +15 -0
  21. data/splits/train_cities.txt +70 -0
  22. data/splits/val_cities.txt +15 -0
  23. district_llm/FINAL_ABLATION_RUNBOOK.md +79 -0
  24. district_llm/RL_GUIDANCE_EVAL_RUNBOOK.md +126 -0
  25. district_llm/RL_LLM_WRAPPER_SWEEP_RUNBOOK.md +121 -0
  26. district_llm/__init__.py +18 -0
  27. district_llm/data.py +27 -0
  28. district_llm/derivation.py +228 -0
  29. district_llm/eval.py +436 -0
  30. district_llm/generate_dataset.py +390 -0
  31. district_llm/guided_control.py +67 -0
  32. district_llm/heuristic_guidance.py +73 -0
  33. district_llm/inference.py +223 -0
  34. district_llm/metrics.py +97 -0
  35. district_llm/prompting.py +66 -0
  36. district_llm/repair.py +392 -0
  37. district_llm/rl_guidance_wrapper.py +1004 -0
  38. district_llm/schema.py +429 -0
  39. district_llm/summary_builder.py +413 -0
  40. district_llm/teachers.py +227 -0
  41. district_llm/train_unsloth.py +129 -0
  42. env/README.md +43 -0
  43. env/__init__.py +18 -0
  44. env/cityflow_adapter.py +102 -0
  45. env/district_summary.py +9 -0
  46. env/intersection_config.py +49 -0
  47. env/observation_builder.py +224 -0
  48. env/reward.py +244 -0
  49. env/scenarios.py +10 -0
  50. env/traffic_env.py +356 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ artifacts/district_llm_adapter_v3/main_run/adapter/tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ third_party/CityFlow/examples/replay.txt filter=lfs diff=lfs merge=lfs -text
38
+ third_party/CityFlow/extern/pybind11/.git.bak/objects/pack/pack-0b2353194187af7e228cfabd32025bb4d3af8551.idx filter=lfs diff=lfs merge=lfs -text
39
+ third_party/CityFlow/extern/pybind11/.git.bak/objects/pack/pack-0b2353194187af7e228cfabd32025bb4d3af8551.pack filter=lfs diff=lfs merge=lfs -text
40
+ third_party/CityFlow/extern/pybind11/.git.bak/objects/pack/pack-0b2353194187af7e228cfabd32025bb4d3af8551.rev filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12-slim AS builder
2
+
3
+ RUN apt-get update && apt-get install -y --no-install-recommends \
4
+ build-essential \
5
+ cmake \
6
+ libboost-all-dev \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
+ WORKDIR /build
10
+
11
+ COPY third_party/CityFlow ./CityFlow
12
+ RUN rm -rf ./CityFlow/build
13
+ RUN pip install --no-cache-dir ./CityFlow
14
+
15
+
16
+ FROM python:3.12-slim AS runtime
17
+
18
+ WORKDIR /app
19
+
20
+ COPY --from=builder /usr/local/lib/python3.12/site-packages/cityflow* \
21
+ /usr/local/lib/python3.12/site-packages/
22
+
23
+ COPY openenv_app/requirements.txt ./requirements.txt
24
+ RUN pip install --no-cache-dir -r requirements.txt
25
+
26
+ COPY __init__.py ./__init__.py
27
+ COPY client.py ./client.py
28
+ COPY models.py ./models.py
29
+ COPY agents/ ./agents/
30
+ COPY district_llm/ ./district_llm/
31
+ COPY env/ ./env/
32
+ COPY openenv_app/ ./openenv_app/
33
+ COPY server/ ./server/
34
+ COPY training/ ./training/
35
+ COPY data/splits/ ./data/splits/
36
+ COPY artifacts/dqn_shared/best_validation.pt ./artifacts/dqn_shared/best_validation.pt
37
+ COPY artifacts/district_llm_adapter_v3/main_run/adapter/ ./artifacts/district_llm_adapter_v3/main_run/adapter/
38
+
39
+ RUN mkdir -p /app/data/generated /app/data/splits
40
+
41
+ ENV DATA_DIR=/app/data/generated
42
+ ENV SPLITS_DIR=/app/data/splits
43
+ ENV CHECKPOINT_PATH=/app/artifacts/dqn_shared/best_validation.pt
44
+ ENV DISTRICT_LLM_ADAPTER_PATH=/app/artifacts/district_llm_adapter_v3/main_run/adapter
45
+
46
+ EXPOSE 7860
47
+
48
+ ENV ENABLE_WEB_INTERFACE=true
49
+ CMD ["sh", "-c", "uvicorn server.app:app --host 0.0.0.0 --port ${PORT:-7860}"]
README.md CHANGED
@@ -1,11 +1,191 @@
1
  ---
2
- title: Traffic Visualizer
3
- emoji: 📈
4
- colorFrom: yellow
5
- colorTo: pink
6
  sdk: docker
7
  pinned: false
8
- short_description: test
 
 
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Agentic Traffic
3
+ emoji: 🏢
4
+ colorFrom: green
5
+ colorTo: purple
6
  sdk: docker
7
  pinned: false
8
+ short_description: Agentic AI to control traffic lights
9
+ app_port: 7860
10
+ base_path: /web
11
  ---
12
 
13
+ # traffic-llm
14
+
15
+ CityFlow-based traffic-control project with intersection-level multi-agent DQN training and district-aware policy variants.
16
+
17
+ ## Training
18
+
19
+ The default local-policy trainer now uses parameter-shared dueling Double DQN with prioritized replay and n-step returns:
20
+
21
+ ```bash
22
+ python3 -m training.train_local_policy train
23
+ ```
24
+
25
+ That trains against `data/generated`, uses `data/splits`, writes checkpoints to `artifacts/dqn_shared`, enables TensorBoard logging, uses parallel CPU rollout workers by default, shows `tqdm` progress bars, and now validates and checkpoints every 40 updates by default.
26
+
27
+ For a broader but still manageable validation pass:
28
+
29
+ ```bash
30
+ python3 -m training.train_local_policy train --max-val-cities 3 --val-scenarios-per-city 7
31
+ ```
32
+
33
+ That evaluates 3 validation cities across all 7 scenario types. This gives 21 learned-policy validation episodes per eval, or 63 total episodes if random and fixed baselines are also enabled.
34
+
35
+ Phase-3-style full training with the same 40-update eval/checkpoint cadence:
36
+
37
+ ```bash
38
+ python3 -m training.train_local_policy train \
39
+ --max-train-cities 70 \
40
+ --max-val-cities 3 \
41
+ --val-scenarios-per-city 7 \
42
+ --policy-arch single_head_with_district_feature \
43
+ --reward-variant wait_queue_throughput
44
+ ```
45
+
46
+ Useful ablations:
47
+
48
+ ```bash
49
+ python3 -m training.train_local_policy train --policy-arch multi_head --reward-variant current
50
+ python3 -m training.train_local_policy train --policy-arch single_head --reward-variant current
51
+ python3 -m training.train_local_policy train --policy-arch single_head_with_district_feature --reward-variant wait_queue_throughput
52
+ ```
53
+
54
+ For a fast phase-1 overfit run on one fixed world:
55
+
56
+ ```bash
57
+ python3 -m training.train_local_policy train \
58
+ --total-updates 25 \
59
+ --train-city-id city_0072 \
60
+ --train-scenario-name normal \
61
+ --overfit-val-on-train-scenario \
62
+ --fast-overfit \
63
+ --policy-arch single_head_with_district_feature \
64
+ --reward-variant wait_queue_throughput
65
+ ```
66
+
67
+ To create or refresh dataset splits:
68
+
69
+ ```bash
70
+ python3 -m training.train_local_policy make-splits
71
+ ```
72
+
73
+ To evaluate the best checkpoint:
74
+
75
+ ```bash
76
+ python3 -m training.train_local_policy evaluate \
77
+ --checkpoint artifacts/dqn_shared/best_validation.pt \
78
+ --split val
79
+ ```
80
+
81
+ To evaluate a heuristic baseline directly:
82
+
83
+ ```bash
84
+ python3 -m training.train_local_policy evaluate --baseline queue_greedy --split val
85
+ ```
86
+
87
+ ## TensorBoard
88
+
89
+ TensorBoard logs are written to `artifacts/dqn_shared/tensorboard` by default.
90
+
91
+ ```bash
92
+ tensorboard --logdir artifacts/dqn_shared/tensorboard
93
+ ```
94
+
95
+ ## District LLM
96
+
97
+ The district LLM stack lives under `district_llm/`. It treats the learned DQN local controller as the low-level executor, derives district-scale SFT labels automatically from DQN rollout windows, and defaults district-model fine-tuning to DQN-derived rows only.
98
+
99
+ Generate district-LLM data from a learned checkpoint:
100
+
101
+ ```bash
102
+ python3 -m district_llm.generate_dataset \
103
+ --controller rl_checkpoint \
104
+ --checkpoint artifacts/dqn_shared/best_validation.pt \
105
+ --episodes 100 \
106
+ --decision-interval 10 \
107
+ --use-checkpoint-env-config \
108
+ --output data/district_llm_train.jsonl
109
+ ```
110
+
111
+ Generate from fixed or heuristic baselines:
112
+
113
+ ```bash
114
+ python3 -m district_llm.generate_dataset --controller fixed --episodes 50 --decision-interval 10 --output data/district_llm_fixed.jsonl
115
+ python3 -m district_llm.generate_dataset --controller queue_greedy --episodes 50 --decision-interval 10 --output data/district_llm_heuristic.jsonl
116
+ python3 -m district_llm.generate_dataset --teacher-spec fixed --teacher-spec random --episodes 50 --decision-interval 10 --output data/district_llm_multi_teacher.jsonl
117
+ ```
118
+
119
+ Train a first-pass district model with Unsloth/QLoRA:
120
+
121
+ ```bash
122
+ python3 -m training.train_district_llm \
123
+ --dataset data/district_llm_train.jsonl \
124
+ --output-dir artifacts/district_llm_qwen \
125
+ --model-name Qwen/Qwen2.5-7B-Instruct \
126
+ --load-in-4bit \
127
+ --lora-rank 16 \
128
+ --max-seq-length 1024 \
129
+ --max-steps 1000
130
+ ```
131
+
132
+ Run single-sample inference:
133
+
134
+ ```bash
135
+ python3 -m district_llm.inference \
136
+ --model artifacts/district_llm_qwen \
137
+ --city-id city_0006 \
138
+ --scenario-name accident \
139
+ --district-id d_00
140
+ ```
141
+
142
+ Run the OpenEnv-compatible district wrapper on top of the current DQN stack:
143
+
144
+ ```bash
145
+ uvicorn openenv_app.app:app --reload
146
+ ```
147
+
148
+ ## Algorithm
149
+
150
+ - Training algorithm: parameter-shared dueling Double DQN.
151
+ - Replay: prioritized replay over per-intersection transitions gathered from full CityFlow worlds.
152
+ - Return target: n-step bootstrap target with target-network updates.
153
+ - Execution: all controllable intersections act simultaneously every RL decision interval.
154
+ - Action space: `0 = hold current phase`, `1 = switch to next green phase`.
155
+ - Safety: `min_green_time` is enforced in the environment and exposed through action masking.
156
+
157
+ Policy architecture modes:
158
+
159
+ - `multi_head`: shared trunk with district-type-specific Q heads.
160
+ - `single_head`: one shared Q head for all intersections, with district type removed from the observation.
161
+ - `single_head_with_district_feature`: one shared Q head for all intersections, with district type left in the observation as an explicit feature.
162
+
163
+ Reward variants:
164
+
165
+ - `current`: backward-compatible waiting and queue penalty.
166
+ - `normalized_wait_queue`: normalized queue and waiting reduction reward.
167
+ - `wait_queue_throughput`: normalized queue/wait reduction plus throughput bonus and imbalance penalty.
168
+
169
+ ## Smoke Test
170
+
171
+ To sanity-check one generated scenario with the real CityFlow environment:
172
+
173
+ ```bash
174
+ python3 scripts/smoke_test_env.py --city-id city_0001 --scenario-name normal --policy random
175
+ ```
176
+
177
+ ## Project layout
178
+
179
+ - `agents/`: heuristic local policies and simple baselines.
180
+ - `env/`: CityFlow environment, topology parsing, observation building, and reward logic.
181
+ - `training/`: dataset utilities, replay-based DQN training, evaluation helpers, TensorBoard logging, and CLIs.
182
+ - `data/`: generated synthetic cities, split files, and dataset generation utilities.
183
+ - `scripts/`: utility scripts, including the CityFlow smoke test.
184
+ - `third_party/`: vendored dependencies, including CityFlow source.
185
+
186
+ ## Notes
187
+
188
+ - The generated dataset is assumed to already exist under `data/generated`.
189
+ - District membership comes from `district_map.json`.
190
+ - District types come from `metadata.json`.
191
+ - Runtime training and evaluation require the `cityflow` Python module to be installed in the active environment.
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """OpenEnv package root for the lean agentic traffic environment."""
agents/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # agents
2
+
3
+ Local traffic-control policies and compatibility shims.
4
+
5
+ ## Main files
6
+
7
+ - [local_policy.py](local_policy.py)
8
+ Active v1 policy interfaces and simple baselines:
9
+ - `HoldPhasePolicy`
10
+ - `FixedCyclePolicy`
11
+ - `QueueGreedyPolicy`
12
+ - [district_controller.py](district_controller.py)
13
+ Older district-level prototype logic kept for compatibility.
14
+ - [district_coordinator.py](district_coordinator.py)
15
+ Import shim for older code paths.
16
+
17
+ ## Notes
18
+
19
+ - The learned local-policy network itself lives in [training/models.py](../training/models.py), not here.
20
+ - For active training, use the parameter-shared DQN path in `training/`, not the district-controller prototypes.
agents/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from agents.local_policy import (
2
+ BaseLocalPolicy,
3
+ FixedCyclePolicy,
4
+ HoldPhasePolicy,
5
+ QueueGreedyPolicy,
6
+ SharedHeuristicLocalPolicy,
7
+ )
8
+
9
+ __all__ = [
10
+ "BaseLocalPolicy",
11
+ "FixedCyclePolicy",
12
+ "HoldPhasePolicy",
13
+ "QueueGreedyPolicy",
14
+ "SharedHeuristicLocalPolicy",
15
+ ]
agents/district_controller.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Callable
5
+
6
+ from agents.message_protocol import DistrictDirective, parse_district_directive
7
+
8
+
9
class BaseDistrictCoordinator(ABC):
    """Interface for district-level strategic coordinators."""

    @abstractmethod
    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        raise NotImplementedError


class RuleBasedDistrictCoordinator(BaseDistrictCoordinator):
    """
    Fast, deterministic, and robust.
    Good first coordinator and good fallback if the LLM output fails.
    """

    def __init__(
        self,
        imbalance_threshold: float = 0.15,
        border_pressure_threshold: float = 0.65,
        default_duration: int = 2,
    ):
        self.imbalance_threshold = imbalance_threshold
        self.border_pressure_threshold = border_pressure_threshold
        self.default_duration = default_duration

    @staticmethod
    def _emit(directive: DistrictDirective) -> dict[str, Any]:
        """Validate a directive and return its serializable dict form."""
        return directive.validate().to_dict()

    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        district_id = district_summary.get("district_id", "unknown")
        members = district_summary.get("intersection_ids", [])

        # Emergency vehicles override every other consideration.
        emergency = district_summary.get("emergency_vehicle", {})
        if emergency.get("present", False):
            return self._emit(
                DistrictDirective(
                    mode="emergency_route",
                    target_intersections=emergency.get("route", members),
                    duration=2,
                    rationale=f"Emergency vehicle detected in district {district_id}.",
                    corridor=emergency.get("corridor"),
                    district_weight=1.0,
                )
            )

        loads = district_summary.get("corridor_loads", {})
        ns_load = float(loads.get("ns", loads.get("north_south", 0.0)))
        ew_load = float(loads.get("ew", loads.get("east_west", 0.0)))

        pressures = district_summary.get("border_pressure", {})
        peak_border = 0.0
        if isinstance(pressures, dict) and pressures:
            peak_border = max(float(v) for v in pressures.values())

        # A strongly pressured border raises the directive's weight.
        weight = 0.7 if peak_border < self.border_pressure_threshold else 0.9

        if ew_load - ns_load > self.imbalance_threshold:
            return self._emit(
                DistrictDirective(
                    mode="prioritize_ew",
                    target_intersections=members,
                    duration=self.default_duration,
                    rationale="East-west corridor is currently more congested than north-south.",
                    corridor="ew",
                    district_weight=weight,
                )
            )

        if ns_load - ew_load > self.imbalance_threshold:
            return self._emit(
                DistrictDirective(
                    mode="prioritize_ns",
                    target_intersections=members,
                    duration=self.default_duration,
                    rationale="North-south corridor is currently more congested than east-west.",
                    corridor="ns",
                    district_weight=weight,
                )
            )

        if peak_border >= self.border_pressure_threshold:
            return self._emit(
                DistrictDirective(
                    mode="damp_border_inflow",
                    target_intersections=members,
                    duration=2,
                    rationale="Border pressure is high; reduce spill-in and smooth cross-district flow.",
                    district_weight=0.8,
                )
            )

        # Nothing stands out: issue an explicit no-op directive.
        return self._emit(
            DistrictDirective(
                mode="none",
                target_intersections=[],
                duration=1,
                rationale="District is reasonably balanced.",
                district_weight=0.5,
            )
        )


class LLMDistrictCoordinator(BaseDistrictCoordinator):
    """
    LLM-backed coordinator.

    `generator_fn` should accept a prompt string and return either:
    - a JSON string, or
    - a dict

    Example:
        coordinator = LLMDistrictCoordinator(generator_fn=my_model_call)
    """

    def __init__(
        self,
        generator_fn: Callable[[str], str | dict[str, Any]],
        fallback: BaseDistrictCoordinator | None = None,
        max_prompt_chars: int = 4000,
    ):
        self.generator_fn = generator_fn
        self.fallback = fallback or RuleBasedDistrictCoordinator()
        self.max_prompt_chars = max_prompt_chars

    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        # Prompt construction happens outside the try-block on purpose:
        # only generator/parse failures fall back to the rule-based path.
        prompt = self.build_prompt(district_summary)
        try:
            raw = self.generator_fn(prompt)
            # The parser coerces malformed or no-op output into a safe
            # directive, so its result can be returned as-is.
            return parse_district_directive(raw).to_dict()
        except Exception:
            return self.fallback.decide(district_summary)

    def build_prompt(self, district_summary: dict[str, Any]) -> str:
        """Render the district summary into the coordinator prompt text."""
        summary_text = repr(district_summary)
        if len(summary_text) > self.max_prompt_chars:
            summary_text = summary_text[: self.max_prompt_chars] + " ...[truncated]"

        return f"""You are a district-level traffic coordinator.

Your job is to choose a single strategic directive for the next few cycles.

Allowed modes:
- none
- prioritize_ns
- prioritize_ew
- green_wave
- emergency_route
- damp_border_inflow

Return ONLY valid JSON with these fields:
{{
"mode": string,
"target_intersections": list[string],
"duration": int,
"rationale": string,
"corridor": string or null,
"district_weight": float
}}

Guidelines:
- Use emergency_route if an emergency vehicle is present.
- Use prioritize_ns or prioritize_ew when one corridor is clearly more congested.
- Use damp_border_inflow when cross-district border pressure is high.
- Keep duration between 1 and 5.
- district_weight should be between 0.0 and 1.0.

District summary:
{summary_text}
"""
agents/district_coordinator.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from agents.district_controller import (
2
+ BaseDistrictCoordinator,
3
+ LLMDistrictCoordinator,
4
+ RuleBasedDistrictCoordinator,
5
+ )
6
+
7
+ __all__ = [
8
+ "BaseDistrictCoordinator",
9
+ "LLMDistrictCoordinator",
10
+ "RuleBasedDistrictCoordinator",
11
+ ]
agents/heuristic_controller.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+
6
class HeuristicController:
    """
    Rule-based controller for a single intersection.

    Action space:
        0 -> choose NS green
        1 -> choose EW green

    Assumes:
        queue_lengths = [N, S, E, W]
        waiting_counts = [N, S, E, W]
    """

    def __init__(
        self,
        min_green_steps: int = 5,
        switch_margin: float = 1.0,
        district_bonus_scale: float = 3.0,
        neighbor_pressure_scale: float = 0.25,
    ):
        self.min_green_steps = min_green_steps
        self.switch_margin = switch_margin
        self.district_bonus_scale = district_bonus_scale
        self.neighbor_pressure_scale = neighbor_pressure_scale

    def act(self, obs: dict[str, Any]) -> int:
        """Return the next phase (0=NS green, 1=EW green) for this step."""
        ns_score, ew_score = self._scored_directions(obs)

        phase_now = int(obs.get("current_phase", 0))
        elapsed = int(obs.get("time_since_switch", 0))

        # Ties favor NS green.
        preferred = 0 if ns_score >= ew_score else 1

        # Respect minimum green time to avoid rapid toggling.
        if elapsed < self.min_green_steps:
            return phase_now

        # Require a clear advantage before switching phases.
        score_of = (ns_score, ew_score)
        if (
            preferred != phase_now
            and score_of[preferred] < score_of[phase_now] + self.switch_margin
        ):
            return phase_now

        return preferred

    def _scored_directions(self, obs: dict[str, Any]) -> tuple[float, float]:
        """Return (ns_score, ew_score) including neighbor and district biases."""
        queues = obs.get("queue_lengths", [0, 0, 0, 0])
        waits = obs.get("waiting_counts", [0, 0, 0, 0])

        ns = queues[0] + queues[1] + 1.5 * (waits[0] + waits[1])
        ew = queues[2] + queues[3] + 1.5 * (waits[2] + waits[3])

        # Small bias from neighboring intersections, when provided as [ns, ew].
        pressure = obs.get("neighbor_pressure", [0.0, 0.0])
        if isinstance(pressure, list) and len(pressure) >= 2:
            ns += self.neighbor_pressure_scale * float(pressure[0])
            ew += self.neighbor_pressure_scale * float(pressure[1])

        # Strategic bias handed down from the district coordinator.
        mode = obs.get("district_mode", "none")
        bonus = self.district_bonus_scale * float(obs.get("district_weight", 0.5))
        if mode == "prioritize_ns":
            ns += bonus
        elif mode == "prioritize_ew":
            ew += bonus
        elif mode == "green_wave":
            corridor = obs.get("district_corridor")
            if corridor == "ns":
                ns += bonus
            elif corridor == "ew":
                ew += bonus
        elif mode == "emergency_route":
            # Emergency traffic gets an extra-strong push along its corridor.
            corridor = obs.get("district_corridor")
            if corridor in {"north_to_south", "south_to_north", "ns"}:
                ns += bonus * 1.5
            elif corridor in {"west_to_east", "east_to_west", "ew"}:
                ew += bonus * 1.5

        return ns, ew
agents/local_policy.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ import numpy as np
6
+
7
+
8
class BaseLocalPolicy(ABC):
    """Interface for batched per-intersection phase policies."""

    @abstractmethod
    def act(self, observation_batch: dict[str, np.ndarray]) -> np.ndarray:
        raise NotImplementedError


class HoldPhasePolicy(BaseLocalPolicy):
    """Always keep the current phase (action 0 everywhere)."""

    def act(self, observation_batch: dict[str, np.ndarray]) -> np.ndarray:
        batch_size = len(observation_batch["intersection_ids"])
        return np.zeros(batch_size, dtype=np.int64)


class RandomPhasePolicy(BaseLocalPolicy):
    """Uniformly sample one of the currently valid actions per intersection."""

    def __init__(self, seed: int = 7):
        self.rng = np.random.default_rng(seed)

    def act(self, observation_batch: dict[str, np.ndarray]) -> np.ndarray:
        mask = observation_batch["action_mask"]
        chosen = np.zeros(mask.shape[0], dtype=np.int64)
        for idx in range(mask.shape[0]):
            legal = np.flatnonzero(mask[idx] > 0.0)
            chosen[idx] = int(self.rng.choice(legal))
        return chosen


class FixedCyclePolicy(BaseLocalPolicy):
    """Switch phase every `green_time` steps, whenever switching is allowed."""

    def __init__(self, green_time: int = 20):
        self.green_time = int(green_time)

    def act(self, observation_batch: dict[str, np.ndarray]) -> np.ndarray:
        due = observation_batch["phase_elapsed"] >= self.green_time
        can_switch = observation_batch["action_mask"][:, 1] > 0.0
        return (due & can_switch).astype(np.int64)


class QueueGreedyPolicy(BaseLocalPolicy):
    """Switch toward the direction with the larger queue-plus-wait pressure."""

    def __init__(self, switch_margin: float = 1.0):
        self.switch_margin = float(switch_margin)

    def act(self, observation_batch: dict[str, np.ndarray]) -> np.ndarray:
        counts = observation_batch["incoming_counts"]
        waiting = observation_batch["incoming_waiting"]
        lane_mask = observation_batch["lane_mask"]
        phase = observation_batch["current_phase"]
        action_mask = observation_batch["action_mask"]

        # First half of the lane axis is treated as NS, second half as EW.
        half = counts.shape[1] // 2
        ns = counts[:, :half].sum(axis=1) + 1.5 * waiting[:, :half].sum(axis=1)
        ew = counts[:, half:].sum(axis=1) + 1.5 * waiting[:, half:].sum(axis=1)

        # NOTE(review): only the NS half is zeroed when it has no live lanes;
        # the EW half is left as-is. Preserved from the original — confirm
        # whether the asymmetry is intentional.
        has_ns_lanes = lane_mask[:, :half].sum(axis=1) > 0
        ns = np.where(has_ns_lanes, ns, 0.0)

        wants_switch = np.where(
            phase == 0,
            ew > ns + self.switch_margin,
            ns > ew + self.switch_margin,
        )
        # Switching must also be legal per the action mask.
        wants_switch = wants_switch & (action_mask[:, 1] > 0.0)
        return wants_switch.astype(np.int64)


class SharedHeuristicLocalPolicy(QueueGreedyPolicy):
    """QueueGreedy variant that also accepts per-intersection dict payloads."""

    def __init__(
        self,
        min_green_steps: int = 5,
        switch_margin: float = 1.0,
        district_bonus_scale: float = 0.0,
        neighbor_pressure_scale: float = 0.0,
    ):
        self.min_green_steps = int(min_green_steps)
        # District/neighbor scales are accepted for interface parity but unused.
        del district_bonus_scale, neighbor_pressure_scale
        super().__init__(switch_margin=switch_margin)

    def act_batch(self, observation_batch):
        # Batched numpy observations take the vectorized parent path.
        if "intersection_ids" in observation_batch:
            return self.act(observation_batch)

        chosen: dict[str, int] = {}
        for node_id, payload in observation_batch.items():
            waits = payload.get("waiting_counts", [0, 0, 0, 0])
            queues = payload.get("queue_lengths", [0, 0, 0, 0])
            phase = int(payload.get("current_phase", 0))
            elapsed = int(payload.get("time_since_switch", 0))

            ns = float(sum(queues[:2]) + 1.5 * sum(waits[:2]))
            ew = float(sum(queues[2:4]) + 1.5 * sum(waits[2:4]))
            wanted = 0 if ns >= ew else 1

            if elapsed < self.min_green_steps:
                chosen[node_id] = phase
            elif wanted != phase and abs(ns - ew) <= self.switch_margin:
                chosen[node_id] = phase
            else:
                chosen[node_id] = wanted
        return chosen
agents/message_protocol.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import asdict, dataclass, field
5
+ from typing import Any
6
+
7
+
8
VALID_MODES = {
    "none",
    "prioritize_ns",
    "prioritize_ew",
    "green_wave",
    "emergency_route",
    "damp_border_inflow",
}

# Corridor labels accepted on a directive; anything else is dropped to None.
_VALID_CORRIDORS = {
    "ns",
    "ew",
    "west_to_east",
    "east_to_west",
    "north_to_south",
    "south_to_north",
}


@dataclass
class NeighborMessage:
    """Congestion report exchanged between adjacent intersections."""

    sender_intersection: str
    receiver_intersection: str
    congestion_level: float
    spillback_risk: bool
    dominant_direction: str  # "ns", "ew", or "balanced"
    queue_total: int

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view suitable for JSON serialization."""
        return asdict(self)


@dataclass
class DistrictDirective:
    """Strategic instruction a district coordinator hands to local policies."""

    mode: str = "none"
    target_intersections: list[str] = field(default_factory=list)
    duration: int = 1
    rationale: str = ""
    corridor: str | None = None
    district_weight: float = 0.5

    def validate(self) -> "DistrictDirective":
        """Coerce every field into its safe domain, in place, and return self.

        Designed to absorb arbitrary LLM/JSON output: unknown modes become
        "none", non-list targets are dropped, numeric fields are coerced and
        clamped (duration to [1, 10], district_weight to [0.0, 1.0]).
        """
        if self.mode not in VALID_MODES:
            self.mode = "none"

        if not isinstance(self.target_intersections, list):
            self.target_intersections = []

        # JSON from an LLM often carries integers as floats (e.g. 3.0) or
        # numeric strings; coerce them instead of silently resetting to 1.
        # bool is excluded explicitly because it is a subclass of int.
        if isinstance(self.duration, bool) or not isinstance(self.duration, int):
            try:
                self.duration = int(self.duration)
            except (TypeError, ValueError):
                self.duration = 1
        self.duration = max(1, min(self.duration, 10))

        if not isinstance(self.rationale, str):
            self.rationale = ""

        if self.corridor is not None and self.corridor not in _VALID_CORRIDORS:
            self.corridor = None

        # Same coercion policy for the weight (accepts int/float/numeric str).
        if isinstance(self.district_weight, bool) or not isinstance(
            self.district_weight, (int, float)
        ):
            try:
                self.district_weight = float(self.district_weight)
            except (TypeError, ValueError):
                self.district_weight = 0.5
        self.district_weight = float(max(0.0, min(1.0, self.district_weight)))

        return self

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view suitable for JSON serialization."""
        return asdict(self)


def parse_district_directive(payload: str | dict[str, Any]) -> DistrictDirective:
    """
    Accept either raw JSON text or a dict and return a validated DistrictDirective.

    Falls back safely to a validated no-op directive on any parse error.
    """
    try:
        if isinstance(payload, str):
            payload = payload.strip()
            if not payload:
                return DistrictDirective().validate()

            try:
                data = json.loads(payload)
            except json.JSONDecodeError:
                # The model may wrap the JSON object in prose; extract the
                # outermost {...} span and retry.
                start = payload.find("{")
                end = payload.rfind("}")
                if start == -1 or end == -1 or end <= start:
                    return DistrictDirective().validate()
                data = json.loads(payload[start : end + 1])
        elif isinstance(payload, dict):
            data = payload
        else:
            return DistrictDirective().validate()

        return DistrictDirective(
            mode=data.get("mode", "none"),
            target_intersections=data.get("target_intersections", []),
            duration=data.get("duration", 1),
            rationale=data.get("rationale", ""),
            corridor=data.get("corridor"),
            district_weight=data.get("district_weight", 0.5),
        ).validate()
    except Exception:
        # Any unexpected shape (non-dict JSON, attribute errors, ...) degrades
        # to the safe no-op directive rather than propagating.
        return DistrictDirective().validate()


def safe_directive_dict(payload: str | dict[str, Any] | None) -> dict[str, Any]:
    """Like parse_district_directive, but None-tolerant and dict-returning."""
    if payload is None:
        return DistrictDirective().validate().to_dict()
    return parse_district_directive(payload).to_dict()
artifacts/district_llm_adapter_v3/main_run/adapter/README.md ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: unsloth/llama-3.1-8b-unsloth-bnb-4bit
3
+ library_name: peft
4
+ pipeline_tag: text-generation
5
+ tags:
6
+ - base_model:adapter:unsloth/llama-3.1-8b-unsloth-bnb-4bit
7
+ - lora
8
+ - sft
9
+ - transformers
10
+ - trl
11
+ - unsloth
12
+ ---
13
+
14
+ # Model Card for Model ID
15
+
16
+ <!-- Provide a quick summary of what the model is/does. -->
17
+
18
+
19
+
20
+ ## Model Details
21
+
22
+ ### Model Description
23
+
24
+ <!-- Provide a longer summary of what this model is. -->
25
+
26
+
27
+
28
+ - **Developed by:** [More Information Needed]
29
+ - **Funded by [optional]:** [More Information Needed]
30
+ - **Shared by [optional]:** [More Information Needed]
31
+ - **Model type:** [More Information Needed]
32
+ - **Language(s) (NLP):** [More Information Needed]
33
+ - **License:** [More Information Needed]
34
+ - **Finetuned from model [optional]:** [More Information Needed]
35
+
36
+ ### Model Sources [optional]
37
+
38
+ <!-- Provide the basic links for the model. -->
39
+
40
+ - **Repository:** [More Information Needed]
41
+ - **Paper [optional]:** [More Information Needed]
42
+ - **Demo [optional]:** [More Information Needed]
43
+
44
+ ## Uses
45
+
46
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
47
+
48
+ ### Direct Use
49
+
50
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
51
+
52
+ [More Information Needed]
53
+
54
+ ### Downstream Use [optional]
55
+
56
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
57
+
58
+ [More Information Needed]
59
+
60
+ ### Out-of-Scope Use
61
+
62
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
63
+
64
+ [More Information Needed]
65
+
66
+ ## Bias, Risks, and Limitations
67
+
68
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
69
+
70
+ [More Information Needed]
71
+
72
+ ### Recommendations
73
+
74
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
75
+
76
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
77
+
78
+ ## How to Get Started with the Model
79
+
80
+ Use the code below to get started with the model.
81
+
82
+ [More Information Needed]
83
+
84
+ ## Training Details
85
+
86
+ ### Training Data
87
+
88
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
89
+
90
+ [More Information Needed]
91
+
92
+ ### Training Procedure
93
+
94
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
95
+
96
+ #### Preprocessing [optional]
97
+
98
+ [More Information Needed]
99
+
100
+
101
+ #### Training Hyperparameters
102
+
103
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
104
+
105
+ #### Speeds, Sizes, Times [optional]
106
+
107
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
108
+
109
+ [More Information Needed]
110
+
111
+ ## Evaluation
112
+
113
+ <!-- This section describes the evaluation protocols and provides the results. -->
114
+
115
+ ### Testing Data, Factors & Metrics
116
+
117
+ #### Testing Data
118
+
119
+ <!-- This should link to a Dataset Card if possible. -->
120
+
121
+ [More Information Needed]
122
+
123
+ #### Factors
124
+
125
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
126
+
127
+ [More Information Needed]
128
+
129
+ #### Metrics
130
+
131
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
132
+
133
+ [More Information Needed]
134
+
135
+ ### Results
136
+
137
+ [More Information Needed]
138
+
139
+ #### Summary
140
+
141
+
142
+
143
+ ## Model Examination [optional]
144
+
145
+ <!-- Relevant interpretability work for the model goes here -->
146
+
147
+ [More Information Needed]
148
+
149
+ ## Environmental Impact
150
+
151
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
152
+
153
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
154
+
155
+ - **Hardware Type:** [More Information Needed]
156
+ - **Hours used:** [More Information Needed]
157
+ - **Cloud Provider:** [More Information Needed]
158
+ - **Compute Region:** [More Information Needed]
159
+ - **Carbon Emitted:** [More Information Needed]
160
+
161
+ ## Technical Specifications [optional]
162
+
163
+ ### Model Architecture and Objective
164
+
165
+ [More Information Needed]
166
+
167
+ ### Compute Infrastructure
168
+
169
+ [More Information Needed]
170
+
171
+ #### Hardware
172
+
173
+ [More Information Needed]
174
+
175
+ #### Software
176
+
177
+ [More Information Needed]
178
+
179
+ ## Citation [optional]
180
+
181
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
182
+
183
+ **BibTeX:**
184
+
185
+ [More Information Needed]
186
+
187
+ **APA:**
188
+
189
+ [More Information Needed]
190
+
191
+ ## Glossary [optional]
192
+
193
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
194
+
195
+ [More Information Needed]
196
+
197
+ ## More Information [optional]
198
+
199
+ [More Information Needed]
200
+
201
+ ## Model Card Authors [optional]
202
+
203
+ [More Information Needed]
204
+
205
+ ## Model Card Contact
206
+
207
+ [More Information Needed]
208
+ ### Framework versions
209
+
210
+ - PEFT 0.18.1
artifacts/district_llm_adapter_v3/main_run/adapter/adapter_config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alora_invocation_tokens": null,
3
+ "alpha_pattern": {},
4
+ "arrow_config": null,
5
+ "auto_mapping": {
6
+ "base_model_class": "LlamaForCausalLM",
7
+ "parent_library": "transformers.models.llama.modeling_llama",
8
+ "unsloth_fixed": true
9
+ },
10
+ "base_model_name_or_path": "unsloth/llama-3.1-8b-unsloth-bnb-4bit",
11
+ "bias": "none",
12
+ "corda_config": null,
13
+ "ensure_weight_tying": false,
14
+ "eva_config": null,
15
+ "exclude_modules": null,
16
+ "fan_in_fan_out": false,
17
+ "inference_mode": true,
18
+ "init_lora_weights": true,
19
+ "layer_replication": null,
20
+ "layers_pattern": null,
21
+ "layers_to_transform": null,
22
+ "loftq_config": {},
23
+ "lora_alpha": 16,
24
+ "lora_bias": false,
25
+ "lora_dropout": 0.05,
26
+ "megatron_config": null,
27
+ "megatron_core": "megatron.core",
28
+ "modules_to_save": null,
29
+ "peft_type": "LORA",
30
+ "peft_version": "0.18.1",
31
+ "qalora_group_size": 16,
32
+ "r": 16,
33
+ "rank_pattern": {},
34
+ "revision": null,
35
+ "target_modules": [
36
+ "down_proj",
37
+ "q_proj",
38
+ "up_proj",
39
+ "gate_proj",
40
+ "v_proj",
41
+ "k_proj",
42
+ "o_proj"
43
+ ],
44
+ "target_parameters": null,
45
+ "task_type": "CAUSAL_LM",
46
+ "trainable_token_indices": null,
47
+ "use_dora": false,
48
+ "use_qalora": false,
49
+ "use_rslora": false
50
+ }
artifacts/district_llm_adapter_v3/main_run/adapter/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69163c214c0ad5462dc64dfb110bea79e50a2c4d5affdf084a8b351352231777
3
+ size 167832240
artifacts/district_llm_adapter_v3/main_run/adapter/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b9e4e7fb171f92fd137b777cc2714bf87d11576700a1dcd7a399e7bbe39537b
3
+ size 17209920
artifacts/district_llm_adapter_v3/main_run/adapter/tokenizer_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "<|begin_of_text|>",
4
+ "clean_up_tokenization_spaces": true,
5
+ "eos_token": "<|end_of_text|>",
6
+ "from_slow": true,
7
+ "is_local": false,
8
+ "legacy": false,
9
+ "model_input_names": [
10
+ "input_ids",
11
+ "attention_mask"
12
+ ],
13
+ "model_max_length": 131072,
14
+ "pad_token": "<|finetune_right_pad_id|>",
15
+ "padding_side": "left",
16
+ "tokenizer_class": "TokenizersBackend",
17
+ "unk_token": null
18
+ }
artifacts/dqn_shared/best_validation.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a478cd3149c74ef0f0a57e5510bccd35262272a0a25bfb76b6aac2e8417af85
3
+ size 1320091
client.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import requests
6
+
7
+ from models import (
8
+ AgenticTrafficAction,
9
+ AgenticTrafficObservation,
10
+ AgenticTrafficState,
11
+ )
12
+
13
+
14
class AgenticTrafficClient:
    """Thin HTTP client for the DistrictFlow OpenEnv server."""

    def __init__(self, base_url: str):
        # Normalize so endpoint paths can always be joined with a single "/".
        self.base_url = base_url.rstrip("/")

    def _post(self, endpoint: str, body: dict[str, Any]) -> dict[str, Any]:
        # Shared POST helper: raises for HTTP errors, returns the decoded JSON body.
        response = requests.post(
            f"{self.base_url}/{endpoint}",
            json=body,
            timeout=60,
        )
        response.raise_for_status()
        return response.json()

    def reset(self, seed: int | None = None) -> AgenticTrafficObservation:
        """Reset the environment and return the initial observation."""
        payload = self._post("reset", {"seed": seed})
        return AgenticTrafficObservation.model_validate(payload["observation"])

    def step(self, action: AgenticTrafficAction) -> AgenticTrafficObservation:
        """Apply one action and return the resulting observation.

        The server's ``done``/``reward`` fields are copied onto the
        observation, defaulting to ``False`` / ``0.0`` when absent.
        """
        payload = self._post("step", {"action": action.model_dump()})
        observation = AgenticTrafficObservation.model_validate(payload["observation"])
        observation.done = bool(payload.get("done", False))
        observation.reward = float(payload.get("reward", 0.0))
        return observation

    def state(self) -> AgenticTrafficState:
        """Fetch and validate the full environment state."""
        response = requests.get(f"{self.base_url}/state", timeout=60)
        response.raise_for_status()
        payload = response.json()
        return AgenticTrafficState.model_validate(payload["state"])

    def health(self) -> dict[str, Any]:
        """Return the server health payload as-is."""
        response = requests.get(f"{self.base_url}/health", timeout=30)
        response.raise_for_status()
        return response.json()
data/splits/README.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/splits
2
+
3
+ City-level train/validation/test splits for the generated dataset.
4
+
5
+ ## Files
6
+
7
+ - [train_cities.txt](train_cities.txt)
8
+ - [val_cities.txt](val_cities.txt)
9
+ - [test_cities.txt](test_cities.txt)
10
+
11
+ ## Important rule
12
+
13
+ Splits are by city only. All scenarios for a given city belong to the same split.
14
+
15
+ ## Regeneration
16
+
17
+ Use:
18
+
19
+ `python3 -m training.train_local_policy make-splits`
20
+
21
+ The split logic is implemented in [training/dataset.py](../../training/dataset.py).
data/splits/test_cities.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ city_0005
2
+ city_0007
3
+ city_0008
4
+ city_0010
5
+ city_0012
6
+ city_0013
7
+ city_0020
8
+ city_0028
9
+ city_0042
10
+ city_0047
11
+ city_0051
12
+ city_0065
13
+ city_0069
14
+ city_0075
15
+ city_0084
data/splits/train_cities.txt ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ city_0001
2
+ city_0002
3
+ city_0003
4
+ city_0004
5
+ city_0006
6
+ city_0011
7
+ city_0014
8
+ city_0015
9
+ city_0017
10
+ city_0018
11
+ city_0019
12
+ city_0021
13
+ city_0022
14
+ city_0023
15
+ city_0024
16
+ city_0025
17
+ city_0026
18
+ city_0027
19
+ city_0030
20
+ city_0032
21
+ city_0033
22
+ city_0034
23
+ city_0035
24
+ city_0036
25
+ city_0037
26
+ city_0038
27
+ city_0039
28
+ city_0040
29
+ city_0041
30
+ city_0043
31
+ city_0044
32
+ city_0045
33
+ city_0046
34
+ city_0048
35
+ city_0049
36
+ city_0050
37
+ city_0052
38
+ city_0053
39
+ city_0057
40
+ city_0058
41
+ city_0059
42
+ city_0060
43
+ city_0061
44
+ city_0062
45
+ city_0063
46
+ city_0064
47
+ city_0066
48
+ city_0067
49
+ city_0068
50
+ city_0070
51
+ city_0072
52
+ city_0074
53
+ city_0076
54
+ city_0077
55
+ city_0079
56
+ city_0080
57
+ city_0081
58
+ city_0082
59
+ city_0083
60
+ city_0085
61
+ city_0087
62
+ city_0088
63
+ city_0089
64
+ city_0092
65
+ city_0093
66
+ city_0094
67
+ city_0095
68
+ city_0097
69
+ city_0099
70
+ city_0100
data/splits/val_cities.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ city_0009
2
+ city_0016
3
+ city_0029
4
+ city_0031
5
+ city_0054
6
+ city_0055
7
+ city_0056
8
+ city_0071
9
+ city_0073
10
+ city_0078
11
+ city_0086
12
+ city_0090
13
+ city_0091
14
+ city_0096
15
+ city_0098
district_llm/FINAL_ABLATION_RUNBOOK.md ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Final Ablation Runbook
2
+
3
+ ## Dataset
4
+
5
+ Generate the constrained v3 dataset:
6
+
7
+ ```bash
8
+ python scripts/generate_large_district_dataset.py \
9
+ --num-train 10000 \
10
+ --num-val 2500 \
11
+ --output-dir data/district_llm_dataset_v3 \
12
+ --checkpoint artifacts/dqn_shared/best_validation.pt \
13
+ --max-candidate-intersections 6 \
14
+ --max-target-intersections 3
15
+ ```
16
+
17
+ Defaults:
18
+
19
+ - candidate pool is visible in the prompt via `candidate_intersections`
20
+ - labels are constrained to visible candidates
21
+ - DQN teacher sources are preferred by default
22
+
23
+ ## Notebook
24
+
25
+ Use [notebooks/llama_finetune.ipynb](../notebooks/llama_finetune.ipynb).
26
+
27
+ Recommended defaults for the A100 main run:
28
+
29
+ - `RUN_MODE = "main_run"`
30
+ - `num_train_epochs = 2`
31
+ - `per_device_train_batch_size = 8`
32
+ - `gradient_accumulation_steps = 4`
33
+ - effective batch size = 32
34
+ - `learning_rate = 1e-4`
35
+ - `warmup_ratio = 0.05`
36
+ - `eval_steps = 100`
37
+ - `save_steps = 100`
38
+
39
+ Smoke test mode:
40
+
41
+ - `RUN_MODE = "smoke_test"`
42
+ - short `max_steps`
43
+ - verifies formatting, checkpointing, and eval wiring
44
+
45
+ Optional max-step override:
46
+
47
+ - set `MAX_STEPS_OVERRIDE = 5000` only for explicit experimentation
48
+ - do not use it as the default main run
49
+
50
+ Artifacts:
51
+
52
+ - checkpoints: `artifacts/district_llm_adapter_v3/<run_mode>/checkpoints`
53
+ - saved adapter: `artifacts/district_llm_adapter_v3/<run_mode>/adapter`
54
+
55
+ ## Evaluation
56
+
57
+ Run offline eval with repair enabled:
58
+
59
+ ```bash
60
+ python -m district_llm.eval \
61
+ --model-path artifacts/district_llm_adapter_v3/main_run/adapter \
62
+ --val-jsonl data/district_llm_dataset_v3/val.jsonl \
63
+ --generated-root data/generated \
64
+ --max-examples 250 \
65
+ --debug-examples 10 \
66
+ --allow-only-visible-candidates \
67
+ --max-target-intersections 3 \
68
+ --fallback-on-empty-targets \
69
+ --fallback-mode heuristic \
70
+ --restrict-targets-to-visible-summary \
71
+ --report-before-after-repair
72
+ ```
73
+
74
+ Key outputs:
75
+
76
+ - raw vs repaired target metrics
77
+ - invalid target-id rate before and after repair
78
+ - visible-candidate-restricted metrics
79
+ - target failure buckets and debug examples
district_llm/RL_GUIDANCE_EVAL_RUNBOOK.md ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RL Guidance Eval Runbook
2
+
3
+ This ablation keeps the RL checkpoint fixed.
4
+
5
+ District guidance is only used at inference time through the wrapper in
6
+ `district_llm/rl_guidance_wrapper.py`. The safest default is
7
+ `target_only_soft`, which applies a small local Q-value bias only at
8
+ `target_intersections`.
9
+
10
+ ## Wrapper Modes
11
+
12
+ - `no_op`: guidance is computed and logged, but RL actions are unchanged.
13
+ - `target_only_soft`: weak local prior on target intersections. Default debug mode.
14
+ - `target_only_medium`: same scope, slightly stronger.
15
+ - `corridor_soft`: small corridor prior on targets plus a few aligned boundary intersections.
16
+ - `global_soft`: weak district-wide prior. Use only as an ablation.
17
+ - `current_legacy`: reference mode approximating the old strong/global wrapper.
18
+
19
+ ## Fast Debug Matrix
20
+
21
+ Use a short horizon first so the wrapper can be debugged quickly:
22
+
23
+ ```bash
24
+ python scripts/eval_rl_guidance_ablation.py \
25
+ --rl-checkpoint artifacts/dqn_shared/best_validation.pt \
26
+ --llm-model-path artifacts/district_llm_adapter_v3/main_run/adapter \
27
+ --modes rl_only rl_heuristic rl_llm \
28
+ --wrapper-modes no_op target_only_soft current_legacy \
29
+ --split val \
30
+ --cities city_0001 \
31
+ --scenarios normal \
32
+ --seeds 7 11 13 \
33
+ --num-episodes 1 \
34
+ --max-episode-seconds 300 \
35
+ --guidance-refresh-steps 10 \
36
+ --guidance-persistence-steps 3 \
37
+ --bias-strength 0.12 \
38
+ --target-only-bias-strength 0.18 \
39
+ --corridor-bias-strength 0.05 \
40
+ --max-intersections-affected 3 \
41
+ --fallback-policy hold_previous \
42
+ --save-guidance-traces \
43
+ --output-dir artifacts/rl_guidance_eval/debug_matrix_300s
44
+ ```
45
+
46
+ This expands into the paired comparison:
47
+
48
+ - `rl_only`
49
+ - `rl_heuristic+no_op`
50
+ - `rl_heuristic+target_only_soft`
51
+ - `rl_heuristic+current_legacy`
52
+ - `rl_llm+no_op`
53
+ - `rl_llm+target_only_soft`
54
+ - `rl_llm+current_legacy`
55
+
56
+ That command runs a superset of the exact smaller matrix from the wrapper audit prompt. Focus analysis on:
57
+
58
+ - `rl_only`
59
+ - `rl_heuristic+no_op`
60
+ - `rl_heuristic+target_only_soft`
61
+ - `rl_llm+no_op`
62
+ - `rl_llm+target_only_soft`
63
+ - `rl_llm+current_legacy`
64
+
65
+ ## What To Look At
66
+
67
+ Primary files:
68
+
69
+ - `summary.json`
70
+ - `episode_metrics.csv`
71
+ - `guidance_traces.jsonl`
72
+ - `config.json`
73
+
74
+ Key wrapper metrics in `episode_metrics.csv`:
75
+
76
+ - `wrapper_mode`
77
+ - `mean_bias_magnitude`
78
+ - `max_bias_magnitude`
79
+ - `avg_num_targeted_intersections`
80
+ - `avg_num_affected_intersections`
81
+ - `percent_steps_with_active_guidance`
82
+ - `num_guidance_refreshes`
83
+ - `num_noop_guidance_events`
84
+ - `fallback_policy_used_count`
85
+
86
+ Interpretation:
87
+
88
+ - If `rl_heuristic+no_op` and `rl_llm+no_op` match `rl_only`, the harness itself is fine.
89
+ - If `current_legacy` collapses while `target_only_soft` stays near `rl_only`, the wrapper was too strong/global.
90
+ - If `rl_llm+target_only_soft` diverges from `rl_heuristic+target_only_soft`, the LLM is adding signal under safe integration.
91
+ - If `avg_num_affected_intersections` is large or `percent_steps_with_active_guidance` is near `1.0`, the wrapper is still too persistent or too broad.
92
+ - If `fallback_policy_used_count` stays high in `rl_llm`, inspect `guidance_traces.jsonl` before trusting traffic metrics.
93
+
94
+ ## Cheap Follow-Up Ablations
95
+
96
+ Softer local prior:
97
+
98
+ ```bash
99
+ --wrapper-modes no_op target_only_soft target_only_medium
100
+ ```
101
+
102
+ Scope ablation:
103
+
104
+ ```bash
105
+ --wrapper-modes target_only_soft corridor_soft global_soft current_legacy
106
+ ```
107
+
108
+ More conservative persistence:
109
+
110
+ ```bash
111
+ --guidance-refresh-steps 8 --guidance-persistence-steps 2
112
+ ```
113
+
114
+ ## Output Layout
115
+
116
+ Outputs are saved under the requested directory, for example:
117
+
118
+ ```text
119
+ artifacts/rl_guidance_eval/debug_matrix_300s/
120
+ config.json
121
+ summary.json
122
+ episode_metrics.csv
123
+ episode_metrics.jsonl
124
+ guidance_traces.jsonl
125
+ seeded_configs/
126
+ ```
district_llm/RL_LLM_WRAPPER_SWEEP_RUNBOOK.md ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RL+LLM Wrapper Sweep
2
+
3
+ This sweep keeps both checkpoints fixed:
4
+
5
+ - RL weights stay fixed.
6
+ - LLM weights stay fixed.
7
+ - Only the inference-time `target_only_soft` wrapper settings change.
8
+
9
+ ## Recommended First Sweep
10
+
11
+ Run the default cheap preset on one city, one scenario, and three seeds:
12
+
13
+ ```bash
14
+ python scripts/sweep_rl_llm_wrapper.py \
15
+ --rl-checkpoint artifacts/dqn_shared/best_validation.pt \
16
+ --llm-model-path artifacts/district_llm_adapter_v3/main_run/adapter \
17
+ --preset strength_targets_gating \
18
+ --split val \
19
+ --cities city_0001 \
20
+ --scenarios normal \
21
+ --seeds 7 11 13 \
22
+ --episodes-per-seed 1 \
23
+ --max-episode-seconds 300 \
24
+ --guidance-refresh-steps 10 \
25
+ --queue-threshold 150 \
26
+ --imbalance-threshold 20 \
27
+ --fallback-policy no_op \
28
+ --output-dir artifacts/rl_llm_wrapper_sweep/first_pass
29
+ ```
30
+
31
+ This preset sweeps a small curated grid over:
32
+
33
+ - `bias_strength` in `{0.025, 0.05, 0.075}`
34
+ - `max_intersections_affected` in `{1, 2}`
35
+ - `gating_mode` in `{always_on, incident_or_spillback, queue_or_imbalance}`
36
+ - `guidance_persistence_steps = 5`
37
+ - `enable_bias_decay = false`
38
+
39
+ It also includes `baseline_current_soft` as a reference row.
40
+
41
+ ## Cheaper Probe
42
+
43
+ If you only want the fastest possible read on strength sensitivity:
44
+
45
+ ```bash
46
+ python scripts/sweep_rl_llm_wrapper.py \
47
+ --rl-checkpoint artifacts/dqn_shared/best_validation.pt \
48
+ --llm-model-path artifacts/district_llm_adapter_v3/main_run/adapter \
49
+ --preset strength_only \
50
+ --cities city_0001 \
51
+ --scenarios normal \
52
+ --seeds 7 11 13 \
53
+ --episodes-per-seed 1 \
54
+ --max-episode-seconds 300 \
55
+ --output-dir artifacts/rl_llm_wrapper_sweep/strength_only
56
+ ```
57
+
58
+ ## Broader Conservative Follow-Up
59
+
60
+ After the first pass identifies a promising strength/gating region:
61
+
62
+ ```bash
63
+ python scripts/sweep_rl_llm_wrapper.py \
64
+ --rl-checkpoint artifacts/dqn_shared/best_validation.pt \
65
+ --llm-model-path artifacts/district_llm_adapter_v3/main_run/adapter \
66
+ --preset full_conservative \
67
+ --cities city_0001 \
68
+ --scenarios normal \
69
+ --seeds 7 11 13 \
70
+ --episodes-per-seed 1 \
71
+ --max-episode-seconds 300 \
72
+ --output-dir artifacts/rl_llm_wrapper_sweep/full_conservative
73
+ ```
74
+
75
+ ## Outputs
76
+
77
+ Each sweep writes:
78
+
79
+ - `config.json`
80
+ - `sweep_results.csv`
81
+ - `sweep_results.parquet` when parquet support is available
82
+ - `paired_episode_metrics.csv`
83
+ - `ranking.json`
84
+ - `summary_report.json`
85
+ - optional `step_metrics.*`
86
+ - optional `guidance_traces.jsonl`
87
+
88
+ ## What To Inspect
89
+
90
+ Start with:
91
+
92
+ - `summary_report.json`
93
+ - `ranking.json`
94
+ - `paired_episode_metrics.csv`
95
+
96
+ Key fields:
97
+
98
+ - `mean_return_delta_vs_rl_only`
99
+ - `mean_throughput_delta_vs_rl_only`
100
+ - `mean_avg_queue_delta_vs_rl_only`
101
+ - `mean_avg_wait_delta_vs_rl_only`
102
+ - `mean_percent_steps_with_active_guidance`
103
+ - `mean_avg_num_affected_intersections`
104
+ - `mean_num_steps_guidance_blocked_by_gate`
105
+
106
+ ## Interpretation
107
+
108
+ The most promising configs should usually look like:
109
+
110
+ - small negative or positive `mean_return_delta_vs_rl_only`
111
+ - low `mean_avg_num_affected_intersections`
112
+ - moderate or low `mean_percent_steps_with_active_guidance`
113
+ - low fallback / invalid guidance counts
114
+
115
+ If the best configs cluster around:
116
+
117
+ - lower `bias_strength`
118
+ - `max_intersections_affected = 1`
119
+ - gated modes like `incident_or_spillback` or `queue_or_imbalance`
120
+
121
+ then the wrapper was still too active and guidance needs to remain a rare local prior.
district_llm/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from district_llm.derivation import DistrictWindowData, LocalIntersectionAction, derive_district_action
2
+ from district_llm.prompting import build_system_prompt, format_district_prompt, format_sft_text
3
+ from district_llm.schema import CandidateIntersection, CongestedIntersection, DistrictAction, DistrictStateSummary
4
+ from district_llm.summary_builder import DistrictStateSummaryBuilder
5
+
6
+ __all__ = [
7
+ "CandidateIntersection",
8
+ "CongestedIntersection",
9
+ "DistrictAction",
10
+ "DistrictStateSummary",
11
+ "DistrictStateSummaryBuilder",
12
+ "DistrictWindowData",
13
+ "LocalIntersectionAction",
14
+ "derive_district_action",
15
+ "build_system_prompt",
16
+ "format_district_prompt",
17
+ "format_sft_text",
18
+ ]
district_llm/data.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ def load_jsonl_text_dataset(
6
+ path: str | Path,
7
+ controller_families: list[str] | None = None,
8
+ controller_types: list[str] | None = None,
9
+ ):
10
+ from datasets import load_dataset
11
+
12
+ dataset = load_dataset("json", data_files=str(Path(path)), split="train")
13
+ if "text" not in dataset.column_names:
14
+ raise ValueError("Expected a JSONL dataset with a 'text' field.")
15
+ if controller_families:
16
+ allowed_families = set(controller_families)
17
+ dataset = dataset.filter(
18
+ lambda row: row.get("controller_family") in allowed_families
19
+ )
20
+ if controller_types:
21
+ allowed_types = set(controller_types)
22
+ dataset = dataset.filter(
23
+ lambda row: row.get("controller_type") in allowed_types
24
+ )
25
+ if len(dataset) == 0:
26
+ raise ValueError("No dataset rows remain after applying the requested filters.")
27
+ return dataset
district_llm/derivation.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any
5
+
6
+ from district_llm.repair import fallback_target_intersections
7
+ from district_llm.schema import DistrictAction, DistrictStateSummary
8
+
9
+
10
@dataclass
class LocalIntersectionAction:
    """One local-controller decision at a single intersection within a window."""

    intersection_id: str
    district_id: str
    # Raw controller action; 1 is treated as a phase-switch request.
    action: int
    current_phase: int
    next_phase: int
    queue_total: float
    wait_total: float
    outgoing_load: float
    # True when the intersection lies on the district boundary.
    is_boundary: bool

    @property
    def switched(self) -> bool:
        """True when a switch was requested and the phase actually changed."""
        return int(self.action) == 1 and self.next_phase != self.current_phase
25
+
26
+
27
@dataclass
class DistrictWindowData:
    """Aggregated district data for one decision window (start/end summaries)."""

    district_id: str
    start_summary: DistrictStateSummary
    end_summary: DistrictStateSummary
    controller_actions: list[LocalIntersectionAction] = field(default_factory=list)
    step_count: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Summarize the window as rounded end-minus-start deltas."""
        start = self.start_summary
        end = self.end_summary

        def _delta(end_value: float, start_value: float) -> float:
            # Rounded to 3 decimals to keep serialized summaries compact.
            return round(end_value - start_value, 3)

        return {
            "district_id": self.district_id,
            "step_count": int(self.step_count),
            "queue_delta": _delta(end.total_queue, start.total_queue),
            "wait_delta": _delta(end.total_wait, start.total_wait),
            "throughput_delta": _delta(end.recent_throughput, start.recent_throughput),
        }
46
+
47
+
48
+ def derive_district_action(
49
+ window_data: DistrictWindowData,
50
+ controller_actions: list[LocalIntersectionAction] | None = None,
51
+ district_state: DistrictStateSummary | None = None,
52
+ max_target_intersections: int = 3,
53
+ ) -> DistrictAction:
54
+ """
55
+ Deterministic first-pass label extraction from local-controller behavior.
56
+
57
+ Heuristic order:
58
+ 1. Incident-heavy windows map to `incident_response`.
59
+ 2. Strong spillback / boundary pressure maps to `clear_spillback`.
60
+ 3. Rising boundary demand maps to `drain_inbound`.
61
+ 4. Persistently high outgoing pressure maps to `drain_outbound`.
62
+ 5. Boundary-heavy rush windows map to `arterial_priority`.
63
+ 6. Clear NS/EW directional dominance maps to `favor_NS` / `favor_EW`.
64
+ 7. Otherwise emit `hold`.
65
+ """
66
+ actions = controller_actions if controller_actions is not None else window_data.controller_actions
67
+ state = district_state if district_state is not None else window_data.start_summary
68
+ end_state = window_data.end_summary
69
+
70
+ duration_steps = max(1, min(int(window_data.step_count or 1), 20))
71
+ phase_counts = {"NS": 0, "EW": 0}
72
+ focus_scores: dict[str, float] = {}
73
+ boundary_focus = 0
74
+ switch_count = 0
75
+
76
+ for item in actions:
77
+ phase_key = "NS" if int(item.next_phase) == 0 else "EW"
78
+ phase_counts[phase_key] += 1
79
+ switch_count += int(item.switched)
80
+ if item.is_boundary:
81
+ boundary_focus += 1
82
+ focus_scores[item.intersection_id] = focus_scores.get(item.intersection_id, 0.0) + (
83
+ item.queue_total + 1.5 * item.wait_total + 2.0 * float(item.switched)
84
+ )
85
+
86
+ total_action_records = max(1, len(actions))
87
+ ns_phase_ratio = phase_counts["NS"] / float(total_action_records)
88
+ ew_phase_ratio = phase_counts["EW"] / float(total_action_records)
89
+ boundary_focus_ratio = boundary_focus / float(total_action_records)
90
+ queue_delta = end_state.total_queue - state.total_queue
91
+ wait_delta = end_state.total_wait - state.total_wait
92
+ boundary_share = state.boundary_queue_total / max(1.0, state.total_queue)
93
+ outgoing_pressure = end_state.total_outgoing_load / max(1.0, end_state.total_queue)
94
+
95
+ if ns_phase_ratio > ew_phase_ratio + 0.1:
96
+ phase_bias = "NS"
97
+ elif ew_phase_ratio > ns_phase_ratio + 0.1:
98
+ phase_bias = "EW"
99
+ else:
100
+ phase_bias = "NONE"
101
+
102
+ if phase_bias == "NONE" and state.dominant_flow in {"NS", "EW"}:
103
+ phase_bias = state.dominant_flow
104
+
105
+ def select_targets(
106
+ strategy: str,
107
+ priority_corridor: str | None,
108
+ selected_phase_bias: str,
109
+ ) -> list[str]:
110
+ return fallback_target_intersections(
111
+ summary=state,
112
+ max_target_intersections=max_target_intersections,
113
+ strategy=strategy,
114
+ priority_corridor=priority_corridor,
115
+ phase_bias=selected_phase_bias,
116
+ focus_scores=focus_scores,
117
+ )
118
+
119
+ if state.incident_flag or end_state.incident_flag:
120
+ target_intersections = select_targets(
121
+ strategy="incident_response",
122
+ priority_corridor=phase_bias if phase_bias in {"NS", "EW"} else "arterial",
123
+ selected_phase_bias=phase_bias,
124
+ )
125
+ return DistrictAction(
126
+ strategy="incident_response",
127
+ priority_corridor=phase_bias if phase_bias in {"NS", "EW"} else "arterial",
128
+ target_intersections=target_intersections,
129
+ phase_bias=phase_bias,
130
+ duration_steps=duration_steps,
131
+ ).validate()
132
+
133
+ if state.spillback_risk or end_state.spillback_risk or (boundary_share >= 0.55 and outgoing_pressure >= 0.45):
134
+ priority_corridor = "inbound" if boundary_share >= 0.55 else phase_bias if phase_bias in {"NS", "EW"} else None
135
+ target_intersections = select_targets(
136
+ strategy="clear_spillback",
137
+ priority_corridor=priority_corridor,
138
+ selected_phase_bias=phase_bias,
139
+ )
140
+ return DistrictAction(
141
+ strategy="clear_spillback",
142
+ priority_corridor=priority_corridor,
143
+ target_intersections=target_intersections,
144
+ phase_bias=phase_bias,
145
+ duration_steps=duration_steps,
146
+ ).validate()
147
+
148
+ if boundary_share >= 0.55 and (queue_delta >= 0.0 or wait_delta >= 0.0):
149
+ target_intersections = select_targets(
150
+ strategy="drain_inbound",
151
+ priority_corridor="inbound",
152
+ selected_phase_bias=phase_bias,
153
+ )
154
+ return DistrictAction(
155
+ strategy="drain_inbound",
156
+ priority_corridor="inbound",
157
+ target_intersections=target_intersections,
158
+ phase_bias=phase_bias,
159
+ duration_steps=duration_steps,
160
+ ).validate()
161
+
162
+ if outgoing_pressure >= 0.65 and end_state.total_queue >= state.total_queue * 0.9:
163
+ target_intersections = select_targets(
164
+ strategy="drain_outbound",
165
+ priority_corridor="outbound",
166
+ selected_phase_bias=phase_bias,
167
+ )
168
+ return DistrictAction(
169
+ strategy="drain_outbound",
170
+ priority_corridor="outbound",
171
+ target_intersections=target_intersections,
172
+ phase_bias=phase_bias,
173
+ duration_steps=duration_steps,
174
+ ).validate()
175
+
176
+ if (
177
+ state.event_flag
178
+ or state.overload_flag
179
+ or end_state.overload_flag
180
+ or (boundary_focus_ratio >= 0.6 and switch_count >= max(2, duration_steps))
181
+ ):
182
+ priority_corridor = phase_bias if phase_bias in {"NS", "EW"} else "arterial"
183
+ target_intersections = select_targets(
184
+ strategy="arterial_priority",
185
+ priority_corridor=priority_corridor,
186
+ selected_phase_bias=phase_bias,
187
+ )
188
+ return DistrictAction(
189
+ strategy="arterial_priority",
190
+ priority_corridor=priority_corridor,
191
+ target_intersections=target_intersections,
192
+ phase_bias=phase_bias,
193
+ duration_steps=duration_steps,
194
+ ).validate()
195
+
196
+ ns_pressure = state.ns_queue + 1.5 * state.ns_wait
197
+ ew_pressure = state.ew_queue + 1.5 * state.ew_wait
198
+ imbalance_threshold = max(5.0, 0.15 * max(1.0, ns_pressure + ew_pressure))
199
+
200
+ if ns_pressure - ew_pressure >= imbalance_threshold:
201
+ target_intersections = select_targets(
202
+ strategy="favor_NS",
203
+ priority_corridor="NS",
204
+ selected_phase_bias="NS",
205
+ )
206
+ return DistrictAction(
207
+ strategy="favor_NS",
208
+ priority_corridor="NS",
209
+ target_intersections=target_intersections,
210
+ phase_bias="NS",
211
+ duration_steps=duration_steps,
212
+ ).validate()
213
+
214
+ if ew_pressure - ns_pressure >= imbalance_threshold:
215
+ target_intersections = select_targets(
216
+ strategy="favor_EW",
217
+ priority_corridor="EW",
218
+ selected_phase_bias="EW",
219
+ )
220
+ return DistrictAction(
221
+ strategy="favor_EW",
222
+ priority_corridor="EW",
223
+ target_intersections=target_intersections,
224
+ phase_bias="EW",
225
+ duration_steps=duration_steps,
226
+ ).validate()
227
+
228
+ return DistrictAction.default_hold(duration_steps=duration_steps)
district_llm/eval.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from collections import Counter
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from district_llm.metrics import aggregate_target_metrics, compute_target_metrics, safe_ratio, target_failure_buckets
10
+ from district_llm.repair import RepairConfig, extract_visible_candidate_ids, sanitize_action_payload
11
+ from district_llm.schema import DistrictAction
12
+ from env.utils import build_topology
13
+
14
+ try:
15
+ from tqdm.auto import tqdm
16
+ except ImportError: # pragma: no cover
17
+ tqdm = None
18
+
19
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for offline district-LLM output evaluation.

    Returns:
        argparse.Namespace with model/data paths, generation limits, and
        repair/reporting switches consumed by ``main``.
    """
    parser = argparse.ArgumentParser(
        description="Offline evaluation for district-LLM outputs."
    )
    # Model under evaluation (full model dir or LoRA adapter dir) and eval data.
    parser.add_argument("--model-path", required=True)
    parser.add_argument("--val-jsonl", required=True)
    parser.add_argument("--max-examples", type=int, default=200)
    parser.add_argument("--debug-examples", type=int, default=10)
    parser.add_argument("--max-new-tokens", type=int, default=128)
    parser.add_argument("--device", default=None)
    parser.add_argument("--generated-root", default="data/generated")
    parser.add_argument("--restrict-targets-to-visible-summary", action="store_true")
    # BooleanOptionalAction generates paired --flag / --no-flag switches.
    parser.add_argument(
        "--allow-only-visible-candidates",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument("--max-target-intersections", type=int, default=3)
    parser.add_argument(
        "--fallback-on-empty-targets",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument(
        "--fallback-mode",
        choices=("heuristic", "hold", "none"),
        default="heuristic",
    )
    parser.add_argument(
        "--report-before-after-repair",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    return parser.parse_args()
53
+
54
+
55
+ def load_rows(path: str | Path, max_examples: int | None = None) -> list[dict[str, Any]]:
56
+ rows = []
57
+ with Path(path).open("r", encoding="utf-8") as handle:
58
+ for line in handle:
59
+ if not line.strip():
60
+ continue
61
+ rows.append(json.loads(line))
62
+ if max_examples is not None and len(rows) >= max_examples:
63
+ break
64
+ return rows
65
+
66
+
67
def extract_json_object(payload: str) -> str:
    """Return the substring spanning the first '{' through the last '}'.

    Raises:
        ValueError: when no such brace-delimited span exists in *payload*.
    """
    opening = payload.find("{")
    closing = payload.rfind("}")
    if opening < 0 or closing <= opening:
        raise ValueError("No JSON object found.")
    return payload[opening : closing + 1]
73
+
74
+
75
def load_model_and_tokenizer(model_path: str, device: str | None = None):
    """Load a causal LM (full weights or a PEFT/LoRA adapter) and its tokenizer.

    Args:
        model_path: directory containing either a full model or a LoRA adapter
            (detected by the presence of ``adapter_config.json``).
        device: explicit device string; defaults to CUDA when available.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode on *device*.

    Raises:
        ImportError: when an adapter is given but 'peft' is not installed.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_dir = Path(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Greedy generation needs a pad token; reuse EOS when none is configured.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    target_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if (model_dir / "adapter_config.json").exists():
        try:
            from peft import AutoPeftModelForCausalLM
        except ImportError as exc:
            raise ImportError(
                "Evaluating a LoRA adapter requires the 'peft' package."
            ) from exc
        # Bug fix: the adapter branch previously ignored `device` and left the
        # model wherever from_pretrained placed it, unlike the plain branch.
        model = AutoPeftModelForCausalLM.from_pretrained(model_path).to(target_device)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path).to(target_device)
    model.eval()
    return model, tokenizer
97
+
98
+
99
def build_generation_prompt(tokenizer, messages: list[dict[str, str]]) -> str:
    """Render chat *messages* into a single generation prompt.

    Uses the tokenizer's chat template when one is configured; otherwise falls
    back to a plain "role: content" transcript ending with an assistant cue.
    """
    template = getattr(tokenizer, "chat_template", None)
    if template:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    transcript = [f"{turn['role']}: {turn['content']}" for turn in messages]
    return "\n".join(transcript) + "\nassistant:"
107
+
108
+
109
def generate_response(model, tokenizer, messages: list[dict[str, str]], max_new_tokens: int) -> str:
    """Greedily generate and decode the assistant completion for *messages*."""
    import torch

    prompt_text = build_generation_prompt(tokenizer, messages)
    model_device = getattr(model, "device", None)
    encoded = tokenizer(prompt_text, return_tensors="pt")
    if model_device is not None:
        # Move every input tensor onto the model's device before generation.
        encoded = {name: tensor.to(model_device) for name, tensor in encoded.items()}
    with torch.no_grad():
        generation = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
    # Strip the prompt tokens so only the newly generated continuation remains.
    prompt_length = encoded["input_ids"].shape[1]
    completion_ids = generation[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
126
+
127
+
128
def parse_prediction(payload: str) -> tuple[bool, bool, dict[str, Any] | None]:
    """Parse raw model text into ``(json_valid, schema_valid, payload_dict)``."""
    try:
        decoded = json.loads(extract_json_object(payload))
    except Exception:
        # No parseable JSON object at all.
        return False, False, None
    try:
        validated = DistrictAction.from_dict(decoded)
    except Exception:
        # JSON parsed but failed schema validation; surface the raw dict.
        return True, False, decoded
    return True, True, validated.to_dict()
138
+
139
+
140
class DistrictTopologyIndex:
    """Lazily loads per-city district topology and caches intersection-id sets."""

    def __init__(self, generated_root: str | Path):
        self.generated_root = Path(generated_root)
        # city_id -> {district_id -> set of intersection ids}
        self._cache: dict[str, dict[str, set[str]]] = {}

    def district_intersections(self, city_id: str, district_id: str) -> set[str]:
        """Return the intersection ids of a district (empty set when unknown)."""
        cached = self._cache.get(city_id)
        if cached is None:
            city_dir = self.generated_root / city_id
            _, districts = build_topology(
                roadnet_path=city_dir / "roadnet.json",
                district_map_path=city_dir / "district_map.json",
                metadata_path=city_dir / "metadata.json",
            )
            cached = {
                name: set(info.intersection_ids) for name, info in districts.items()
            }
            self._cache[city_id] = cached
        return cached.get(district_id, set())
160
+
161
+
162
+ def field_accuracy(pred: dict[str, Any] | None, gt: dict[str, Any], field: str) -> float:
163
+ if pred is None:
164
+ return 0.0
165
+ return float(pred.get(field) == gt.get(field))
166
+
167
+
168
def invalid_target_fraction(pred_targets: list[str], district_candidates: set[str]) -> float:
    """Fraction of predicted target ids that are not valid district intersections."""
    if not pred_targets:
        return 0.0
    invalid_total = sum(target not in district_candidates for target in pred_targets)
    return safe_ratio(invalid_total, len(pred_targets))
173
+
174
+
175
def evaluate_rows(
    rows: list[dict[str, Any]],
    model,
    tokenizer,
    max_new_tokens: int,
    topology_index: DistrictTopologyIndex,
    restrict_targets_to_visible_summary: bool,
    debug_examples: int,
    repair_config: RepairConfig,
    report_before_after_repair: bool,
) -> dict[str, Any]:
    """Generate a prediction per row and score it before and after repair.

    Each row is expected to carry a 3-message chat (system, user prompt,
    assistant ground-truth JSON) under ``messages``, plus ``city_id`` /
    ``district_id`` used for topology lookups. Returns a dict of aggregate
    metrics; the first *debug_examples* rows are kept verbatim under
    ``debug_examples``.
    """
    # Validity / accuracy accumulators: "*_before" is the raw model output,
    # "*_after" is the output after sanitize_action_payload repair.
    json_valid_count = 0
    schema_valid_count = 0
    field_totals_before = Counter()
    field_totals_after = Counter()
    full_object_correct_before = 0
    full_object_correct_after = 0
    target_rows_before: list[dict[str, float]] = []
    target_rows_after: list[dict[str, float]] = []
    restricted_target_rows_before: list[dict[str, float]] = []
    restricted_target_rows_after: list[dict[str, float]] = []
    invalid_rates_before: list[float] = []
    invalid_rates_after: list[float] = []
    fallback_used_count = 0
    failure_buckets = Counter()
    debug_rows = []

    # Progress bar is optional; tqdm may be None when the import failed.
    progress = (
        tqdm(total=len(rows), desc="eval", dynamic_ncols=True)
        if tqdm is not None
        else None
    )

    try:
        for row in rows:
            messages = row["messages"]
            # messages[2] is the assistant turn containing the gold JSON action.
            ground_truth = json.loads(messages[2]["content"])
            raw_prediction = generate_response(
                model=model,
                tokenizer=tokenizer,
                messages=messages[:2],
                max_new_tokens=max_new_tokens,
            )
            json_valid, schema_valid, prediction_before = parse_prediction(raw_prediction)
            # Repair always runs; when the output was not even JSON the repair
            # starts from None. NOTE(review): `summary=row` passes the whole
            # example row as the summary — confirm sanitize_action_payload
            # expects this shape.
            repaired_action, repair_report = sanitize_action_payload(
                payload=prediction_before if json_valid else None,
                summary=row,
                prompt_text=messages[1]["content"],
                config=repair_config,
            )
            prediction_after = repaired_action.to_dict()
            json_valid_count += int(json_valid)
            schema_valid_count += int(schema_valid)
            fallback_used_count += int(repair_report.fallback_used)

            # Scalar-field accuracy, pre-repair.
            field_totals_before["strategy"] += field_accuracy(prediction_before, ground_truth, "strategy")
            field_totals_before["priority_corridor"] += field_accuracy(prediction_before, ground_truth, "priority_corridor")
            field_totals_before["phase_bias"] += field_accuracy(prediction_before, ground_truth, "phase_bias")
            field_totals_before["duration_steps"] += field_accuracy(prediction_before, ground_truth, "duration_steps")

            # Scalar-field accuracy, post-repair.
            field_totals_after["strategy"] += field_accuracy(prediction_after, ground_truth, "strategy")
            field_totals_after["priority_corridor"] += field_accuracy(prediction_after, ground_truth, "priority_corridor")
            field_totals_after["phase_bias"] += field_accuracy(prediction_after, ground_truth, "phase_bias")
            field_totals_after["duration_steps"] += field_accuracy(prediction_after, ground_truth, "duration_steps")

            if prediction_before == ground_truth:
                full_object_correct_before += 1
            if prediction_after == ground_truth:
                full_object_correct_after += 1

            pred_targets_before = [] if prediction_before is None else list(prediction_before.get("target_intersections", []))
            pred_targets_after = list(prediction_after.get("target_intersections", []))
            gt_targets = list(ground_truth.get("target_intersections", []))
            # Candidates literally visible in the prompt vs. all intersections
            # that exist in the district topology.
            visible_candidates = set(
                extract_visible_candidate_ids(summary=row, prompt_text=messages[1]["content"])
            )
            district_candidates = topology_index.district_intersections(
                city_id=row["city_id"],
                district_id=row["district_id"],
            )
            invalid_before = [item for item in pred_targets_before if item not in district_candidates]
            invalid_after = [item for item in pred_targets_after if item not in district_candidates]
            # Only meaningful when the prompt exposed a candidate list.
            non_visible_before = [
                item for item in pred_targets_before
                if visible_candidates and item not in visible_candidates
            ]

            metrics_before = compute_target_metrics(pred_targets_before, gt_targets)
            metrics_after = compute_target_metrics(pred_targets_after, gt_targets)
            target_rows_before.append(metrics_before)
            target_rows_after.append(metrics_after)
            invalid_rates_before.append(invalid_target_fraction(pred_targets_before, district_candidates))
            invalid_rates_after.append(invalid_target_fraction(pred_targets_after, district_candidates))

            if restrict_targets_to_visible_summary:
                # Re-score with both prediction and ground truth filtered to
                # the prompt-visible candidate set.
                filtered_pred_before = [item for item in pred_targets_before if item in visible_candidates]
                filtered_pred_after = [item for item in pred_targets_after if item in visible_candidates]
                filtered_gt = [item for item in gt_targets if item in visible_candidates]
                restricted_target_rows_before.append(
                    compute_target_metrics(filtered_pred_before, filtered_gt)
                )
                restricted_target_rows_after.append(
                    compute_target_metrics(filtered_pred_after, filtered_gt)
                )

            # Each bucket counts at most once per row (hence set()).
            for failure_bucket in set(
                target_failure_buckets(
                    pred_list=pred_targets_before,
                    gt_list=gt_targets,
                    visible_candidates=visible_candidates,
                    invalid_ids=invalid_before,
                    non_visible_ids=non_visible_before,
                    repaired_targets=pred_targets_after,
                    fallback_used=repair_report.fallback_used,
                )
            ):
                failure_buckets[failure_bucket] += 1

            if len(debug_rows) < debug_examples:
                debug_rows.append(
                    {
                        "district_summary": messages[1]["content"],
                        "predicted_json_raw": raw_prediction,
                        "predicted_json_parsed_before_repair": prediction_before,
                        "predicted_json_parsed_after_repair": prediction_after,
                        "ground_truth_json": ground_truth,
                        "target_intersections_metrics_before_repair": metrics_before,
                        "target_intersections_metrics_after_repair": metrics_after,
                        "repair_report": repair_report.to_dict(),
                        "visible_candidate_ids": sorted(visible_candidates),
                        "failure_buckets": sorted(
                            set(
                                target_failure_buckets(
                                    pred_list=pred_targets_before,
                                    gt_list=gt_targets,
                                    visible_candidates=visible_candidates,
                                    invalid_ids=invalid_before,
                                    non_visible_ids=non_visible_before,
                                    repaired_targets=pred_targets_after,
                                    fallback_used=repair_report.fallback_used,
                                )
                            )
                        ),
                    }
                )
            if progress is not None:
                progress.update(1)
    finally:
        if progress is not None:
            progress.close()

    # Guard against division by zero when rows is empty.
    total_rows = max(1, len(rows))
    results = {
        "num_examples": len(rows),
        "json_validity_rate": float(json_valid_count) / total_rows,
        "schema_validity_rate": float(schema_valid_count) / total_rows,
        "field_accuracy": {
            "strategy": float(field_totals_before["strategy"]) / total_rows,
            "priority_corridor": float(field_totals_before["priority_corridor"]) / total_rows,
            "phase_bias": float(field_totals_before["phase_bias"]) / total_rows,
            "duration_steps": float(field_totals_before["duration_steps"]) / total_rows,
        },
        "field_accuracy_after_repair": {
            "strategy": float(field_totals_after["strategy"]) / total_rows,
            "priority_corridor": float(field_totals_after["priority_corridor"]) / total_rows,
            "phase_bias": float(field_totals_after["phase_bias"]) / total_rows,
            "duration_steps": float(field_totals_after["duration_steps"]) / total_rows,
        },
        "target_intersections_before_repair": aggregate_target_metrics(target_rows_before),
        "target_intersections_after_repair": aggregate_target_metrics(target_rows_after),
        # Unsuffixed key mirrors the after-repair numbers (headline metric).
        "target_intersections": aggregate_target_metrics(target_rows_after),
        "target_intersections_failure_buckets": dict(sorted(failure_buckets.items())),
        "exact_full_object_accuracy": float(full_object_correct_before) / total_rows,
        "exact_full_object_accuracy_after_repair": float(full_object_correct_after) / total_rows,
        "debug_examples": debug_rows,
    }
    if restrict_targets_to_visible_summary:
        results["target_intersections_restricted_to_visible_summary_before_repair"] = aggregate_target_metrics(
            restricted_target_rows_before
        )
        results["target_intersections_restricted_to_visible_summary_after_repair"] = aggregate_target_metrics(
            restricted_target_rows_after
        )
        results["target_intersections_restricted_to_visible_summary"] = aggregate_target_metrics(
            restricted_target_rows_after
        )
    if report_before_after_repair:
        results["target_intersections_before_after_repair"] = {
            "invalid_id_rate_before_repair": float(sum(invalid_rates_before) / total_rows),
            "invalid_id_rate_after_repair": float(sum(invalid_rates_after) / total_rows),
            "exact_set_match_before_repair": aggregate_target_metrics(target_rows_before).get("exact_set_match", 0.0),
            "exact_set_match_after_repair": aggregate_target_metrics(target_rows_after).get("exact_set_match", 0.0),
            "jaccard_before_repair": aggregate_target_metrics(target_rows_before).get("jaccard", 0.0),
            "jaccard_after_repair": aggregate_target_metrics(target_rows_after).get("jaccard", 0.0),
            "fallback_used_rate": float(fallback_used_count) / total_rows,
        }
    return results
372
+
373
+
374
def print_debug_examples(debug_rows: list[dict[str, Any]]) -> None:
    """Pretty-print per-example debug records collected by evaluate_rows."""
    sorted_json_fields = (
        "predicted_json_parsed_before_repair",
        "predicted_json_parsed_after_repair",
        "ground_truth_json",
        "target_intersections_metrics_before_repair",
        "target_intersections_metrics_after_repair",
        "repair_report",
        "visible_candidate_ids",
    )
    for position, record in enumerate(debug_rows, start=1):
        prefix = f"[debug {position}]"
        print(f"{prefix} district_summary:")
        print(record["district_summary"])
        print(f"{prefix} predicted_json_raw={record['predicted_json_raw']}")
        for field_name in sorted_json_fields:
            print(f"{prefix} {field_name}={json.dumps(record[field_name], sort_keys=True)}")
        # failure_buckets is already a sorted list; dumped without sort_keys.
        print(f"{prefix} failure_buckets={json.dumps(record['failure_buckets'])}")
408
+
409
+
410
def main() -> None:
    """Entry point: load data and model, run the evaluation, print a report.

    Prints the aggregate metrics as sorted JSON (debug rows excluded), then
    the per-example debug dump.
    """
    args = parse_args()
    rows = load_rows(args.val_jsonl, max_examples=args.max_examples)
    model, tokenizer = load_model_and_tokenizer(args.model_path, device=args.device)
    topology_index = DistrictTopologyIndex(args.generated_root)
    results = evaluate_rows(
        rows=rows,
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=args.max_new_tokens,
        topology_index=topology_index,
        restrict_targets_to_visible_summary=args.restrict_targets_to_visible_summary,
        debug_examples=args.debug_examples,
        repair_config=RepairConfig(
            allow_only_visible_candidates=args.allow_only_visible_candidates,
            max_target_intersections=args.max_target_intersections,
            fallback_on_empty_targets=args.fallback_on_empty_targets,
            fallback_mode=args.fallback_mode,
        ),
        report_before_after_repair=args.report_before_after_repair,
    )
    # Keep the machine-readable summary separate from the verbose debug dump.
    print(json.dumps({k: v for k, v in results.items() if k != "debug_examples"}, indent=2, sort_keys=True))
    print_debug_examples(results["debug_examples"])
433
+
434
+
435
+ if __name__ == "__main__":
436
+ main()
district_llm/generate_dataset.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import random
6
+ from dataclasses import dataclass, field
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import numpy as np
11
+
12
+ from district_llm.derivation import DistrictWindowData, LocalIntersectionAction, derive_district_action
13
+ from district_llm.prompting import format_district_prompt, format_sft_text
14
+ from district_llm.summary_builder import DistrictStateSummaryBuilder
15
+ from district_llm.teachers import BaseTeacher, build_teacher, parse_teacher_spec
16
+ from env.observation_builder import ObservationConfig
17
+ from env.reward import RewardConfig
18
+ from env.traffic_env import EnvConfig, TrafficEnv
19
+ from training.cityflow_dataset import CityFlowDataset, ScenarioSpec
20
+
21
+
22
@dataclass
class _WindowBuffer:
    """Accumulates one district-decision window during a rollout."""

    # District state summary captured at the start of the window.
    start_summary: Any
    # Local controller actions recorded over the window, one per intersection-step.
    controller_actions: list[LocalIntersectionAction] = field(default_factory=list)
    # Number of local decision steps accumulated so far.
    step_count: int = 0
27
+
28
+
29
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for district-LLM SFT dataset generation.

    Returns:
        argparse.Namespace covering teacher/controller selection, rollout
        sizing, output paths, and the environment/reward configuration.
    """
    parser = argparse.ArgumentParser(
        description="Generate district-LLM SFT data from CityFlow rollouts."
    )
    # Teacher/controller selection.
    parser.add_argument(
        "--controller",
        default="queue_greedy",
        choices=("rl_checkpoint", "hold", "fixed", "random", "queue_greedy"),
        help="Single controller source used when --teacher-spec is not provided.",
    )
    parser.add_argument("--checkpoint", default=None)
    parser.add_argument(
        "--teacher-spec",
        action="append",
        default=[],
        help="Repeatable source spec, e.g. rl_checkpoint=artifacts/dqn_shared/best_validation.pt or fixed.",
    )
    # Rollout sizing and district-level decision cadence.
    parser.add_argument("--episodes", type=int, default=10)
    parser.add_argument(
        "--decision-interval",
        "--district-decision-interval",
        dest="district_decision_interval",
        type=int,
        default=10,
        help="District-LLM decision interval in local-controller decision steps.",
    )
    # Output and scenario selection.
    parser.add_argument("--output", required=True)
    parser.add_argument("--generated-root", default="data/generated")
    parser.add_argument("--splits-root", default="data/splits")
    parser.add_argument("--split", default="train", choices=("train", "val", "test"))
    parser.add_argument("--city-id", default=None)
    parser.add_argument("--scenario-name", default=None)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--fixed-green-time", type=int, default=20)
    parser.add_argument("--device", default=None)
    parser.add_argument("--append", action="store_true")
    parser.add_argument("--top-k-congested", type=int, default=3)
    parser.add_argument("--max-candidate-intersections", type=int, default=6)
    parser.add_argument("--max-target-intersections", type=int, default=3)
    parser.add_argument("--use-checkpoint-env-config", action="store_true")

    # Environment / observation / reward configuration (mirrors EnvConfig).
    parser.add_argument("--env-decision-interval", type=int, default=5)
    parser.add_argument("--simulator-interval", type=int, default=1)
    parser.add_argument("--min-green-time", type=int, default=10)
    parser.add_argument("--thread-num", type=int, default=1)
    parser.add_argument("--max-episode-seconds", type=int, default=None)
    parser.add_argument("--max-incoming-lanes", type=int, default=16)
    parser.add_argument("--count-scale", type=float, default=20.0)
    parser.add_argument("--elapsed-time-scale", type=float, default=60.0)
    parser.add_argument("--disable-district-context", action="store_true")
    parser.add_argument("--disable-outgoing-congestion", action="store_true")
    parser.add_argument("--reward-variant", default="wait_queue_throughput")
    parser.add_argument("--waiting-weight", type=float, default=1.0)
    parser.add_argument("--vehicle-weight", type=float, default=0.1)
    parser.add_argument("--pressure-weight", type=float, default=0.0)
    parser.add_argument("--reward-scale", type=float, default=0.1)
    parser.add_argument("--disable-lane-reward-normalization", action="store_true")
    parser.add_argument("--reward-clip", type=float, default=5.0)
    parser.add_argument("--queue-delta-weight", type=float, default=2.0)
    parser.add_argument("--wait-delta-weight", type=float, default=4.0)
    parser.add_argument("--queue-level-weight", type=float, default=0.5)
    parser.add_argument("--wait-level-weight", type=float, default=1.0)
    parser.add_argument("--throughput-weight", type=float, default=0.1)
    parser.add_argument("--imbalance-weight", type=float, default=0.1)
    parser.add_argument("--reward-delta-clip", type=float, default=2.0)
    parser.add_argument("--reward-level-normalizer", type=float, default=10.0)
    parser.add_argument("--throughput-normalizer", type=float, default=2.0)
    return parser.parse_args()
97
+
98
+
99
def build_env_config(args: argparse.Namespace) -> EnvConfig:
    """Translate parsed CLI arguments into an EnvConfig.

    Maps the flat argparse namespace onto the nested EnvConfig /
    ObservationConfig / RewardConfig structure; ``--disable-*`` flags are
    inverted into the corresponding ``include_*`` / ``normalize_*`` options.
    """
    return EnvConfig(
        simulator_interval=args.simulator_interval,
        decision_interval=args.env_decision_interval,
        min_green_time=args.min_green_time,
        thread_num=args.thread_num,
        max_episode_seconds=args.max_episode_seconds,
        observation=ObservationConfig(
            max_incoming_lanes=args.max_incoming_lanes,
            count_scale=args.count_scale,
            elapsed_time_scale=args.elapsed_time_scale,
            include_outgoing_congestion=not args.disable_outgoing_congestion,
            include_district_context=not args.disable_district_context,
            include_district_type_feature=True,
        ),
        reward=RewardConfig(
            variant=args.reward_variant,
            waiting_weight=args.waiting_weight,
            vehicle_weight=args.vehicle_weight,
            pressure_weight=args.pressure_weight,
            reward_scale=args.reward_scale,
            normalize_by_lane_count=not args.disable_lane_reward_normalization,
            clip_reward=args.reward_clip,
            queue_delta_weight=args.queue_delta_weight,
            wait_delta_weight=args.wait_delta_weight,
            queue_level_weight=args.queue_level_weight,
            wait_level_weight=args.wait_level_weight,
            throughput_weight=args.throughput_weight,
            imbalance_weight=args.imbalance_weight,
            delta_clip=args.reward_delta_clip,
            level_normalizer=args.reward_level_normalizer,
            throughput_normalizer=args.throughput_normalizer,
        ),
    )
133
+
134
+
135
def build_env(env_config: EnvConfig, scenario_spec: ScenarioSpec) -> TrafficEnv:
    """Instantiate a TrafficEnv for *scenario_spec* under *env_config*."""
    scenario_fields = (
        "city_id",
        "scenario_name",
        "city_dir",
        "scenario_dir",
        "config_path",
        "roadnet_path",
        "district_map_path",
        "metadata_path",
    )
    # Forward each scenario attribute under its own keyword name.
    scenario_kwargs = {name: getattr(scenario_spec, name) for name in scenario_fields}
    return TrafficEnv(env_config=env_config, **scenario_kwargs)
147
+
148
+
149
def resolve_teachers(args: argparse.Namespace) -> list[BaseTeacher]:
    """Build the teacher controllers requested on the command line.

    When no --teacher-spec was given, falls back to the single --controller
    source (attaching --checkpoint for the rl_checkpoint case).
    """
    specs = list(args.teacher_spec)
    if not specs:
        default_spec = (
            f"rl_checkpoint={args.checkpoint}"
            if args.controller == "rl_checkpoint"
            else args.controller
        )
        specs = [default_spec]

    resolved: list[BaseTeacher] = []
    for spec in specs:
        controller_type, checkpoint = parse_teacher_spec(spec)
        if controller_type == "rl_checkpoint":
            # Fall back to the global --checkpoint when the spec carries none.
            checkpoint = checkpoint or args.checkpoint
        resolved.append(
            build_teacher(
                controller_type=controller_type,
                checkpoint=checkpoint,
                fixed_green_time=args.fixed_green_time,
                seed=args.seed,
                device=args.device,
            )
        )
    return resolved
169
+
170
+
171
def resolve_env_config(args: argparse.Namespace, teachers: list[BaseTeacher]) -> EnvConfig:
    """Choose the env config: CLI-derived, or a checkpoint-embedded one on request.

    Raises:
        ValueError: when multiple checkpoint teachers disagree on the config.
    """
    cli_config = build_env_config(args)
    if not args.use_checkpoint_env_config:
        return cli_config

    embedded = [
        teacher.env_config for teacher in teachers if teacher.env_config is not None
    ]
    if not embedded:
        # No teacher carries an embedded config; use the CLI-derived one.
        return cli_config

    reference = embedded[0]
    assert reference is not None  # guaranteed by the filter above
    if any(candidate != reference for candidate in embedded[1:]):
        raise ValueError("Checkpoint teachers use different env configs. Generate separate datasets.")
    return reference
188
+
189
+
190
def sample_scenario(
    dataset: CityFlowDataset,
    rng: random.Random,
    split: str,
    city_id: str | None,
    scenario_name: str | None,
) -> ScenarioSpec:
    """Resolve a scenario: exact lookup when fully specified, else sample.

    With both ``city_id`` and ``scenario_name`` given, the spec is built
    deterministically; otherwise the dataset samples from ``split``,
    honoring whichever constraint (if any) is present.
    """
    fully_specified = bool(city_id) and bool(scenario_name)
    if fully_specified:
        return dataset.build_scenario_spec(city_id, scenario_name)
    return dataset.sample_scenario(
        split_name=split,
        rng=rng,
        city_id=city_id,
        scenario_name=scenario_name,
    )
205
+
206
+
207
def extract_step_actions(
    env: TrafficEnv,
    observation_batch: dict[str, Any],
    next_observation_batch: dict[str, Any],
    actions: np.ndarray,
) -> dict[str, list[LocalIntersectionAction]]:
    """Group this step's per-intersection actions by district.

    For each controlled intersection, records the low-level action taken,
    the phase before and after the step, and simple congestion aggregates
    (queue, wait, outgoing load) read from the observation batches and the
    simulator's lane counts.
    """
    # One (possibly empty) bucket per district so downstream code can index
    # every district without key checks.
    grouped: dict[str, list[LocalIntersectionAction]] = {district_id: [] for district_id in env.districts}
    lane_vehicle_count = env.adapter.get_lane_vehicle_count()

    for index, intersection_id in enumerate(observation_batch["intersection_ids"]):
        district_id = observation_batch["district_ids"][index]
        grouped[district_id].append(
            LocalIntersectionAction(
                intersection_id=intersection_id,
                district_id=district_id,
                action=int(actions[index]),
                current_phase=int(observation_batch["current_phase"][index]),
                # Phase after env.step(), read from the *next* observation batch.
                next_phase=int(next_observation_batch["current_phase"][index]),
                queue_total=float(np.asarray(observation_batch["incoming_counts"][index], dtype=np.float32).sum()),
                wait_total=float(np.asarray(observation_batch["incoming_waiting"][index], dtype=np.float32).sum()),
                # Total vehicles currently on the intersection's outgoing lanes;
                # lanes missing from the simulator's count map contribute 0.
                outgoing_load=float(
                    sum(
                        float(lane_vehicle_count.get(lane_id, 0))
                        for lane_id in env.intersections[intersection_id].outgoing_lanes
                    )
                ),
                is_boundary=bool(env.intersections[intersection_id].is_boundary),
            )
        )
    return grouped
237
+
238
+
239
def generate_examples_for_episode(
    env: TrafficEnv,
    teacher: BaseTeacher,
    district_interval: int,
    top_k_congested: int,
    max_candidate_intersections: int,
    max_target_intersections: int,
    episode_index: int,
) -> list[dict[str, Any]]:
    """Roll out one teacher-controlled episode and emit district SFT samples.

    The episode is chopped into per-district windows of ``district_interval``
    env steps (the final window may be shorter when the episode terminates).
    For each closed window, a district-level action is derived from the
    controller's behavior inside that window and paired with a prompt built
    from the summary at the *start* of the window.

    Returns:
        One record per (district, window), carrying the SFT text, the prompt
        and JSON response, state/window summaries, and provenance metadata
        (teacher identity, episode/window indices, sim time).
    """
    summary_builder = DistrictStateSummaryBuilder(
        top_k=top_k_congested,
        candidate_limit=max_candidate_intersections,
    )
    observation_batch = env.reset()
    summary_builder.reset()
    current_summaries = summary_builder.build_all(env, observation_batch)
    # Each district accumulates actions into a window buffer anchored at the
    # summary taken when the window opened.
    windows = {
        district_id: _WindowBuffer(start_summary=summary)
        for district_id, summary in current_summaries.items()
    }
    samples: list[dict[str, Any]] = []
    done = False
    # Global counter over emitted windows (shared across districts).
    window_index = 0

    while not done:
        actions = teacher.act(observation_batch)
        next_observation_batch, rewards, done, info = env.step(actions)
        # Rewards/info are irrelevant for supervised data generation.
        del rewards, info
        step_actions = extract_step_actions(env, observation_batch, next_observation_batch, actions)
        next_summaries = summary_builder.build_all(env, next_observation_batch)

        for district_id, buffer in windows.items():
            buffer.controller_actions.extend(step_actions[district_id])
            buffer.step_count += 1
            # Emit when the window is full, or flush a partial window at
            # episode end.
            should_emit = buffer.step_count >= district_interval or done
            if not should_emit:
                continue

            end_summary = next_summaries[district_id]
            window_data = DistrictWindowData(
                district_id=district_id,
                start_summary=buffer.start_summary,
                end_summary=end_summary,
                controller_actions=list(buffer.controller_actions),
                step_count=buffer.step_count,
            )
            # Ground-truth district action is derived from what the
            # controller actually did across the window.
            action = derive_district_action(
                window_data=window_data,
                max_target_intersections=max_target_intersections,
            )
            # The prompt reflects only information available at window start.
            prompt = format_district_prompt(
                buffer.start_summary,
                max_target_intersections=max_target_intersections,
                allow_only_visible_candidates=True,
            )
            samples.append(
                {
                    "text": format_sft_text(
                        buffer.start_summary,
                        action,
                        max_target_intersections=max_target_intersections,
                        allow_only_visible_candidates=True,
                    ),
                    "prompt": prompt,
                    "response_json": action.to_dict(),
                    "state": buffer.start_summary.to_dict(),
                    "candidate_intersections": buffer.start_summary.to_dict().get("candidate_intersections", []),
                    "window_summary": window_data.to_dict(),
                    "city_id": env.city_id,
                    "district_id": district_id,
                    "district_type": env.districts[district_id].district_type,
                    "scenario": env.scenario_name,
                    "controller_type": teacher.metadata.controller_type,
                    "controller_id": teacher.metadata.controller_id,
                    "controller_family": teacher.metadata.controller_family,
                    "teacher_algorithm": teacher.metadata.teacher_algorithm,
                    "checkpoint_path": teacher.metadata.checkpoint_path,
                    "episode_index": int(episode_index),
                    "window_index": int(window_index),
                    "decision_interval": int(district_interval),
                    "sim_time": int(buffer.start_summary.sim_time),
                }
            )
            # Start the next window where this one ended.
            windows[district_id] = _WindowBuffer(start_summary=end_summary)
            window_index += 1

        observation_batch = next_observation_batch

    return samples
328
+
329
+
330
def append_jsonl(path: Path, records: list[dict[str, Any]], append: bool) -> None:
    """Write records as JSON Lines (sorted keys), appending or truncating.

    Parent directories are created on demand. ``append=False`` overwrites
    any existing file; ``append=True`` adds to it.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    if append:
        mode = "a"
    else:
        mode = "w"
    lines = [json.dumps(record, sort_keys=True) + "\n" for record in records]
    with path.open(mode, encoding="utf-8") as handle:
        handle.writelines(lines)
337
+
338
+
339
def main() -> None:
    """CLI entry point: roll out teacher policies and dump district SFT data.

    For each episode, samples (or fixes) a scenario, runs every resolved
    teacher on a fresh environment, and appends the resulting records to
    the JSONL file at ``--output``. Progress is reported as one JSON line
    per episode on stdout.
    """
    args = parse_args()
    dataset = CityFlowDataset(
        generated_root=args.generated_root,
        splits_root=args.splits_root,
    )
    dataset.generate_default_splits()
    rng = random.Random(args.seed)
    teachers = resolve_teachers(args)
    env_config = resolve_env_config(args, teachers)

    output_path = Path(args.output)
    # The first write honors --append; all subsequent episodes append, so a
    # multi-episode run accumulates into a single file.
    write_mode_append = bool(args.append)

    for episode_index in range(args.episodes):
        scenario_spec = sample_scenario(
            dataset=dataset,
            rng=rng,
            split=args.split,
            city_id=args.city_id,
            scenario_name=args.scenario_name,
        )
        episode_records: list[dict[str, Any]] = []
        for teacher in teachers:
            # Fresh env per teacher so rollouts never share simulator state.
            env = build_env(env_config=env_config, scenario_spec=scenario_spec)
            episode_records.extend(
                generate_examples_for_episode(
                    env=env,
                    teacher=teacher,
                    district_interval=args.district_decision_interval,
                    top_k_congested=args.top_k_congested,
                    max_candidate_intersections=args.max_candidate_intersections,
                    max_target_intersections=args.max_target_intersections,
                    episode_index=episode_index,
                )
            )
        append_jsonl(output_path, episode_records, append=write_mode_append)
        write_mode_append = True
        # Machine-readable per-episode progress line.
        print(
            json.dumps(
                {
                    "episode_index": episode_index,
                    "city_id": scenario_spec.city_id,
                    "scenario_name": scenario_spec.scenario_name,
                    "records_written": len(episode_records),
                }
            )
        )


if __name__ == "__main__":
    main()
district_llm/guided_control.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+
7
+ from district_llm.schema import DistrictAction
8
+
9
+
10
class DistrictGuidedLocalController:
    """
    Wrap a low-level controller and bias its actions with district directives.

    The shared DQN still produces the base per-intersection action, and the
    district plan only nudges hold/switch decisions toward the requested phase.
    """

    def __init__(self, base_teacher):
        self.base_teacher = base_teacher

    def act(
        self,
        observation_batch: dict[str, Any],
        district_actions: dict[str, DistrictAction] | None = None,
    ) -> np.ndarray:
        """Return base actions, overridden wherever a district directive applies."""
        base_actions = np.asarray(self.base_teacher.act(observation_batch), dtype=np.int64)
        if not district_actions:
            return base_actions

        guided = base_actions.copy()
        for index, district_id in enumerate(observation_batch["district_ids"]):
            directive = district_actions.get(district_id)
            if directive is not None:
                guided[index] = self._apply_directive(
                    observation_batch=observation_batch,
                    index=index,
                    base_action=int(base_actions[index]),
                    directive=directive,
                )
        return guided

    @staticmethod
    def _apply_directive(
        observation_batch: dict[str, Any],
        index: int,
        base_action: int,
        directive: DistrictAction,
    ) -> int:
        """Map one directive onto a hold(0)/switch(1) decision for one signal."""
        mask = observation_batch["action_mask"][index]
        phase = int(observation_batch["current_phase"][index])
        switch_allowed = bool(mask[1] > 0.0)

        # "hold" strategies and neutral bias defer to the base controller.
        if directive.strategy == "hold" or directive.phase_bias == "NONE":
            return int(base_action)

        if directive.phase_bias == "NS":
            # Presumably phase 0 serves NS (TODO confirm): hold it when
            # active, otherwise switch toward it if the mask allows.
            if phase == 0:
                return 0
            return 1 if switch_allowed else 0

        if directive.phase_bias == "EW":
            if phase != 0:
                return 0
            return 1 if switch_allowed else 0

        # Unknown bias value: leave the base action untouched.
        return int(base_action)
district_llm/heuristic_guidance.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from district_llm.repair import fallback_target_intersections
6
+ from district_llm.schema import DistrictAction, DistrictStateSummary
7
+
8
+
9
@dataclass(frozen=True)
class HeuristicGuidanceConfig:
    """Tunables for the rule-based district guidance generator."""

    # Maximum number of target intersections included in a directive.
    max_target_intersections: int = 3
    # Directive horizon (controller steps) for incident/construction response.
    incident_duration_steps: int = 12
    # Horizon for spillback-clearing (also reused for arterial-priority)
    # directives.
    spillback_duration_steps: int = 10
    # Horizon for ordinary flow-favoring or hold directives.
    default_duration_steps: int = 8
+
16
+
17
def generate_heuristic_guidance(
    summary: DistrictStateSummary,
    config: HeuristicGuidanceConfig | None = None,
) -> DistrictAction:
    """Derive a district directive from summary flags via fixed-priority rules.

    Rule precedence: incident/construction > spillback > event/overload >
    dominant-flow favoring > hold. Targets come from the repair module's
    fallback selection, and the final action is validated before return.
    """
    cfg = config or HeuristicGuidanceConfig()
    flow = summary.dominant_flow
    directional = flow in {"NS", "EW"}
    biased_flow = flow if directional else "NONE"

    if summary.incident_flag or summary.construction_flag:
        strategy = "incident_response"
        priority_corridor = flow if directional else "arterial"
        phase_bias = biased_flow
        duration_steps = cfg.incident_duration_steps
    elif summary.spillback_risk:
        strategy = "clear_spillback"
        # Prefer the inbound corridor when boundary queues dominate.
        boundary_share = summary.boundary_queue_total / max(1.0, summary.total_queue)
        if boundary_share >= 0.45:
            priority_corridor = "inbound"
        elif directional:
            priority_corridor = flow
        else:
            priority_corridor = None
        phase_bias = biased_flow
        duration_steps = cfg.spillback_duration_steps
    elif summary.event_flag or summary.overload_flag:
        strategy = "arterial_priority"
        priority_corridor = flow if directional else "arterial"
        phase_bias = biased_flow
        # Arterial priority shares the spillback horizon.
        duration_steps = cfg.spillback_duration_steps
    elif flow == "NS":
        strategy, priority_corridor, phase_bias = "favor_NS", "NS", "NS"
        duration_steps = cfg.default_duration_steps
    elif flow == "EW":
        strategy, priority_corridor, phase_bias = "favor_EW", "EW", "EW"
        duration_steps = cfg.default_duration_steps
    else:
        strategy, priority_corridor, phase_bias = "hold", None, "NONE"
        duration_steps = cfg.default_duration_steps

    targets = fallback_target_intersections(
        summary=summary,
        max_target_intersections=cfg.max_target_intersections,
        strategy=strategy,
        priority_corridor=priority_corridor,
        phase_bias=phase_bias,
    )
    return DistrictAction(
        strategy=strategy,
        priority_corridor=priority_corridor,
        target_intersections=targets,
        phase_bias=phase_bias,
        duration_steps=duration_steps,
    ).validate()
district_llm/inference.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, Callable
8
+
9
+ from district_llm.prompting import format_district_prompt
10
+ from district_llm.repair import RepairConfig, RepairReport, sanitize_action_payload
11
+ from district_llm.schema import DistrictAction, DistrictStateSummary
12
+ from district_llm.summary_builder import DistrictStateSummaryBuilder
13
+ from env.observation_builder import ObservationConfig
14
+ from env.reward import RewardConfig
15
+ from env.traffic_env import EnvConfig, TrafficEnv
16
+ from training.cityflow_dataset import CityFlowDataset
17
+
18
+
19
+ def _extract_json_object(payload: str) -> str:
20
+ start = payload.find("{")
21
+ end = payload.rfind("}")
22
+ if start == -1 or end == -1 or end <= start:
23
+ raise ValueError("No JSON object found in model output.")
24
+ return payload[start : end + 1]
25
+
26
+
27
@dataclass(frozen=True)
class DistrictLLMInferenceResult:
    """Full provenance of one LLM inference: final action plus diagnostics."""

    # The validated, possibly repaired action.
    action: DistrictAction
    # Raw model completion text before any JSON extraction.
    raw_text: str
    # Payload parsed from the completion (the fallback action's dict when
    # parsing failed).
    parsed_payload_before_repair: dict[str, Any]
    # What sanitize_action_payload changed (removed ids, truncation, fallback).
    repair_report: RepairReport
    # False when no JSON object could be extracted/parsed from raw_text.
    json_valid: bool
    # False when the payload was invalid before repair (set together with
    # json_valid in parse_action).
    schema_valid_before_repair: bool
35
+
36
+
37
class DistrictLLMInference:
    """Generate, parse, and repair district actions from a causal LM.

    Either a ready-made ``generator_fn`` (prompt -> completion text) is
    supplied, or a HuggingFace model — optionally a PEFT/LoRA adapter
    directory — is loaded from ``model_name_or_path``.
    """

    def __init__(
        self,
        generator_fn: Callable[[str], str] | None = None,
        model_name_or_path: str | None = None,
        device: str | None = None,
        fallback_action: DistrictAction | None = None,
        repair_config: RepairConfig | None = None,
    ):
        self.fallback_action = fallback_action or DistrictAction.default_hold()
        self.generator_fn = generator_fn
        self.repair_config = repair_config or RepairConfig()
        self.tokenizer = None
        self.model = None
        self.device = device or "cpu"

        if self.generator_fn is None:
            if not model_name_or_path:
                raise ValueError("Provide either generator_fn or model_name_or_path.")
            # Heavy deps are imported lazily so generator_fn-only usage does
            # not require torch/transformers.
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer

            self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
            model_dir = Path(model_name_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
                # generate() needs a pad token; reuse EOS when none is set.
                self.tokenizer.pad_token = self.tokenizer.eos_token

            if (model_dir / "adapter_config.json").exists():
                # The path is a LoRA adapter directory: load through PEFT.
                try:
                    from peft import AutoPeftModelForCausalLM
                except ImportError as exc:
                    raise ImportError("Loading a LoRA adapter requires the 'peft' package.") from exc
                self.model = AutoPeftModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
            else:
                self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
            self.model.eval()

    def generate_raw(self, prompt: str, max_new_tokens: int = 128) -> str:
        """Return the raw greedy completion for ``prompt`` (new tokens only)."""
        if self.generator_fn is not None:
            return self.generator_fn(prompt)
        import torch

        assert self.model is not None and self.tokenizer is not None
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Strip the echoed prompt; decode only freshly generated tokens.
        generated = outputs[0][inputs["input_ids"].shape[1] :]
        return self.tokenizer.decode(generated, skip_special_tokens=True)

    def parse_action(
        self,
        payload: str,
        summary: DistrictStateSummary | None = None,
    ) -> tuple[DistrictAction, RepairReport, dict[str, Any], bool, bool]:
        """Parse model text into a sanitized action.

        Returns:
            ``(action, repair_report, parsed_payload, json_valid,
            schema_valid_before_repair)``. When no JSON object can be
            extracted, the fallback action's dict is sanitized instead and
            both validity flags are False.
        """
        json_valid = True
        schema_valid_before_repair = True
        try:
            parsed_payload = json.loads(_extract_json_object(payload))
        except Exception:
            json_valid = False
            schema_valid_before_repair = False
            parsed_payload = self.fallback_action.to_dict()
        action, repair_report = sanitize_action_payload(
            payload=parsed_payload,
            summary=summary,
            config=self.repair_config,
        )
        return action, repair_report, parsed_payload, json_valid, schema_valid_before_repair

    def predict_with_result(
        self,
        summary: DistrictStateSummary,
        max_new_tokens: int = 128,
    ) -> DistrictLLMInferenceResult:
        """Run prompt -> generate -> parse and return the full diagnostic result."""
        prompt = format_district_prompt(
            summary,
            max_target_intersections=self.repair_config.max_target_intersections,
            allow_only_visible_candidates=self.repair_config.allow_only_visible_candidates,
        )
        raw = self.generate_raw(prompt=prompt, max_new_tokens=max_new_tokens)
        action, repair_report, parsed_payload, json_valid, schema_valid_before_repair = self.parse_action(
            raw,
            summary=summary,
        )
        return DistrictLLMInferenceResult(
            action=action,
            raw_text=raw,
            parsed_payload_before_repair=parsed_payload,
            repair_report=repair_report,
            json_valid=json_valid,
            schema_valid_before_repair=schema_valid_before_repair,
        )

    def predict(self, summary: DistrictStateSummary, max_new_tokens: int = 128) -> DistrictAction:
        """Convenience wrapper returning only the final action."""
        return self.predict_with_result(summary=summary, max_new_tokens=max_new_tokens).action
138
+
139
+
140
def parse_args() -> argparse.Namespace:
    """Define and parse the CLI for single-sample district LLM inference."""
    parser = argparse.ArgumentParser(description="Run single-sample district LLM inference.")
    parser.add_argument("--model", required=True, help="Model name, local path, or LoRA adapter path.")
    parser.add_argument("--generated-root", default="data/generated")
    parser.add_argument("--splits-root", default="data/splits")
    parser.add_argument("--city-id", required=True)
    parser.add_argument("--scenario-name", required=True)
    parser.add_argument("--district-id", required=True)
    # None lets the inference wrapper auto-select cuda/cpu.
    parser.add_argument("--device", default=None)
    parser.add_argument("--max-new-tokens", type=int, default=128)
    # Restrict predicted targets to candidates visible in the prompt summary.
    parser.add_argument(
        "--allow-only-visible-candidates",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument("--max-target-intersections", type=int, default=3)
    # When repair empties target_intersections, substitute fallback targets.
    parser.add_argument(
        "--fallback-on-empty-targets",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument(
        "--fallback-mode",
        choices=("heuristic", "hold", "none"),
        default="heuristic",
    )
    return parser.parse_args()
167
+
168
+
169
def build_env(scenario_spec) -> TrafficEnv:
    """Create a TrafficEnv with the fixed inference-time environment settings."""
    env_config = EnvConfig(
        simulator_interval=1,
        decision_interval=5,
        min_green_time=10,
        thread_num=1,
        observation=ObservationConfig(),
        reward=RewardConfig(variant="wait_queue_throughput"),
    )
    # Forward the identity/path fields of the spec verbatim to the env.
    spec_fields = (
        "city_id",
        "scenario_name",
        "city_dir",
        "scenario_dir",
        "config_path",
        "roadnet_path",
        "district_map_path",
        "metadata_path",
    )
    kwargs = {name: getattr(scenario_spec, name) for name in spec_fields}
    return TrafficEnv(env_config=env_config, **kwargs)
189
+
190
+
191
def main() -> None:
    """CLI entry point: print one predicted district action as pretty JSON.

    Builds the requested scenario environment, summarizes its districts at
    reset time, and runs a single LLM inference for ``--district-id``.

    Raises:
        ValueError: when ``--district-id`` does not exist in the scenario.
    """
    args = parse_args()
    dataset = CityFlowDataset(
        generated_root=args.generated_root,
        splits_root=args.splits_root,
    )
    scenario_spec = dataset.build_scenario_spec(args.city_id, args.scenario_name)
    env = build_env(scenario_spec)
    # Keep at least 6 candidates visible so repair has room to choose targets.
    summary_builder = DistrictStateSummaryBuilder(candidate_limit=max(6, args.max_target_intersections))
    observation_batch = env.reset()
    summaries = summary_builder.build_all(env, observation_batch)
    if args.district_id not in summaries:
        raise ValueError(f"Unknown district_id '{args.district_id}' for {args.city_id}/{args.scenario_name}.")
    inference = DistrictLLMInference(
        model_name_or_path=args.model,
        device=args.device,
        fallback_action=DistrictAction.default_hold(),
        repair_config=RepairConfig(
            allow_only_visible_candidates=args.allow_only_visible_candidates,
            max_target_intersections=args.max_target_intersections,
            fallback_on_empty_targets=args.fallback_on_empty_targets,
            fallback_mode=args.fallback_mode,
        ),
    )
    action = inference.predict(
        summary=summaries[args.district_id],
        max_new_tokens=args.max_new_tokens,
    )
    print(action.to_pretty_json())


if __name__ == "__main__":
    main()
district_llm/metrics.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+
6
def safe_ratio(numerator: int | float, denominator: int | float, default_if_empty: float = 0.0) -> float:
    """Divide, returning ``default_if_empty`` for zero/negative denominators."""
    if denominator > 0:
        return float(numerator) / float(denominator)
    return default_if_empty
10
+
11
+
12
def compute_target_metrics(pred_list: list[str], gt_list: list[str]) -> dict[str, float]:
    """Score predicted target ids against ground truth (set + order metrics).

    When both lists are empty, the ratio metrics default to 1.0 (perfect
    vacuous match); otherwise an empty side contributes 0.0 to the affected
    ratios.
    """
    predictions = list(pred_list)
    truth = list(gt_list)
    prediction_set = set(predictions)
    truth_set = set(truth)
    shared = prediction_set & truth_set
    combined = prediction_set | truth_set
    hits = len(shared)

    empty_default = 1.0 if not prediction_set and not truth_set else 0.0

    def _ratio(numerator: float, denominator: float) -> float:
        # Same contract as safe_ratio(...) with empty_default as the default.
        if denominator > 0:
            return float(numerator) / float(denominator)
        return empty_default

    return {
        "exact_list_match": float(predictions == truth),
        "exact_set_match": float(prediction_set == truth_set),
        "overlap_count": float(hits),
        "overlap_rate": _ratio(hits, len(truth_set)),
        "precision": _ratio(hits, len(prediction_set)),
        "recall": _ratio(hits, len(truth_set)),
        "jaccard": _ratio(hits, len(combined)),
        "hit_at_1": float(hits >= 1),
        "hit_at_2": float(hits >= 2),
        "hit_at_3": float(hits >= 3),
    }
39
+
40
+
41
def aggregate_target_metrics(metric_rows: list[dict[str, float]]) -> dict[str, float]:
    """Average each metric across rows; an empty input yields an empty dict.

    The key set is taken from the first row; every row is expected to carry
    the same keys.
    """
    if not metric_rows:
        return {}
    count = len(metric_rows)
    totals: dict[str, float] = {key: 0.0 for key in metric_rows[0]}
    for row in metric_rows:
        for key in totals:
            totals[key] += row[key]
    return {key: float(total / count) for key, total in totals.items()}
49
+
50
+
51
def target_failure_buckets(
    pred_list: list[str],
    gt_list: list[str],
    visible_candidates: set[str],
    invalid_ids: list[str] | None = None,
    non_visible_ids: list[str] | None = None,
    repaired_targets: list[str] | None = None,
    fallback_used: bool = False,
) -> list[str]:
    """Classify a prediction/ground-truth pair into diagnostic buckets.

    Args:
        pred_list: Predicted target intersection ids (order-sensitive).
        gt_list: Ground-truth target ids.
        visible_candidates: Ids that were visible in the prompt summary.
        invalid_ids: Ids removed during repair as invalid.
        non_visible_ids: Ids removed during repair as not visible.
        repaired_targets: Final targets after repair, when repair ran.
        fallback_used: Whether repair fell back to substitute targets.

    Returns:
        Ordered list of unique bucket labels. Previously the
        "prediction_contains_ids_not_visible_in_summary" label could appear
        twice (once via ``non_visible_ids`` and once via the direct
        visibility check); labels are now deduplicated while preserving
        first-occurrence order.
    """
    buckets: list[str] = []

    def add(bucket: str) -> None:
        # Each diagnostic label appears at most once.
        if bucket not in buckets:
            buckets.append(bucket)

    pred_set = set(pred_list)
    gt_set = set(gt_list)

    if not pred_list:
        add("prediction_empty")
    if not gt_list:
        add("ground_truth_empty")
    if pred_list and gt_list and pred_set == gt_set and pred_list != gt_list:
        add("same_set_different_order")
    elif pred_set & gt_set:
        add("partial_overlap")
    elif pred_list and gt_list:
        add("no_overlap")

    if invalid_ids:
        add("prediction_contains_invalid_ids")
    if non_visible_ids:
        add("prediction_contains_ids_not_visible_in_summary")
    if pred_list and visible_candidates and any(item not in visible_candidates for item in pred_list):
        add("prediction_contains_ids_not_visible_in_summary")
    if fallback_used:
        add("fallback_used")

    if repaired_targets is not None:
        repaired_set = set(repaired_targets)
        if repaired_set == gt_set and pred_set != gt_set:
            add("repaired_successfully")
        elif (invalid_ids or non_visible_ids or fallback_used) and repaired_set != gt_set:
            add("repair_failed")

    return buckets
92
+
93
+
94
def average_item_rate(values: list[list[Any]]) -> float:
    """Ratio of total items to total slots, counting each empty list as one slot.

    Returns 0.0 for an empty `values` (matching safe_ratio's default).
    """
    total_items = 0
    total_slots = 0
    for item in values:
        size = len(item)
        total_items += size
        total_slots += size if size else 1
    # Inlined safe_ratio(total_items, total_slots) with its 0.0 default.
    if total_slots > 0:
        return float(total_items) / float(total_slots)
    return 0.0
district_llm/prompting.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from district_llm.schema import DISTRICT_STRATEGIES, PHASE_BIASES, PRIORITY_CORRIDORS, DistrictAction, DistrictStateSummary
4
+
5
+
6
DEFAULT_MAX_TARGET_INTERSECTIONS = 3


def build_system_prompt(
    max_target_intersections: int = DEFAULT_MAX_TARGET_INTERSECTIONS,
    allow_only_visible_candidates: bool = True,
) -> str:
    """Compose the system prompt stating the JSON schema and target-id rules."""
    parts = [
        "You are a district traffic coordinator for RL traffic lights. ",
        "Return only valid JSON with exactly these keys: ",
        "strategy, priority_corridor, target_intersections, phase_bias, duration_steps. ",
        f"target_intersections must be a JSON array with at most {int(max_target_intersections)} unique ids.",
    ]
    if allow_only_visible_candidates:
        # Restrict target ids to the candidates shown in the prompt.
        parts.append(
            " If candidate_intersections is present, target_intersections must use only ids from that list."
        )
    parts.append(
        " Do not invent intersection ids. Deduplicate ids. If uncertain, prefer the most congested valid candidates."
    )
    return "".join(parts)
26
+
27
+
28
def format_district_prompt(
    summary: DistrictStateSummary,
    max_target_intersections: int = DEFAULT_MAX_TARGET_INTERSECTIONS,
    allow_only_visible_candidates: bool = True,
) -> str:
    """Render the schema header, district state, and decision cue as one prompt."""
    limit = int(max_target_intersections)
    if allow_only_visible_candidates:
        target_rule = f"target_intersections: up to {limit} ids from candidate_intersections only"
    else:
        target_rule = f"target_intersections: up to {limit} valid ids"
    lines = [
        "### DISTRICT ACTION SCHEMA",
        f"strategy: {'|'.join(DISTRICT_STRATEGIES)}",
        f"phase_bias: {'|'.join(PHASE_BIASES)}",
        f"priority_corridor: {'|'.join(PRIORITY_CORRIDORS)}|none",
        "duration_steps: integer 1..20",
        target_rule,
        "rules: return only valid JSON; do not invent ids; deduplicate target_intersections",
        "fallback: if uncertain, prefer the most congested visible candidates",
        "",
        "### DISTRICT STATE",
        summary.to_prompt_text(),
        "",
        "### DECISION",
    ]
    return "\n".join(lines)
55
+
56
+
57
def format_sft_text(
    summary: DistrictStateSummary,
    action: DistrictAction,
    max_target_intersections: int = DEFAULT_MAX_TARGET_INTERSECTIONS,
    allow_only_visible_candidates: bool = True,
) -> str:
    """Concatenate the prompt and the pretty-printed action JSON for SFT."""
    prompt = format_district_prompt(
        summary,
        max_target_intersections=max_target_intersections,
        allow_only_visible_candidates=allow_only_visible_candidates,
    )
    return f"{prompt}\n{action.to_pretty_json()}"
district_llm/repair.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from dataclasses import dataclass, field
5
+ from typing import Any
6
+
7
+ from district_llm.schema import (
8
+ DISTRICT_STRATEGIES,
9
+ PHASE_BIASES,
10
+ PRIORITY_CORRIDORS,
11
+ CandidateIntersection,
12
+ DistrictAction,
13
+ DistrictStateSummary,
14
+ candidate_priority_score,
15
+ canonicalize_target_intersections,
16
+ )
17
+
18
+
19
+ INTERSECTION_ID_PATTERN = re.compile(r"\bi_\d+\b")
20
+
21
+
22
@dataclass(frozen=True)
class RepairConfig:
    """Knobs controlling how model-emitted actions are sanitized."""

    # Drop target ids that are not among the summary's visible candidates.
    allow_only_visible_candidates: bool = True
    # Hard cap on target_intersections length after repair.
    max_target_intersections: int = 3
    # Substitute fallback targets when filtering empties the list.
    fallback_on_empty_targets: bool = True
    # Fallback selection strategy ("heuristic", "hold", or "none" per the
    # inference CLI's --fallback-mode choices).
    fallback_mode: str = "heuristic"
+
29
+
30
@dataclass
class RepairReport:
    """Record of every change repair made to a model-predicted action."""

    raw_targets: list[str] = field(default_factory=list)
    repaired_targets: list[str] = field(default_factory=list)
    invalid_ids_removed: list[str] = field(default_factory=list)
    non_visible_ids_removed: list[str] = field(default_factory=list)
    deduplicated: bool = False
    truncated: bool = False
    fallback_used: bool = False
    fallback_mode: str | None = None
    empty_after_filtering: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Serialize as JSON-ready plain types (lists copied, flags coerced)."""
        payload: dict[str, Any] = {}
        payload["raw_targets"] = list(self.raw_targets)
        payload["repaired_targets"] = list(self.repaired_targets)
        payload["invalid_ids_removed"] = list(self.invalid_ids_removed)
        payload["non_visible_ids_removed"] = list(self.non_visible_ids_removed)
        payload["deduplicated"] = bool(self.deduplicated)
        payload["truncated"] = bool(self.truncated)
        payload["fallback_used"] = bool(self.fallback_used)
        payload["fallback_mode"] = self.fallback_mode
        payload["empty_after_filtering"] = bool(self.empty_after_filtering)
        return payload
54
+
55
+
56
def normalize_candidate_intersections(
    payload: list[CandidateIntersection | dict[str, Any]] | None,
) -> list[dict[str, Any]]:
    """Coerce candidates (dataclasses or dicts) into fresh plain dicts.

    Entries of any other type are silently dropped; ``None`` yields [].
    """
    result: list[dict[str, Any]] = []
    entries = payload if payload is not None else []
    for entry in entries:
        if isinstance(entry, CandidateIntersection):
            result.append(entry.to_dict())
            continue
        if isinstance(entry, dict):
            result.append(dict(entry))
    return result
66
+
67
+
68
def parse_candidate_intersections_from_text(text: str) -> list[dict[str, Any]]:
    """Recover candidate intersections from a rendered prompt's text block.

    Parses list items that follow a literal ``candidate_intersections:``
    line, e.g. ``- i_12 q=4 w=2.5 out=1 phase=0 boundary=1 reasons=a|b``
    (format inferred from the key handling below — confirm against the
    summary renderer). Parsing stops at the next ``something:`` header.
    Unknown ``key=value`` tokens are ignored; missing keys keep
    conservative defaults.
    """
    if "candidate_intersections:" not in text:
        return []

    candidates: list[dict[str, Any]] = []
    capture = False
    for line in text.splitlines():
        stripped = line.strip()
        if stripped == "candidate_intersections:":
            capture = True
            continue
        if not capture:
            continue
        if stripped == "- none":
            # Placeholder emitted when a district has no candidates.
            continue
        if not stripped.startswith("- "):
            if stripped.endswith(":"):
                # Next section header: the candidate block is finished.
                break
            continue
        fields = stripped[2:].split()
        if not fields:
            continue
        # First token is the id; the rest are key=value attribute pairs.
        candidate: dict[str, Any] = {
            "intersection_id": fields[0],
            "queue_total": 0.0,
            "wait_total": 0.0,
            "outgoing_load": 0.0,
            "current_phase": 0,
            "is_boundary": False,
            "spillback_risk": False,
            "incident_proximity": False,
            "overload_marker": False,
            "event_proximity": False,
            "corridor_alignment": "BALANCED",
            "selection_reasons": [],
        }
        for token in fields[1:]:
            if "=" not in token:
                continue
            key, value = token.split("=", 1)
            if key == "q":
                candidate["queue_total"] = float(value)
            elif key == "w":
                candidate["wait_total"] = float(value)
            elif key == "out":
                candidate["outgoing_load"] = float(value)
            elif key == "phase":
                candidate["current_phase"] = int(value)
            elif key == "boundary":
                candidate["is_boundary"] = value == "1"
            elif key == "spillback":
                candidate["spillback_risk"] = value == "1"
            elif key == "incident":
                candidate["incident_proximity"] = value == "1"
            elif key == "overload":
                candidate["overload_marker"] = value == "1"
            elif key == "event":
                candidate["event_proximity"] = value == "1"
            elif key == "align":
                candidate["corridor_alignment"] = value
            elif key == "reasons":
                # "none" encodes an empty reason list; otherwise '|'-separated.
                candidate["selection_reasons"] = [] if value == "none" else value.split("|")
        candidates.append(candidate)
    return normalized_candidate_intersections_from_dicts(candidates)
132
+
133
+
134
def normalized_candidate_intersections_from_dicts(
    payload: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Round-trip raw dicts through CandidateIntersection to validate them.

    Best-effort: any entry that cannot be coerced (bad types, failed
    construction) is dropped rather than failing the whole batch.
    """
    normalized: list[dict[str, Any]] = []
    for raw in payload:
        try:
            candidate = CandidateIntersection(
                intersection_id=str(raw.get("intersection_id", "")).strip(),
                queue_total=float(raw.get("queue_total", 0.0)),
                wait_total=float(raw.get("wait_total", 0.0)),
                outgoing_load=float(raw.get("outgoing_load", 0.0)),
                current_phase=int(raw.get("current_phase", 0)),
                is_boundary=bool(raw.get("is_boundary", False)),
                spillback_risk=bool(raw.get("spillback_risk", False)),
                incident_proximity=bool(raw.get("incident_proximity", False)),
                overload_marker=bool(raw.get("overload_marker", False)),
                event_proximity=bool(raw.get("event_proximity", False)),
                corridor_alignment=str(raw.get("corridor_alignment", "BALANCED")),
                selection_reasons=list(raw.get("selection_reasons", [])),
            )
            normalized.append(candidate.to_dict())
        except Exception:
            # Deliberate best-effort filter over possibly LLM-produced data.
            continue
    return normalized
159
+
160
+
161
def candidate_intersections_from_context(
    summary: DistrictStateSummary | dict[str, Any] | None = None,
    prompt_text: str | None = None,
) -> list[dict[str, Any]]:
    """Extract normalized candidate intersections from whichever context exists.

    Preference order: typed summary -> summary dict (top level, then nested
    under "state") -> free-text prompt parsing. Returns [] when no source
    provides candidates.
    """
    if isinstance(summary, DistrictStateSummary):
        return normalize_candidate_intersections(summary.candidate_intersections)
    if isinstance(summary, dict):
        if "candidate_intersections" in summary:
            return normalize_candidate_intersections(summary.get("candidate_intersections"))
        nested = summary.get("state")
        if isinstance(nested, dict) and "candidate_intersections" in nested:
            return normalize_candidate_intersections(nested.get("candidate_intersections"))
    if prompt_text:
        return parse_candidate_intersections_from_text(prompt_text)
    return []
176
+
177
+
178
def fallback_target_intersections(
    summary: DistrictStateSummary | dict[str, Any] | None = None,
    prompt_text: str | None = None,
    max_target_intersections: int = 3,
    strategy: str | None = None,
    priority_corridor: str | None = None,
    phase_bias: str | None = None,
    focus_scores: dict[str, float] | None = None,
) -> list[str]:
    """Pick fallback target intersections when the model output is unusable.

    Preference order: ranked structured candidates (priority score plus
    focus/strategy bonuses; ties broken by queue, wait, outgoing load, then
    id) -> the summary's top-congested list -> ids scraped from the prompt.
    """
    candidates = candidate_intersections_from_context(summary=summary, prompt_text=prompt_text)
    if candidates:
        def _rank_key(entry: dict[str, Any]) -> tuple:
            combined_score = (
                candidate_priority_score(entry)
                + _focus_score_bonus(entry, focus_scores)
                + _strategy_target_bonus(
                    candidate=entry,
                    strategy=strategy,
                    priority_corridor=priority_corridor,
                    phase_bias=phase_bias,
                )
            )
            # Negate numeric components so sorted() ranks highest first;
            # the id gives a deterministic final tie-break.
            return (
                -combined_score,
                -float(entry.get("queue_total", 0.0)),
                -float(entry.get("wait_total", 0.0)),
                -float(entry.get("outgoing_load", 0.0)),
                str(entry.get("intersection_id", "")),
            )

        ranked = sorted(candidates, key=_rank_key)
        ordered_ids = canonicalize_target_intersections(
            [entry["intersection_id"] for entry in ranked],
            ranked,
            limit=max_target_intersections,
        )
        return ordered_ids[:max_target_intersections]

    if isinstance(summary, DistrictStateSummary):
        return [
            entry.intersection_id
            for entry in summary.top_congested_intersections[:max_target_intersections]
        ]
    if isinstance(summary, dict):
        top_congested = summary.get("top_congested_intersections") or summary.get("state", {}).get(
            "top_congested_intersections", []
        )
        return [
            str(entry.get("intersection_id"))
            for entry in top_congested[:max_target_intersections]
            if str(entry.get("intersection_id", "")).strip()
        ]
    if prompt_text:
        # De-duplicated in first-seen order via dict.fromkeys.
        return list(dict.fromkeys(INTERSECTION_ID_PATTERN.findall(prompt_text)))[:max_target_intersections]
    return []
227
+
228
+
229
+ def _focus_score_bonus(candidate: dict[str, Any], focus_scores: dict[str, float] | None) -> float:
230
+ if not focus_scores:
231
+ return 0.0
232
+ max_focus = max(max(focus_scores.values()), 1.0)
233
+ return 4.0 * float(focus_scores.get(str(candidate.get("intersection_id", "")), 0.0)) / max_focus
234
+
235
+
236
+ def _strategy_target_bonus(
237
+ candidate: dict[str, Any],
238
+ strategy: str | None,
239
+ priority_corridor: str | None,
240
+ phase_bias: str | None,
241
+ ) -> float:
242
+ reasons = set(candidate.get("selection_reasons", []))
243
+ corridor_alignment = str(candidate.get("corridor_alignment", "BALANCED"))
244
+ bonus = 0.0
245
+
246
+ if strategy == "incident_response":
247
+ bonus += 4.0 * float(bool(candidate.get("incident_proximity", False)))
248
+ elif strategy == "clear_spillback":
249
+ bonus += 4.0 * float(bool(candidate.get("spillback_risk", False)))
250
+ bonus += 1.0 * float(bool(candidate.get("is_boundary", False)))
251
+ elif strategy == "drain_inbound":
252
+ bonus += 4.0 * float(bool(candidate.get("is_boundary", False)))
253
+ bonus += 1.0 * float(bool(candidate.get("spillback_risk", False)))
254
+ elif strategy == "drain_outbound":
255
+ bonus += 4.0 * float("outgoing" in reasons)
256
+ bonus += 1.0 * float(bool(candidate.get("spillback_risk", False)))
257
+ elif strategy == "arterial_priority":
258
+ bonus += 2.0 * float(bool(candidate.get("is_boundary", False)))
259
+ bonus += 1.5 * float(bool(candidate.get("overload_marker", False)))
260
+ bonus += 1.5 * float(bool(candidate.get("event_proximity", False)))
261
+ elif strategy == "favor_NS":
262
+ bonus += 4.0 * float(corridor_alignment == "NS")
263
+ elif strategy == "favor_EW":
264
+ bonus += 4.0 * float(corridor_alignment == "EW")
265
+
266
+ if priority_corridor in {"NS", "EW"}:
267
+ bonus += 1.5 * float(corridor_alignment == priority_corridor)
268
+ elif priority_corridor == "inbound":
269
+ bonus += 1.5 * float(bool(candidate.get("is_boundary", False)))
270
+ elif priority_corridor == "outbound":
271
+ bonus += 1.5 * float("outgoing" in reasons)
272
+ elif priority_corridor == "arterial":
273
+ bonus += 0.75 * float(bool(candidate.get("is_boundary", False)))
274
+
275
+ if phase_bias in {"NS", "EW"}:
276
+ bonus += 0.5 * float(corridor_alignment == phase_bias)
277
+
278
+ return bonus
279
+
280
+
281
def extract_visible_candidate_ids(
    summary: DistrictStateSummary | dict[str, Any] | None = None,
    prompt_text: str | None = None,
) -> list[str]:
    """Return every intersection id visible to the model in this context.

    Uses structured candidate metadata when available; otherwise scrapes ids
    from the prompt text, de-duplicated in first-seen order.
    """
    candidates = candidate_intersections_from_context(summary=summary, prompt_text=prompt_text)
    if candidates:
        return [entry["intersection_id"] for entry in candidates]
    if prompt_text:
        return list(dict.fromkeys(INTERSECTION_ID_PATTERN.findall(prompt_text)))
    return []
291
+
292
+
293
def sanitize_action_payload(
    payload: dict[str, Any] | None,
    summary: DistrictStateSummary | dict[str, Any] | None = None,
    prompt_text: str | None = None,
    config: RepairConfig | None = None,
) -> tuple[DistrictAction, RepairReport]:
    """Validate and repair a raw model action payload into a DistrictAction.

    Pipeline: normalize the target list -> de-duplicate (order-preserving) ->
    drop malformed or non-visible ids -> truncate to the configured maximum ->
    optional fallback when nothing survives -> clamp enum-like fields.
    Returns the repaired action together with a RepairReport recording every
    change. Raises ValueError on an unknown config.fallback_mode.
    """
    config = config or RepairConfig()
    payload = dict(payload or {})
    candidate_intersections = candidate_intersections_from_context(summary=summary, prompt_text=prompt_text)
    visible_candidate_ids = [item["intersection_id"] for item in candidate_intersections]
    visible_candidate_set = set(visible_candidate_ids)

    # Targets may arrive as a string blob, a list/tuple, or anything else
    # (treated as empty).
    raw_target_payload = payload.get("target_intersections", [])
    if isinstance(raw_target_payload, str):
        raw_targets = INTERSECTION_ID_PATTERN.findall(raw_target_payload)
    elif isinstance(raw_target_payload, (list, tuple)):
        raw_targets = [str(item).strip() for item in raw_target_payload if str(item).strip()]
    else:
        raw_targets = []

    report = RepairReport(raw_targets=list(raw_targets))
    # Order-preserving de-duplication.
    deduped_targets: list[str] = []
    seen: set[str] = set()
    for item in raw_targets:
        if item in seen:
            report.deduplicated = True
            continue
        seen.add(item)
        deduped_targets.append(item)

    # Drop ids that don't match the id pattern, and (when configured and a
    # visibility set exists) ids the model could not have seen.
    filtered_targets: list[str] = []
    for item in deduped_targets:
        if not INTERSECTION_ID_PATTERN.fullmatch(item):
            report.invalid_ids_removed.append(item)
            continue
        if config.allow_only_visible_candidates and visible_candidate_set and item not in visible_candidate_set:
            report.non_visible_ids_removed.append(item)
            continue
        filtered_targets.append(item)

    if len(filtered_targets) > int(config.max_target_intersections):
        report.truncated = True
        filtered_targets = canonicalize_target_intersections(
            filtered_targets,
            candidate_intersections,
            limit=int(config.max_target_intersections),
        )

    if not filtered_targets:
        # Only flag empty_after_filtering when the model DID propose targets.
        report.empty_after_filtering = bool(raw_targets)
        if config.fallback_on_empty_targets:
            report.fallback_used = True
            report.fallback_mode = config.fallback_mode
            if config.fallback_mode == "heuristic":
                filtered_targets = fallback_target_intersections(
                    summary=summary,
                    prompt_text=prompt_text,
                    max_target_intersections=int(config.max_target_intersections),
                )
            elif config.fallback_mode == "hold":
                filtered_targets = []
            elif config.fallback_mode == "none":
                filtered_targets = []
            else:
                raise ValueError(f"Unsupported fallback_mode '{config.fallback_mode}'.")

    # Clamp enum-like fields to their allowed vocabularies.
    strategy = str(payload.get("strategy", "hold"))
    if strategy not in DISTRICT_STRATEGIES:
        strategy = "hold"

    priority_corridor = payload.get("priority_corridor")
    if priority_corridor is not None:
        priority_corridor = str(priority_corridor)
        if priority_corridor not in PRIORITY_CORRIDORS:
            priority_corridor = None

    phase_bias = str(payload.get("phase_bias", "NONE"))
    if phase_bias not in PHASE_BIASES:
        phase_bias = "NONE"

    duration_steps_raw = payload.get("duration_steps", 1)
    try:
        duration_steps = int(duration_steps_raw)
    except (TypeError, ValueError):
        duration_steps = 1
    # Clamp to the supported range [1, 20].
    duration_steps = max(1, min(duration_steps, 20))

    if config.fallback_mode == "hold" and report.fallback_used and not filtered_targets:
        # "hold" fallback overrides whatever the model asked for.
        action = DistrictAction.default_hold(duration_steps=duration_steps)
    else:
        action = DistrictAction(
            strategy=strategy,
            priority_corridor=priority_corridor,
            target_intersections=filtered_targets,
            phase_bias=phase_bias,
            duration_steps=duration_steps,
        ).validate()

    report.repaired_targets = list(action.target_intersections)
    return action, report
district_llm/rl_guidance_wrapper.py ADDED
@@ -0,0 +1,1004 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import asdict, dataclass, field, replace
5
+ import hashlib
6
+ from time import perf_counter
7
+ from typing import Any
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from district_llm.heuristic_guidance import (
13
+ HeuristicGuidanceConfig,
14
+ generate_heuristic_guidance,
15
+ )
16
+ from district_llm.inference import DistrictLLMInference, DistrictLLMInferenceResult
17
+ from district_llm.repair import RepairReport
18
+ from district_llm.schema import CandidateIntersection, DistrictAction, DistrictStateSummary
19
+ from district_llm.summary_builder import DistrictStateSummaryBuilder
20
+ from district_llm.teachers import RLCheckpointTeacher
21
+
22
+
23
# Allowed vocabularies for GuidanceInfluenceConfig; validate() checks against
# these tuples, so adding a value here is how a new mode becomes legal.
WRAPPER_MODES: tuple[str, ...] = (
    "no_op",
    "target_only_soft",
    "target_only_medium",
    "corridor_soft",
    "global_soft",
    "current_legacy",
)
FALLBACK_POLICIES: tuple[str, ...] = (
    "no_op",
    "hold_previous",
    "heuristic_weak",
)
GATING_MODES: tuple[str, ...] = (
    "always_on",
    "incident_or_spillback",
    "queue_threshold",
    "imbalance_threshold",
    "queue_or_imbalance",
    "combined",
)
BIAS_DECAY_SCHEDULES: tuple[str, ...] = (
    "linear",
)
# Strategy-specific multipliers on the guidance bias strength; names suggest
# "hold" disables biasing entirely (usage is outside this chunk — verify).
STRATEGY_BIAS_MULTIPLIERS: dict[str, float] = {
    "hold": 0.0,
    "favor_NS": 1.0,
    "favor_EW": 1.0,
    "drain_inbound": 1.05,
    "drain_outbound": 1.05,
    "clear_spillback": 1.1,
    "incident_response": 1.15,
    "arterial_priority": 1.05,
}
57
+
58
+
59
@dataclass(frozen=True)
class GuidanceInfluenceConfig:
    """
    Conservative inference-time wrapper around the fixed DQN policy.

    The DQN checkpoint remains unchanged. Guidance is treated as a weak prior
    and only biases Q-values slightly before greedy action selection.
    """

    # Bias application mode; must be one of WRAPPER_MODES.
    wrapper_mode: str = "target_only_soft"
    # Bias magnitudes (exact application semantics live in the wrapper logic).
    bias_strength: float = 0.12
    target_only_bias_strength: float = 0.18
    corridor_bias_strength: float = 0.05
    max_intersections_affected: int = 3
    # Cadence: how often guidance is regenerated and how long it persists.
    guidance_refresh_steps: int = 5
    guidance_persistence_steps: int = 3
    max_guidance_duration: int = 6
    apply_global_bias: bool = False
    apply_target_only: bool = True
    # Gating: conditions under which guidance may influence the policy.
    gating_mode: str = "always_on"
    min_avg_queue_for_guidance: float = 150.0
    min_queue_imbalance_for_guidance: float = 20.0
    require_incident_or_spillback: bool = False
    allow_guidance_in_normal_conditions: bool = True
    enable_bias_decay: bool = True
    bias_decay_schedule: str = "linear"
    fallback_policy: str = "hold_previous"
    log_guidance_debug: bool = False
    max_debug_chars: int = 240

    def validate(self) -> "GuidanceInfluenceConfig":
        """Check enum fields against the allowed vocabularies and counters >= 1.

        Returns self for fluent use; raises ValueError on the first violation.
        """
        if self.wrapper_mode not in WRAPPER_MODES:
            raise ValueError(
                f"Unsupported wrapper_mode '{self.wrapper_mode}'. Expected one of {WRAPPER_MODES}."
            )
        if self.fallback_policy not in FALLBACK_POLICIES:
            raise ValueError(
                f"Unsupported fallback_policy '{self.fallback_policy}'. Expected one of {FALLBACK_POLICIES}."
            )
        if self.gating_mode not in GATING_MODES:
            raise ValueError(
                f"Unsupported gating_mode '{self.gating_mode}'. Expected one of {GATING_MODES}."
            )
        if self.bias_decay_schedule not in BIAS_DECAY_SCHEDULES:
            raise ValueError(
                f"Unsupported bias_decay_schedule '{self.bias_decay_schedule}'. "
                f"Expected one of {BIAS_DECAY_SCHEDULES}."
            )
        if self.guidance_refresh_steps < 1:
            raise ValueError("guidance_refresh_steps must be at least 1.")
        if self.guidance_persistence_steps < 1:
            raise ValueError("guidance_persistence_steps must be at least 1.")
        if self.max_guidance_duration < 1:
            raise ValueError("max_guidance_duration must be at least 1.")
        if self.max_intersections_affected < 1:
            raise ValueError("max_intersections_affected must be at least 1.")
        return self
116
+
117
+
118
@dataclass(frozen=True)
class RLPolicyDecision:
    """Raw output of the fixed RL policy for one observation batch."""

    # Per-row Q-values; actions are their greedy argmax
    # (see FixedRLPolicyAdapter.decide).
    q_values: np.ndarray
    actions: np.ndarray
122
+
123
+
124
@dataclass
class GuidanceDecision:
    """One guidance proposal (LLM or heuristic) plus validity/repair metadata."""

    source: str  # provider identifier, e.g. "llm" or "heuristic"
    action: DistrictAction
    runtime_seconds: float
    raw_text: str | None = None
    parsed_payload_before_repair: dict[str, Any] | None = None
    repair_report: RepairReport | None = None
    json_valid: bool = True
    schema_valid_before_repair: bool = True
    provider_error: str | None = None
    fallback_policy_applied: str | None = None

    @property
    def repair_applied(self) -> bool:
        """True when the repair pipeline changed the proposal in any way."""
        report = self.repair_report
        if report is None:
            return False
        return any(
            (
                report.invalid_ids_removed,
                report.non_visible_ids_removed,
                report.deduplicated,
                report.truncated,
                report.fallback_used,
                report.empty_after_filtering,
            )
        )

    @property
    def invalid_before_repair(self) -> bool:
        """True when the raw proposal was unusable before repair.

        Unusable means: the provider errored, the output was not valid
        JSON/schema, or targets contained invalid / non-visible ids or
        ended up empty after filtering.
        """
        report = self.repair_report
        if self.provider_error:
            return True
        if not self.json_valid or not self.schema_valid_before_repair:
            return True
        if report is None:
            return False
        return bool(
            report.invalid_ids_removed
            or report.non_visible_ids_removed
            or report.empty_after_filtering
        )

    def to_trace_payload(self) -> dict[str, Any]:
        """Serialize this decision (including derived flags) for trace logging."""
        return {
            "source": self.source,
            "runtime_seconds": float(self.runtime_seconds),
            "action": self.action.to_dict(),
            "raw_text": self.raw_text,
            "parsed_payload_before_repair": self.parsed_payload_before_repair,
            "repair_report": None if self.repair_report is None else self.repair_report.to_dict(),
            "json_valid": bool(self.json_valid),
            "schema_valid_before_repair": bool(self.schema_valid_before_repair),
            "repair_applied": bool(self.repair_applied),
            "invalid_before_repair": bool(self.invalid_before_repair),
            "provider_error": self.provider_error,
            "fallback_policy_applied": self.fallback_policy_applied,
        }
183
+
184
+
185
@dataclass(frozen=True)
class GuidanceApplicationPlan:
    """Concrete, resolved plan for how one guidance decision biases Q-values."""

    wrapper_mode: str
    scope: str
    affected_intersections: tuple[str, ...]
    targeted_intersections: tuple[str, ...]
    target_candidate_ids: tuple[str, ...]
    priority_direction: str | None
    strength_scale: float
    base_bias_strength: float
    target_bias_strength: float
    corridor_bias_strength: float
    apply_global_bias: bool
    apply_target_only: bool
    max_intersections_affected: int

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict (tuples become lists)."""
        return {
            "wrapper_mode": self.wrapper_mode,
            "scope": self.scope,
            "affected_intersections": list(self.affected_intersections),
            "targeted_intersections": list(self.targeted_intersections),
            "target_candidate_ids": list(self.target_candidate_ids),
            "priority_direction": self.priority_direction,
            "strength_scale": float(self.strength_scale),
            "base_bias_strength": float(self.base_bias_strength),
            "target_bias_strength": float(self.target_bias_strength),
            "corridor_bias_strength": float(self.corridor_bias_strength),
            "apply_global_bias": bool(self.apply_global_bias),
            "apply_target_only": bool(self.apply_target_only),
            "max_intersections_affected": int(self.max_intersections_affected),
        }
217
+
218
+
219
@dataclass
class ActiveDistrictGuidance:
    """Guidance currently in force for one district, with its validity window."""

    district_id: str
    summary: DistrictStateSummary  # summary the guidance was generated from
    decision: GuidanceDecision
    application_plan: GuidanceApplicationPlan
    generated_step: int  # decision step at generation time
    # Step at which this guidance expires — inclusive/exclusive semantics are
    # decided by the consumer (outside this chunk); confirm before relying on it.
    expires_step: int
    fallback_used: bool = False
228
+
229
+
230
@dataclass(frozen=True)
class GuidanceGateDecision:
    """Outcome of the gating check deciding whether guidance may apply."""

    allowed: bool
    gating_mode: str
    triggered_conditions: tuple[str, ...]  # conditions that argued FOR guidance
    blocked_reasons: tuple[str, ...]       # reasons guidance was suppressed
    avg_queue: float
    queue_imbalance: float
    incident_flag: bool
    spillback_risk: bool
    overload_flag: bool

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict (tuples become lists)."""
        return {
            "allowed": bool(self.allowed),
            "gating_mode": self.gating_mode,
            "triggered_conditions": list(self.triggered_conditions),
            "blocked_reasons": list(self.blocked_reasons),
            "avg_queue": float(self.avg_queue),
            "queue_imbalance": float(self.queue_imbalance),
            "incident_flag": bool(self.incident_flag),
            "spillback_risk": bool(self.spillback_risk),
            "overload_flag": bool(self.overload_flag),
        }
254
+
255
+
256
@dataclass
class GuidanceRefreshTrace:
    """Full record of one guidance refresh, suitable for JSON trace logs."""

    mode_source: str
    district_id: str
    decision_step: int
    summary_hash: str
    summary_excerpt: str
    summary_payload: dict[str, Any]
    guidance: dict[str, Any]  # raw guidance before repair
    repaired_guidance: dict[str, Any]
    fallback_used: bool
    fallback_policy: str
    application_plan: dict[str, Any]
    applied_biases: dict[str, float]
    gate_decision: dict[str, Any] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize for logging.

        Note the key renames: summary_payload -> "summary",
        guidance -> "raw_guidance".
        """
        return {
            "mode_source": self.mode_source,
            "district_id": self.district_id,
            "decision_step": int(self.decision_step),
            "summary_hash": self.summary_hash,
            "summary_excerpt": self.summary_excerpt,
            "summary": self.summary_payload,
            "raw_guidance": self.guidance,
            "repaired_guidance": self.repaired_guidance,
            "fallback_used": bool(self.fallback_used),
            "fallback_policy": self.fallback_policy,
            "application_plan": self.application_plan,
            "applied_biases": self.applied_biases,
            "gate_decision": self.gate_decision,
        }
288
+
289
+
290
@dataclass
class WrapperEpisodeStats:
    """Per-episode counters for how often and how strongly guidance applied."""

    step_count: int = 0
    steps_with_active_guidance: int = 0
    guidance_refresh_count: int = 0
    guidance_blocked_step_count: int = 0
    guidance_blocked_refresh_count: int = 0
    bias_application_count: int = 0
    noop_guidance_events: int = 0
    fallback_event_count: int = 0
    total_affected_intersections: int = 0
    total_targeted_intersections: int = 0
    total_bias_magnitude: float = 0.0
    max_bias_magnitude: float = 0.0

    def to_dict(self) -> dict[str, float]:
        """Summarize counters as floats; averages guard division by zero.

        NOTE(review): "percent_steps_with_active_guidance" is a fraction in
        [0, 1], not a percentage — the key name is kept for downstream
        compatibility.
        """
        refresh_count = max(1, self.guidance_refresh_count)
        return {
            "num_guidance_refreshes": float(self.guidance_refresh_count),
            "num_steps_guidance_blocked_by_gate": float(self.guidance_blocked_step_count),
            "num_guidance_refreshes_blocked_by_gate": float(self.guidance_blocked_refresh_count),
            "num_bias_applications": float(self.bias_application_count),
            "num_noop_guidance_events": float(self.noop_guidance_events),
            "fallback_policy_used_count": float(self.fallback_event_count),
            "avg_num_affected_intersections": float(self.total_affected_intersections) / float(refresh_count),
            "avg_num_targeted_intersections": float(self.total_targeted_intersections) / float(refresh_count),
            "mean_bias_magnitude": float(self.total_bias_magnitude) / float(max(1, self.bias_application_count)),
            "max_bias_magnitude": float(self.max_bias_magnitude),
            "percent_steps_with_active_guidance": float(self.steps_with_active_guidance)
            / float(max(1, self.step_count)),
        }
321
+
322
+
323
@dataclass
class GuidedActionBatch:
    """Result of one guided decision step: base vs. biased Q-values and actions."""

    actions: np.ndarray       # final (possibly guidance-biased) greedy actions
    base_actions: np.ndarray  # greedy actions of the unbiased policy
    base_q_values: np.ndarray
    guided_q_values: np.ndarray  # base_q_values + q_bias
    q_bias: np.ndarray
    refresh_traces: list[GuidanceRefreshTrace] = field(default_factory=list)
    runtime_seconds: float = 0.0
332
+
333
+
334
class BaseGuidanceProvider(ABC):
    """Interface for anything that turns a district summary into guidance."""

    # Short identifier stamped on every GuidanceDecision (e.g. "llm", "heuristic").
    source_name: str

    @abstractmethod
    def generate(self, summary: DistrictStateSummary) -> GuidanceDecision:
        """Produce a GuidanceDecision for the given district summary."""
        raise NotImplementedError
340
+
341
+
342
class HeuristicGuidanceProvider(BaseGuidanceProvider):
    """Guidance provider backed by the deterministic heuristic policy."""

    source_name = "heuristic"

    def __init__(self, config: HeuristicGuidanceConfig | None = None):
        self.config = config or HeuristicGuidanceConfig()

    def generate(self, summary: DistrictStateSummary) -> GuidanceDecision:
        """Produce a heuristic action for *summary*.

        No repair metadata is attached — the decision's validity flags keep
        their defaults (valid), since the heuristic output is constructed
        directly rather than parsed.
        """
        started = perf_counter()
        action = generate_heuristic_guidance(summary=summary, config=self.config)
        return GuidanceDecision(
            source=self.source_name,
            action=action,
            runtime_seconds=perf_counter() - started,
            parsed_payload_before_repair=action.to_dict(),
        )
357
+
358
+
359
class LLMGuidanceProvider(BaseGuidanceProvider):
    """Guidance provider backed by the fine-tuned district LLM."""

    source_name = "llm"

    def __init__(self, inference: DistrictLLMInference, max_new_tokens: int = 128):
        self.inference = inference
        self.max_new_tokens = int(max_new_tokens)

    def generate(self, summary: DistrictStateSummary) -> GuidanceDecision:
        """Run LLM inference on *summary* and wrap the result.

        Forwards the inference result's raw text, pre-repair payload, repair
        report, and JSON/schema validity flags onto the GuidanceDecision.
        """
        started = perf_counter()
        result: DistrictLLMInferenceResult = self.inference.predict_with_result(
            summary=summary,
            max_new_tokens=self.max_new_tokens,
        )
        return GuidanceDecision(
            source=self.source_name,
            action=result.action,
            runtime_seconds=perf_counter() - started,
            raw_text=result.raw_text,
            parsed_payload_before_repair=result.parsed_payload_before_repair,
            repair_report=result.repair_report,
            json_valid=result.json_valid,
            schema_valid_before_repair=result.schema_valid_before_repair,
        )
382
+
383
+
384
class FixedRLPolicyAdapter:
    """Thin adapter exposing a frozen RL checkpoint as a greedy Q-policy."""

    def __init__(self, checkpoint_path: str, device: str | None = None):
        # The teacher loads the checkpoint; weights are never updated here.
        self.teacher = RLCheckpointTeacher(checkpoint_path=checkpoint_path, device=device)
        self.device = self.teacher.device

    @property
    def env_config(self) -> Any | None:
        """Environment config stored alongside the checkpoint, if any."""
        return self.teacher.env_config

    def decide(self, observation_batch: dict[str, Any]) -> RLPolicyDecision:
        """Compute masked Q-values and greedy (argmax) actions for one batch.

        Expects "observations", "district_type_indices", and "action_mask"
        keys; observations are normalized when the teacher ships an
        obs_normalizer, otherwise used as-is.
        """
        raw_obs = observation_batch["observations"].astype(np.float32)
        normalized_obs = (
            self.teacher.obs_normalizer.normalize(raw_obs)
            if self.teacher.obs_normalizer is not None
            else raw_obs
        )
        obs_tensor = torch.as_tensor(normalized_obs, dtype=torch.float32, device=self.device)
        district_type_tensor = torch.as_tensor(
            observation_batch["district_type_indices"],
            dtype=torch.int64,
            device=self.device,
        )
        action_mask_tensor = torch.as_tensor(
            observation_batch["action_mask"],
            dtype=torch.float32,
            device=self.device,
        )
        # Inference only: no gradients needed for the frozen policy.
        with torch.no_grad():
            q_values = self.teacher.model.forward(
                observations=obs_tensor,
                district_type_indices=district_type_tensor,
                action_mask=action_mask_tensor,
            )
        q_values_np = q_values.detach().cpu().numpy().astype(np.float32)
        return RLPolicyDecision(q_values=q_values_np, actions=q_values_np.argmax(axis=1).astype(np.int64))
419
+
420
+
421
+ class DistrictGuidedRLController:
422
def __init__(
    self,
    policy: FixedRLPolicyAdapter,
    mode_source: str,
    summary_builder: DistrictStateSummaryBuilder | None = None,
    guidance_provider: BaseGuidanceProvider | None = None,
    influence_config: GuidanceInfluenceConfig | None = None,
    heuristic_provider: BaseGuidanceProvider | None = None,
):
    """Wrap *policy* with optional guidance biasing.

    When guidance_provider or summary_builder is None the controller
    degrades to the plain policy (see act). The influence config is
    validated eagerly so bad configs fail at construction time.
    """
    self.policy = policy
    self.mode_source = mode_source
    self.summary_builder = summary_builder
    self.guidance_provider = guidance_provider
    self.influence_config = (influence_config or GuidanceInfluenceConfig()).validate()
    self.heuristic_provider = heuristic_provider
    # Per-district active guidance and refresh schedule; reset each episode.
    self._active_guidance: dict[str, ActiveDistrictGuidance] = {}
    self._next_refresh_step_by_district: dict[str, int] = {}
    self._episode_stats = WrapperEpisodeStats()
440
+
441
def reset(self) -> None:
    """Clear all per-episode state ahead of a new rollout."""
    self._active_guidance = {}
    self._next_refresh_step_by_district = {}
    self._episode_stats = WrapperEpisodeStats()
    builder = self.summary_builder
    if builder is not None:
        builder.reset()
447
+
448
def active_guidance_snapshot(self) -> dict[str, dict[str, Any]]:
    """Return the active guidance actions keyed by district id, in sorted order."""
    snapshot: dict[str, dict[str, Any]] = {}
    for district_id in sorted(self._active_guidance):
        snapshot[district_id] = self._active_guidance[district_id].decision.action.to_dict()
    return snapshot
453
+
454
def episode_debug_summary(self) -> dict[str, Any]:
    """Episode statistics augmented with the wrapper-mode/fallback settings."""
    summary = dict(self._episode_stats.to_dict())
    summary["wrapper_mode"] = self.influence_config.wrapper_mode
    summary["fallback_policy"] = self.influence_config.fallback_policy
    return summary
463
+
464
def act(self, env, observation_batch: dict[str, Any]) -> GuidedActionBatch:
    """Run the fixed policy, then add guidance biases to Q-values before argmax.

    With no guidance provider configured the base greedy actions pass
    through untouched (bias stays zero). Otherwise each row whose district
    has active guidance gets a per-action bias added before greedy
    selection; episode statistics are updated along the way.
    """
    started = perf_counter()
    base_decision = self.policy.decide(observation_batch)
    base_q_values = base_decision.q_values
    guided_q_values = base_q_values.copy()
    q_bias = np.zeros_like(guided_q_values, dtype=np.float32)

    refresh_traces = self._refresh_guidance_if_needed(env=env, observation_batch=observation_batch)
    if self.guidance_provider is None:
        # Guidance disabled: return the base policy's actions unchanged.
        self._episode_stats.step_count += 1
        return GuidedActionBatch(
            actions=base_decision.actions.copy(),
            base_actions=base_decision.actions,
            base_q_values=base_q_values,
            guided_q_values=guided_q_values,
            q_bias=q_bias,
            refresh_traces=refresh_traces,
            runtime_seconds=perf_counter() - started,
        )

    active_any = False
    decision_step = int(observation_batch.get("decision_step", 0))
    # Pre-index candidate metadata per district for O(1) lookup per row.
    candidate_lookup_by_district = {
        district_id: {
            item.intersection_id: item
            for item in active.summary.candidate_intersections
        }
        for district_id, active in self._active_guidance.items()
    }
    for row_index, intersection_id in enumerate(observation_batch["intersection_ids"]):
        district_id = str(observation_batch["district_ids"][row_index])
        active = self._active_guidance.get(district_id)
        if active is None:
            continue
        active_any = True
        candidate = candidate_lookup_by_district[district_id].get(str(intersection_id))
        row_bias = self._row_action_bias(
            active=active,
            candidate=candidate,
            intersection_id=str(intersection_id),
            current_phase=int(observation_batch["current_phase"][row_index]),
            decision_step=decision_step,
        )
        # None means "no bias for this row" (e.g. out of scope) — skip stats.
        if row_bias is None:
            continue
        q_bias[row_index] = row_bias
        guided_q_values[row_index] = guided_q_values[row_index] + row_bias
        magnitude = float(np.abs(row_bias).max())
        self._episode_stats.bias_application_count += 1
        self._episode_stats.total_bias_magnitude += magnitude
        self._episode_stats.max_bias_magnitude = max(self._episode_stats.max_bias_magnitude, magnitude)

    self._episode_stats.step_count += 1
    if active_any:
        self._episode_stats.steps_with_active_guidance += 1
    # Greedy selection over the biased Q-values.
    actions = guided_q_values.argmax(axis=1).astype(np.int64)
    return GuidedActionBatch(
        actions=actions,
        base_actions=base_decision.actions,
        base_q_values=base_q_values,
        guided_q_values=guided_q_values,
        q_bias=q_bias,
        refresh_traces=refresh_traces,
        runtime_seconds=perf_counter() - started,
    )
529
+
530
    def _refresh_guidance_if_needed(
        self,
        env,
        observation_batch: dict[str, Any],
    ) -> list[GuidanceRefreshTrace]:
        """Regenerate guidance for every district whose refresh step has arrived.

        For each due district the gate is evaluated first; blocked districts
        have any active guidance cleared, are rescheduled, and get a synthetic
        hold/no-op trace.  Allowed districts get a fresh decision from the
        provider (with fallback handling), an application plan, and are
        rescheduled at the plan's expiry step.  Returns one trace per
        refreshed district (possibly empty).
        """
        if self.guidance_provider is None or self.summary_builder is None:
            return []

        decision_step = int(observation_batch.get("decision_step", 0))
        due_districts = [
            district_id
            for district_id in tuple(sorted(env.districts))
            if self._district_requires_refresh(district_id=district_id, decision_step=decision_step)
        ]
        if not due_districts:
            return []

        summaries = self.summary_builder.build_all(env, observation_batch)
        refresh_traces: list[GuidanceRefreshTrace] = []
        gate_blocked_this_step = False
        for district_id in due_districts:
            summary = summaries[district_id]
            previous_active = self._active_guidance.get(district_id)
            gate_decision = _evaluate_guidance_gate(summary=summary, config=self.influence_config)
            if not gate_decision.allowed:
                # Gate blocked: drop active guidance, schedule a re-check, and
                # emit a hold/no-op trace for observability.
                gate_blocked_this_step = True
                self._active_guidance.pop(district_id, None)
                self._next_refresh_step_by_district[district_id] = (
                    decision_step + self._resolve_blocked_refresh_horizon()
                )
                self._episode_stats.guidance_refresh_count += 1
                self._episode_stats.guidance_blocked_refresh_count += 1
                self._episode_stats.noop_guidance_events += 1
                decision = GuidanceDecision(
                    source=f"{self.mode_source}_gate_blocked",
                    action=DistrictAction.default_hold(),
                    runtime_seconds=0.0,
                    fallback_policy_applied="gate_blocked",
                )
                # Force a no-op plan regardless of the configured wrapper mode.
                application_plan = _build_application_plan(
                    summary=summary,
                    action=decision.action,
                    config=replace(self.influence_config, wrapper_mode="no_op"),
                    district_intersection_ids=tuple(env.districts[district_id].intersection_ids),
                )
                trace = GuidanceRefreshTrace(
                    mode_source=self.mode_source,
                    district_id=district_id,
                    decision_step=decision_step,
                    summary_hash=_summary_hash(summary),
                    summary_excerpt=summary.to_prompt_text()[:240],
                    summary_payload=summary.to_dict(),
                    guidance=decision.to_trace_payload(),
                    repaired_guidance=decision.action.to_dict(),
                    fallback_used=False,
                    fallback_policy="gate_blocked",
                    application_plan=application_plan.to_dict(),
                    applied_biases={
                        "base": 0.0,
                        "target": 0.0,
                        "corridor": 0.0,
                        "strength_scale": 0.0,
                    },
                    gate_decision=gate_decision.to_dict(),
                )
                refresh_traces.append(trace)
                if self.influence_config.log_guidance_debug:
                    _log_guidance_debug(trace)
                continue

            decision, fallback_used = self._generate_guidance(
                district_id=district_id,
                summary=summary,
                previous_active=previous_active,
            )
            application_plan = _build_application_plan(
                summary=summary,
                action=decision.action,
                config=self.influence_config,
                district_intersection_ids=tuple(env.districts[district_id].intersection_ids),
            )
            expires_step = decision_step + self._resolve_refresh_horizon(decision.action)
            active = ActiveDistrictGuidance(
                district_id=district_id,
                summary=summary,
                decision=decision,
                application_plan=application_plan,
                generated_step=decision_step,
                expires_step=expires_step,
                fallback_used=fallback_used,
            )
            self._active_guidance[district_id] = active
            # The guidance expiry step doubles as the next refresh step.
            self._next_refresh_step_by_district[district_id] = int(expires_step)

            self._episode_stats.guidance_refresh_count += 1
            self._episode_stats.total_affected_intersections += len(application_plan.affected_intersections)
            self._episode_stats.total_targeted_intersections += len(application_plan.targeted_intersections)
            if application_plan.wrapper_mode == "no_op" or not application_plan.affected_intersections:
                self._episode_stats.noop_guidance_events += 1
            if fallback_used:
                self._episode_stats.fallback_event_count += 1

            trace = GuidanceRefreshTrace(
                mode_source=self.mode_source,
                district_id=district_id,
                decision_step=decision_step,
                summary_hash=_summary_hash(summary),
                summary_excerpt=summary.to_prompt_text()[:240],
                summary_payload=summary.to_dict(),
                guidance=decision.to_trace_payload(),
                repaired_guidance=decision.action.to_dict(),
                fallback_used=fallback_used,
                fallback_policy=self.influence_config.fallback_policy if fallback_used else "none",
                application_plan=application_plan.to_dict(),
                applied_biases={
                    "base": float(application_plan.base_bias_strength),
                    "target": float(application_plan.target_bias_strength),
                    "corridor": float(application_plan.corridor_bias_strength),
                    "strength_scale": float(application_plan.strength_scale),
                },
                gate_decision=gate_decision.to_dict(),
            )
            refresh_traces.append(trace)
            if self.influence_config.log_guidance_debug:
                _log_guidance_debug(trace)
        if gate_blocked_this_step:
            self._episode_stats.guidance_blocked_step_count += 1
        return refresh_traces
658
+
659
    def _generate_guidance(
        self,
        district_id: str,
        summary: DistrictStateSummary,
        previous_active: ActiveDistrictGuidance | None,
    ) -> tuple[GuidanceDecision, bool]:
        """Ask the guidance provider for a decision, applying the configured fallback.

        Returns ``(decision, fallback_used)``.  Provider exceptions are turned
        into an invalid hold decision, which then triggers the fallback ladder:
        ``hold_previous`` (reuse the last action), ``heuristic_weak`` (delegate
        to the heuristic provider), or the default no-op hold.
        """
        fallback_used = False
        try:
            decision = self.guidance_provider.generate(summary)
        except Exception as exc:
            # Broad catch is deliberate: a provider failure must never crash
            # the control loop; the error is kept on the decision for tracing.
            decision = GuidanceDecision(
                source=self.guidance_provider.source_name,
                action=DistrictAction.default_hold(),
                runtime_seconds=0.0,
                provider_error=str(exc),
                json_valid=False,
                schema_valid_before_repair=False,
            )
        if not _should_fallback(decision):
            return decision, fallback_used

        fallback_used = True
        fallback_policy = self.influence_config.fallback_policy
        if fallback_policy == "hold_previous" and previous_active is not None:
            # Reuse the previous action but keep the failed decision's
            # diagnostics (raw text, repair report, errors) for the trace.
            fallback_decision = GuidanceDecision(
                source=f"{decision.source}_fallback_hold_previous",
                action=previous_active.decision.action,
                runtime_seconds=decision.runtime_seconds,
                raw_text=decision.raw_text,
                parsed_payload_before_repair=decision.parsed_payload_before_repair,
                repair_report=decision.repair_report,
                json_valid=decision.json_valid,
                schema_valid_before_repair=decision.schema_valid_before_repair,
                provider_error=decision.provider_error,
                fallback_policy_applied=fallback_policy,
            )
            return fallback_decision, fallback_used

        if fallback_policy == "heuristic_weak" and self.heuristic_provider is not None:
            fallback_decision = self.heuristic_provider.generate(summary)
            fallback_decision.fallback_policy_applied = fallback_policy
            return fallback_decision, fallback_used

        # Default fallback: a no-op hold that still carries the diagnostics.
        fallback_decision = GuidanceDecision(
            source=f"{decision.source}_fallback_no_op",
            action=DistrictAction.default_hold(),
            runtime_seconds=decision.runtime_seconds,
            raw_text=decision.raw_text,
            parsed_payload_before_repair=decision.parsed_payload_before_repair,
            repair_report=decision.repair_report,
            json_valid=decision.json_valid,
            schema_valid_before_repair=decision.schema_valid_before_repair,
            provider_error=decision.provider_error,
            fallback_policy_applied=fallback_policy,
        )
        return fallback_decision, fallback_used
715
+
716
+ def _district_requires_refresh(self, district_id: str, decision_step: int) -> bool:
717
+ next_refresh_step = self._next_refresh_step_by_district.get(district_id)
718
+ if next_refresh_step is None:
719
+ return True
720
+ return decision_step >= int(next_refresh_step)
721
+
722
+ def _resolve_refresh_horizon(self, action: DistrictAction) -> int:
723
+ requested = max(1, min(int(action.duration_steps), self.influence_config.max_guidance_duration))
724
+ return min(
725
+ requested,
726
+ int(self.influence_config.guidance_refresh_steps),
727
+ int(self.influence_config.guidance_persistence_steps),
728
+ )
729
+
730
+ def _resolve_blocked_refresh_horizon(self) -> int:
731
+ return max(
732
+ 1,
733
+ min(
734
+ int(self.influence_config.guidance_refresh_steps),
735
+ int(self.influence_config.guidance_persistence_steps),
736
+ ),
737
+ )
738
+
739
    def _row_action_bias(
        self,
        active: ActiveDistrictGuidance,
        candidate: CandidateIntersection | None,
        intersection_id: str,
        current_phase: int,
        decision_step: int,
    ) -> np.ndarray | None:
        """Compute the additive Q-value bias for one intersection row, or None.

        Returns None when guidance does not apply to this row (no-op plan, row
        outside the plan's affected set, no resolvable direction, or a final
        magnitude of zero).  Otherwise returns a length-2 float32 vector with
        the magnitude placed on the preferred action.
        """
        plan = active.application_plan
        if plan.wrapper_mode == "no_op":
            return None
        if intersection_id not in set(plan.affected_intersections):
            return None

        preferred_action = _preferred_action_for_direction(
            direction=plan.priority_direction,
            current_phase=current_phase,
        )
        if preferred_action is None:
            return None

        # Optional linear decay over the guidance lifetime, floored at 0.25 so
        # aging guidance never fully vanishes before expiry.
        decay = 1.0
        if self.influence_config.enable_bias_decay:
            horizon = max(1, active.expires_step - active.generated_step)
            age = max(0, decision_step - active.generated_step)
            if self.influence_config.bias_decay_schedule == "linear":
                decay = max(0.25, 1.0 - (float(age) / float(horizon)))

        # Stack bias components: base, explicit target bonus, corridor-alignment
        # bonus, and a half-strength boundary bonus for corridor/global scopes.
        magnitude = plan.base_bias_strength * plan.strength_scale * decay
        if intersection_id in set(plan.targeted_intersections):
            magnitude += plan.target_bias_strength * plan.strength_scale * decay
        if candidate is not None and plan.priority_direction in {"NS", "EW"}:
            if candidate.corridor_alignment == plan.priority_direction:
                magnitude += plan.corridor_bias_strength * plan.strength_scale * decay
            if candidate.is_boundary and plan.scope in {"corridor_local", "global"}:
                magnitude += 0.5 * plan.corridor_bias_strength * plan.strength_scale * decay

        # Strategy-specific scaling of the final magnitude (table lookup,
        # defaulting to 1.0 for unlisted strategies).
        strategy_multiplier = STRATEGY_BIAS_MULTIPLIERS.get(active.decision.action.strategy, 1.0)
        magnitude *= strategy_multiplier
        if magnitude <= 0.0:
            return None

        # NOTE(review): assumes a 2-action space per intersection — confirm
        # against the policy's Q-value width.
        bias = np.zeros(2, dtype=np.float32)
        bias[preferred_action] += float(magnitude)
        return bias
784
+
785
+
786
def _build_application_plan(
    summary: DistrictStateSummary,
    action: DistrictAction,
    config: GuidanceInfluenceConfig,
    district_intersection_ids: tuple[str, ...],
) -> GuidanceApplicationPlan:
    """Translate a district action into a concrete bias-application plan.

    The wrapper mode in ``config`` selects scope and strengths:

    - ``no_op``: nothing applied (zero strengths, empty affected set).
    - ``current_legacy``: district-wide bias with floored strengths.
    - ``target_only_soft`` / ``target_only_medium``: only the action's
      candidate-validated targets, at half / full strength scale.
    - ``corridor_soft``: targets plus aligned boundary candidates.
    - any other value: a weak district-wide ``global_soft`` plan.
    """
    wrapper_mode = config.wrapper_mode
    # Keep only requested targets that are actual candidates in the summary.
    target_ids = tuple(
        intersection_id
        for intersection_id in action.target_intersections
        if intersection_id in {item.intersection_id for item in summary.candidate_intersections}
    )
    candidate_lookup = {
        item.intersection_id: item
        for item in summary.candidate_intersections
    }
    priority_direction = _resolve_guidance_direction(action=action, summary=summary)
    if wrapper_mode == "no_op":
        return GuidanceApplicationPlan(
            wrapper_mode=wrapper_mode,
            scope="none",
            affected_intersections=(),
            targeted_intersections=target_ids,
            target_candidate_ids=tuple(candidate_lookup),
            priority_direction=priority_direction,
            strength_scale=0.0,
            base_bias_strength=0.0,
            target_bias_strength=0.0,
            corridor_bias_strength=0.0,
            apply_global_bias=False,
            apply_target_only=True,
            max_intersections_affected=0,
        )

    if wrapper_mode == "current_legacy":
        # Legacy behaviour: bias every intersection in the district, with
        # minimum-floor strengths.
        affected = tuple(district_intersection_ids)
        return GuidanceApplicationPlan(
            wrapper_mode=wrapper_mode,
            scope="global",
            affected_intersections=affected,
            targeted_intersections=target_ids,
            target_candidate_ids=tuple(candidate_lookup),
            priority_direction=priority_direction,
            strength_scale=1.0,
            base_bias_strength=float(max(config.bias_strength, 0.75)),
            target_bias_strength=float(max(config.target_only_bias_strength, 1.25)),
            corridor_bias_strength=float(max(config.corridor_bias_strength, 0.5)),
            apply_global_bias=True,
            apply_target_only=False,
            max_intersections_affected=max(len(affected), config.max_intersections_affected),
        )

    if wrapper_mode in {"target_only_soft", "target_only_medium"}:
        strength_scale = 0.5 if wrapper_mode == "target_only_soft" else 1.0
        affected = target_ids[: config.max_intersections_affected]
        return GuidanceApplicationPlan(
            wrapper_mode=wrapper_mode,
            scope="targeted",
            affected_intersections=affected,
            targeted_intersections=target_ids,
            target_candidate_ids=tuple(candidate_lookup),
            priority_direction=priority_direction,
            strength_scale=strength_scale,
            base_bias_strength=float(config.bias_strength),
            target_bias_strength=float(config.target_only_bias_strength),
            corridor_bias_strength=float(config.corridor_bias_strength),
            apply_global_bias=False,
            apply_target_only=True,
            max_intersections_affected=config.max_intersections_affected,
        )

    if wrapper_mode == "corridor_soft":
        # Extend the target list with boundary candidates aligned to the
        # priority direction, then cap the affected count.
        ranked = list(target_ids)
        extras = [
            item.intersection_id
            for item in summary.candidate_intersections
            if item.intersection_id not in ranked
            and item.is_boundary
            and (priority_direction is None or item.corridor_alignment == priority_direction)
        ]
        affected = tuple((ranked + extras)[: config.max_intersections_affected])
        return GuidanceApplicationPlan(
            wrapper_mode=wrapper_mode,
            scope="corridor_local",
            affected_intersections=affected,
            targeted_intersections=target_ids,
            target_candidate_ids=tuple(candidate_lookup),
            priority_direction=priority_direction,
            strength_scale=0.6,
            base_bias_strength=float(config.bias_strength),
            target_bias_strength=float(config.target_only_bias_strength),
            corridor_bias_strength=float(config.corridor_bias_strength),
            apply_global_bias=False,
            apply_target_only=False,
            max_intersections_affected=config.max_intersections_affected,
        )

    # Fallback for any unrecognized mode: weak district-wide bias.
    affected_global = tuple(district_intersection_ids)
    return GuidanceApplicationPlan(
        wrapper_mode="global_soft",
        scope="global",
        affected_intersections=affected_global,
        targeted_intersections=target_ids,
        target_candidate_ids=tuple(candidate_lookup),
        priority_direction=priority_direction,
        strength_scale=0.35,
        base_bias_strength=float(config.bias_strength),
        target_bias_strength=float(config.target_only_bias_strength),
        corridor_bias_strength=float(config.corridor_bias_strength),
        apply_global_bias=True,
        apply_target_only=False,
        max_intersections_affected=config.max_intersections_affected,
    )
899
+
900
+
901
+ def _should_fallback(decision: GuidanceDecision) -> bool:
902
+ if decision.provider_error is not None:
903
+ return True
904
+ if not decision.json_valid or not decision.schema_valid_before_repair:
905
+ return True
906
+ report = decision.repair_report
907
+ if report is None:
908
+ return False
909
+ return bool(
910
+ report.fallback_used
911
+ or report.empty_after_filtering
912
+ )
913
+
914
+
915
def _evaluate_guidance_gate(
    summary: DistrictStateSummary,
    config: GuidanceInfluenceConfig,
) -> GuidanceGateDecision:
    """Apply the configured gating mode to a district summary and explain the outcome."""
    imbalance = abs(float(summary.ns_queue) - float(summary.ew_queue))
    queue_hot = float(summary.avg_queue) >= float(config.min_avg_queue_for_guidance)
    imbalance_hot = imbalance >= float(config.min_queue_imbalance_for_guidance)
    disruption = bool(summary.incident_flag or summary.spillback_risk or summary.overload_flag)
    triggers = {
        "incident_or_spillback": disruption,
        "queue_threshold": queue_hot,
        "imbalance_threshold": imbalance_hot,
    }
    triggered = tuple(name for name, hit in triggers.items() if hit)

    mode = config.gating_mode
    # Dispatch table instead of an if/elif chain; unknown modes default to
    # "any trigger fires".
    mode_rules = {
        "always_on": True,
        "incident_or_spillback": disruption,
        "queue_threshold": queue_hot,
        "imbalance_threshold": imbalance_hot,
        "queue_or_imbalance": queue_hot or imbalance_hot,
    }
    allowed = mode_rules.get(mode, disruption or queue_hot or imbalance_hot)

    blocked: list[str] = []
    if config.require_incident_or_spillback and not disruption:
        allowed = False
        blocked.append("requires_incident_or_spillback")
    if not config.allow_guidance_in_normal_conditions and not triggered:
        allowed = False
        blocked.append("normal_conditions_blocked")
    if not allowed and not blocked:
        blocked.append(f"gating_mode:{mode}")

    return GuidanceGateDecision(
        allowed=allowed,
        gating_mode=mode,
        triggered_conditions=triggered,
        blocked_reasons=tuple(blocked),
        avg_queue=float(summary.avg_queue),
        queue_imbalance=float(imbalance),
        incident_flag=bool(summary.incident_flag),
        spillback_risk=bool(summary.spillback_risk),
        overload_flag=bool(summary.overload_flag),
    )
964
+
965
+
966
+ def _resolve_guidance_direction(action: DistrictAction, summary: DistrictStateSummary) -> str | None:
967
+ if action.phase_bias in {"NS", "EW"}:
968
+ return action.phase_bias
969
+ if action.priority_corridor in {"NS", "EW"}:
970
+ return action.priority_corridor
971
+ if summary.dominant_flow in {"NS", "EW"}:
972
+ return summary.dominant_flow
973
+ return None
974
+
975
+
976
+ def _preferred_action_for_direction(direction: str | None, current_phase: int) -> int | None:
977
+ if direction == "NS":
978
+ return 0 if current_phase == 0 else 1
979
+ if direction == "EW":
980
+ return 0 if current_phase != 0 else 1
981
+ return None
982
+
983
+
984
+ def _summary_hash(summary: DistrictStateSummary) -> str:
985
+ return hashlib.sha1(summary.to_json().encode("utf-8")).hexdigest()[:16]
986
+
987
+
988
def guidance_config_payload(config: GuidanceInfluenceConfig) -> dict[str, Any]:
    """Serialize the influence config to a plain dict, validating it first."""
    # Validate first so the payload always reflects a self-consistent config.
    return asdict(config.validate())
990
+
991
+
992
def _log_guidance_debug(trace: GuidanceRefreshTrace) -> None:
    """Print a one-line, human-readable record of a guidance refresh event."""
    gate = trace.gate_decision
    fields = " ".join(
        (
            f"mode={trace.mode_source}",
            f"district={trace.district_id}",
            f"wrapper_mode={trace.application_plan['wrapper_mode']}",
            f"gate_allowed={gate.get('allowed') if gate else True}",
            f"scope={trace.application_plan['scope']}",
            f"targets={trace.repaired_guidance.get('target_intersections', [])}",
            f"affected={trace.application_plan['affected_intersections']}",
            f"fallback_used={trace.fallback_used}",
            f"fallback_policy={trace.fallback_policy}",
        )
    )
    print(f"[guidance-debug] {fields}")
district_llm/schema.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass, field
5
+ from typing import Any
6
+
7
+
8
# Closed vocabulary of district-level strategies a guidance source may emit.
DISTRICT_STRATEGIES: tuple[str, ...] = (
    "hold",
    "favor_NS",
    "favor_EW",
    "drain_inbound",
    "drain_outbound",
    "clear_spillback",
    "incident_response",
    "arterial_priority",
)
# Allowed phase-bias values ("NONE" = no directional preference).
PHASE_BIASES: tuple[str, ...] = ("NONE", "NS", "EW")
# Allowed priority-corridor labels for a DistrictAction.
PRIORITY_CORRIDORS: tuple[str, ...] = (
    "NS",
    "EW",
    "inbound",
    "outbound",
    "arterial",
)
# Allowed dominant-flow / corridor-alignment labels.
DOMINANT_FLOWS: tuple[str, ...] = ("NS", "EW", "BALANCED")
# Canonical ordering of candidate-selection reason tags; reason lists are
# normalized to this order so serialized output is deterministic.
CANDIDATE_REASON_TAGS: tuple[str, ...] = (
    "congested",
    "boundary",
    "spillback",
    "incident",
    "outgoing",
    "overload",
    "event",
)
36
+
37
+
38
+ def _round_float(value: float, digits: int = 3) -> float:
39
+ return round(float(value), digits)
40
+
41
+
42
+ def _dedupe_string_list(values: list[str] | tuple[str, ...] | None, limit: int | None = None) -> list[str]:
43
+ normalized: list[str] = []
44
+ seen: set[str] = set()
45
+ for item in values or []:
46
+ value = str(item).strip()
47
+ if not value or value in seen:
48
+ continue
49
+ normalized.append(value)
50
+ seen.add(value)
51
+ if limit is not None and len(normalized) >= limit:
52
+ break
53
+ return normalized
54
+
55
+
56
def _stable_reason_list(values: list[str] | tuple[str, ...] | None) -> list[str]:
    """Keep only known reason tags, emitted in canonical CANDIDATE_REASON_TAGS order."""
    wanted = {str(raw).strip() for raw in (values or ()) if str(raw).strip()}
    return [tag for tag in CANDIDATE_REASON_TAGS if tag in wanted]
59
+
60
+
61
def candidate_priority_score(candidate: "CandidateIntersection | dict[str, Any]") -> float:
    """Scalar urgency score: weighted queue/wait/outgoing load plus flat risk-marker bonuses."""
    item = candidate.to_dict() if hasattr(candidate, "to_dict") else dict(candidate)
    weighted_metrics = (
        (1.0, "queue_total"),
        (1.5, "wait_total"),
        (0.5, "outgoing_load"),
    )
    flag_bonuses = (
        (2.0, "spillback_risk"),
        (1.5, "incident_proximity"),
        (1.0, "is_boundary"),
        (0.75, "event_proximity"),
        (0.75, "overload_marker"),
    )
    score = 0.0
    for weight, key in weighted_metrics:
        score += weight * float(item.get(key, 0.0))
    for bonus, key in flag_bonuses:
        if item.get(key, False):
            score += bonus
    return score
73
+
74
+
75
def candidate_priority_tuple(candidate: "CandidateIntersection | dict[str, Any]") -> tuple[float, float, float, float, str]:
    """Deterministic sort key: (priority score, queue, wait, outgoing load, id)."""
    item = candidate.to_dict() if hasattr(candidate, "to_dict") else dict(candidate)
    score = candidate_priority_score(item)
    queue = float(item.get("queue_total", 0.0))
    wait = float(item.get("wait_total", 0.0))
    outgoing = float(item.get("outgoing_load", 0.0))
    identifier = str(item.get("intersection_id", ""))
    return (score, queue, wait, outgoing, identifier)
84
+
85
+
86
def canonicalize_target_intersections(
    targets: list[str] | tuple[str, ...] | None,
    candidates: list["CandidateIntersection | dict[str, Any]"] | None = None,
    limit: int | None = None,
) -> list[str]:
    """Normalize a target list and order it by descending candidate priority.

    Targets are stripped/deduped first.  When *candidates* are supplied, known
    targets sort by priority (highest first, id as tiebreaker); unknown targets
    sort after all known ones.  *limit* truncates the final list.
    """
    normalized = _dedupe_string_list(targets, limit=None)
    if not candidates:
        return normalized[:limit] if limit is not None else normalized

    # Compute each candidate's priority tuple ONCE (the previous version
    # recomputed it five times per candidate, re-deriving the score each time)
    # and negate the numeric components so ascending sort yields highest first.
    candidate_order: dict[str, tuple[float, float, float, float, str]] = {}
    for candidate in candidates:
        score, queue, wait, outgoing, identifier = candidate_priority_tuple(candidate)
        candidate_order[identifier] = (-score, -queue, -wait, -outgoing, identifier)
    # Unknown targets get a key that sorts after every known candidate.
    normalized.sort(key=lambda item: candidate_order.get(item, (1.0, 1.0, 1.0, 1.0, item)))
    if limit is not None:
        normalized = normalized[:limit]
    return normalized
109
+
110
+
111
@dataclass
class CongestedIntersection:
    """Compact row describing one congested intersection in a district summary."""

    intersection_id: str
    queue_total: float
    wait_total: float
    outgoing_load: float
    current_phase: int
    is_boundary: bool

    def to_dict(self) -> dict[str, Any]:
        """Serialize with floats rounded to 3 decimals."""
        return {
            "intersection_id": self.intersection_id,
            "queue_total": _round_float(self.queue_total),
            "wait_total": _round_float(self.wait_total),
            "outgoing_load": _round_float(self.outgoing_load),
            "current_phase": int(self.current_phase),
            "is_boundary": bool(self.is_boundary),
        }

    def to_prompt_line(self) -> str:
        """Render one compact bullet line for inclusion in a prompt."""
        parts = (
            f"- {self.intersection_id}",
            f"q={self.queue_total:.2f}",
            f"w={self.wait_total:.2f}",
            f"out={self.outgoing_load:.2f}",
            f"phase={self.current_phase}",
            f"boundary={int(self.is_boundary)}",
        )
        return " ".join(parts)
139
+
140
+
141
@dataclass
class CandidateIntersection:
    """Candidate row enriched with risk markers and the reasons it was selected."""

    intersection_id: str
    queue_total: float
    wait_total: float
    outgoing_load: float
    current_phase: int
    is_boundary: bool
    spillback_risk: bool = False
    incident_proximity: bool = False
    overload_marker: bool = False
    event_proximity: bool = False
    corridor_alignment: str = "BALANCED"
    selection_reasons: list[str] = field(default_factory=list)

    def validate(self) -> "CandidateIntersection":
        """Check corridor_alignment against the vocabulary and canonicalize reasons in place."""
        if self.corridor_alignment not in DOMINANT_FLOWS:
            raise ValueError(
                f"Invalid corridor_alignment '{self.corridor_alignment}'. Expected one of {DOMINANT_FLOWS}."
            )
        self.selection_reasons = _stable_reason_list(self.selection_reasons)
        return self

    def to_dict(self) -> dict[str, Any]:
        """Serialize (after validation) with floats rounded to 3 decimals."""
        self.validate()
        payload: dict[str, Any] = {
            "intersection_id": self.intersection_id,
            "queue_total": _round_float(self.queue_total),
            "wait_total": _round_float(self.wait_total),
            "outgoing_load": _round_float(self.outgoing_load),
            "current_phase": int(self.current_phase),
            "is_boundary": bool(self.is_boundary),
            "spillback_risk": bool(self.spillback_risk),
            "incident_proximity": bool(self.incident_proximity),
            "overload_marker": bool(self.overload_marker),
            "event_proximity": bool(self.event_proximity),
            "corridor_alignment": self.corridor_alignment,
            "selection_reasons": list(self.selection_reasons),
        }
        return payload

    def to_prompt_line(self) -> str:
        """Render one compact single-line form (after validation) for prompts."""
        self.validate()
        reason_text = "|".join(self.selection_reasons) or "none"
        parts = (
            f"- {self.intersection_id}",
            f"q={self.queue_total:.2f}",
            f"w={self.wait_total:.2f}",
            f"out={self.outgoing_load:.2f}",
            f"phase={self.current_phase}",
            f"boundary={int(self.is_boundary)}",
            f"spillback={int(self.spillback_risk)}",
            f"incident={int(self.incident_proximity)}",
            f"overload={int(self.overload_marker)}",
            f"event={int(self.event_proximity)}",
            f"align={self.corridor_alignment}",
            f"reasons={reason_text}",
        )
        return " ".join(parts)
198
+
199
+
200
@dataclass
class DistrictAction:
    """Validated district-level guidance action with a closed strategy vocabulary."""

    strategy: str = "hold"
    priority_corridor: str | None = None
    target_intersections: list[str] = field(default_factory=list)
    phase_bias: str = "NONE"
    duration_steps: int = 1

    def validate(self) -> "DistrictAction":
        """Raise ValueError on any out-of-vocabulary field; canonicalize targets in place."""
        if self.strategy not in DISTRICT_STRATEGIES:
            raise ValueError(
                f"Invalid strategy '{self.strategy}'. Expected one of {DISTRICT_STRATEGIES}."
            )
        if self.priority_corridor is not None and self.priority_corridor not in PRIORITY_CORRIDORS:
            raise ValueError(
                f"Invalid priority_corridor '{self.priority_corridor}'. "
                f"Expected one of {PRIORITY_CORRIDORS} or None."
            )
        if self.phase_bias not in PHASE_BIASES:
            raise ValueError(
                f"Invalid phase_bias '{self.phase_bias}'. Expected one of {PHASE_BIASES}."
            )
        if not isinstance(self.duration_steps, int):
            raise ValueError("duration_steps must be an integer.")
        if not 1 <= self.duration_steps <= 20:
            raise ValueError("duration_steps must be between 1 and 20.")
        # At most 8 distinct, non-empty target ids are kept.
        self.target_intersections = _dedupe_string_list(self.target_intersections, limit=8)
        return self

    @classmethod
    def default_hold(cls, duration_steps: int = 1) -> "DistrictAction":
        """Safe no-op action: hold with no corridor, targets, or phase bias."""
        clamped = max(1, min(int(duration_steps), 20))
        return cls(
            strategy="hold",
            priority_corridor=None,
            target_intersections=[],
            phase_bias="NONE",
            duration_steps=clamped,
        )

    @classmethod
    def from_dict(cls, payload: dict[str, Any]) -> "DistrictAction":
        """Build and validate an action from a plain dict, filling defaults."""
        action = cls(
            strategy=str(payload.get("strategy", "hold")),
            priority_corridor=payload.get("priority_corridor"),
            target_intersections=list(payload.get("target_intersections", [])),
            phase_bias=str(payload.get("phase_bias", "NONE")),
            duration_steps=int(payload.get("duration_steps", 1)),
        )
        return action.validate()

    @classmethod
    def from_json(cls, payload: str) -> "DistrictAction":
        """Build and validate an action from a JSON string."""
        return cls.from_dict(json.loads(payload))

    def to_dict(self) -> dict[str, Any]:
        """Serialize the validated action to a plain dict."""
        self.validate()
        return {
            "strategy": self.strategy,
            "priority_corridor": self.priority_corridor,
            "target_intersections": list(self.target_intersections),
            "phase_bias": self.phase_bias,
            "duration_steps": int(self.duration_steps),
        }

    def to_json(self) -> str:
        """Canonical compact JSON (sorted keys, no whitespace)."""
        return json.dumps(self.to_dict(), sort_keys=True, separators=(",", ":"))

    def to_pretty_json(self) -> str:
        """Human-readable JSON (sorted keys, 2-space indent)."""
        return json.dumps(self.to_dict(), sort_keys=True, indent=2)

    def to_rl_context(self) -> dict[str, Any]:
        """Dict form with keys renamed for consumption by the RL wrapper."""
        payload = self.to_dict()
        payload["district_strategy"] = payload.pop("strategy")
        payload["district_duration_steps"] = payload.pop("duration_steps")
        return payload
274
+
275
+
276
@dataclass
class DistrictStateSummary:
    """Aggregated, LLM-facing snapshot of a single district at one decision step.

    Produced by ``DistrictStateSummaryBuilder`` and serialized either as
    canonical JSON (``to_json``) for datasets or as a line-oriented prompt
    (``to_prompt_text``) for the district LLM. ``validate()`` enforces the
    dominant-flow enum and caps list lengths; serializers call it implicitly.
    """

    # Identity / provenance.
    city_id: str
    district_id: str
    district_type: str
    scenario_name: str
    scenario_type: str
    decision_step: int
    sim_time: int
    intersection_count: int
    # Queue statistics aggregated over the district's intersections.
    avg_queue: float
    max_queue: float
    total_queue: float
    # Waiting-vehicle statistics.
    avg_wait: float
    max_wait: float
    total_wait: float
    # Outgoing-lane load statistics.
    avg_outgoing_load: float
    max_outgoing_load: float
    total_outgoing_load: float
    # Trend features relative to the previous decision step.
    recent_throughput: float
    queue_change: float
    wait_change: float
    throughput_change: float
    # Directional pressure split (north-south vs east-west).
    ns_queue: float
    ew_queue: float
    ns_wait: float
    ew_wait: float
    # One of DOMINANT_FLOWS (checked in validate()).
    dominant_flow: str
    # Boundary-intersection aggregates and scenario condition flags.
    boundary_queue_total: float
    boundary_wait_total: float
    spillback_risk: bool
    incident_flag: bool
    construction_flag: bool
    overload_flag: bool
    event_flag: bool
    top_congested_intersections: list[CongestedIntersection] = field(default_factory=list)
    candidate_intersections: list[CandidateIntersection] = field(default_factory=list)

    def validate(self) -> "DistrictStateSummary":
        """Check enum fields and clamp list lengths; returns ``self`` for chaining.

        Raises:
            ValueError: if ``dominant_flow`` is not one of ``DOMINANT_FLOWS``.
        """
        if self.dominant_flow not in DOMINANT_FLOWS:
            raise ValueError(
                f"Invalid dominant_flow '{self.dominant_flow}'. Expected one of {DOMINANT_FLOWS}."
            )
        # Hard caps keep serialized summaries bounded: at most 5 congested and
        # 8 candidate intersections survive validation (lists are re-copied).
        self.top_congested_intersections = list(self.top_congested_intersections[:5])
        self.candidate_intersections = list(self.candidate_intersections[:8])
        return self

    def candidate_ids(self) -> list[str]:
        """Return the candidate intersection ids in their stored order."""
        self.validate()
        return [item.intersection_id for item in self.candidate_intersections]

    def candidate_lookup(self) -> dict[str, CandidateIntersection]:
        """Return candidate intersections keyed by intersection id."""
        self.validate()
        return {
            item.intersection_id: item
            for item in self.candidate_intersections
        }

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict with floats rounded via _round_float."""
        self.validate()
        return {
            "city_id": self.city_id,
            "district_id": self.district_id,
            "district_type": self.district_type,
            "scenario_name": self.scenario_name,
            "scenario_type": self.scenario_type,
            "decision_step": int(self.decision_step),
            "sim_time": int(self.sim_time),
            "intersection_count": int(self.intersection_count),
            "avg_queue": _round_float(self.avg_queue),
            "max_queue": _round_float(self.max_queue),
            "total_queue": _round_float(self.total_queue),
            "avg_wait": _round_float(self.avg_wait),
            "max_wait": _round_float(self.max_wait),
            "total_wait": _round_float(self.total_wait),
            "avg_outgoing_load": _round_float(self.avg_outgoing_load),
            "max_outgoing_load": _round_float(self.max_outgoing_load),
            "total_outgoing_load": _round_float(self.total_outgoing_load),
            "recent_throughput": _round_float(self.recent_throughput),
            "queue_change": _round_float(self.queue_change),
            "wait_change": _round_float(self.wait_change),
            "throughput_change": _round_float(self.throughput_change),
            "ns_queue": _round_float(self.ns_queue),
            "ew_queue": _round_float(self.ew_queue),
            "ns_wait": _round_float(self.ns_wait),
            "ew_wait": _round_float(self.ew_wait),
            "dominant_flow": self.dominant_flow,
            "boundary_queue_total": _round_float(self.boundary_queue_total),
            "boundary_wait_total": _round_float(self.boundary_wait_total),
            "spillback_risk": bool(self.spillback_risk),
            "incident_flag": bool(self.incident_flag),
            "construction_flag": bool(self.construction_flag),
            "overload_flag": bool(self.overload_flag),
            "event_flag": bool(self.event_flag),
            "top_congested_intersections": [
                item.to_dict() for item in self.top_congested_intersections
            ],
            "candidate_intersections": [
                item.to_dict() for item in self.candidate_intersections
            ],
        }

    def to_json(self) -> str:
        """Serialize to a compact, key-sorted JSON string (canonical form)."""
        return json.dumps(self.to_dict(), sort_keys=True, separators=(",", ":"))

    def to_prompt_text(self) -> str:
        """Render the summary as the line-oriented text fed into LLM prompts.

        The exact field order and formatting is part of the prompt contract;
        do not reorder or reformat these lines.
        """
        self.validate()
        top_lines = [item.to_prompt_line() for item in self.top_congested_intersections]
        candidate_lines = [item.to_prompt_line() for item in self.candidate_intersections]
        # Placeholder keeps the section headers meaningful when lists are empty.
        if not top_lines:
            top_lines = ["- none"]
        if not candidate_lines:
            candidate_lines = ["- none"]
        return "\n".join(
            [
                f"city_id: {self.city_id}",
                f"district_id: {self.district_id}",
                f"district_type: {self.district_type}",
                f"scenario: {self.scenario_name}",
                f"scenario_type: {self.scenario_type}",
                f"decision_step: {self.decision_step}",
                f"sim_time: {self.sim_time}",
                f"intersection_count: {self.intersection_count}",
                f"avg_queue: {self.avg_queue:.2f}",
                f"max_queue: {self.max_queue:.2f}",
                f"total_queue: {self.total_queue:.2f}",
                f"avg_wait: {self.avg_wait:.2f}",
                f"max_wait: {self.max_wait:.2f}",
                f"total_wait: {self.total_wait:.2f}",
                f"avg_outgoing_load: {self.avg_outgoing_load:.2f}",
                f"max_outgoing_load: {self.max_outgoing_load:.2f}",
                f"total_outgoing_load: {self.total_outgoing_load:.2f}",
                f"recent_throughput: {self.recent_throughput:.2f}",
                f"queue_change: {self.queue_change:.2f}",
                f"wait_change: {self.wait_change:.2f}",
                f"throughput_change: {self.throughput_change:.2f}",
                f"ns_queue: {self.ns_queue:.2f}",
                f"ew_queue: {self.ew_queue:.2f}",
                f"ns_wait: {self.ns_wait:.2f}",
                f"ew_wait: {self.ew_wait:.2f}",
                f"dominant_flow: {self.dominant_flow}",
                f"boundary_queue_total: {self.boundary_queue_total:.2f}",
                f"boundary_wait_total: {self.boundary_wait_total:.2f}",
                f"spillback_risk: {int(self.spillback_risk)}",
                f"incident_flag: {int(self.incident_flag)}",
                f"construction_flag: {int(self.construction_flag)}",
                f"overload_flag: {int(self.overload_flag)}",
                f"event_flag: {int(self.event_flag)}",
                "top_congested_intersections:",
                *top_lines,
                "candidate_intersections:",
                *candidate_lines,
            ]
        )
district_llm/summary_builder.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+
9
+ from district_llm.schema import CandidateIntersection, CongestedIntersection, DistrictStateSummary, candidate_priority_score
10
+ from env.utils import load_json
11
+
12
+
13
@dataclass
class _SummaryContext:
    """Mutable per-episode state the builder carries between decision steps."""

    # Last step's summaries keyed by district id (for trend deltas).
    previous_summaries: dict[str, DistrictStateSummary]
    # City-wide finished-vehicle count at the previous step.
    previous_finished_vehicles: int
17
+
18
+
19
class DistrictStateSummaryBuilder:
    """Builds per-district ``DistrictStateSummary`` objects from a live env.

    The builder lazily caches scenario metadata and road topology on first use
    and remembers the previous step's summaries so trend features
    (queue_change, wait_change, throughput_change) can be derived.
    Call ``reset()`` when the env switches episodes.
    """

    def __init__(self, top_k: int = 3, candidate_limit: int = 6):
        # top_k: how many most-congested intersections to report per district.
        # candidate_limit: max candidate intersections exposed downstream.
        self.top_k = int(top_k)
        self.candidate_limit = int(candidate_limit)
        self._context = _SummaryContext(previous_summaries={}, previous_finished_vehicles=0)
        self._scenario_metadata: dict[str, Any] | None = None
        self._road_endpoints: dict[str, tuple[str, str]] | None = None
        self._incident_intersections: set[str] = set()

    def reset(self) -> None:
        """Drop all cached episode state (metadata, topology, trend context)."""
        self._context = _SummaryContext(previous_summaries={}, previous_finished_vehicles=0)
        self._scenario_metadata = None
        self._road_endpoints = None
        self._incident_intersections = set()

    def build_all(self, env, observation_batch: dict[str, Any]) -> dict[str, DistrictStateSummary]:
        """Build one summary per district for the current decision step."""
        # Load scenario metadata / roadnet topology once per episode.
        if self._scenario_metadata is None:
            metadata_path = Path(env.scenario_dir) / "scenario_metadata.json"
            self._scenario_metadata = load_json(metadata_path) if metadata_path.exists() else {}
            self._road_endpoints = self._load_road_endpoints(Path(env.roadnet_path))
            self._incident_intersections = self._derive_incident_intersections()

        lane_vehicle_count = env.adapter.get_lane_vehicle_count()
        finished_vehicles = int(env.adapter.get_finished_vehicle_count())
        district_summaries: dict[str, DistrictStateSummary] = {}

        for district_id in env.districts:
            district_summaries[district_id] = self._build_single(
                env=env,
                observation_batch=observation_batch,
                lane_vehicle_count=lane_vehicle_count,
                district_id=district_id,
                finished_vehicles=finished_vehicles,
            )

        # Remember this step so the next call can compute trend deltas.
        self._context.previous_summaries = district_summaries
        self._context.previous_finished_vehicles = finished_vehicles
        return district_summaries

    def _build_single(
        self,
        env,
        observation_batch: dict[str, Any],
        lane_vehicle_count: dict[str, int],
        district_id: str,
        finished_vehicles: int,
    ) -> DistrictStateSummary:
        """Aggregate per-intersection observations into one district summary."""
        district = env.districts[district_id]
        scenario_metadata = self._scenario_metadata or {}
        intersection_ids = observation_batch["intersection_ids"]
        district_ids = observation_batch["district_ids"]
        incoming_counts = observation_batch["incoming_counts"]
        incoming_waiting = observation_batch["incoming_waiting"]
        current_phase = observation_batch["current_phase"]

        queue_totals: list[float] = []
        wait_totals: list[float] = []
        outgoing_loads: list[float] = []
        ns_queue = 0.0
        ew_queue = 0.0
        ns_wait = 0.0
        ew_wait = 0.0
        boundary_queue_total = 0.0
        boundary_wait_total = 0.0
        congestion_items: list[CongestedIntersection] = []
        candidate_seed_items: list[dict[str, Any]] = []

        for index, intersection_id in enumerate(intersection_ids):
            # The batch covers the whole city; keep only this district's rows.
            if district_ids[index] != district_id:
                continue

            queue_total = float(np.asarray(incoming_counts[index], dtype=np.float32).sum())
            wait_total = float(np.asarray(incoming_waiting[index], dtype=np.float32).sum())
            outgoing_load = self._compute_outgoing_load(
                env=env,
                lane_vehicle_count=lane_vehicle_count,
                intersection_id=intersection_id,
            )
            queue_totals.append(queue_total)
            wait_totals.append(wait_total)
            outgoing_loads.append(outgoing_load)

            # Lane layout convention: first half of the padded lane slots are
            # treated as NS, second half as EW — TODO confirm against the
            # observation builder.
            midpoint = incoming_counts.shape[1] // 2
            ns_queue_local = float(np.asarray(incoming_counts[index][:midpoint], dtype=np.float32).sum())
            ew_queue_local = float(np.asarray(incoming_counts[index][midpoint:], dtype=np.float32).sum())
            ns_wait_local = float(np.asarray(incoming_waiting[index][:midpoint], dtype=np.float32).sum())
            ew_wait_local = float(np.asarray(incoming_waiting[index][midpoint:], dtype=np.float32).sum())
            ns_queue += ns_queue_local
            ew_queue += ew_queue_local
            ns_wait += ns_wait_local
            ew_wait += ew_wait_local

            intersection_config = env.intersections[intersection_id]
            if intersection_config.is_boundary:
                boundary_queue_total += queue_total
                boundary_wait_total += wait_total

            congestion_items.append(
                CongestedIntersection(
                    intersection_id=intersection_id,
                    queue_total=queue_total,
                    wait_total=wait_total,
                    outgoing_load=outgoing_load,
                    current_phase=int(current_phase[index]),
                    is_boundary=bool(intersection_config.is_boundary),
                )
            )
            candidate_seed_items.append(
                {
                    "intersection_id": intersection_id,
                    "queue_total": queue_total,
                    "wait_total": wait_total,
                    "outgoing_load": outgoing_load,
                    "current_phase": int(current_phase[index]),
                    "is_boundary": bool(intersection_config.is_boundary),
                    # Per-intersection spillback heuristic: heavy outgoing load
                    # relative to local queue (lower bar at boundaries).
                    "spillback_risk": bool(
                        outgoing_load >= max(6.0, queue_total * 0.6)
                        or (
                            intersection_config.is_boundary
                            and outgoing_load >= max(4.0, queue_total * 0.4)
                        )
                    ),
                    "incident_proximity": intersection_id in self._incident_intersections,
                    "corridor_alignment": self._compute_corridor_alignment(
                        ns_queue=ns_queue_local,
                        ew_queue=ew_queue_local,
                        ns_wait=ns_wait_local,
                        ew_wait=ew_wait_local,
                    ),
                }
            )

        # Fall back to a single zero so mean()/max() stay defined for an
        # empty district.
        queue_array = np.asarray(queue_totals or [0.0], dtype=np.float32)
        wait_array = np.asarray(wait_totals or [0.0], dtype=np.float32)
        outgoing_array = np.asarray(outgoing_loads or [0.0], dtype=np.float32)

        previous_summary = self._context.previous_summaries.get(district_id)
        # NOTE(review): finished_vehicles is a city-wide counter, so every
        # district reports the same recent_throughput; the first step (or a
        # legitimately-zero previous count) yields 0.0 — confirm intended.
        recent_throughput = float(
            finished_vehicles - self._context.previous_finished_vehicles
            if self._context.previous_finished_vehicles
            else 0.0
        )
        queue_change = 0.0 if previous_summary is None else float(queue_array.sum() - previous_summary.total_queue)
        wait_change = 0.0 if previous_summary is None else float(wait_array.sum() - previous_summary.total_wait)
        throughput_change = (
            0.0
            if previous_summary is None
            else recent_throughput - previous_summary.recent_throughput
        )

        # Directional pressure: waiting vehicles weighted 1.5x over raw queue;
        # a 10% margin is required before a direction is declared dominant.
        directional_ns = ns_queue + 1.5 * ns_wait
        directional_ew = ew_queue + 1.5 * ew_wait
        if directional_ns > directional_ew * 1.1:
            dominant_flow = "NS"
        elif directional_ew > directional_ns * 1.1:
            dominant_flow = "EW"
        else:
            dominant_flow = "BALANCED"

        boundary_share = boundary_queue_total / max(1.0, float(queue_array.sum()))
        spillback_risk = bool(
            outgoing_array.max() >= max(8.0, queue_array.max() * 0.5)
            or (boundary_share >= 0.6 and queue_change >= 0.0)
        )

        top_intersections = sorted(
            congestion_items,
            key=lambda item: (item.queue_total + 1.5 * item.wait_total + 0.5 * item.outgoing_load),
            reverse=True,
        )[: self.top_k]

        # Scenario flags derived from scenario_metadata.json.
        overload_flag = bool(
            scenario_metadata.get("overload_district") == district_id
            or (scenario_metadata.get("name") == "district_overload" and queue_array.sum() >= 25.0)
        )
        event_flag = bool(scenario_metadata.get("event_district") == district_id)
        incident_flag = bool(
            scenario_metadata.get("name") in {"accident", "construction"}
            or bool(scenario_metadata.get("blocked_roads"))
        )
        construction_flag = bool(scenario_metadata.get("name") == "construction")
        candidate_intersections = self._build_candidate_intersections(
            candidate_seed_items=candidate_seed_items,
            overload_flag=overload_flag,
            event_flag=event_flag,
        )

        return DistrictStateSummary(
            city_id=env.city_id,
            district_id=district_id,
            district_type=district.district_type,
            scenario_name=env.scenario_name,
            scenario_type=str(scenario_metadata.get("intensity", env.scenario_name)),
            decision_step=int(observation_batch["decision_step"]),
            sim_time=int(observation_batch["sim_time"]),
            intersection_count=int(len(district.intersection_ids)),
            avg_queue=float(queue_array.mean()),
            max_queue=float(queue_array.max()),
            total_queue=float(queue_array.sum()),
            avg_wait=float(wait_array.mean()),
            max_wait=float(wait_array.max()),
            total_wait=float(wait_array.sum()),
            avg_outgoing_load=float(outgoing_array.mean()),
            max_outgoing_load=float(outgoing_array.max()),
            total_outgoing_load=float(outgoing_array.sum()),
            recent_throughput=recent_throughput,
            queue_change=queue_change,
            wait_change=wait_change,
            throughput_change=throughput_change,
            ns_queue=ns_queue,
            ew_queue=ew_queue,
            ns_wait=ns_wait,
            ew_wait=ew_wait,
            dominant_flow=dominant_flow,
            boundary_queue_total=boundary_queue_total,
            boundary_wait_total=boundary_wait_total,
            spillback_risk=spillback_risk,
            incident_flag=incident_flag,
            construction_flag=construction_flag,
            overload_flag=overload_flag,
            event_flag=event_flag,
            top_congested_intersections=top_intersections,
            candidate_intersections=candidate_intersections,
        ).validate()

    @staticmethod
    def _compute_outgoing_load(env, lane_vehicle_count: dict[str, int], intersection_id: str) -> float:
        """Sum current vehicle counts over an intersection's outgoing lanes."""
        intersection_config = env.intersections[intersection_id]
        if not intersection_config.outgoing_lanes:
            return 0.0
        return float(
            sum(float(lane_vehicle_count.get(lane_id, 0)) for lane_id in intersection_config.outgoing_lanes)
        )

    @staticmethod
    def _compute_corridor_alignment(
        ns_queue: float,
        ew_queue: float,
        ns_wait: float,
        ew_wait: float,
    ) -> str:
        """Classify a single intersection's dominant direction (same weighting
        and 10% margin as the district-level dominant_flow)."""
        ns_pressure = ns_queue + 1.5 * ns_wait
        ew_pressure = ew_queue + 1.5 * ew_wait
        if ns_pressure > ew_pressure * 1.1:
            return "NS"
        if ew_pressure > ns_pressure * 1.1:
            return "EW"
        return "BALANCED"

    @staticmethod
    def _load_road_endpoints(roadnet_path: Path) -> dict[str, tuple[str, str]]:
        """Map road id -> (start intersection id, end intersection id)."""
        roadnet = load_json(roadnet_path)
        return {
            str(road["id"]): (
                str(road["startIntersection"]),
                str(road["endIntersection"]),
            )
            for road in roadnet.get("roads", [])
        }

    def _derive_incident_intersections(self) -> set[str]:
        """Collect intersections touching blocked/accident/construction roads.

        Falls back to 'penalized_roads' keys when no explicit incident road
        list is present in the scenario metadata.
        """
        if not self._road_endpoints:
            return set()
        scenario_metadata = self._scenario_metadata or {}
        details = scenario_metadata.get("details", {})
        incident_roads = list(scenario_metadata.get("blocked_roads", []))
        incident_roads.extend(details.get("accident_roads", []))
        incident_roads.extend(details.get("construction_roads", []))
        if not incident_roads:
            incident_roads.extend(list((scenario_metadata.get("penalized_roads") or {}).keys()))

        intersections: set[str] = set()
        for road_id in incident_roads:
            endpoints = self._road_endpoints.get(str(road_id))
            if endpoints is None:
                continue
            intersections.update(endpoints)
        return intersections

    def _build_candidate_intersections(
        self,
        candidate_seed_items: list[dict[str, Any]],
        overload_flag: bool,
        event_flag: bool,
    ) -> list[CandidateIntersection]:
        """Select up to candidate_limit intersections, tagged with the reasons
        (congested / boundary / spillback / incident / outgoing / overload /
        event) that earned them a slot; ordered by priority score."""
        if not candidate_seed_items or self.candidate_limit <= 0:
            return []

        def severity_key(item: dict[str, Any]) -> tuple[float, float, float, float, str]:
            # Build a throwaway CandidateIntersection only to reuse the shared
            # priority-score heuristic; reasons are filled in later.
            candidate = CandidateIntersection(
                intersection_id=str(item["intersection_id"]),
                queue_total=float(item["queue_total"]),
                wait_total=float(item["wait_total"]),
                outgoing_load=float(item["outgoing_load"]),
                current_phase=int(item["current_phase"]),
                is_boundary=bool(item["is_boundary"]),
                spillback_risk=bool(item["spillback_risk"]),
                incident_proximity=bool(item["incident_proximity"]),
                overload_marker=overload_flag,
                event_proximity=event_flag,
                corridor_alignment=str(item["corridor_alignment"]),
                selection_reasons=[],
            )
            return (
                candidate_priority_score(candidate),
                float(item["queue_total"]),
                float(item["wait_total"]),
                float(item["outgoing_load"]),
                str(item["intersection_id"]),
            )

        # FIX: compute the severity tuple once per item. The previous version
        # invoked severity_key() five times per element inside the sort-key
        # lambda, constructing and scoring a CandidateIntersection each time.
        # The resulting ordering is identical.
        def overall_sort_key(item: dict[str, Any]) -> tuple[float, float, float, float, str]:
            score, queue_total, wait_total, outgoing_load, intersection_id = severity_key(item)
            return (-score, -queue_total, -wait_total, -outgoing_load, intersection_id)

        overall_sorted = sorted(candidate_seed_items, key=overall_sort_key)
        boundary_sorted = [item for item in overall_sorted if item["is_boundary"]]
        spillback_sorted = [item for item in overall_sorted if item["spillback_risk"]]
        incident_sorted = [item for item in overall_sorted if item["incident_proximity"]]
        outgoing_sorted = sorted(
            candidate_seed_items,
            key=lambda item: (
                -float(item["outgoing_load"]),
                -float(item["queue_total"]),
                -float(item["wait_total"]),
                str(item["intersection_id"]),
            ),
        )

        reason_tags: dict[str, set[str]] = {}
        selected_ids: list[str] = []

        def mark(items: list[dict[str, Any]], tag: str, limit: int) -> None:
            # Tag the first `limit` items with `tag` and add new ids in order.
            for item in items[:limit]:
                intersection_id = str(item["intersection_id"])
                reason_tags.setdefault(intersection_id, set()).add(tag)
                if intersection_id not in selected_ids:
                    selected_ids.append(intersection_id)

        mark(overall_sorted, "congested", max(1, min(self.top_k, self.candidate_limit)))
        mark(boundary_sorted, "boundary", min(2, self.candidate_limit))
        mark(spillback_sorted, "spillback", min(2, self.candidate_limit))
        mark(incident_sorted, "incident", min(2, self.candidate_limit))
        mark(outgoing_sorted, "outgoing", min(2, self.candidate_limit))
        if overload_flag:
            mark(overall_sorted, "overload", min(2, self.candidate_limit))
        if event_flag:
            event_seed = boundary_sorted if boundary_sorted else outgoing_sorted
            mark(event_seed, "event", min(2, self.candidate_limit))

        # Top up with the overall severity ranking until the limit is reached.
        for item in overall_sorted:
            if len(selected_ids) >= self.candidate_limit:
                break
            intersection_id = str(item["intersection_id"])
            if intersection_id in selected_ids:
                continue
            selected_ids.append(intersection_id)
            reason_tags.setdefault(intersection_id, {"congested"})

        seed_lookup = {
            str(item["intersection_id"]): item
            for item in candidate_seed_items
        }
        candidates = [
            CandidateIntersection(
                intersection_id=intersection_id,
                queue_total=float(seed_lookup[intersection_id]["queue_total"]),
                wait_total=float(seed_lookup[intersection_id]["wait_total"]),
                outgoing_load=float(seed_lookup[intersection_id]["outgoing_load"]),
                current_phase=int(seed_lookup[intersection_id]["current_phase"]),
                is_boundary=bool(seed_lookup[intersection_id]["is_boundary"]),
                spillback_risk=bool(seed_lookup[intersection_id]["spillback_risk"]),
                incident_proximity=bool(seed_lookup[intersection_id]["incident_proximity"]),
                overload_marker=overload_flag,
                event_proximity=event_flag,
                corridor_alignment=str(seed_lookup[intersection_id]["corridor_alignment"]),
                selection_reasons=sorted(reason_tags.get(intersection_id, {"congested"})),
            ).validate()
            for intersection_id in selected_ids[: self.candidate_limit]
        ]
        return sorted(
            candidates,
            key=lambda item: (
                -candidate_priority_score(item),
                -item.queue_total,
                -item.wait_total,
                -item.outgoing_load,
                item.intersection_id,
            ),
        )
district_llm/teachers.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from abc import ABC, abstractmethod
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import numpy as np
10
+
11
+ from agents.local_policy import (
12
+ BaseLocalPolicy,
13
+ FixedCyclePolicy,
14
+ HoldPhasePolicy,
15
+ QueueGreedyPolicy,
16
+ RandomPhasePolicy,
17
+ )
18
+
19
+
20
# Names of the non-RL baseline controllers accepted by build_teacher().
BASELINE_TYPES: tuple[str, ...] = ("hold", "fixed", "random", "queue_greedy")
21
+
22
+
23
@dataclass(frozen=True)
class TeacherMetadata:
    """Immutable provenance record describing a teacher controller."""

    controller_type: str
    controller_id: str
    controller_family: str
    teacher_algorithm: str
    checkpoint_path: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view of the record, in field-declaration order."""
        field_names = (
            "controller_type",
            "controller_id",
            "controller_family",
            "teacher_algorithm",
            "checkpoint_path",
        )
        return {name: getattr(self, name) for name in field_names}
39
+
40
+
41
class BaseTeacher(ABC):
    """Abstract interface shared by all teacher controllers.

    A teacher maps a batched observation dict to one action per controlled
    intersection. Concrete subclasses wrap heuristic baselines or trained RL
    checkpoints.
    """

    def __init__(self, metadata: TeacherMetadata):
        # Provenance metadata recorded alongside generated traces.
        self.metadata = metadata

    @property
    def env_config(self) -> Any | None:
        # Subclasses that carry a training-time env config override this.
        return None

    @abstractmethod
    def act(self, observation_batch: dict[str, Any]) -> np.ndarray:
        raise NotImplementedError
52
+
53
+
54
class BaselineTeacher(BaseTeacher):
    """Adapts a heuristic local policy to the teacher interface."""

    def __init__(self, policy: BaseLocalPolicy, metadata: TeacherMetadata):
        super().__init__(metadata=metadata)
        self.policy = policy

    def act(self, observation_batch: dict[str, Any]) -> np.ndarray:
        # Normalize whatever the wrapped policy returns to an int64 array.
        raw_actions = self.policy.act(observation_batch)
        return np.asarray(raw_actions, dtype=np.int64)
61
+
62
+
63
class RLCheckpointTeacher(BaseTeacher):
    """Teacher that replays a trained DQN checkpoint to produce actions."""

    def __init__(
        self,
        checkpoint_path: str | Path,
        device: str | None = None,
        deterministic: bool = True,
    ):
        # PyTorch is an optional dependency; fail with a clear message.
        try:
            import torch
        except ImportError as exc:
            raise ImportError(
                "RL checkpoint teachers require PyTorch to be installed."
            ) from exc

        from training.models import RunningNormalizer, TrafficControlQNetwork
        from training.train_local_policy import load_env_config

        checkpoint_path = Path(checkpoint_path)
        self._torch = torch
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects from
        # the checkpoint file — only load checkpoints from trusted sources.
        self.checkpoint = torch.load(
            checkpoint_path,
            map_location=self.device,
            weights_only=False,
        )
        # Architecture may live under 'network_architecture' or the older
        # 'policy_architecture' key — presumably for backward compatibility
        # with earlier checkpoint formats; verify against the trainer.
        network_architecture = self.checkpoint.get("network_architecture") or self.checkpoint.get(
            "policy_architecture",
            {},
        )
        trainer_config = self.checkpoint.get("dqn_config", {})
        policy_arch = network_architecture.get(
            "policy_arch",
            trainer_config.get("policy_arch", "single_head_with_district_feature"),
        )
        # Rebuild the Q-network from the recorded architecture and config
        # (falling back to defaults where keys are absent).
        self.model = TrafficControlQNetwork(
            observation_dim=int(network_architecture["observation_dim"]),
            action_dim=int(network_architecture.get("action_dim", 2)),
            hidden_dim=int(trainer_config.get("hidden_dim", 256)),
            num_layers=int(trainer_config.get("hidden_layers", 2)),
            district_types=tuple(network_architecture.get("district_types", ())),
            policy_arch=policy_arch,
            dueling=bool(network_architecture.get("dueling", True)),
        ).to(self.device)
        # Weights key also has an old ('policy_state_dict') and new
        # ('q_network_state_dict') spelling.
        self.model.load_state_dict(
            self.checkpoint.get("q_network_state_dict") or self.checkpoint["policy_state_dict"]
        )
        self.model.eval()
        # Optional observation normalizer saved alongside the weights.
        self.obs_normalizer = None
        if self.checkpoint.get("obs_normalizer"):
            self.obs_normalizer = RunningNormalizer()
            self.obs_normalizer.load_state_dict(self.checkpoint["obs_normalizer"])

        checkpoint_id = checkpoint_path.stem
        super().__init__(
            metadata=TeacherMetadata(
                controller_type="rl_checkpoint",
                controller_id=checkpoint_id,
                controller_family="dqn",
                teacher_algorithm="dqn",
                checkpoint_path=str(checkpoint_path),
            )
        )
        self.deterministic = bool(deterministic)
        # Env config captured at training time, if the checkpoint stored one.
        self._env_config = (
            load_env_config(self.checkpoint["env_config"])
            if self.checkpoint.get("env_config")
            else None
        )

    @property
    def env_config(self) -> Any | None:
        return self._env_config

    def act(self, observation_batch: dict[str, Any]) -> np.ndarray:
        """Select one action per intersection from the batched observation.

        Applies the saved observation normalizer (when present) and queries
        the Q-network with epsilon=0 under no_grad.
        """
        torch = self._torch
        raw_obs = observation_batch["observations"].astype(np.float32)
        normalized_obs = self.obs_normalizer.normalize(raw_obs) if self.obs_normalizer else raw_obs
        obs_tensor = torch.as_tensor(normalized_obs, dtype=torch.float32, device=self.device)
        district_type_tensor = torch.as_tensor(
            observation_batch["district_type_indices"],
            dtype=torch.int64,
            device=self.device,
        )
        action_mask_tensor = torch.as_tensor(
            observation_batch["action_mask"],
            dtype=torch.float32,
            device=self.device,
        )
        with torch.no_grad():
            actions = self.model.act(
                observations=obs_tensor,
                district_type_indices=district_type_tensor,
                action_mask=action_mask_tensor,
                deterministic=self.deterministic,
                epsilon=0.0,
            )
        return actions.cpu().numpy().astype(np.int64)
160
+
161
+
162
def build_teacher(
    controller_type: str,
    checkpoint: str | None = None,
    fixed_green_time: int = 20,
    seed: int = 7,
    device: str | None = None,
) -> BaseTeacher:
    """Instantiate a teacher controller by name.

    Raises ValueError for an unknown controller type, or for
    'rl_checkpoint' without a checkpoint path.
    """
    if controller_type == "rl_checkpoint":
        if not checkpoint:
            raise ValueError("controller_type='rl_checkpoint' requires --checkpoint.")
        return RLCheckpointTeacher(checkpoint_path=checkpoint, device=device)

    # Baseline dispatch table: (policy factory, controller_id, algorithm).
    # Factories keep policy construction lazy — only the matched entry runs.
    baseline_specs = {
        "hold": (HoldPhasePolicy, "hold", "hold"),
        "fixed": (
            lambda: FixedCyclePolicy(green_time=fixed_green_time),
            f"fixed_{fixed_green_time}",
            "fixed_cycle",
        ),
        "random": (
            lambda: RandomPhasePolicy(seed=seed),
            f"random_{seed}",
            "random",
        ),
        "queue_greedy": (QueueGreedyPolicy, "queue_greedy", "queue_greedy"),
    }
    if controller_type in baseline_specs:
        policy_factory, controller_id, teacher_algorithm = baseline_specs[controller_type]
        return BaselineTeacher(
            policy=policy_factory(),
            metadata=TeacherMetadata(
                controller_type=controller_type,
                controller_id=controller_id,
                controller_family="baseline",
                teacher_algorithm=teacher_algorithm,
            ),
        )
    raise ValueError(
        f"Unsupported controller_type '{controller_type}'. "
        f"Expected rl_checkpoint or one of {BASELINE_TYPES}."
    )
217
+
218
+
219
def parse_teacher_spec(spec: str) -> tuple[str, str | None]:
    """Split a 'type' or 'type=checkpoint' spec into its two parts.

    An empty checkpoint segment ('type=') is normalized to None.
    """
    controller_type, separator, checkpoint_path = spec.partition("=")
    if not separator:
        return spec.strip(), None
    return controller_type.strip(), checkpoint_path.strip() or None
224
+
225
+
226
def teachers_metadata_json(teachers: list[BaseTeacher]) -> str:
    """Serialize each teacher's metadata to a key-sorted JSON array string."""
    metadata_records = [teacher.metadata.to_dict() for teacher in teachers]
    return json.dumps(metadata_records, sort_keys=True)
district_llm/train_unsloth.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ from district_llm.data import load_jsonl_text_dataset
7
+
8
+
9
def parse_args() -> argparse.Namespace:
    """Parse command-line options for district-LLM SFT with Unsloth/QLoRA."""
    parser = argparse.ArgumentParser(
        description="Supervised fine-tune a district LLM on DQN-derived district traces with Unsloth/QLoRA."
    )
    add = parser.add_argument
    # Data and output locations.
    add("--dataset", required=True, help="JSONL dataset with a 'text' field.")
    add("--output-dir", required=True)
    # Base model and sequence settings.
    add("--model-name", default="Qwen/Qwen2.5-7B-Instruct")
    add("--max-seq-length", type=int, default=1024)
    add("--load-in-4bit", action="store_true")
    # LoRA adapter hyperparameters.
    add("--lora-rank", type=int, default=16)
    add("--lora-alpha", type=int, default=16)
    add("--lora-dropout", type=float, default=0.0)
    # Optimization schedule.
    add("--batch-size", type=int, default=2)
    add("--gradient-accumulation-steps", type=int, default=8)
    add("--learning-rate", type=float, default=2e-4)
    add("--warmup-steps", type=int, default=50)
    add("--max-steps", type=int, default=500)
    add("--logging-steps", type=int, default=10)
    add("--save-steps", type=int, default=100)
    add("--seed", type=int, default=7)
    add("--dataset-num-proc", type=int, default=2)
    # Optional evaluation dataset and checkpoint resumption.
    add("--eval-dataset", default=None)
    add("--resume-from-checkpoint", default=None)
    add(
        "--include-non-dqn-sources",
        action="store_true",
        help="By default the trainer keeps only DQN-derived rows (controller_family=dqn).",
    )
    return parser.parse_args()
38
+
39
+
40
def main() -> None:
    """Entry point: load the JSONL dataset, attach a LoRA adapter with
    Unsloth, run TRL's SFTTrainer, and save the adapter + tokenizer."""
    args = parse_args()

    # Heavy training deps are imported lazily so the CLI can report a clear
    # error instead of failing at module import time.
    try:
        import torch
        from trl import SFTTrainer
        from transformers import TrainingArguments
        from unsloth import FastLanguageModel
    except ImportError as exc:
        raise ImportError(
            "District LLM training requires 'unsloth' and 'trl'. "
            "Install them in the active environment before running this entry point."
        ) from exc

    # By default only DQN-derived rows are kept (controller_family == "dqn").
    controller_families = None if args.include_non_dqn_sources else ["dqn"]
    train_dataset = load_jsonl_text_dataset(
        args.dataset,
        controller_families=controller_families,
    )
    eval_dataset = (
        load_jsonl_text_dataset(
            args.eval_dataset,
            controller_families=controller_families,
        )
        if args.eval_dataset
        else None
    )
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # dtype=None lets Unsloth pick the best dtype for the hardware.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        max_seq_length=args.max_seq_length,
        dtype=None,
        load_in_4bit=bool(args.load_in_4bit),
    )
    # Attach LoRA adapters to all attention and MLP projection matrices.
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=args.seed,
    )

    # NOTE(review): newer TRL versions renamed 'tokenizer' to
    # 'processing_class', and newer transformers renamed
    # 'evaluation_strategy' to 'eval_strategy' — confirm both against the
    # pinned library versions before upgrading.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        dataset_text_field="text",
        max_seq_length=args.max_seq_length,
        dataset_num_proc=args.dataset_num_proc,
        packing=False,
        args=TrainingArguments(
            output_dir=str(output_dir),
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            learning_rate=args.learning_rate,
            logging_steps=args.logging_steps,
            save_steps=args.save_steps,
            # Prefer bf16 where supported, otherwise fp16 on CUDA.
            bf16=bool(torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
            fp16=bool(torch.cuda.is_available() and not torch.cuda.is_bf16_supported()),
            optim="paged_adamw_8bit",
            lr_scheduler_type="cosine",
            seed=args.seed,
            report_to="none",
            # Evaluate on the save cadence only when an eval set is provided.
            evaluation_strategy="steps" if eval_dataset is not None else "no",
            eval_steps=args.save_steps if eval_dataset is not None else None,
        ),
    )
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    # Persist the adapter weights and tokenizer next to the trainer output.
    model.save_pretrained(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))
125
+ tokenizer.save_pretrained(str(output_dir))
126
+
127
+
128
# CLI entry point: run fine-tuning only when executed as a script.
if __name__ == "__main__":
    main()
env/README.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # env
2
+
3
+ CityFlow environment implementation for intersection-level RL with district-type metadata.
4
+
5
+ ## Main files
6
+
7
+ - [traffic_env.py](traffic_env.py)
8
+ Main environment. One episode corresponds to one `(city, scenario)` pair.
9
+ - [cityflow_adapter.py](cityflow_adapter.py)
10
+ Thin wrapper around the CityFlow Python engine.
11
+ - [observation_builder.py](observation_builder.py)
12
+ Converts variable city topology into fixed-size per-intersection tensors.
13
+ - [reward.py](reward.py)
14
+ Configurable local reward calculation.
15
+ - [utils.py](utils.py)
16
+ Topology parsing and helper functions.
17
+ - [intersection_config.py](intersection_config.py)
18
+ Internal topology dataclasses.
19
+
20
+ ## How it works
21
+
22
+ - Reads `roadnet.json`, `district_map.json`, and district types from `metadata.json`.
23
+ - Identifies non-virtual controllable intersections with at least two green phases.
24
+ - Uses one action per controllable intersection.
25
+ - Enforces `min_green_time` inside the environment.
26
+ - Advances CityFlow for `decision_interval` simulator steps between policy decisions.
27
+ - Returns a batched observation for all controlled intersections.
28
+
29
+ ## Observation model
30
+
31
+ Per intersection:
32
+
33
+ - padded incoming lane vehicle counts
34
+ - padded incoming lane waiting counts
35
+ - incoming lane mask
36
+ - current green phase index
37
+ - elapsed time in current phase
38
+ - optional outgoing congestion summary
39
+ - district-type one-hot features
40
+ - optional small district context
41
+ - boundary-intersection indicator
42
+
43
+ The observation dimension is exposed as `TrafficEnv.observation_dim`.
env/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from env.intersection_config import DistrictConfig, IntersectionConfig, PhaseConfig
2
+ from env.observation_builder import ObservationBuilder, ObservationConfig
3
+ from env.reward import RewardCalculator, RewardConfig
4
+ from env.traffic_env import EnvConfig, TrafficEnv
5
+ from env.utils import build_topology
6
+
7
+ __all__ = [
8
+ "DistrictConfig",
9
+ "EnvConfig",
10
+ "IntersectionConfig",
11
+ "ObservationBuilder",
12
+ "ObservationConfig",
13
+ "PhaseConfig",
14
+ "RewardCalculator",
15
+ "RewardConfig",
16
+ "TrafficEnv",
17
+ "build_topology",
18
+ ]
env/cityflow_adapter.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+
7
class CityFlowAdapter:
    """Thin wrapper around the CityFlow engine.

    Adds bookkeeping on top of the raw engine: the last phase explicitly set
    per intersection (fallback when the engine exposes no phase getter) and
    the set of active vehicle ids, from which finished-vehicle counts are
    derived when the engine lacks a dedicated accessor.
    """

    def __init__(self, config_path: str | Path, thread_num: int = 1):
        self.config_path = str(config_path)
        self.thread_num = int(thread_num)
        self.engine = None
        self._phase_cache: dict[str, int] = {}
        self._active_vehicle_ids: set[str] = set()
        self._finished_vehicle_ids: set[str] = set()

    def reset(self) -> None:
        """Create a fresh engine instance and clear all per-episode state."""
        try:
            import cityflow
        except ImportError as exc:
            raise RuntimeError(
                "CityFlow is not installed. Install the CityFlow Python bindings "
                "before running smoke tests, training, or evaluation."
            ) from exc

        self.engine = cityflow.Engine(self.config_path, thread_num=self.thread_num)
        self._phase_cache.clear()
        self._finished_vehicle_ids.clear()
        self._active_vehicle_ids = self._fetch_active_vehicle_ids()

    def step(self) -> None:
        """Advance the engine one tick and update vehicle bookkeeping."""
        self._require_engine()
        self.engine.next_step()
        now_active = self._fetch_active_vehicle_ids()
        # Anything that was active before but is absent now has finished.
        self._finished_vehicle_ids |= self._active_vehicle_ids - now_active
        self._active_vehicle_ids = now_active

    def set_tl_phase(self, intersection_id: str, phase: int) -> None:
        """Set the traffic-light phase and remember it in the local cache."""
        self._require_engine()
        phase_index = int(phase)
        self.engine.set_tl_phase(intersection_id, phase_index)
        self._phase_cache[intersection_id] = phase_index

    def get_tl_phase(self, intersection_id: str) -> int:
        """Return the current phase via the engine when possible, else the cache."""
        self._require_engine()
        if not hasattr(self.engine, "get_tl_phase"):
            return int(self._phase_cache.get(intersection_id, 0))
        phase = int(self.engine.get_tl_phase(intersection_id))
        self._phase_cache[intersection_id] = phase
        return phase

    def get_lane_vehicle_count(self) -> dict[str, int]:
        """Return per-lane vehicle counts coerced to plain ints."""
        self._require_engine()
        raw = self.engine.get_lane_vehicle_count()
        return {lane_id: int(count) for lane_id, count in raw.items()}

    def get_lane_waiting_vehicle_count(self) -> dict[str, int]:
        """Return per-lane waiting-vehicle counts coerced to plain ints."""
        self._require_engine()
        raw = self.engine.get_lane_waiting_vehicle_count()
        return {lane_id: int(count) for lane_id, count in raw.items()}

    def get_current_time(self) -> int:
        """Return the simulation clock, truncated to whole seconds."""
        self._require_engine()
        return int(self.engine.get_current_time())

    def get_vehicle_count(self) -> int:
        """Return the active-vehicle count (tracked id set as fallback)."""
        self._require_engine()
        if hasattr(self.engine, "get_vehicle_count"):
            return int(self.engine.get_vehicle_count())
        return len(self._active_vehicle_ids)

    def get_average_travel_time(self) -> float | None:
        """Return the engine's average travel time, or None if unsupported."""
        self._require_engine()
        if hasattr(self.engine, "get_average_travel_time"):
            return float(self.engine.get_average_travel_time())
        return None

    def get_finished_vehicle_count(self) -> int:
        """Return the finished-vehicle count (tracked id set as fallback)."""
        self._require_engine()
        if hasattr(self.engine, "get_finished_vehicle_count"):
            return int(self.engine.get_finished_vehicle_count())
        return len(self._finished_vehicle_ids)

    def get_active_vehicle_ids(self) -> set[str]:
        """Return a defensive copy of the active vehicle id set."""
        return set(self._active_vehicle_ids)

    def _fetch_active_vehicle_ids(self) -> set[str]:
        """Read active vehicle ids from the engine; empty set when unavailable."""
        engine = self.engine
        if engine is None or not hasattr(engine, "get_vehicles"):
            return set()
        vehicles = engine.get_vehicles()
        return set(vehicles.keys() if isinstance(vehicles, dict) else vehicles)

    def _require_engine(self) -> None:
        """Raise when an engine-backed method is called before reset()."""
        if self.engine is None:
            raise RuntimeError(
                "CityFlow engine has not been initialized. Call reset() before use."
            )
env/district_summary.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
class DistrictSummaryBuilder:
    """Builds a compact per-district summary from live simulator state."""

    def build(self, adapter, district_config):
        """Return a summary dict for one district.

        Parameters
        ----------
        adapter:
            Object exposing ``get_lane_waiting_vehicle_count() -> dict[str, int]``
            (e.g. ``CityFlowAdapter``).
        district_config:
            Object with ``id`` and ``intersection_ids`` attributes.
            NOTE(review): ``DistrictConfig`` in ``env.intersection_config``
            exposes ``district_id``, not ``id`` — confirm which type callers
            actually pass here.

        Returns
        -------
        dict with ``district_id``, ``intersection_ids`` and ``avg_wait`` keys.
        ``avg_wait`` is the mean waiting count over all lanes reported by the
        adapter; 0.0 when no lanes are reported (the original raised
        ZeroDivisionError in that case).
        """
        waiting = adapter.get_lane_waiting_vehicle_count()
        avg_wait = sum(waiting.values()) / len(waiting) if waiting else 0.0

        return {
            "district_id": district_config.id,
            "intersection_ids": district_config.intersection_ids,
            "avg_wait": avg_wait,
        }
env/intersection_config.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
# Closed vocabulary of district types used throughout the env package.
DISTRICT_TYPES: tuple[str, ...] = (
    "residential",
    "commercial",
    "industrial",
    "mixed",
)
# Stable index per district type; ObservationBuilder uses these indices for
# its 4-wide one-hot district-type feature.
DISTRICT_TYPE_TO_INDEX: dict[str, int] = {
    district_type: index for index, district_type in enumerate(DISTRICT_TYPES)
}
# Fallback district type when none is specified.
DEFAULT_DISTRICT_TYPE = "mixed"
15
+
16
+
17
@dataclass(frozen=True)
class PhaseConfig:
    """One green phase of an intersection's signal plan."""

    engine_phase_index: int  # phase index used when calling the CityFlow engine
    available_road_links: tuple[int, ...]  # road-link indices active in this phase
    incoming_lanes_served: tuple[str, ...]  # incoming lane ids associated with this phase
    outgoing_lanes_served: tuple[str, ...]  # outgoing lane ids associated with this phase
23
+
24
+
25
@dataclass(frozen=True)
class IntersectionConfig:
    """Static topology and district metadata for one controllable intersection."""

    intersection_id: str
    district_id: str
    district_type: str  # expected to be one of DISTRICT_TYPES
    district_type_index: int  # index of ``district_type`` (see DISTRICT_TYPE_TO_INDEX)
    incoming_lanes: tuple[str, ...]
    outgoing_lanes: tuple[str, ...]
    is_boundary: bool  # boundary-intersection indicator, exposed as an observation feature
    green_phases: tuple[PhaseConfig, ...]  # selectable green phases, in position order
    all_phase_indices: tuple[int, ...]
    initial_engine_phase_index: int

    @property
    def num_green_phases(self) -> int:
        """Number of selectable green phases at this intersection."""
        return len(self.green_phases)
41
+
42
+
43
@dataclass(frozen=True)
class DistrictConfig:
    """Static metadata for one district (a named group of intersections)."""

    district_id: str
    district_type: str  # expected to be one of DISTRICT_TYPES
    district_type_index: int  # index of ``district_type`` (see DISTRICT_TYPE_TO_INDEX)
    intersection_ids: tuple[str, ...]
    neighbor_districts: tuple[str, ...]
env/observation_builder.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import numpy as np
6
+
7
+ from env.intersection_config import DistrictConfig, IntersectionConfig
8
+ from env.utils import normalize_scalar
9
+
10
+
11
@dataclass(frozen=True)
class ObservationConfig:
    """Knobs controlling the layout and scaling of observation vectors."""

    max_incoming_lanes: int = 16  # per-lane feature blocks are zero-padded to this width
    count_scale: float = 20.0  # divisor used to normalize vehicle/waiting counts
    elapsed_time_scale: float = 60.0  # divisor used to normalize time-in-phase
    include_outgoing_congestion: bool = True  # mean outgoing-lane count meta feature
    include_district_context: bool = True  # adds 2 district-average features per row
    include_district_type_feature: bool = True  # adds a 4-wide district-type one-hot per row
+
20
+
21
class ObservationBuilder:
    """Converts raw per-lane simulator state into fixed-size observation rows.

    One row per controlled intersection, rows ordered by sorted intersection
    id. Row layout (see ``_compute_observation_dim``):
    ``[lane counts | lane waiting | lane mask | 7 meta features |
    optional 4-wide district-type one-hot | optional 2-wide district context]``.
    """

    def __init__(
        self,
        intersections: dict[str, IntersectionConfig],
        districts: dict[str, DistrictConfig],
        config: ObservationConfig | None = None,
    ):
        self.intersections = intersections
        self.districts = districts
        self.config = config or ObservationConfig()
        # Deterministic row order: intersection ids sorted lexicographically.
        self.intersection_ids = tuple(sorted(intersections))
        # NOTE(review): _district_lookup is currently unused within this class.
        self._district_lookup = {
            intersection_id: intersections[intersection_id].district_id
            for intersection_id in self.intersection_ids
        }
        # Guard against empty districts so per-district averages never divide by 0.
        self._district_sizes = {
            district_id: max(1, len(district.intersection_ids))
            for district_id, district in districts.items()
        }
        self.observation_dim = self._compute_observation_dim()

    def build(
        self,
        lane_vehicle_count: dict[str, int],
        lane_waiting_count: dict[str, int],
        phase_positions: dict[str, int],
        phase_elapsed_times: dict[str, int],
        switch_allowed: dict[str, bool],
    ) -> dict[str, np.ndarray | tuple[str, ...]]:
        """Build the batched observation for all controlled intersections.

        Args:
            lane_vehicle_count: lane id -> vehicle count (missing lanes read as 0).
            lane_waiting_count: lane id -> waiting count (missing lanes read as 0).
            phase_positions: intersection id -> current green-phase position.
            phase_elapsed_times: intersection id -> time spent in current phase.
            switch_allowed: intersection id -> whether a phase switch is permitted.

        Returns:
            Dict of batched arrays (first axis = intersections in
            ``self.intersection_ids`` order) plus id/type tuples.
        """
        district_context = self._compute_district_context(
            lane_vehicle_count=lane_vehicle_count,
            lane_waiting_count=lane_waiting_count,
        )

        num_intersections = len(self.intersection_ids)
        max_lanes = self.config.max_incoming_lanes

        observations = np.zeros(
            (num_intersections, self.observation_dim),
            dtype=np.float32,
        )
        incoming_counts = np.zeros((num_intersections, max_lanes), dtype=np.float32)
        incoming_waiting = np.zeros((num_intersections, max_lanes), dtype=np.float32)
        lane_mask = np.zeros((num_intersections, max_lanes), dtype=np.float32)
        action_mask = np.ones((num_intersections, 2), dtype=np.float32)
        current_phase = np.zeros(num_intersections, dtype=np.int64)
        phase_elapsed = np.zeros(num_intersections, dtype=np.float32)
        outgoing_congestion = np.zeros(num_intersections, dtype=np.float32)
        district_type_indices = np.zeros(num_intersections, dtype=np.int64)
        boundary_mask = np.zeros(num_intersections, dtype=np.float32)

        for row_index, intersection_id in enumerate(self.intersection_ids):
            config = self.intersections[intersection_id]
            lane_count_vector, waiting_vector, mask_vector = self._lane_vectors(
                config=config,
                lane_vehicle_count=lane_vehicle_count,
                lane_waiting_count=lane_waiting_count,
            )
            incoming_counts[row_index] = lane_count_vector
            incoming_waiting[row_index] = waiting_vector
            lane_mask[row_index] = mask_vector

            phase_index = int(phase_positions[intersection_id])
            phase_time = float(phase_elapsed_times[intersection_id])
            phase_count = max(1, config.num_green_phases)
            current_phase[row_index] = phase_index
            phase_elapsed[row_index] = phase_time
            district_type_indices[row_index] = config.district_type_index
            boundary_mask[row_index] = 1.0 if config.is_boundary else 0.0

            # Pack the row left-to-right; next_col tracks the write offset and
            # must stay consistent with _compute_observation_dim.
            next_col = 0
            observations[row_index, next_col : next_col + max_lanes] = (
                lane_count_vector / self.config.count_scale
            )
            next_col += max_lanes
            observations[row_index, next_col : next_col + max_lanes] = (
                waiting_vector / self.config.count_scale
            )
            next_col += max_lanes
            observations[row_index, next_col : next_col + max_lanes] = mask_vector
            next_col += max_lanes

            if self.config.include_outgoing_congestion:
                outgoing_congestion[row_index] = self._mean_outgoing_congestion(
                    config=config,
                    lane_vehicle_count=lane_vehicle_count,
                )

            # Fixed 7-element meta block: normalized phase position, time in
            # phase, outgoing congestion, total incoming count, phase count,
            # switch-allowed flag, boundary flag.
            meta_features = [
                normalize_scalar(phase_index, max(1, phase_count - 1))
                if phase_count > 1
                else 0.0,
                normalize_scalar(phase_time, self.config.elapsed_time_scale),
                normalize_scalar(float(outgoing_congestion[row_index]), self.config.count_scale),
                normalize_scalar(float(lane_count_vector.sum()), self.config.count_scale),
                normalize_scalar(float(phase_count), 4.0),
                1.0 if switch_allowed[intersection_id] else 0.0,
                boundary_mask[row_index],
            ]
            observations[row_index, next_col : next_col + len(meta_features)] = meta_features
            next_col += len(meta_features)

            if self.config.include_district_type_feature:
                # 4-wide one-hot indexed by district type.
                observations[row_index, next_col + config.district_type_index] = 1.0
                next_col += 4

            if self.config.include_district_context:
                # Tail of the row; next_col is not advanced after this write.
                district_values = district_context.get(
                    config.district_id,
                    (0.0, 0.0),
                )
                observations[row_index, next_col : next_col + len(district_values)] = district_values

            if not switch_allowed[intersection_id]:
                # Action column 1 (switch) is disabled while a switch is not allowed.
                action_mask[row_index, 1] = 0.0

        return {
            "observations": observations,
            "incoming_counts": incoming_counts,
            "incoming_waiting": incoming_waiting,
            "lane_mask": lane_mask,
            "action_mask": action_mask,
            "current_phase": current_phase,
            "phase_elapsed": phase_elapsed,
            "outgoing_congestion": outgoing_congestion,
            "boundary_mask": boundary_mask,
            "district_type_indices": district_type_indices,
            "district_types": tuple(
                self.intersections[intersection_id].district_type
                for intersection_id in self.intersection_ids
            ),
            "district_ids": tuple(
                self.intersections[intersection_id].district_id
                for intersection_id in self.intersection_ids
            ),
            "intersection_ids": self.intersection_ids,
        }

    def _compute_observation_dim(self) -> int:
        """Row width: 3 lane blocks + 7 meta features, plus optional extras."""
        base_dim = self.config.max_incoming_lanes * 3 + 7
        if self.config.include_district_type_feature:
            base_dim += 4
        if self.config.include_district_context:
            base_dim += 2
        return base_dim

    def _lane_vectors(
        self,
        config: IntersectionConfig,
        lane_vehicle_count: dict[str, int],
        lane_waiting_count: dict[str, int],
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return (counts, waiting, mask) padded/truncated to max_incoming_lanes."""
        max_lanes = self.config.max_incoming_lanes
        count_vector = np.zeros(max_lanes, dtype=np.float32)
        waiting_vector = np.zeros(max_lanes, dtype=np.float32)
        mask_vector = np.zeros(max_lanes, dtype=np.float32)

        # Lanes beyond max_lanes are silently dropped; missing counts read as 0.
        for lane_index, lane_id in enumerate(config.incoming_lanes[:max_lanes]):
            count_vector[lane_index] = float(lane_vehicle_count.get(lane_id, 0))
            waiting_vector[lane_index] = float(lane_waiting_count.get(lane_id, 0))
            mask_vector[lane_index] = 1.0

        return count_vector, waiting_vector, mask_vector

    def _mean_outgoing_congestion(
        self,
        config: IntersectionConfig,
        lane_vehicle_count: dict[str, int],
    ) -> float:
        """Mean vehicle count over the intersection's outgoing lanes (0.0 if none)."""
        if not config.outgoing_lanes:
            return 0.0
        total = sum(float(lane_vehicle_count.get(lane_id, 0)) for lane_id in config.outgoing_lanes)
        return total / float(len(config.outgoing_lanes))

    def _compute_district_context(
        self,
        lane_vehicle_count: dict[str, int],
        lane_waiting_count: dict[str, int],
    ) -> dict[str, tuple[float, float]]:
        """Per-district (avg count, avg waiting) over member intersections,
        normalized by count_scale; empty dict when the feature is disabled."""
        context: dict[str, tuple[float, float]] = {}
        if not self.config.include_district_context:
            return context

        for district_id, district in self.districts.items():
            total_count = 0.0
            total_waiting = 0.0
            for intersection_id in district.intersection_ids:
                config = self.intersections[intersection_id]
                total_count += sum(
                    float(lane_vehicle_count.get(lane_id, 0))
                    for lane_id in config.incoming_lanes
                )
                total_waiting += sum(
                    float(lane_waiting_count.get(lane_id, 0))
                    for lane_id in config.incoming_lanes
                )

            size = float(self._district_sizes[district_id])
            context[district_id] = (
                normalize_scalar(total_count / size, self.config.count_scale),
                normalize_scalar(total_waiting / size, self.config.count_scale),
            )

        return context
env/reward.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import numpy as np
6
+
7
# Names of the supported reward formulations (validated by RewardCalculator).
REWARD_VARIANTS: tuple[str, ...] = (
    "current",
    "normalized_wait_queue",
    "wait_queue_throughput",
)
12
+
13
+
14
@dataclass(frozen=True)
class RewardConfig:
    """Weights and normalizers for RewardCalculator.

    The first group applies to the level-based "current" variant; the second
    group applies to the delta-based variants ("normalized_wait_queue",
    "wait_queue_throughput").
    """

    variant: str = "current"  # must be one of REWARD_VARIANTS
    # --- "current" variant ---
    waiting_weight: float = 1.0  # penalty weight on total waiting vehicles
    vehicle_weight: float = 0.1  # penalty weight on total queued vehicles
    pressure_weight: float = 0.0  # 0.0 disables the pressure term
    reward_scale: float = 0.1  # multiplies every component (all variants)
    normalize_by_lane_count: bool = True  # divide totals by incoming-lane count
    clip_reward: float | None = 5.0  # symmetric final clip; None disables
    # --- delta-based variants ---
    queue_delta_weight: float = 2.0  # reward for queue shrinking step-to-step
    wait_delta_weight: float = 4.0  # reward for waiting shrinking step-to-step
    queue_level_weight: float = 0.5  # penalty on absolute queue level
    wait_level_weight: float = 1.0  # penalty on absolute waiting level
    throughput_weight: float = 0.1  # only used by "wait_queue_throughput"
    imbalance_weight: float = 0.1  # only used by "wait_queue_throughput"
    delta_clip: float = 2.0  # clip on delta and level terms
    level_normalizer: float = 10.0  # divisor for level penalties
    throughput_normalizer: float = 2.0  # divisor for per-intersection throughput
32
+
33
+
34
@dataclass(frozen=True)
class RewardBreakdown:
    """Per-intersection reward plus the named components that sum to it."""

    reward: np.ndarray  # float32, one entry per intersection
    components: dict[str, np.ndarray]  # each array shaped like ``reward``
38
+
39
+
40
class RewardCalculator:
    """Computes per-intersection rewards under one of REWARD_VARIANTS.

    Stateful across steps: the delta-based variants compare the current
    normalized queue/wait levels (and finished-vehicle count) against those
    remembered from the previous call, so ``reset()`` should be called at
    episode start.
    """

    def __init__(self, config: RewardConfig | None = None):
        self.config = config or RewardConfig()
        if self.config.variant not in REWARD_VARIANTS:
            raise ValueError(
                f"Unsupported reward variant: {self.config.variant}. "
                f"Expected one of {REWARD_VARIANTS}."
            )
        # Previous-step normalized state; None until reset()/first compute.
        self._prev_queue_norm: np.ndarray | None = None
        self._prev_wait_norm: np.ndarray | None = None
        self._prev_finished_vehicle_count = 0.0

    def reset(
        self,
        incoming_waiting: np.ndarray,
        incoming_counts: np.ndarray,
        incoming_lane_counts: np.ndarray | None = None,
        finished_vehicle_count: float = 0.0,
    ) -> None:
        """Seed the previous-step state from the initial observation."""
        queue_norm, wait_norm, _ = self._normalized_state(
            incoming_waiting=incoming_waiting,
            incoming_counts=incoming_counts,
            incoming_lane_counts=incoming_lane_counts,
        )
        self._prev_queue_norm = queue_norm
        self._prev_wait_norm = wait_norm
        self._prev_finished_vehicle_count = float(finished_vehicle_count)

    def compute(
        self,
        incoming_waiting: np.ndarray,
        incoming_counts: np.ndarray,
        outgoing_counts: np.ndarray | None = None,
        incoming_lane_counts: np.ndarray | None = None,
        finished_vehicle_count: float = 0.0,
    ) -> np.ndarray:
        """Return only the reward array (see ``compute_breakdown``)."""
        return self.compute_breakdown(
            incoming_waiting=incoming_waiting,
            incoming_counts=incoming_counts,
            outgoing_counts=outgoing_counts,
            incoming_lane_counts=incoming_lane_counts,
            finished_vehicle_count=finished_vehicle_count,
        ).reward

    def compute_breakdown(
        self,
        incoming_waiting: np.ndarray,
        incoming_counts: np.ndarray,
        outgoing_counts: np.ndarray | None = None,
        incoming_lane_counts: np.ndarray | None = None,
        finished_vehicle_count: float = 0.0,
    ) -> RewardBreakdown:
        """Dispatch to the configured variant and return reward + components.

        Arrays are per-intersection 2-D (intersection x lane) inputs;
        ``incoming_lane_counts`` is the per-intersection number of real
        incoming lanes, used for normalization.
        """
        if self.config.variant == "current":
            return self._compute_current(
                incoming_waiting=incoming_waiting,
                incoming_counts=incoming_counts,
                outgoing_counts=outgoing_counts,
                incoming_lane_counts=incoming_lane_counts,
            )
        return self._compute_delta_based(
            incoming_waiting=incoming_waiting,
            incoming_counts=incoming_counts,
            incoming_lane_counts=incoming_lane_counts,
            finished_vehicle_count=finished_vehicle_count,
            include_throughput=self.config.variant == "wait_queue_throughput",
        )

    def _compute_current(
        self,
        incoming_waiting: np.ndarray,
        incoming_counts: np.ndarray,
        outgoing_counts: np.ndarray | None = None,
        incoming_lane_counts: np.ndarray | None = None,
    ) -> RewardBreakdown:
        """Stateless level-based reward: penalize waiting and queued totals,
        optionally adding a (outgoing - incoming) pressure term."""
        waiting_total = incoming_waiting.sum(axis=1)
        vehicle_total = incoming_counts.sum(axis=1)
        normalization = self._lane_normalization(waiting_total.shape[0], incoming_lane_counts)

        components = {
            "wait_term": (-self.config.waiting_weight * waiting_total / normalization).astype(np.float32),
            "queue_term": (-self.config.vehicle_weight * vehicle_total / normalization).astype(np.float32),
        }
        if outgoing_counts is not None and self.config.pressure_weight != 0.0:
            outgoing_total = outgoing_counts.sum(axis=1)
            # NOTE(review): rewards outgoing > incoming; confirm intended sign
            # relative to the usual (incoming - outgoing) pressure definition.
            components["pressure_term"] = (
                self.config.pressure_weight * (outgoing_total - vehicle_total) / normalization
            ).astype(np.float32)
        components = self._scale_components(components)
        reward = self._finalize_reward(components)
        return RewardBreakdown(reward=reward, components=components)

    def _compute_delta_based(
        self,
        incoming_waiting: np.ndarray,
        incoming_counts: np.ndarray,
        incoming_lane_counts: np.ndarray | None,
        finished_vehicle_count: float,
        include_throughput: bool,
    ) -> RewardBreakdown:
        """Stateful reward: positive for shrinking queues/waits since the last
        call, penalties on absolute levels, and (optionally) a throughput
        bonus plus a lane-imbalance penalty."""
        queue_norm, wait_norm, lane_norm = self._normalized_state(
            incoming_waiting=incoming_waiting,
            incoming_counts=incoming_counts,
            incoming_lane_counts=incoming_lane_counts,
        )

        # First call without reset(): fall back to zero deltas.
        if self._prev_queue_norm is None or self._prev_wait_norm is None:
            self._prev_queue_norm = queue_norm.copy()
            self._prev_wait_norm = wait_norm.copy()

        queue_delta = np.clip(
            self._prev_queue_norm - queue_norm,
            -self.config.delta_clip,
            self.config.delta_clip,
        ).astype(np.float32)
        wait_delta = np.clip(
            self._prev_wait_norm - wait_norm,
            -self.config.delta_clip,
            self.config.delta_clip,
        ).astype(np.float32)

        components: dict[str, np.ndarray] = {
            "queue_term": (self.config.queue_delta_weight * queue_delta).astype(np.float32),
            "wait_term": (self.config.wait_delta_weight * wait_delta).astype(np.float32),
            "queue_level_term": (
                -self.config.queue_level_weight
                * np.clip(queue_norm / self.config.level_normalizer, 0.0, self.config.delta_clip)
            ).astype(np.float32),
            "wait_level_term": (
                -self.config.wait_level_weight
                * np.clip(wait_norm / self.config.level_normalizer, 0.0, self.config.delta_clip)
            ).astype(np.float32),
        }

        if include_throughput:
            num_intersections = max(1, queue_norm.shape[0])
            # Vehicles finished since the previous call, never negative.
            finished_delta = max(
                0.0,
                float(finished_vehicle_count) - self._prev_finished_vehicle_count,
            )
            throughput_per_intersection = finished_delta / float(num_intersections)
            # Same (capped) bonus is broadcast to every intersection.
            throughput_term = np.full(
                queue_norm.shape,
                self.config.throughput_weight
                * min(1.0, throughput_per_intersection / self.config.throughput_normalizer),
                dtype=np.float32,
            )
            # Std of per-lane waiting penalizes uneven queues across lanes.
            imbalance = np.std(
                incoming_waiting / lane_norm[:, None],
                axis=1,
            ).astype(np.float32)
            components["throughput_term"] = throughput_term
            components["imbalance_term"] = (-self.config.imbalance_weight * imbalance).astype(
                np.float32
            )

        components = self._scale_components(components)
        reward = self._finalize_reward(components)
        # Remember current state only after the reward has been computed.
        self._prev_queue_norm = queue_norm
        self._prev_wait_norm = wait_norm
        self._prev_finished_vehicle_count = float(finished_vehicle_count)
        return RewardBreakdown(reward=reward, components=components)

    def _normalized_state(
        self,
        incoming_waiting: np.ndarray,
        incoming_counts: np.ndarray,
        incoming_lane_counts: np.ndarray | None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return (queue, wait, lane_norm): per-intersection totals divided by
        the lane normalization."""
        lane_norm = self._lane_normalization(incoming_counts.shape[0], incoming_lane_counts)
        queue_norm = (incoming_counts.sum(axis=1) / lane_norm).astype(np.float32)
        wait_norm = (incoming_waiting.sum(axis=1) / lane_norm).astype(np.float32)
        return queue_norm, wait_norm, lane_norm

    def _lane_normalization(
        self,
        batch_size: int,
        incoming_lane_counts: np.ndarray | None,
    ) -> np.ndarray:
        """Per-intersection divisor: lane counts (min 1) when enabled, else ones."""
        normalization = np.ones(batch_size, dtype=np.float32)
        if incoming_lane_counts is not None and self.config.normalize_by_lane_count:
            normalization = np.maximum(1.0, incoming_lane_counts.astype(np.float32))
        return normalization

    def _finalize_reward(self, components: dict[str, np.ndarray]) -> np.ndarray:
        """Sum all components and apply the optional symmetric clip."""
        reward = np.zeros_like(next(iter(components.values())), dtype=np.float32)
        for term in components.values():
            reward += term.astype(np.float32)

        if self.config.clip_reward is not None:
            reward = np.clip(
                reward,
                -float(self.config.clip_reward),
                float(self.config.clip_reward),
            )
        return reward.astype(np.float32)

    def _scale_components(
        self,
        components: dict[str, np.ndarray],
    ) -> dict[str, np.ndarray]:
        """Multiply every component by the global reward_scale."""
        scale = float(self.config.reward_scale)
        return {
            name: (values.astype(np.float32) * scale).astype(np.float32)
            for name, values in components.items()
        }
env/scenarios.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
class ScenarioGenerator:
    """Generates per-episode scenario settings from an integer seed."""

    def generate(self, seed):
        """Return a deterministic scenario dict for ``seed``.

        Uses a private ``random.Random`` instance instead of reseeding the
        module-global generator, so callers' use of the shared ``random``
        state is not perturbed as a side effect. The value stream is
        identical to the previous ``random.seed(seed)`` implementation.

        Returns a dict with:
        - ``traffic_bias``: one of "ns", "ew", "balanced"
        - ``emergency_vehicle``: bool, True with probability 0.2
        """
        import random

        rng = random.Random(seed)
        return {
            "traffic_bias": rng.choice(["ns", "ew", "balanced"]),
            "emergency_vehicle": rng.random() < 0.2,
        }
+ }
env/traffic_env.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+
8
+ from env.cityflow_adapter import CityFlowAdapter
9
+ from env.intersection_config import DistrictConfig, IntersectionConfig
10
+ from env.observation_builder import ObservationBuilder, ObservationConfig
11
+ from env.reward import RewardCalculator, RewardConfig
12
+ from env.utils import build_topology, load_json
13
+
14
+
15
@dataclass(frozen=True)
class EnvConfig:
    """Top-level configuration for TrafficEnv.

    ``observation`` and ``reward`` defaults are shared frozen-dataclass
    instances; sharing is safe because both types are immutable.
    """

    simulator_interval: int = 1
    decision_interval: int = 5  # simulator steps advanced per policy decision
    min_green_time: int = 10  # minimum time a phase must hold before switching
    thread_num: int = 1  # passed to the CityFlow engine
    observation: ObservationConfig = ObservationConfig()
    reward: RewardConfig = RewardConfig()
    # Episode length in simulated time; when None, TrafficEnv falls back to
    # the CityFlow config's "step" value.
    max_episode_seconds: int | None = None
24
+
25
+
26
+ class TrafficEnv:
27
+ def __init__(
28
+ self,
29
+ city_id: str,
30
+ scenario_name: str,
31
+ city_dir: str | Path,
32
+ scenario_dir: str | Path,
33
+ config_path: str | Path,
34
+ roadnet_path: str | Path,
35
+ district_map_path: str | Path | None = None,
36
+ metadata_path: str | Path | None = None,
37
+ env_config: EnvConfig | None = None,
38
+ ):
39
+ self.city_id = city_id
40
+ self.scenario_name = scenario_name
41
+ self.city_dir = Path(city_dir)
42
+ self.scenario_dir = Path(scenario_dir)
43
+ self.original_config_path = Path(config_path)
44
+ self.roadnet_path = Path(roadnet_path)
45
+ self.district_map_path = Path(district_map_path) if district_map_path else None
46
+ self.metadata_path = Path(metadata_path) if metadata_path else None
47
+ self.env_config = env_config or EnvConfig()
48
+
49
+ self.intersections, self.districts = build_topology(
50
+ roadnet_path=self.roadnet_path,
51
+ district_map_path=self.district_map_path,
52
+ metadata_path=self.metadata_path,
53
+ )
54
+ if not self.intersections:
55
+ raise ValueError(
56
+ f"No controllable intersections found for {self.city_id}/{self.scenario_name}."
57
+ )
58
+
59
+ self.controlled_intersection_ids = tuple(sorted(self.intersections))
60
+ self.observation_builder = ObservationBuilder(
61
+ intersections=self.intersections,
62
+ districts=self.districts,
63
+ config=self.env_config.observation,
64
+ )
65
+ self.reward_calculator = RewardCalculator(self.env_config.reward)
66
+ self.adapter = CityFlowAdapter(
67
+ config_path=self.original_config_path,
68
+ thread_num=self.env_config.thread_num,
69
+ )
70
+
71
+ config_payload = load_json(self.original_config_path)
72
+ self.max_episode_seconds = int(
73
+ self.env_config.max_episode_seconds
74
+ or config_payload.get("step", 0)
75
+ )
76
+ self.metadata = load_json(self.metadata_path) if self.metadata_path else {}
77
+ self._district_type_labels = tuple(
78
+ self.intersections[intersection_id].district_type
79
+ for intersection_id in self.controlled_intersection_ids
80
+ )
81
+ self._incoming_lane_counts = np.asarray(
82
+ [
83
+ max(1, len(self.intersections[intersection_id].incoming_lanes))
84
+ for intersection_id in self.controlled_intersection_ids
85
+ ],
86
+ dtype=np.float32,
87
+ )
88
+
89
+ self.current_phase_positions: dict[str, int] = {}
90
+ self.phase_elapsed_times: dict[str, int] = {}
91
+ self.decision_step_count = 0
92
+ self.episode_return = 0.0
93
+ self.total_episode_return = 0.0
94
+ self.last_info: dict[str, Any] = {}
95
+ self.reward_component_sums: dict[str, float] = {}
96
+
97
+ @property
98
+ def observation_dim(self) -> int:
99
+ return self.observation_builder.observation_dim
100
+
101
+ def reset(self, seed: int | None = None) -> dict[str, Any]:
102
+ del seed
103
+ self.adapter.reset()
104
+ self.decision_step_count = 0
105
+ self.episode_return = 0.0
106
+ self.total_episode_return = 0.0
107
+ self.reward_component_sums = {}
108
+
109
+ self.current_phase_positions = {}
110
+ self.phase_elapsed_times = {}
111
+ for intersection_id in self.controlled_intersection_ids:
112
+ config = self.intersections[intersection_id]
113
+ initial_position = 0
114
+ initial_phase = config.green_phases[initial_position].engine_phase_index
115
+ self.current_phase_positions[intersection_id] = initial_position
116
+ self.phase_elapsed_times[intersection_id] = 0
117
+ self.adapter.set_tl_phase(intersection_id, initial_phase)
118
+
119
+ observation = self._build_observation()
120
+ self.reward_calculator.reset(
121
+ incoming_waiting=observation["incoming_waiting"],
122
+ incoming_counts=observation["incoming_counts"],
123
+ incoming_lane_counts=self._incoming_lane_counts,
124
+ finished_vehicle_count=self.adapter.get_finished_vehicle_count(),
125
+ )
126
+ self.last_info = self._build_info(
127
+ rewards=np.zeros(len(self.controlled_intersection_ids), dtype=np.float32),
128
+ avg_incoming_counts=observation["incoming_counts"],
129
+ avg_incoming_waiting=observation["incoming_waiting"],
130
+ reward_components={},
131
+ )
132
+ return observation
133
+
134
+ def step(
135
+ self,
136
+ actions: dict[str, int] | list[int] | np.ndarray,
137
+ ) -> tuple[dict[str, Any], np.ndarray, bool, dict[str, Any]]:
138
+ normalized_actions = self._normalize_actions(actions)
139
+ self._apply_actions(normalized_actions)
140
+
141
+ avg_incoming_counts, avg_incoming_waiting, avg_outgoing_counts = self._advance_simulator()
142
+ reward_breakdown = self.reward_calculator.compute_breakdown(
143
+ incoming_waiting=avg_incoming_waiting,
144
+ incoming_counts=avg_incoming_counts,
145
+ outgoing_counts=avg_outgoing_counts,
146
+ incoming_lane_counts=self._incoming_lane_counts,
147
+ finished_vehicle_count=self.adapter.get_finished_vehicle_count(),
148
+ )
149
+ rewards = reward_breakdown.reward
150
+ self.decision_step_count += 1
151
+ self.total_episode_return += float(rewards.sum())
152
+ self.episode_return = self._mean_step_intersection_reward()
153
+ self._accumulate_reward_components(reward_breakdown.components)
154
+
155
+ observation = self._build_observation()
156
+ done = self.adapter.get_current_time() >= self.max_episode_seconds
157
+ info = self._build_info(
158
+ rewards=rewards,
159
+ avg_incoming_counts=avg_incoming_counts,
160
+ avg_incoming_waiting=avg_incoming_waiting,
161
+ reward_components=reward_breakdown.components,
162
+ )
163
+ self.last_info = info
164
+ return observation, rewards, done, info
165
+
166
+ def _build_observation(self) -> dict[str, Any]:
167
+ lane_vehicle_count = self.adapter.get_lane_vehicle_count()
168
+ lane_waiting_count = self.adapter.get_lane_waiting_vehicle_count()
169
+ switch_allowed = {
170
+ intersection_id: (
171
+ self.phase_elapsed_times[intersection_id] >= self.env_config.min_green_time
172
+ )
173
+ for intersection_id in self.controlled_intersection_ids
174
+ }
175
+
176
+ observation = self.observation_builder.build(
177
+ lane_vehicle_count=lane_vehicle_count,
178
+ lane_waiting_count=lane_waiting_count,
179
+ phase_positions=self.current_phase_positions,
180
+ phase_elapsed_times=self.phase_elapsed_times,
181
+ switch_allowed=switch_allowed,
182
+ )
183
+ observation["city_id"] = self.city_id
184
+ observation["scenario_name"] = self.scenario_name
185
+ observation["decision_step"] = self.decision_step_count
186
+ observation["sim_time"] = self.adapter.get_current_time()
187
+ return observation
188
+
189
+ def _apply_actions(self, actions: np.ndarray) -> None:
190
+ for action_index, intersection_id in enumerate(self.controlled_intersection_ids):
191
+ config = self.intersections[intersection_id]
192
+ current_position = self.current_phase_positions[intersection_id]
193
+ can_switch = self.phase_elapsed_times[intersection_id] >= self.env_config.min_green_time
194
+ should_switch = int(actions[action_index]) == 1 and can_switch
195
+
196
+ if should_switch:
197
+ next_position = (current_position + 1) % config.num_green_phases
198
+ engine_phase = config.green_phases[next_position].engine_phase_index
199
+ self.adapter.set_tl_phase(intersection_id, engine_phase)
200
+ self.current_phase_positions[intersection_id] = next_position
201
+ self.phase_elapsed_times[intersection_id] = 0
202
+ else:
203
+ current_engine_phase = config.green_phases[current_position].engine_phase_index
204
+ self.adapter.set_tl_phase(intersection_id, current_engine_phase)
205
+
206
+ def _advance_simulator(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
207
+ num_intersections = len(self.controlled_intersection_ids)
208
+ max_lanes = self.env_config.observation.max_incoming_lanes
209
+ avg_incoming_counts = np.zeros((num_intersections, max_lanes), dtype=np.float32)
210
+ avg_incoming_waiting = np.zeros((num_intersections, max_lanes), dtype=np.float32)
211
+ avg_outgoing_counts = np.zeros((num_intersections, max_lanes), dtype=np.float32)
212
+
213
+ for _ in range(self.env_config.decision_interval):
214
+ self.adapter.step()
215
+ lane_vehicle_count = self.adapter.get_lane_vehicle_count()
216
+ lane_waiting_count = self.adapter.get_lane_waiting_vehicle_count()
217
+
218
+ for row_index, intersection_id in enumerate(self.controlled_intersection_ids):
219
+ config = self.intersections[intersection_id]
220
+ for lane_index, lane_id in enumerate(
221
+ config.incoming_lanes[: self.env_config.observation.max_incoming_lanes]
222
+ ):
223
+ avg_incoming_counts[row_index, lane_index] += float(
224
+ lane_vehicle_count.get(lane_id, 0)
225
+ )
226
+ avg_incoming_waiting[row_index, lane_index] += float(
227
+ lane_waiting_count.get(lane_id, 0)
228
+ )
229
+ for lane_index, lane_id in enumerate(
230
+ config.outgoing_lanes[: self.env_config.observation.max_incoming_lanes]
231
+ ):
232
+ avg_outgoing_counts[row_index, lane_index] += float(
233
+ lane_vehicle_count.get(lane_id, 0)
234
+ )
235
+
236
+ self.phase_elapsed_times[intersection_id] += self.env_config.simulator_interval
237
+
238
+ avg_incoming_counts /= float(self.env_config.decision_interval)
239
+ avg_incoming_waiting /= float(self.env_config.decision_interval)
240
+ avg_outgoing_counts /= float(self.env_config.decision_interval)
241
+ return avg_incoming_counts, avg_incoming_waiting, avg_outgoing_counts
242
+
243
def _build_info(
    self,
    rewards: np.ndarray,
    avg_incoming_counts: np.ndarray,
    avg_incoming_waiting: np.ndarray,
    reward_components: dict[str, np.ndarray],
) -> dict[str, Any]:
    """Build the diagnostics dict attached to every reset/step result."""
    waiting_per_intersection = avg_incoming_waiting.sum(axis=1)
    incoming_per_intersection = avg_incoming_counts.sum(axis=1)

    metrics: dict[str, Any] = {
        "num_controlled_intersections": len(self.controlled_intersection_ids),
        "mean_reward": float(rewards.mean()) if rewards.size else 0.0,
        "mean_step_intersection_reward": self._mean_step_intersection_reward(),
        "mean_waiting_vehicles": float(waiting_per_intersection.mean()),
        "mean_incoming_vehicles": float(incoming_per_intersection.mean()),
        "total_waiting_vehicles": float(avg_incoming_waiting.sum()),
        "total_incoming_vehicles": float(avg_incoming_counts.sum()),
        "running_vehicles": self.adapter.get_vehicle_count(),
        "throughput": self.adapter.get_finished_vehicle_count(),
        "average_travel_time": self.adapter.get_average_travel_time(),
        "reward_variant": self.env_config.reward.variant,
    }
    # Fold in per-component reward metrics and per-district-type breakdowns.
    metrics.update(self._reward_component_metrics(reward_components))
    metrics.update(
        per_district_type_metrics(
            district_types=self._district_type_labels,
            rewards=rewards,
            avg_incoming_counts=avg_incoming_counts,
            avg_incoming_waiting=avg_incoming_waiting,
        )
    )

    return {
        "city_id": self.city_id,
        "scenario_name": self.scenario_name,
        "decision_step": self.decision_step_count,
        "sim_time": self.adapter.get_current_time(),
        "episode_return": float(self.episode_return),
        "total_episode_return": float(self.total_episode_return),
        "intersection_ids": self.controlled_intersection_ids,
        "district_types": self._district_type_labels,
        "metrics": metrics,
    }
285
+
286
+ def _normalize_actions(
287
+ self,
288
+ actions: dict[str, int] | list[int] | np.ndarray,
289
+ ) -> np.ndarray:
290
+ if isinstance(actions, dict):
291
+ return np.asarray(
292
+ [int(actions.get(intersection_id, 0)) for intersection_id in self.controlled_intersection_ids],
293
+ dtype=np.int64,
294
+ )
295
+ array = np.asarray(actions, dtype=np.int64)
296
+ if array.shape != (len(self.controlled_intersection_ids),):
297
+ raise ValueError(
298
+ "Actions must provide exactly one action per controlled intersection."
299
+ )
300
+ return array
301
+
302
+ def _mean_step_intersection_reward(self) -> float:
303
+ denominator = max(
304
+ 1,
305
+ self.decision_step_count * len(self.controlled_intersection_ids),
306
+ )
307
+ return float(self.total_episode_return) / float(denominator)
308
+
309
+ def _accumulate_reward_components(self, components: dict[str, np.ndarray]) -> None:
310
+ for name, values in components.items():
311
+ self.reward_component_sums[name] = self.reward_component_sums.get(name, 0.0) + float(
312
+ np.asarray(values, dtype=np.float32).mean()
313
+ )
314
+
315
+ def _reward_component_metrics(
316
+ self,
317
+ reward_components: dict[str, np.ndarray],
318
+ ) -> dict[str, float]:
319
+ metrics: dict[str, float] = {}
320
+ for name, values in reward_components.items():
321
+ metrics[f"reward_component_step_{name}"] = float(
322
+ np.asarray(values, dtype=np.float32).mean()
323
+ )
324
+ if self.decision_step_count <= 0:
325
+ return metrics
326
+ for name, total in self.reward_component_sums.items():
327
+ metrics[f"reward_component_mean_{name}"] = float(total) / float(
328
+ self.decision_step_count
329
+ )
330
+ return metrics
331
+
332
+
333
def per_district_type_metrics(
    district_types: tuple[str, ...],
    rewards: np.ndarray,
    avg_incoming_counts: np.ndarray,
    avg_incoming_waiting: np.ndarray,
) -> dict[str, float]:
    """Aggregate per-intersection statistics by district type.

    Args:
        district_types: One label per controlled intersection, aligned with
            the rows of the array arguments.
        rewards: Per-intersection rewards, shape ``(num_intersections,)``.
        avg_incoming_counts: Per-intersection, per-lane average vehicle
            counts, shape ``(num_intersections, num_lanes)``.
        avg_incoming_waiting: Per-intersection, per-lane average waiting
            counts, same shape.

    Returns:
        For each distinct district type: the intersection count and the
        mean reward / waiting vehicles / incoming vehicles over the
        intersections of that type.
    """
    metrics: dict[str, float] = {}
    if not district_types:
        return metrics

    reward_vector = np.asarray(rewards, dtype=np.float32)
    # Build the label array once; vectorized equality replaces the original
    # per-type Python list comprehension.
    labels = np.asarray(district_types)
    incoming_totals = avg_incoming_counts.sum(axis=1)
    waiting_totals = avg_incoming_waiting.sum(axis=1)

    for district_type in sorted(set(district_types)):
        # Mask is never empty: district_type is drawn from district_types
        # itself, so the original `if not mask.any()` guard was dead code.
        mask = labels == district_type
        metrics[f"num_{district_type}_intersections"] = float(mask.sum())
        metrics[f"mean_reward_{district_type}"] = float(reward_vector[mask].mean())
        metrics[f"mean_waiting_vehicles_{district_type}"] = float(waiting_totals[mask].mean())
        metrics[f"mean_incoming_vehicles_{district_type}"] = float(incoming_totals[mask].mean())

    return metrics