Spaces:
Sleeping
Sleeping
| """ | |
| reward_engine.py — Multi-signal reward shaping for USAR environment. | |
| Reward philosophy: | |
| The reward function must provide useful signal at EVERY step, not just | |
| at episode end. An agent that learns from sparse end-of-episode reward | |
| will never generalize to the hard task. | |
| Reward signals (per step): | |
| 1. Survivor rescue reward — primary signal, weighted by criticality | |
| 2. Survival preservation reward — reward for keeping survivors alive via speed | |
| 3. Resource efficiency penalty — penalize waste (medkits on non-critical, etc.) | |
| 4. Inaction penalty — mild penalty for WAIT when urgent cases exist | |
| 5. Misallocation penalty — penalty for assigning wrong team to wrong site | |
| 6. Secondary collapse penalty — large penalty when a site collapses with survivors | |
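| 7. Critical save bonus — extra reward for rescuing critical survivors | |
| 8. Time-urgency bonus — extra reward for rescuing survivors while average survival probability is still above 0.5 | |
| 9. Fatigue penalty — mild penalty for deploying an exhausted team | |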
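| Per-step rewards are NOT clipped: the step total is the plain sum of these signals, so it can fall outside [-1, +1] (with default weights, a secondary collapse alone costs 2.0 per survivor lost). | |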
| Episode score is normalized to [0, 1]. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Optional | |
| # --------------------------------------------------------------------------- | |
| # Reward weights — tuned for meaningful signal at each task difficulty | |
| # --------------------------------------------------------------------------- | |
@dataclass
class RewardWeights:
    """Per-signal weights used by ``RewardEngine`` to shape step rewards.

    Positive values are rewards and negative values are penalties; the engine
    multiplies them by per-step counts/flags and sums the results.  Individual
    weights can be overridden with keyword arguments, e.g.
    ``RewardWeights(secondary_collapse_penalty=-1.0)``.
    """

    rescue_alive: float = 1.00  # per survivor rescued alive
    rescue_dead: float = 0.10  # per survivor recovered dead — they were at least found
    critical_save_bonus: float = 0.50  # extra, per critical survivor saved alive
    # Multiplied by 10x the step's rise in average survival probability
    # (a drop earns nothing).
    survival_preservation: float = 0.05
    # Per urgent critical survivor (capped at 5) when the step is inaction (WAIT).
    inaction_penalty: float = -0.15
    misallocation_penalty: float = -0.10  # flat, when the step is flagged as misallocation
    resource_waste_penalty: float = -0.20  # flat, e.g. medkit on non-critical, air on cleared site
    secondary_collapse_penalty: float = -2.0  # per survivor lost in a secondary collapse
    # Per survivor rescued alive while the pre-step average survival
    # probability is still above 0.5.
    time_urgency_bonus: float = 0.08
    fatigue_overuse_penalty: float = -0.05  # flat, when the deployed team was fatigued
@dataclass
class StepRewardBreakdown:
    """Detailed reward breakdown for one step — used for logging and grading.

    Each field is one signed reward component (penalty components carry their
    negative sign), so ``total()`` is simply the sum of all fields.
    """

    rescue_reward: float = 0.0
    critical_bonus: float = 0.0
    survival_preservation: float = 0.0
    inaction_penalty: float = 0.0
    misallocation_penalty: float = 0.0
    resource_waste: float = 0.0
    collapse_penalty: float = 0.0
    time_urgency_bonus: float = 0.0
    fatigue_penalty: float = 0.0

    def total(self) -> float:
        """Return the sum of every reward component for this step (unclipped)."""
        return (
            self.rescue_reward
            + self.critical_bonus
            + self.survival_preservation
            + self.inaction_penalty
            + self.misallocation_penalty
            + self.resource_waste
            + self.collapse_penalty
            + self.time_urgency_bonus
            + self.fatigue_penalty
        )
class RewardEngine:
    """
    Computes per-step and episode-level rewards for USAR episodes.

    ``compute_step_reward`` returns a ``StepRewardBreakdown`` and appends it to
    ``history``; ``cumulative_reward`` sums the recorded step totals and
    ``reset`` clears them.  Step rewards are plain sums and are not clipped;
    only the episode score is clamped to [0, 1].
    """

    def __init__(self, weights: Optional[RewardWeights] = None):
        """Create an engine using *weights*, or default ``RewardWeights``."""
        self.weights = weights if weights is not None else RewardWeights()
        # One breakdown per compute_step_reward call, in call order.
        self.history: List[StepRewardBreakdown] = []

    def compute_step_reward(
        self,
        action_type: str,
        rescued_alive: int,
        rescued_dead: int,
        critical_saved: int,
        survivors_lost_to_collapse: int,
        was_inaction: bool,
        urgent_critical_count: int,
        misallocation: bool,
        resource_wasted: bool,
        team_was_fatigued: bool,
        avg_survival_prob_before: float,
        avg_survival_prob_after: float,
    ) -> StepRewardBreakdown:
        """
        Compute the full reward breakdown for one step and record it.

        Args:
            action_type: Type of the action taken. Not used by the current
                reward computation; accepted for callers' convenience.
            rescued_alive: Survivors extracted alive this step.
            rescued_dead: Survivors recovered dead this step.
            critical_saved: Critical survivors saved alive this step; they earn
                ``critical_save_bonus`` on top of ``rescue_alive``.
            survivors_lost_to_collapse: Survivors lost to a secondary collapse.
            was_inaction: True when the step counts as inaction (e.g. WAIT).
            urgent_critical_count: Urgent critical survivors present; the
                inaction penalty scales with this count, capped at 5.
            misallocation: True when the step was flagged as a misallocation.
            resource_wasted: True when a resource was flagged as wasted.
            team_was_fatigued: True when the deployed team was fatigued.
            avg_survival_prob_before: Average survival probability before the step.
            avg_survival_prob_after: Average survival probability after the step.

        Returns:
            The ``StepRewardBreakdown`` for this step (also appended to
            ``history``). Components are not clipped or normalized.
        """
        w = self.weights
        bd = StepRewardBreakdown()

        # 1. Rescue rewards
        bd.rescue_reward = rescued_alive * w.rescue_alive + rescued_dead * w.rescue_dead

        # 2. Critical save bonus
        bd.critical_bonus = critical_saved * w.critical_save_bonus

        # 3. Time-urgency bonus — reward rescuing while survival prob is still high
        if rescued_alive > 0 and avg_survival_prob_before > 0.5:
            bd.time_urgency_bonus = rescued_alive * w.time_urgency_bonus

        # 4. Survival preservation — reward only a rise in average survival
        #    probability; the x10 rescales small probability deltas.
        preservation_delta = avg_survival_prob_after - avg_survival_prob_before
        if preservation_delta > 0:
            bd.survival_preservation = preservation_delta * w.survival_preservation * 10

        # 5. Inaction penalty — WAIT while critical survivors are present is bad;
        #    scaling is capped at 5 so one step cannot be dominated by it.
        if was_inaction and urgent_critical_count > 0:
            bd.inaction_penalty = w.inaction_penalty * min(urgent_critical_count, 5)

        # 6. Misallocation penalty
        if misallocation:
            bd.misallocation_penalty = w.misallocation_penalty

        # 7. Resource waste
        if resource_wasted:
            bd.resource_waste = w.resource_waste_penalty

        # 8. Secondary collapse penalty (per survivor lost)
        if survivors_lost_to_collapse > 0:
            bd.collapse_penalty = survivors_lost_to_collapse * w.secondary_collapse_penalty

        # 9. Fatigue overuse
        if team_was_fatigued:
            bd.fatigue_penalty = w.fatigue_overuse_penalty

        self.history.append(bd)
        return bd

    def compute_episode_score(
        self,
        total_rescued: int,
        total_lost: int,
        total_initial: int,
        critical_rescued: int,
        total_critical: int,
        steps_taken: int,
        max_steps: int,
        resources_used_efficiently: bool,
    ) -> Dict[str, float]:
        """
        Compute the final normalized episode score in [0, 1].

        Score = weighted combination of:
          - Rescue rate (50%)
          - Critical rescue rate (25%; 1.0 when there were no critical survivors)
          - Time efficiency (15%)
          - Resource efficiency (10%; 1.0 if efficient, else 0.6)
        The score is multiplied by 0.6 when more survivors were lost than
        rescued, then clamped to [0, 1].

        Returns:
            Dict with keys ``score``, ``rescue_rate``, ``critical_rate``,
            ``time_efficiency`` and ``resource_efficiency``, each rounded to
            4 decimals. All are 0.0 when ``total_initial`` is not positive.
        """
        if total_initial <= 0:
            return {"score": 0.0, "rescue_rate": 0.0, "critical_rate": 0.0,
                    "time_efficiency": 0.0, "resource_efficiency": 0.0}

        # Rescue rate — primary metric
        rescue_rate = total_rescued / total_initial

        # Critical rescue rate
        critical_rate = (critical_rescued / total_critical) if total_critical > 0 else 1.0

        # Time efficiency — decays linearly from 1.0 (no steps used) to 0.5
        # (whole step budget used). It reflects only how much of the budget the
        # episode consumed, not when individual rescues happened. A non-positive
        # budget is treated as fully used rather than dividing by zero.
        step_fraction_used = steps_taken / max_steps if max_steps > 0 else 1.0
        if rescue_rate > 0:
            time_efficiency = max(0.0, 1.0 - step_fraction_used * 0.5)
        else:
            time_efficiency = 0.0

        # Resource efficiency
        resource_efficiency = 1.0 if resources_used_efficiently else 0.6

        # Weighted score
        score = (
            0.50 * rescue_rate
            + 0.25 * critical_rate
            + 0.15 * time_efficiency
            + 0.10 * resource_efficiency
        )

        # Penalize if more survivors lost than rescued (catastrophic failure)
        if total_lost > total_rescued:
            score *= 0.6

        return {
            "score": round(min(1.0, max(0.0, score)), 4),
            "rescue_rate": round(rescue_rate, 4),
            "critical_rate": round(critical_rate, 4),
            "time_efficiency": round(time_efficiency, 4),
            "resource_efficiency": round(resource_efficiency, 4),
        }

    def cumulative_reward(self) -> float:
        """Return the sum of ``total()`` over every recorded step breakdown."""
        # total is a method, so it must be called; summing the bound methods
        # themselves raises TypeError for any non-empty history.
        return sum(bd.total() for bd in self.history)

    def reset(self):
        """Forget all recorded step breakdowns (rebinds ``history`` to a new list)."""
        self.history = []