agentic-traffic / training /rollout.py
Aditya2162's picture
Upload folder using huggingface_hub
3d2dbcf verified
from __future__ import annotations
from typing import Protocol
import numpy as np
import torch
from agents.local_policy import BaseLocalPolicy
from training.models import RunningNormalizer, TrafficControlQNetwork
class TorchLocalPolicyProtocol(Protocol):
    """Structural interface for tensor-based policies accepted by ``evaluate_policy``.

    Any object with a compatible ``act`` method satisfies this protocol; no
    inheritance is required.
    """

    def act(
        self,
        observations: torch.Tensor,
        district_type_indices: torch.Tensor,
        action_mask: torch.Tensor | None = None,
        deterministic: bool = False,
        epsilon: float = 0.0,
    ) -> torch.Tensor:
        """Return a tensor of actions for a batch of observations.

        Args:
            observations: Observation tensor (``evaluate_policy`` passes float32).
            district_type_indices: Integer tensor (``evaluate_policy`` passes int64).
            action_mask: Optional mask tensor (``evaluate_policy`` passes float32).
            deterministic: Implementation-defined action-selection flag.
            epsilon: Presumably an exploration rate; ``evaluate_policy`` always
                passes ``0.0`` — confirm against concrete implementations.
        """
def evaluate_policy(
env_factory,
actor: TrafficControlQNetwork | BaseLocalPolicy | TorchLocalPolicyProtocol,
device: torch.device | None = None,
obs_normalizer: RunningNormalizer | None = None,
deterministic: bool = True,
log_prefix: str | None = None,
log_every_steps: int = 0,
) -> dict[str, float | str]:
env = env_factory()
observation_batch = env.reset()
done = False
final_info = env.last_info
max_decision_steps = max(
1,
int(getattr(env, "max_episode_seconds", 0) // max(1, env.env_config.decision_interval)),
)
while not done:
if isinstance(actor, BaseLocalPolicy):
actions = actor.act(observation_batch)
else:
raw_obs = observation_batch["observations"].astype(np.float32)
normalized_obs = obs_normalizer.normalize(raw_obs) if obs_normalizer else raw_obs
obs_tensor = torch.as_tensor(normalized_obs, dtype=torch.float32, device=device)
district_type_tensor = torch.as_tensor(
observation_batch["district_type_indices"],
dtype=torch.int64,
device=device,
)
action_mask_tensor = torch.as_tensor(
observation_batch["action_mask"],
dtype=torch.float32,
device=device,
)
with torch.no_grad():
action_tensor = actor.act(
observations=obs_tensor,
district_type_indices=district_type_tensor,
action_mask=action_mask_tensor,
deterministic=deterministic,
epsilon=0.0,
)
actions = action_tensor.cpu().numpy()
observation_batch, _, done, final_info = env.step(actions)
if log_prefix and log_every_steps > 0:
decision_step = int(getattr(env, "decision_step_count", 0))
should_log = decision_step == 1 or done or (decision_step % log_every_steps == 0)
if should_log:
sim_time = int(getattr(env.adapter, "get_current_time", lambda: 0)())
metrics = final_info.get("metrics", {}) if isinstance(final_info, dict) else {}
print(
f"{log_prefix} step={decision_step}/{max_decision_steps} "
f"sim_time={sim_time}s wait={float(metrics.get('mean_waiting_vehicles', float('nan'))):.2f} "
f"throughput={float(metrics.get('throughput', float('nan'))):.1f}"
)
metrics = {
key: float(value)
for key, value in final_info["metrics"].items()
if value is not None and isinstance(value, (int, float))
}
metrics.update(
{
"city_id": env.city_id,
"scenario_name": env.scenario_name,
"episode_return": float(env.episode_return),
"total_episode_return": float(env.total_episode_return),
"decision_steps": float(env.decision_step_count),
}
)
return metrics