scrapeRL / backend /app /core /reward.py
NeerajCodz's picture
feat: add core RL environment models (observation, action, reward, env)
ab65628
"""Reward computation engine with component breakdown."""
from dataclasses import dataclass, field
from typing import Any
from app.config import Settings, get_settings
from app.core.action import Action, ActionType
from app.core.observation import Observation
@dataclass
class RewardBreakdown:
    """Per-component record of one reward computation.

    The four core components (accuracy, efficiency, cost, completeness) are
    scaled by caller-supplied weights. The bonus fields are added and the
    penalty fields subtracted with no weighting. ``total`` and ``components``
    are filled in by :meth:`compute_total`.
    """

    # Weighted core components
    accuracy: float = 0.0
    efficiency: float = 0.0
    cost: float = 0.0
    completeness: float = 0.0
    # Unweighted shaping terms (bonuses are added, penalties subtracted)
    progress_bonus: float = 0.0
    error_penalty: float = 0.0
    time_penalty: float = 0.0
    redundancy_penalty: float = 0.0
    exploration_bonus: float = 0.0
    verification_bonus: float = 0.0
    # Results populated by compute_total()
    total: float = 0.0
    components: dict[str, float] = field(default_factory=dict)

    def compute_total(self, weights: dict[str, float]) -> float:
        """Combine all components into ``self.total`` and return it.

        Args:
            weights: Multipliers for the core components. Missing keys fall
                back to 0.4 for "accuracy" and 0.2 each for "efficiency",
                "cost" and "completeness".

        Returns:
            The combined reward, which is also stored on ``self.total``.
            Every component value is snapshotted into ``self.components``.
        """
        w_accuracy = weights.get("accuracy", 0.4)
        w_efficiency = weights.get("efficiency", 0.2)
        w_cost = weights.get("cost", 0.2)
        w_completeness = weights.get("completeness", 0.2)

        weighted_core = (
            self.accuracy * w_accuracy
            + self.efficiency * w_efficiency
            + self.cost * w_cost
            + self.completeness * w_completeness
        )
        # Terms are accumulated strictly left to right:
        # core first, then bonuses, then penalties.
        self.total = (
            weighted_core
            + self.progress_bonus
            + self.exploration_bonus
            + self.verification_bonus
            - self.error_penalty
            - self.time_penalty
            - self.redundancy_penalty
        )

        snapshot_order = (
            "accuracy",
            "efficiency",
            "cost",
            "completeness",
            "progress_bonus",
            "error_penalty",
            "time_penalty",
            "redundancy_penalty",
            "exploration_bonus",
            "verification_bonus",
        )
        self.components = {name: getattr(self, name) for name in snapshot_order}
        return self.total

    def to_dict(self) -> dict[str, float]:
        """Return ``total`` followed by the component snapshot.

        The snapshot is whatever the most recent :meth:`compute_total` call
        stored. Before any such call, only ``total`` is present.
        """
        result = {"total": self.total}
        result.update(self.components)
        return result
