Spaces:
Sleeping
Sleeping
| """Reward computation engine with component breakdown.""" | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
| from app.config import Settings, get_settings | |
| from app.core.action import Action, ActionType | |
| from app.core.observation import Observation | |
@dataclass
class RewardBreakdown:
    """Detailed breakdown of reward components.

    The four core components (accuracy, efficiency, cost, completeness) are
    scaled by weights in ``compute_total``. Bonuses are added unweighted and
    penalties (stored as positive magnitudes) are subtracted unweighted.
    """

    # Core (weighted) components
    accuracy: float = 0.0
    efficiency: float = 0.0
    cost: float = 0.0
    completeness: float = 0.0

    # Bonus/penalty components (unweighted; penalties are positive values
    # that get subtracted from the total)
    progress_bonus: float = 0.0
    error_penalty: float = 0.0
    time_penalty: float = 0.0
    redundancy_penalty: float = 0.0
    exploration_bonus: float = 0.0
    verification_bonus: float = 0.0

    # Metadata, populated by compute_total()
    total: float = 0.0
    # default_factory gives every instance its own dict; this only works
    # because the class is a dataclass.
    components: dict[str, float] = field(default_factory=dict)

    def compute_total(self, weights: dict[str, float]) -> float:
        """Compute, store and return the weighted total reward.

        Args:
            weights: Weights for the core components, keyed by ``"accuracy"``,
                ``"efficiency"``, ``"cost"`` and ``"completeness"``. Missing
                keys fall back to 0.4, 0.2, 0.2 and 0.2 respectively.

        Returns:
            The total reward. It is also stored in ``self.total``, and
            ``self.components`` is refreshed with the raw (unweighted)
            component values.
        """
        self.total = (
            self.accuracy * weights.get("accuracy", 0.4)
            + self.efficiency * weights.get("efficiency", 0.2)
            + self.cost * weights.get("cost", 0.2)
            + self.completeness * weights.get("completeness", 0.2)
            + self.progress_bonus
            + self.exploration_bonus
            + self.verification_bonus
            - self.error_penalty
            - self.time_penalty
            - self.redundancy_penalty
        )
        self.components = {
            "accuracy": self.accuracy,
            "efficiency": self.efficiency,
            "cost": self.cost,
            "completeness": self.completeness,
            "progress_bonus": self.progress_bonus,
            "error_penalty": self.error_penalty,
            "time_penalty": self.time_penalty,
            "redundancy_penalty": self.redundancy_penalty,
            "exploration_bonus": self.exploration_bonus,
            "verification_bonus": self.verification_bonus,
        }
        return self.total

    def to_dict(self) -> dict[str, float]:
        """Return the total plus the components captured by the last ``compute_total``.

        Before ``compute_total`` has been called, only ``"total"`` is present.
        """
        return {
            "total": self.total,
            **self.components,
        }
