"""
reward_engine.py — Multi-signal reward shaping for USAR environment.

Reward philosophy:
    The reward function must provide useful signal at EVERY step, not just at
    episode end. An agent that learns from sparse end-of-episode reward will
    never generalize to the hard task.

Reward signals (per step):
    1. Survivor rescue reward — primary signal, weighted by criticality
    2. Survival preservation reward — reward for keeping survivors alive via speed
    3. Resource efficiency penalty — penalize waste (medkits on non-critical, etc.)
    4. Inaction penalty — mild penalty for WAIT when urgent cases exist
    5. Misallocation penalty — penalty for assigning wrong team to wrong site
    6. Secondary collapse penalty — large penalty when a site collapses with survivors
    7. Critical save bonus — extra reward for rescuing critical survivors
    8. Time urgency bonus — extra reward for rescuing while survival odds are high
    9. Fatigue overuse penalty — penalty for deploying an exhausted team

Per-step rewards are plain weighted sums of the components above. This module
does NOT clamp them, so a caller that needs a [-1, +1] per-step range must
normalize the total itself. The episode score is normalized to [0, 1].
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional


# ---------------------------------------------------------------------------
# Reward weights — tuned for meaningful signal at each task difficulty
# ---------------------------------------------------------------------------

@dataclass
class RewardWeights:
    """Scalar weights applied to each per-step reward component."""

    rescue_alive: float = 1.00                # per survivor rescued alive
    rescue_dead: float = 0.10                 # partial — they were at least found
    critical_save_bonus: float = 0.50         # bonus per critical survivor saved alive
    survival_preservation: float = 0.05       # per step survivors stay alive due to speed
    inaction_penalty: float = -0.15           # per WAIT step when critical survivors exist
    misallocation_penalty: float = -0.10      # wrong team spec to wrong debris
    resource_waste_penalty: float = -0.20     # medkit on non-critical, air on cleared site
    secondary_collapse_penalty: float = -2.0  # per survivor lost in secondary collapse
    time_urgency_bonus: float = 0.08          # bonus for rescuing before survival prob drops below 0.5
    fatigue_overuse_penalty: float = -0.05    # deploying exhausted team


@dataclass
class StepRewardBreakdown:
    """Detailed reward breakdown for one step — used for logging and grading."""

    rescue_reward: float = 0.0
    critical_bonus: float = 0.0
    survival_preservation: float = 0.0
    inaction_penalty: float = 0.0
    misallocation_penalty: float = 0.0
    resource_waste: float = 0.0
    collapse_penalty: float = 0.0
    time_urgency_bonus: float = 0.0
    fatigue_penalty: float = 0.0

    @property
    def total(self) -> float:
        """Return the unclamped sum of every component of this step."""
        return (
            self.rescue_reward
            + self.critical_bonus
            + self.survival_preservation
            + self.inaction_penalty
            + self.misallocation_penalty
            + self.resource_waste
            + self.collapse_penalty
            + self.time_urgency_bonus
            + self.fatigue_penalty
        )


class RewardEngine:
    """
    Computes per-step and episode-level rewards for USAR episodes.

    Every breakdown produced by ``compute_step_reward`` is appended to
    ``history`` so ``cumulative_reward`` can sum them; ``reset`` starts a
    fresh history.
    """

    def __init__(self, weights: Optional[RewardWeights] = None):
        # Fall back to the default weight set when none is supplied.
        self.weights = weights or RewardWeights()
        self.history: List[StepRewardBreakdown] = []

    def compute_step_reward(
        self,
        action_type: str,
        rescued_alive: int,
        rescued_dead: int,
        critical_saved: int,
        survivors_lost_to_collapse: int,
        was_inaction: bool,
        urgent_critical_count: int,
        misallocation: bool,
        resource_wasted: bool,
        team_was_fatigued: bool,
        avg_survival_prob_before: float,
        avg_survival_prob_after: float,
    ) -> StepRewardBreakdown:
        """
        Compute the full reward breakdown for one step.

        Args:
            action_type: Label of the action taken. Accepted for interface
                compatibility; it is not read by the current computation.
            rescued_alive: Survivors rescued alive this step.
            rescued_dead: Survivors recovered dead this step.
            critical_saved: Critical survivors rescued alive this step.
            survivors_lost_to_collapse: Survivors lost to a secondary collapse.
            was_inaction: Whether the step counted as inaction (WAIT).
            urgent_critical_count: Number of urgent critical survivors present.
            misallocation: Whether a team was assigned to an unsuitable site.
            resource_wasted: Whether a resource was spent wastefully.
            team_was_fatigued: Whether an exhausted team was deployed.
            avg_survival_prob_before: Average survival probability before the step.
            avg_survival_prob_after: Average survival probability after the step.

        Returns:
            The breakdown for this step; it is also appended to ``history``.
        """
        w = self.weights
        bd = StepRewardBreakdown()

        # 1. Rescue rewards
        bd.rescue_reward = rescued_alive * w.rescue_alive + rescued_dead * w.rescue_dead

        # 2. Critical save bonus
        bd.critical_bonus = critical_saved * w.critical_save_bonus

        # 3. Time urgency bonus — reward for speed (when survival prob is still high)
        if rescued_alive > 0 and avg_survival_prob_before > 0.5:
            bd.time_urgency_bonus = rescued_alive * w.time_urgency_bonus

        # 4. Survival preservation — rewarded only when the average survival
        #    probability rose during the step; a drop earns nothing here.
        preservation_delta = avg_survival_prob_after - avg_survival_prob_before
        if preservation_delta > 0:
            bd.survival_preservation = preservation_delta * w.survival_preservation * 10

        # 5. Inaction penalty — WAIT when critical survivors present is bad.
        #    Scales with the number of urgent cases, capped at 5.
        if was_inaction and urgent_critical_count > 0:
            bd.inaction_penalty = w.inaction_penalty * min(urgent_critical_count, 5)

        # 6. Misallocation penalty
        if misallocation:
            bd.misallocation_penalty = w.misallocation_penalty

        # 7. Resource waste
        if resource_wasted:
            bd.resource_waste = w.resource_waste_penalty

        # 8. Secondary collapse penalty (per survivor lost)
        if survivors_lost_to_collapse > 0:
            bd.collapse_penalty = survivors_lost_to_collapse * w.secondary_collapse_penalty

        # 9. Fatigue overuse
        if team_was_fatigued:
            bd.fatigue_penalty = w.fatigue_overuse_penalty

        self.history.append(bd)
        return bd

    def compute_episode_score(
        self,
        total_rescued: int,
        total_lost: int,
        total_initial: int,
        critical_rescued: int,
        total_critical: int,
        steps_taken: int,
        max_steps: int,
        resources_used_efficiently: bool,
    ) -> Dict[str, float]:
        """
        Compute final normalized episode score [0, 1].

        Score = weighted combination of:
            - Rescue rate (50%)
            - Critical rescue rate (25%)
            - Time efficiency (15%)
            - Resource efficiency (10%)

        The weighted score is multiplied by 0.6 when more survivors were lost
        than rescued, then clamped to [0, 1].

        Args:
            total_rescued: Survivors rescued over the episode.
            total_lost: Survivors lost over the episode.
            total_initial: Survivors present at episode start.
            critical_rescued: Critical survivors rescued.
            total_critical: Critical survivors present at episode start.
            steps_taken: Steps actually used.
            max_steps: Step budget. A non-positive budget is treated as fully
                consumed instead of raising ``ZeroDivisionError``.
            resources_used_efficiently: Whether resources were used without waste.

        Returns:
            Dict with keys ``score``, ``rescue_rate``, ``critical_rate``,
            ``time_efficiency`` and ``resource_efficiency``, each rounded to
            four decimals. All values are 0.0 when ``total_initial`` is 0.
        """
        if total_initial == 0:
            return {
                "score": 0.0,
                "rescue_rate": 0.0,
                "critical_rate": 0.0,
                "time_efficiency": 0.0,
                "resource_efficiency": 0.0,
            }

        # Rescue rate — primary metric
        rescue_rate = total_rescued / total_initial

        # Critical rescue rate — an episode with no critical survivors scores 1.0
        critical_rate = (critical_rescued / total_critical) if total_critical > 0 else 1.0

        # Fraction of the step budget consumed. With no budget at all there is
        # nothing to divide by, so count the budget as entirely used.
        if max_steps > 0:
            step_fraction_used = steps_taken / max_steps
        else:
            step_fraction_used = 1.0

        # Time efficiency — decays linearly from 1.0 (no steps used) to 0.5
        # (whole budget used); awarded only if at least one survivor was rescued.
        if rescue_rate > 0:
            time_efficiency = max(0.0, 1.0 - step_fraction_used * 0.5)
        else:
            time_efficiency = 0.0

        # Resource efficiency
        resource_efficiency = 1.0 if resources_used_efficiently else 0.6

        # Weighted score
        score = (
            0.50 * rescue_rate
            + 0.25 * critical_rate
            + 0.15 * time_efficiency
            + 0.10 * resource_efficiency
        )

        # Penalize if more survivors lost than rescued (catastrophic failure)
        if total_lost > total_rescued:
            score *= 0.6

        return {
            "score": round(min(1.0, max(0.0, score)), 4),
            "rescue_rate": round(rescue_rate, 4),
            "critical_rate": round(critical_rate, 4),
            "time_efficiency": round(time_efficiency, 4),
            "resource_efficiency": round(resource_efficiency, 4),
        }

    @property
    def cumulative_reward(self) -> float:
        """Return the sum of ``total`` over every step recorded in ``history``."""
        return sum(bd.total for bd in self.history)

    def reset(self):
        """Discard the recorded step history by binding a new empty list."""
        self.history = []