class RewardEngine:
    """
    Computes rewards for actions in the web scraping environment.

    Weighted reward components (weights come from settings):
    - Accuracy: How correct the extracted data is
    - Efficiency: Steps remaining relative to the step budget
    - Cost: Token usage and API calls
    - Completeness: Extraction progress towards task completion

    Plus unweighted bonuses/penalties for:
    - Progress: Increases in extraction progress and newly extracted fields
    - Errors: A failed last action, growing with consecutive errors
    - Time: Using more than 80% of the step budget
    - Redundancy: Re-extracting the same field or re-visiting the same URL
    - Exploration: Navigating to a URL not visited before
    - Verification: Taking a verification action

    The engine keeps per-episode tracking state (action history, extraction
    attempts per field, visits per URL). Call ``reset()`` between episodes.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize the reward engine.

        Args:
            settings: Source of the core reward weights and of the default
                episode step budget used by ``compute_terminal_reward``.
                ``get_settings()`` is used when omitted.
        """
        self.settings = settings or get_settings()
        self.weights = {
            "accuracy": self.settings.reward_accuracy_weight,
            "efficiency": self.settings.reward_efficiency_weight,
            "cost": self.settings.reward_cost_weight,
            "completeness": self.settings.reward_completeness_weight,
        }
        # Per-episode tracking used by the exploration bonus and the
        # redundancy penalty; cleared by reset().
        self._action_history: list[Action] = []
        self._extraction_attempts: dict[str, int] = {}
        self._url_visits: dict[str, int] = {}

    def reset(self) -> None:
        """Reset tracking state for a new episode."""
        self._action_history.clear()
        self._extraction_attempts.clear()
        self._url_visits.clear()

    def compute_reward(
        self,
        action: Action,
        prev_observation: Observation,
        new_observation: Observation,
        ground_truth: dict[str, Any] | None = None,
        max_steps: int = 50,
    ) -> tuple[float, RewardBreakdown]:
        """
        Compute reward for an action.

        Args:
            action: The action that was taken.
            prev_observation: Observation before the action.
            new_observation: Observation after the action.
            ground_truth: Optional ground truth data for accuracy calculation.
            max_steps: Maximum steps allowed in episode.

        Returns:
            Tuple of (total_reward, breakdown).
        """
        breakdown = RewardBreakdown()
        self._action_history.append(action)

        # Weighted core components
        breakdown.accuracy = self._compute_accuracy(
            action, new_observation, ground_truth
        )
        breakdown.efficiency = self._compute_efficiency(
            new_observation.step_number, max_steps
        )
        breakdown.cost = self._compute_cost_reward(new_observation)
        breakdown.completeness = self._compute_completeness(
            prev_observation, new_observation
        )

        # Bonuses. The exploration bonus records the URL visit that the
        # redundancy penalty below reads, so it must be computed first.
        breakdown.progress_bonus = self._compute_progress_bonus(
            prev_observation, new_observation
        )
        breakdown.exploration_bonus = self._compute_exploration_bonus(
            action, new_observation
        )
        breakdown.verification_bonus = self._compute_verification_bonus(
            action, new_observation
        )

        # Penalties
        breakdown.error_penalty = self._compute_error_penalty(new_observation)
        breakdown.time_penalty = self._compute_time_penalty(new_observation, max_steps)
        breakdown.redundancy_penalty = self._compute_redundancy_penalty(action)

        total = breakdown.compute_total(self.weights)
        return total, breakdown

    def _compute_accuracy(
        self,
        action: Action,
        observation: Observation,
        ground_truth: dict[str, Any] | None,
    ) -> float:
        """Compute the accuracy component.

        Without ground truth, returns the mean ``confidence`` of the fields
        extracted so far, or a neutral 0.5 when nothing is extracted yet.

        With ground truth, returns the fraction of *extracted* ground-truth
        fields whose values match. Fields not yet extracted are not counted
        here. Returns 0.0 when nothing is extracted or there is no overlap.
        """
        if ground_truth is None:
            fields_so_far = observation.extracted_so_far
            if fields_so_far:
                return sum(f.confidence for f in fields_so_far) / len(fields_so_far)
            return 0.5  # Neutral

        extracted = observation.get_extraction_dict()
        if not extracted:
            return 0.0

        matches = [
            self._values_match(extracted[name], expected)
            for name, expected in ground_truth.items()
            if name in extracted
        ]
        if not matches:
            return 0.0
        return sum(matches) / len(matches)

    def _values_match(self, actual: Any, expected: Any) -> bool:
        """Check whether an extracted value matches the expected value.

        Matching rules, in order:
        - Exact equality.
        - Strings: case- and surrounding-whitespace-insensitive equality, or
          one non-empty string contained in the other.
        - Numbers: within 1% of ``expected`` (absolute 0.01 when expected is 0).
        """
        if actual == expected:
            return True

        if isinstance(actual, str) and isinstance(expected, str):
            actual_clean = actual.strip().lower()
            expected_clean = expected.strip().lower()
            if actual_clean == expected_clean:
                return True
            # Partial match only between non-empty strings. "" is a substring
            # of every string, so an empty or blank extraction would
            # otherwise count as correct.
            if actual_clean and expected_clean and (
                expected_clean in actual_clean or actual_clean in expected_clean
            ):
                return True

        if isinstance(actual, (int, float)) and isinstance(expected, (int, float)):
            tolerance = abs(expected) * 0.01 if expected != 0 else 0.01
            return abs(actual - expected) <= tolerance

        return False

    def _compute_efficiency(self, current_step: int, max_steps: int) -> float:
        """Return the fraction of the step budget still unused, clamped to >= 0.

        A non-positive ``max_steps`` carries no budget information and
        yields 0.0 instead of dividing by zero.
        """
        if max_steps <= 0:
            return 0.0
        remaining_ratio = (max_steps - current_step) / max_steps
        return max(0.0, remaining_ratio)

    def _compute_cost_reward(self, observation: Observation) -> float:
        """Score cost efficiency in [0, 1] from token usage and API calls.

        Each measure is linearly penalized up to its expected maximum
        (10000 tokens, 50 calls), and the two scores are averaged.
        """
        max_expected_tokens = 10000
        max_expected_calls = 50
        token_efficiency = 1.0 - min(
            observation.tokens_used / max_expected_tokens, 1.0
        )
        call_efficiency = 1.0 - min(
            observation.api_calls_made / max_expected_calls, 1.0
        )
        return (token_efficiency + call_efficiency) / 2

    def _compute_completeness(
        self,
        prev_observation: Observation,
        new_observation: Observation,
    ) -> float:
        """Completeness is the new observation's ``extraction_progress``."""
        return new_observation.extraction_progress

    def _compute_progress_bonus(
        self,
        prev_observation: Observation,
        new_observation: Observation,
    ) -> float:
        """Bonus for progress made by this step.

        Adds 0.5 per unit increase in ``extraction_progress`` and 0.1 per
        newly extracted field. Decreases are not penalized.
        """
        progress_delta = (
            new_observation.extraction_progress - prev_observation.extraction_progress
        )
        new_extractions = len(new_observation.extracted_so_far) - len(
            prev_observation.extracted_so_far
        )

        bonus = 0.0
        if progress_delta > 0:
            bonus += progress_delta * 0.5
        if new_extractions > 0:
            bonus += new_extractions * 0.1
        return bonus

    def _compute_exploration_bonus(
        self,
        action: Action,
        observation: Observation,
    ) -> float:
        """Give 0.05 for navigating to a non-empty URL not visited this episode.

        Side effect: records the visit in ``_url_visits``, which the
        redundancy penalty relies on.
        """
        if action.action_type != ActionType.NAVIGATE:
            return 0.0
        url = action.get_param("url", "")
        bonus = 0.05 if url and url not in self._url_visits else 0.0
        # Counted even for an empty URL, so repeated URL-less navigation is
        # still seen as redundant.
        self._url_visits[url] = self._url_visits.get(url, 0) + 1
        return bonus

    def _compute_verification_bonus(
        self,
        action: Action,
        observation: Observation,
    ) -> float:
        """Give 0.05 for VERIFY_FACT / VERIFY_FIELD actions."""
        if action.action_type in [ActionType.VERIFY_FACT, ActionType.VERIFY_FIELD]:
            return 0.05
        return 0.0

    def _compute_error_penalty(self, observation: Observation) -> float:
        """Penalty of 0.1 for a failed last action, plus 0.05 per consecutive error."""
        if observation.last_action_error:
            base_penalty = 0.1
            consecutive_penalty = observation.consecutive_errors * 0.05
            return base_penalty + consecutive_penalty
        return 0.0

    def _compute_time_penalty(
        self,
        observation: Observation,
        max_steps: int,
    ) -> float:
        """Penalty growing linearly once more than 80% of the step budget is used.

        A non-positive ``max_steps`` yields no penalty instead of dividing
        by zero.
        """
        if max_steps <= 0:
            return 0.0
        step_ratio = observation.step_number / max_steps
        if step_ratio > 0.8:
            return (step_ratio - 0.8) * 0.5
        return 0.0

    def _compute_redundancy_penalty(self, action: Action) -> float:
        """Penalty for repeating an extraction or a navigation.

        Every EXTRACT_FIELD action is counted per field name, including the
        episode's first action. The n-th repeat costs ``min(n * 0.05, 0.2)``.

        For NAVIGATE, the visit count recorded by
        ``_compute_exploration_bonus`` is used, so that method must run
        first. The n-th repeat visit costs ``min(n * 0.03, 0.15)``.
        """
        if action.action_type == ActionType.EXTRACT_FIELD:
            field_name = action.get_param("field_name", "")
            attempts = self._extraction_attempts.get(field_name, 0)
            self._extraction_attempts[field_name] = attempts + 1
            if attempts > 0:
                return min(attempts * 0.05, 0.2)
        elif action.action_type == ActionType.NAVIGATE:
            visits = self._url_visits.get(action.get_param("url", ""), 0)
            if visits > 1:
                return min((visits - 1) * 0.03, 0.15)
        return 0.0

    def compute_terminal_reward(
        self,
        observation: Observation,
        success: bool,
        ground_truth: dict[str, Any] | None = None,
        max_steps: int | None = None,
    ) -> tuple[float, RewardBreakdown]:
        """
        Compute final reward at episode termination.

        Args:
            observation: Final observation.
            success: Whether the task was completed successfully.
            ground_truth: Optional ground truth for accuracy. On success this
                is scored over *all* ground-truth fields, so missing fields
                count as wrong.
            max_steps: Step budget for the efficiency bonus. Defaults to
                ``settings.max_steps_per_episode``.

        Returns:
            Tuple of (total_reward, breakdown).
        """
        breakdown = RewardBreakdown()

        if success:
            # Big bonus for successful completion
            breakdown.completeness = 1.0
            breakdown.progress_bonus = 0.5

            if ground_truth:
                extracted = observation.get_extraction_dict()
                correct = sum(
                    1
                    for name, expected in ground_truth.items()
                    if name in extracted and self._values_match(extracted[name], expected)
                )
                breakdown.accuracy = correct / len(ground_truth)
            else:
                breakdown.accuracy = observation.extraction_progress

            step_budget = (
                self.settings.max_steps_per_episode if max_steps is None else max_steps
            )
            # Same clamped formula as the per-step efficiency, so overrunning
            # the budget yields 0 instead of a negative reward.
            breakdown.efficiency = self._compute_efficiency(
                observation.step_number, step_budget
            )
        else:
            # Partial credit for progress made
            breakdown.completeness = observation.extraction_progress * 0.5
            breakdown.error_penalty = 0.3

        total = breakdown.compute_total(self.weights)
        return total, breakdown