usar / server /simulation /reward_engine.py
PranovRaghavendhra's picture
USAR OpenEnv v1.0 - Urban Search and Rescue
ce614ef
"""
reward_engine.py — Multi-signal reward shaping for USAR environment.
Reward philosophy:
The reward function must provide useful signal at EVERY step, not just
at episode end. An agent that learns from sparse end-of-episode reward
will never generalize to the hard task.
Reward signals (per step):
1. Survivor rescue reward — primary signal, weighted by criticality
2. Survival preservation reward — reward for keeping survivors alive via speed
3. Resource efficiency penalty — penalize waste (medkits on non-critical, etc.)
4. Inaction penalty — mild penalty for WAIT when urgent cases exist
5. Misallocation penalty — penalty for assigning wrong team to wrong site
6. Secondary collapse penalty — large penalty when a site collapses with survivors
7. Critical save bonus — extra reward for rescuing critical survivors
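Per-step rewards are designed around a [-1, +1] scale, but this module does
not clamp them: a step total is the plain sum of its components, so it can
fall outside that range (e.g. the secondary collapse penalty is -2.0 per
survivor lost).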
Episode score is normalized to [0, 1].
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Optional
# ---------------------------------------------------------------------------
# Reward weights — tuned for meaningful signal at each task difficulty
# ---------------------------------------------------------------------------
@dataclass
class RewardWeights:
    """Scalar weights applied to each reward signal when scoring a step.

    Positive values are rewards, negative values are penalties. Field order
    is part of the constructor signature and must not be rearranged.
    """

    # Reward per survivor rescued alive.
    rescue_alive: float = 1.00
    # Partial reward per survivor recovered dead — they were at least found.
    rescue_dead: float = 0.10
    # Extra reward per critical survivor saved alive.
    critical_save_bonus: float = 0.50
    # Weight for steps in which survivors stay alive due to speed.
    survival_preservation: float = 0.05
    # Penalty per WAIT step while critical survivors exist.
    inaction_penalty: float = -0.15
    # Penalty for sending the wrong team specialization to the wrong debris.
    misallocation_penalty: float = -0.10
    # Penalty for waste: medkit on non-critical, air on a cleared site.
    resource_waste_penalty: float = -0.20
    # Penalty per survivor lost in a secondary collapse.
    secondary_collapse_penalty: float = -2.0
    # Bonus for rescuing before survival probability drops below 0.5.
    time_urgency_bonus: float = 0.08
    # Penalty for deploying an exhausted team.
    fatigue_overuse_penalty: float = -0.05
@dataclass
class StepRewardBreakdown:
    """Per-component reward record for one step — used for logging and grading.

    Every component defaults to 0.0. ``total`` is the plain sum of all
    components; no clamping or normalization is applied here.
    """

    # Reward for survivors rescued this step (alive and dead).
    rescue_reward: float = 0.0
    # Bonus for critical survivors saved alive.
    critical_bonus: float = 0.0
    # Reward when average survival probability rose over the step.
    survival_preservation: float = 0.0
    # Penalty for waiting while critical survivors are present.
    inaction_penalty: float = 0.0
    # Penalty for a mismatched team assignment.
    misallocation_penalty: float = 0.0
    # Penalty for a wasted resource.
    resource_waste: float = 0.0
    # Penalty for survivors lost to a secondary collapse.
    collapse_penalty: float = 0.0
    # Bonus for rescues made while survival probability was still high.
    time_urgency_bonus: float = 0.0
    # Penalty for deploying a fatigued team.
    fatigue_penalty: float = 0.0

    @property
    def total(self) -> float:
        """Return the sum of all nine reward components."""
        parts = (
            self.rescue_reward,
            self.critical_bonus,
            self.survival_preservation,
            self.inaction_penalty,
            self.misallocation_penalty,
            self.resource_waste,
            self.collapse_penalty,
            self.time_urgency_bonus,
            self.fatigue_penalty,
        )
        # Accumulate left to right, starting from the first component, so the
        # floating-point result matches a chained ``a + b + c + ...``.
        running = parts[0]
        for part in parts[1:]:
            running += part
        return running
class RewardEngine:
    """
    Computes per-step and episode-level rewards for USAR episodes.

    Each breakdown returned by ``compute_step_reward`` is also appended to
    ``history``; ``cumulative_reward`` sums their totals and ``reset``
    discards them.
    """

    def __init__(self, weights: Optional[RewardWeights] = None):
        """Create an engine using *weights*, or default ``RewardWeights`` if omitted."""
        self.weights = weights if weights is not None else RewardWeights()
        self.history: List[StepRewardBreakdown] = []

    def compute_step_reward(
        self,
        action_type: str,
        rescued_alive: int,
        rescued_dead: int,
        critical_saved: int,
        survivors_lost_to_collapse: int,
        was_inaction: bool,
        urgent_critical_count: int,
        misallocation: bool,
        resource_wasted: bool,
        team_was_fatigued: bool,
        avg_survival_prob_before: float,
        avg_survival_prob_after: float,
    ) -> StepRewardBreakdown:
        """
        Compute the full reward breakdown for one step and record it in ``history``.

        Args:
            action_type: Not used by the current computation; kept in the
                signature for callers.
            rescued_alive: Survivors rescued alive this step.
            rescued_dead: Survivors recovered dead this step.
            critical_saved: Critical survivors saved alive this step.
            survivors_lost_to_collapse: Survivors lost to a secondary collapse.
            was_inaction: Whether the step was an inaction (WAIT) step.
            urgent_critical_count: Number of urgent critical cases present;
                the inaction penalty scales with this, capped at 5.
            misallocation: Whether a team was misallocated this step.
            resource_wasted: Whether a resource was wasted this step.
            team_was_fatigued: Whether a fatigued team was deployed.
            avg_survival_prob_before: Average survival probability before the step.
            avg_survival_prob_after: Average survival probability after the step.

        Returns:
            The populated ``StepRewardBreakdown`` (the same object that is
            appended to ``history``).
        """
        w = self.weights
        bd = StepRewardBreakdown()

        # 1. Rescue rewards
        bd.rescue_reward = rescued_alive * w.rescue_alive + rescued_dead * w.rescue_dead

        # 2. Critical save bonus
        bd.critical_bonus = critical_saved * w.critical_save_bonus

        # 3. Time urgency bonus — reward for speed (rescue made while the
        #    average survival prob is still above 0.5)
        if rescued_alive > 0 and avg_survival_prob_before > 0.5:
            bd.time_urgency_bonus = rescued_alive * w.time_urgency_bonus

        # 4. Survival preservation — rewarded only when the average survival
        #    prob improved over the step; a drop yields no penalty here.
        #    (proxy: if a team is deployed on urgent site, survival prob should hold or rise)
        preservation_delta = avg_survival_prob_after - avg_survival_prob_before
        if preservation_delta > 0:
            bd.survival_preservation = preservation_delta * w.survival_preservation * 10

        # 5. Inaction penalty — WAIT when critical survivors present is bad.
        #    Scales with the number of urgent cases, capped at 5.
        if was_inaction and urgent_critical_count > 0:
            bd.inaction_penalty = w.inaction_penalty * min(urgent_critical_count, 5)

        # 6. Misallocation penalty
        if misallocation:
            bd.misallocation_penalty = w.misallocation_penalty

        # 7. Resource waste
        if resource_wasted:
            bd.resource_waste = w.resource_waste_penalty

        # 8. Secondary collapse penalty (per survivor lost)
        if survivors_lost_to_collapse > 0:
            bd.collapse_penalty = survivors_lost_to_collapse * w.secondary_collapse_penalty

        # 9. Fatigue overuse
        if team_was_fatigued:
            bd.fatigue_penalty = w.fatigue_overuse_penalty

        self.history.append(bd)
        return bd

    def compute_episode_score(
        self,
        total_rescued: int,
        total_lost: int,
        total_initial: int,
        critical_rescued: int,
        total_critical: int,
        steps_taken: int,
        max_steps: int,
        resources_used_efficiently: bool,
    ) -> Dict[str, float]:
        """
        Compute final normalized episode score [0, 1].

        Score = weighted combination of:
        - Rescue rate (50%)
        - Critical rescue rate (25%)
        - Time efficiency (15%)
        - Resource efficiency (10%)

        The combined score is multiplied by 0.6 when more survivors were lost
        than rescued, then clamped to [0, 1]. All returned values are rounded
        to 4 decimal places.

        Returns:
            Dict with keys ``score``, ``rescue_rate``, ``critical_rate``,
            ``time_efficiency`` and ``resource_efficiency``. All values are
            0.0 when ``total_initial`` is 0.
        """
        if total_initial == 0:
            return {"score": 0.0, "rescue_rate": 0.0, "critical_rate": 0.0,
                    "time_efficiency": 0.0, "resource_efficiency": 0.0}

        # Rescue rate — primary metric
        rescue_rate = total_rescued / total_initial

        # Critical rescue rate — full credit when there were no critical survivors
        critical_rate = (critical_rescued / total_critical) if total_critical > 0 else 1.0

        # Time efficiency — fraction of the step budget consumed.
        # A non-positive budget would divide by zero; treat it as fully
        # consumed so no extra time credit is granted.
        step_fraction_used = steps_taken / max_steps if max_steps > 0 else 1.0
        if rescue_rate > 0:
            # Using fewer steps scores higher; a full budget still earns 0.5.
            time_efficiency = max(0.0, 1.0 - step_fraction_used * 0.5)
        else:
            time_efficiency = 0.0

        # Resource efficiency
        resource_efficiency = 1.0 if resources_used_efficiently else 0.6

        # Weighted score
        score = (
            0.50 * rescue_rate
            + 0.25 * critical_rate
            + 0.15 * time_efficiency
            + 0.10 * resource_efficiency
        )

        # Penalize if more survivors lost than rescued (catastrophic failure)
        if total_lost > total_rescued:
            score *= 0.6

        return {
            "score": round(min(1.0, max(0.0, score)), 4),
            "rescue_rate": round(rescue_rate, 4),
            "critical_rate": round(critical_rate, 4),
            "time_efficiency": round(time_efficiency, 4),
            "resource_efficiency": round(resource_efficiency, 4),
        }

    @property
    def cumulative_reward(self) -> float:
        """Sum of ``total`` over every breakdown recorded since the last reset."""
        return sum(bd.total for bd in self.history)

    def reset(self) -> None:
        """Discard recorded history by rebinding ``history`` to a new empty list."""
        self.history = []