class RewardEngine:
    """
    Computes rewards for actions in the web scraping environment.

    Weighted reward components (weights come from settings):
    - Accuracy: How correct the extracted data is
    - Efficiency: Fraction of the step budget not yet used
    - Cost: Token usage and API calls relative to fixed budgets
    - Completeness: Current extraction progress

    Unweighted bonuses/penalties:
    - Progress: Increase in extraction progress and new extractions
    - Errors: The last action failed (grows with consecutive errors)
    - Time: Using more than 80% of the step budget
    - Redundancy: Re-extracting a field or re-navigating to a URL
    - Exploration: Navigating to a URL not yet visited this episode
    - Verification: Taking a verification action

    The engine keeps per-episode tracking state (action history, extraction
    attempts, URL visit counts). Call ``reset()`` before a new episode.
    """

    # Usage budgets for the cost component. Usage at or above a budget earns
    # zero reward on that dimension. Override on a subclass or instance to
    # tune for other workloads.
    MAX_EXPECTED_TOKENS: int = 10_000
    MAX_EXPECTED_CALLS: int = 50

    def __init__(self, settings: Settings | None = None) -> None:
        """Initialize the reward engine.

        Args:
            settings: Source of the component weights and of
                ``max_steps_per_episode``. Falls back to ``get_settings()``.
        """
        self.settings = settings if settings is not None else get_settings()
        self.weights = {
            "accuracy": self.settings.reward_accuracy_weight,
            "efficiency": self.settings.reward_efficiency_weight,
            "cost": self.settings.reward_cost_weight,
            "completeness": self.settings.reward_completeness_weight,
        }
        # Per-episode tracking used by the redundancy and exploration terms.
        self._action_history: list[Action] = []
        self._extraction_attempts: dict[str, int] = {}
        self._url_visits: dict[str, int] = {}

    def reset(self) -> None:
        """Reset tracking state for a new episode."""
        self._action_history.clear()
        self._extraction_attempts.clear()
        self._url_visits.clear()

    def compute_reward(
        self,
        action: Action,
        prev_observation: Observation,
        new_observation: Observation,
        ground_truth: dict[str, Any] | None = None,
        max_steps: int = 50,
    ) -> tuple[float, RewardBreakdown]:
        """
        Compute the reward for a single action.

        Args:
            action: The action that was taken.
            prev_observation: Observation before the action.
            new_observation: Observation after the action.
            ground_truth: Optional ground truth data for accuracy calculation.
            max_steps: Maximum steps allowed in the episode; must be positive.

        Returns:
            Tuple of (total_reward, breakdown).

        Raises:
            ValueError: If ``max_steps`` is not positive. The check runs before
                any tracking state is modified.
        """
        if max_steps <= 0:
            raise ValueError(f"max_steps must be positive, got {max_steps}")

        breakdown = RewardBreakdown()
        self._action_history.append(action)

        # Core components
        breakdown.accuracy = self._compute_accuracy(
            action, new_observation, ground_truth
        )
        breakdown.efficiency = self._compute_efficiency(
            new_observation.step_number, max_steps
        )
        breakdown.cost = self._compute_cost_reward(new_observation)
        breakdown.completeness = self._compute_completeness(
            prev_observation, new_observation
        )

        # Bonuses
        breakdown.progress_bonus = self._compute_progress_bonus(
            prev_observation, new_observation
        )
        # NOTE: the exploration bonus records the URL visit that the
        # redundancy penalty reads below, so it must be computed first.
        breakdown.exploration_bonus = self._compute_exploration_bonus(
            action, new_observation
        )
        breakdown.verification_bonus = self._compute_verification_bonus(
            action, new_observation
        )

        # Penalties
        breakdown.error_penalty = self._compute_error_penalty(new_observation)
        breakdown.time_penalty = self._compute_time_penalty(new_observation, max_steps)
        breakdown.redundancy_penalty = self._compute_redundancy_penalty(action)

        total = breakdown.compute_total(self.weights)
        return total, breakdown

    def _compute_accuracy(
        self,
        action: Action,
        observation: Observation,
        ground_truth: dict[str, Any] | None,
    ) -> float:
        """Compute the accuracy component in [0, 1].

        Without ground truth, this is the mean confidence of the extractions so
        far, or 0.5 (neutral) when nothing has been extracted. With ground
        truth, it is the fraction of extracted ground-truth fields whose
        values match. Ground-truth fields that have not been extracted yet are
        ignored here.
        """
        if ground_truth is None:
            extractions = observation.extracted_so_far
            if extractions:
                return sum(f.confidence for f in extractions) / len(extractions)
            return 0.5  # Neutral

        extracted = observation.get_extraction_dict()
        if not extracted:
            return 0.0

        compared = [
            (extracted[name], expected)
            for name, expected in ground_truth.items()
            if name in extracted
        ]
        if not compared:
            return 0.0
        correct = sum(
            1 for actual, expected in compared if self._values_match(actual, expected)
        )
        return correct / len(compared)

    @staticmethod
    def _is_number(value: Any) -> bool:
        """Return True for int/float values, excluding bools (a subclass of int)."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _values_match(self, actual: Any, expected: Any) -> bool:
        """Check whether an extracted value matches the expected value.

        Rules:
        - Exact equality always matches.
        - Strings match case-insensitively after stripping whitespace. They
          also match when one is a substring of the other, provided both are
          non-empty.
        - Non-bool numbers match within 1% of ``expected``, or within an
          absolute 0.01 when ``expected`` is 0.
        """
        if actual == expected:
            return True

        if isinstance(actual, str) and isinstance(expected, str):
            actual_clean = actual.strip().lower()
            expected_clean = expected.strip().lower()
            if actual_clean == expected_clean:
                return True
            # "" is a substring of every string, so empty values must not
            # take part in the partial match or they would match anything.
            if not actual_clean or not expected_clean:
                return False
            return expected_clean in actual_clean or actual_clean in expected_clean

        if self._is_number(actual) and self._is_number(expected):
            tolerance = abs(expected) * 0.01 if expected != 0 else 0.01
            return abs(actual - expected) <= tolerance

        return False

    def _compute_efficiency(self, current_step: int, max_steps: int) -> float:
        """Return the unused fraction of the step budget, clamped at 0.0.

        Raises:
            ValueError: If ``max_steps`` is not positive.
        """
        if max_steps <= 0:
            raise ValueError(f"max_steps must be positive, got {max_steps}")
        remaining_ratio = (max_steps - current_step) / max_steps
        return max(0.0, remaining_ratio)

    def _compute_cost_reward(self, observation: Observation) -> float:
        """Return the mean of token- and call-efficiency, each in [0, 1].

        Each efficiency is ``1 - usage / budget``, floored at 0, using the
        ``MAX_EXPECTED_TOKENS`` / ``MAX_EXPECTED_CALLS`` budgets.
        """
        token_efficiency = 1.0 - min(
            observation.tokens_used / self.MAX_EXPECTED_TOKENS, 1.0
        )
        call_efficiency = 1.0 - min(
            observation.api_calls_made / self.MAX_EXPECTED_CALLS, 1.0
        )
        return (token_efficiency + call_efficiency) / 2

    def _compute_completeness(
        self,
        prev_observation: Observation,
        new_observation: Observation,
    ) -> float:
        """Return the extraction progress reported by the new observation."""
        return new_observation.extraction_progress

    def _compute_progress_bonus(
        self,
        prev_observation: Observation,
        new_observation: Observation,
    ) -> float:
        """Bonus for progress made by this action.

        Adds 0.5 x the increase in extraction progress, plus 0.1 per new
        extraction. Regressions earn no bonus; they are not penalised here.
        """
        progress_delta = (
            new_observation.extraction_progress - prev_observation.extraction_progress
        )
        new_extractions = len(new_observation.extracted_so_far) - len(
            prev_observation.extracted_so_far
        )

        bonus = 0.0
        if progress_delta > 0:
            bonus += progress_delta * 0.5
        if new_extractions > 0:
            bonus += new_extractions * 0.1
        return bonus

    def _compute_exploration_bonus(
        self,
        action: Action,
        observation: Observation,
    ) -> float:
        """Bonus of 0.05 for navigating to a URL not yet visited this episode.

        Side effect: increments the visit count for the navigated URL, which
        ``_compute_redundancy_penalty`` later reads.
        """
        if action.action_type != ActionType.NAVIGATE:
            return 0.0
        url = action.get_param("url", "")
        bonus = 0.05 if url and url not in self._url_visits else 0.0
        self._url_visits[url] = self._url_visits.get(url, 0) + 1
        return bonus

    def _compute_verification_bonus(
        self,
        action: Action,
        observation: Observation,
    ) -> float:
        """Bonus of 0.05 for VERIFY_FACT / VERIFY_FIELD actions."""
        if action.action_type in (ActionType.VERIFY_FACT, ActionType.VERIFY_FIELD):
            return 0.05
        return 0.0

    def _compute_error_penalty(self, observation: Observation) -> float:
        """Penalty of 0.1 plus 0.05 per consecutive error when the last action failed."""
        if observation.last_action_error:
            base_penalty = 0.1
            consecutive_penalty = observation.consecutive_errors * 0.05
            return base_penalty + consecutive_penalty
        return 0.0

    def _compute_time_penalty(
        self,
        observation: Observation,
        max_steps: int,
    ) -> float:
        """Penalty growing linearly (slope 0.5) once 80% of the step budget is used."""
        step_ratio = observation.step_number / max_steps
        if step_ratio > 0.8:
            return (step_ratio - 0.8) * 0.5
        return 0.0

    def _compute_redundancy_penalty(self, action: Action) -> float:
        """Penalty for repeating an extraction or a navigation within the episode.

        Repeated EXTRACT_FIELD on the same field costs 0.05 per earlier
        attempt, capped at 0.2. Navigating to a URL already visited costs 0.03
        per extra visit, capped at 0.15. URL visit counts are recorded by
        ``_compute_exploration_bonus``, which must run first.
        """
        if action.action_type == ActionType.EXTRACT_FIELD:
            # Record every attempt, including the episode's first action, so
            # later repeats of that field are penalised too.
            field_name = action.get_param("field_name", "")
            attempts = self._extraction_attempts.get(field_name, 0)
            self._extraction_attempts[field_name] = attempts + 1
            if attempts > 0:
                return min(attempts * 0.05, 0.2)

        if action.action_type == ActionType.NAVIGATE:
            url = action.get_param("url", "")
            visits = self._url_visits.get(url, 0)
            if visits > 1:
                return min((visits - 1) * 0.03, 0.15)

        return 0.0

    def compute_terminal_reward(
        self,
        observation: Observation,
        success: bool,
        ground_truth: dict[str, Any] | None = None,
        max_steps: int | None = None,
    ) -> tuple[float, RewardBreakdown]:
        """
        Compute the final reward at episode termination.

        Args:
            observation: Final observation.
            success: Whether the task was completed successfully.
            ground_truth: Optional ground truth for accuracy.
            max_steps: Step budget for the efficiency term. Defaults to
                ``settings.max_steps_per_episode``.

        Returns:
            Tuple of (total_reward, breakdown).

        Raises:
            ValueError: On success, if the step budget is not positive.
        """
        breakdown = RewardBreakdown()

        if success:
            # Big bonus for successful completion
            breakdown.completeness = 1.0
            breakdown.progress_bonus = 0.5

            if ground_truth:
                extracted = observation.get_extraction_dict()
                correct = sum(
                    1
                    for name, expected in ground_truth.items()
                    if name in extracted and self._values_match(extracted[name], expected)
                )
                # Unlike the per-step term, every ground-truth field counts
                # here: fields that were never extracted count as wrong.
                breakdown.accuracy = correct / len(ground_truth)
            else:
                breakdown.accuracy = observation.extraction_progress

            budget = (
                max_steps if max_steps is not None
                else self.settings.max_steps_per_episode
            )
            # Same clamped definition as the per-step efficiency, so going
            # over budget yields 0 rather than a negative value.
            breakdown.efficiency = self._compute_efficiency(
                observation.step_number, budget
            )
        else:
            # Partial credit for progress made
            breakdown.completeness = observation.extraction_progress * 0.5
            breakdown.error_penalty = 0.3

        total = breakdown.compute_total(self.weights)
        return total, breakdown