from __future__ import annotations from typing import Protocol import numpy as np import torch from agents.local_policy import BaseLocalPolicy from training.models import RunningNormalizer, TrafficControlQNetwork class TorchLocalPolicyProtocol(Protocol): def act( self, observations: torch.Tensor, district_type_indices: torch.Tensor, action_mask: torch.Tensor | None = None, deterministic: bool = False, epsilon: float = 0.0, ) -> torch.Tensor: ... def evaluate_policy( env_factory, actor: TrafficControlQNetwork | BaseLocalPolicy | TorchLocalPolicyProtocol, device: torch.device | None = None, obs_normalizer: RunningNormalizer | None = None, deterministic: bool = True, log_prefix: str | None = None, log_every_steps: int = 0, ) -> dict[str, float | str]: env = env_factory() observation_batch = env.reset() done = False final_info = env.last_info max_decision_steps = max( 1, int(getattr(env, "max_episode_seconds", 0) // max(1, env.env_config.decision_interval)), ) while not done: if isinstance(actor, BaseLocalPolicy): actions = actor.act(observation_batch) else: raw_obs = observation_batch["observations"].astype(np.float32) normalized_obs = obs_normalizer.normalize(raw_obs) if obs_normalizer else raw_obs obs_tensor = torch.as_tensor(normalized_obs, dtype=torch.float32, device=device) district_type_tensor = torch.as_tensor( observation_batch["district_type_indices"], dtype=torch.int64, device=device, ) action_mask_tensor = torch.as_tensor( observation_batch["action_mask"], dtype=torch.float32, device=device, ) with torch.no_grad(): action_tensor = actor.act( observations=obs_tensor, district_type_indices=district_type_tensor, action_mask=action_mask_tensor, deterministic=deterministic, epsilon=0.0, ) actions = action_tensor.cpu().numpy() observation_batch, _, done, final_info = env.step(actions) if log_prefix and log_every_steps > 0: decision_step = int(getattr(env, "decision_step_count", 0)) should_log = decision_step == 1 or done or (decision_step % log_every_steps == 0) if should_log: sim_time = int(getattr(env.adapter, "get_current_time", lambda: 0)()) metrics = final_info.get("metrics", {}) if isinstance(final_info, dict) else {} print( f"{log_prefix} step={decision_step}/{max_decision_steps} " f"sim_time={sim_time}s wait={float(metrics.get('mean_waiting_vehicles', float('nan'))):.2f} " f"throughput={float(metrics.get('throughput', float('nan'))):.1f}" ) metrics = { key: float(value) for key, value in final_info["metrics"].items() if value is not None and isinstance(value, (int, float)) } metrics.update( { "city_id": env.city_id, "scenario_name": env.scenario_name, "episode_return": float(env.episode_return), "total_episode_return": float(env.total_episode_return), "decision_steps": float(env.decision_step_count), } ) return metrics