Aditya2162's picture
Upload folder using huggingface_hub
3d2dbcf verified
# training
Training, evaluation, device selection, and dataset utilities for the local multi-agent DQN stack.
## Main files
- [train_local_policy.py](/Users/aditya/Developer/traffic-llm/training/train_local_policy.py)
Main CLI for split generation, DQN training, and checkpoint evaluation.
- [trainer.py](/Users/aditya/Developer/traffic-llm/training/trainer.py)
Replay buffer, DQN training loop, checkpointing, validation, and aggregate metrics.
- [rollout.py](/Users/aditya/Developer/traffic-llm/training/rollout.py)
Evaluation helper shared by learned agents and rule-based baselines.
- [models.py](/Users/aditya/Developer/traffic-llm/training/models.py)
Dueling Q-network and running observation normalization.
- [device.py](/Users/aditya/Developer/traffic-llm/training/device.py)
Torch device selection for `cuda`, `mps`, or `cpu`.
- [cityflow_dataset.py](/root/aditya/agentic-traffic/training/cityflow_dataset.py)
City discovery, split loading, and scenario sampling.
## Main entry points
- `python3 -m training.train_local_policy make-splits`
- `python3 -m training.train_local_policy train`
- `python3 -m training.train_local_policy train --max-val-cities 3 --val-scenarios-per-city 7`
- `python3 -m training.train_local_policy train --max-train-cities 70 --max-val-cities 3 --val-scenarios-per-city 7 --policy-arch single_head_with_district_feature --reward-variant wait_queue_throughput`
- `python3 -m training.train_local_policy train --policy-arch single_head_with_district_feature --reward-variant wait_queue_throughput`
- `python3 -m training.train_local_policy evaluate --checkpoint artifacts/dqn_shared/best_validation.pt --split val`
- `python3 -m training.train_local_policy evaluate --baseline queue_greedy --split val`
- `tensorboard --logdir artifacts/dqn_shared/tensorboard`
## Training flow
1. Load city-level splits from `data/splits/`.
2. Sample one `(city, scenario)` episode at a time from the train split.
3. Run one shared Q-network across all controlled intersections in that city.
4. Collect per-intersection transitions into prioritized replay with n-step returns, using parallel CPU rollout workers by default.
5. Perform Double DQN updates against a target network.
6. Periodically evaluate on validation cities and save checkpoints. By default, eval and checkpoint cadence are both 40 updates, and each validation pass also writes an `update_XXXX.pt` checkpoint.