| import math |
| import copy |
| import json |
| import re |
| from core.life_state import LifeMetrics |
| from core.task import Task |
|
|
|
|
|
|
def _extract_completion_reasoning(completion: str) -> str:
    """
    Return the ``reasoning`` field of a JSON completion, or "" when unavailable.

    The completion is first parsed as-is; if that does not yield a dict, the
    outermost ``{...}`` span is tried instead, so Markdown-fenced JSON and JSON
    embedded in surrounding prose are handled too. A non-string ``reasoning``
    value is treated as missing.
    """
    candidates = [completion.strip()]
    braced = re.search(r"\{.*\}", completion, re.DOTALL)
    if braced:
        candidates.append(braced.group(0))
    for text in candidates:
        try:
            data = json.loads(text)
        except (ValueError, RecursionError):
            continue
        if isinstance(data, dict):
            reasoning = data.get("reasoning", "")
            return reasoning if isinstance(reasoning, str) else ""
    return ""


def compute_reward(
    state_before: LifeMetrics,
    state_after: LifeMetrics,
    resources_used: dict,
    actions_taken: int,
    metric_changes: dict = None,
    completion: str = None,
    disruption_baseline: int = None,
    action_type: str = ""
) -> tuple[float, dict]:
    """
    Computes the reward for a life step based on changes in LifeMetrics and resource usage.

    The base reward is a weighted sum of four components:
      * outcome (0.40): sum of positive metric deltas (delta / 100), each domain
        weighted 1/6 and that weight split evenly across the domain's sub-metrics;
      * containment (0.25): fraction of metrics that did not worsen;
      * efficiency (0.20): 1 - mean normalized resource use (time/20, money/500,
        energy/100), clamped to [0, 1];
      * preservation (0.15): sigmoid of the mean relationship delta / 10.
    Penalties are then added and the result is clamped to [-1, 1].

    Args:
        state_before: The state at the start of the step.
        state_after: The state after actions and cascades.
        resources_used: Dict with keys 'time', 'money', 'energy' (missing keys count as 0).
        actions_taken: Integer count of intentional actions performed.
        metric_changes: Optional claimed per-metric changes. Used for the
            plausibility penalty and as the default disruption baseline.
        completion: Optional raw model completion. It is scored for format and
            mined for its ``reasoning`` field. Both scores are reported in the
            breakdown only and do not change the returned reward.
        disruption_baseline: Number of metrics allowed to worsen before
            CASCADE_SPREAD_WIDER fires. Defaults to ``len(metric_changes)``, or 2
            when no metric changes are given.
        action_type: Action label used by the reasoning-alignment score.

    Returns:
        tuple[float, dict]: (final_reward, breakdown_dict)
    """
    before_flat = state_before.flatten()
    after_flat = state_after.flatten()

    # Equal weight per domain; each domain's weight is later split over its
    # sub-metrics so domains with many sub-metrics do not dominate.
    domain_weights = {
        "career": 1/6,
        "finances": 1/6,
        "relationships": 1/6,
        "physical_health": 1/6,
        "mental_wellbeing": 1/6,
        "time": 1/6
    }

    submetrics_per_domain = {}
    for key in before_flat:
        domain = key.split('.')[0]
        submetrics_per_domain[domain] = submetrics_per_domain.get(domain, 0) + 1

    outcome_score = 0.0
    for key in before_flat:
        delta = after_flat[key] - before_flat[key]
        if delta > 0:
            # Only improvements feed the outcome score; regressions are handled
            # by the containment score and the penalties below.
            domain = key.split('.')[0]
            weight = domain_weights[domain] / submetrics_per_domain[domain]
            outcome_score += (delta / 100.0) * weight

    worsened_count = sum(1 for key in before_flat if after_flat[key] < before_flat[key])
    total_metrics = len(before_flat)
    # Guard against an empty metric set instead of dividing by zero.
    cascade_containment_score = (1.0 - (worsened_count / total_metrics)) if total_metrics else 1.0

    # Normalize each resource to a comparable scale (time/20h, money/$500, energy/100pts).
    m_time = resources_used.get('time', 0.0) / 20.0
    m_money = resources_used.get('money', 0.0) / 500.0
    m_energy = resources_used.get('energy', 0.0) / 100.0

    resource_efficiency_score = 1.0 - ((m_time + m_money + m_energy) / 3.0)
    resource_efficiency_score = max(0.0, min(1.0, resource_efficiency_score))

    rel_keys = [key for key in before_flat if key.startswith('relationships.')]
    if rel_keys:
        avg_rel_before = sum(before_flat[key] for key in rel_keys) / len(rel_keys)
        avg_rel_after = sum(after_flat[key] for key in rel_keys) / len(rel_keys)
        delta_rel = avg_rel_after - avg_rel_before
    else:
        # No relationship metrics: treat as unchanged rather than dividing by zero.
        delta_rel = 0.0

    # Sigmoid: no change -> 0.5, gains approach 1, losses approach 0.
    relationship_preservation_score = 1.0 / (1.0 + math.exp(-delta_rel / 10.0))

    base_reward = (
        (0.40 * outcome_score) +
        (0.25 * cascade_containment_score) +
        (0.20 * resource_efficiency_score) +
        (0.15 * relationship_preservation_score)
    )

    penalties = 0.0
    fired = []

    # Any metric ending below 20 is treated as a critical failure.
    if any(v < 20 for v in after_flat.values()):
        penalties -= 0.50
        fired.append("CRITICAL_FLOOR_VIOLATION")

    # More metrics worsened than the action was expected to touch.
    if disruption_baseline is None:
        disruption_baseline = len(metric_changes) if metric_changes else 2

    if worsened_count > disruption_baseline:
        penalties -= 0.30
        fired.append("CASCADE_SPREAD_WIDER")

    if actions_taken == 0:
        penalties -= 0.40
        fired.append("INACTION_PENALTY")

    # Average relationship score dropped by more than 20 points in one step.
    if delta_rel < -20:
        penalties -= 0.15
        fired.append("RELATIONSHIP_COLLAPSE")

    # Anti-gaming: claimed changes out of proportion to the resources spent.
    plaus = 0.0
    if metric_changes:
        plaus = reward_plausibility_check(metric_changes, resources_used)
        if plaus < 0:
            penalties += plaus
            fired.append("PLAUSIBILITY_VIOLATION")

    # Diagnostic-only scores (reported in the breakdown, not added to the reward).
    comp_reward = 0.0
    reasoning = ""
    if completion:
        comp_reward = reward_format_compliance(completion)
        reasoning = _extract_completion_reasoning(completion)

    reasoning_score = reward_reasoning_coherence(reasoning, action_type=action_type)

    final_reward = max(-1.0, min(1.0, base_reward + penalties))

    breakdown = {
        "components": {
            "outcome": outcome_score,
            "containment": cascade_containment_score,
            "efficiency": resource_efficiency_score,
            "preservation": relationship_preservation_score,
            "format_compliance": comp_reward,
            "plausibility": plaus,
            "reasoning_alignment": reasoning_score
        },
        "base_reward": base_reward,
        "penalties_total": penalties,
        "penalties_fired": fired,
        "metrics_worsened": worsened_count,
        "rel_delta": delta_rel
    }

    return final_reward, breakdown
|
|
def compute_milestone_reward(milestones_achieved: list[str], task: Task) -> float:
    """
    Returns the fraction of the task's total milestone reward that has been achieved.

    Args:
        milestones_achieved: IDs of milestones reached so far. ``None`` is
            treated as "none achieved".
        task: Task whose ``milestones`` each expose ``id`` and ``reward``.

    Returns:
        float: achieved reward / total possible reward, capped at 1.0. Returns 0.0
        when the task has no milestones or their rewards sum to zero.
    """
    if not task.milestones:
        return 0.0
    total_possible = sum(m.reward for m in task.milestones)
    if total_possible == 0:
        return 0.0
    # Set lookup: O(1) membership per milestone instead of scanning the list each time.
    achieved_ids = set(milestones_achieved or ())
    achieved = sum(m.reward for m in task.milestones if m.id in achieved_ids)
    return min(1.0, achieved / total_possible)
|
|
def compute_task_completion_reward(success_conditions_met: list[bool], task: Task) -> float:
    """
    Binary task-completion signal.

    Args:
        success_conditions_met: Per-condition flags. An empty list or ``None`` yields 0.0.
        task: Not used by this function; kept for signature compatibility.

    Returns:
        float: 1.0 when at least one flag is truthy, otherwise 0.0.
    """
    # NOTE(review): any single satisfied condition counts as completion; if all
    # conditions are meant to be required, this should use all() instead.
    if success_conditions_met and any(success_conditions_met):
        return 1.0
    return 0.0
|
|
def compute_replan_bonus(exo_events_seen: int, milestones_after_event: int) -> float:
    """
    Bonus for reaching milestones after exo events have been observed.

    Each milestone reached after an event is worth 0.5 divided by the number of
    events seen. The bonus is capped at 1.0, and is 0.0 when no events occurred.
    """
    if exo_events_seen == 0:
        return 0.0
    milestones_per_event = milestones_after_event / exo_events_seen
    return min(1.0, milestones_per_event * 0.5)
|
|
def compute_dead_end_penalty(routes_remaining: int) -> float:
    """Returns -0.5 once no routes remain (``routes_remaining <= 0``), otherwise 0.0."""
    if routes_remaining <= 0:
        return -0.5
    return 0.0
|
|
def compute_task_reward(
    state_before: LifeMetrics,
    state_after: LifeMetrics,
    resources_used: dict,
    actions_taken: int,
    milestones_achieved: list[str],
    success_conditions_met: list[bool],
    exo_events_seen: int,
    milestones_after_event: int,
    routes_remaining: int,
    rollback_used: bool,
    cascade_collapse: bool,
    task: Task,
    reasoning: str = "",
    completion: str = "",
    conflict_domain: str = "",
    step_count: int = 0,
    max_steps: int = 0,
    metric_changes: dict = None,
    cumulative_rel_delta: float = 0.0,
    action_type: str = ""
) -> tuple[float, dict]:
    """
    Computes the task-level reward for one step, layered on top of ``compute_reward``.

    Base reward weights: milestone 0.35, completion 0.25, local outcome 0.10,
    replan 0.10, efficiency 0.10, relationship preservation 0.05, reasoning 0.05.
    Penalties: TIMEOUT (reward_timeout_check), DEAD_END (compute_dead_end_penalty),
    ROLLBACK_USED (-0.1), CASCADE_COLLAPSE (-0.3), TASK_INACTION_PENALTY (-0.2),
    CUMULATIVE_RELATIONSHIP_EROSION (-0.15, when cumulative_rel_delta < -20).
    The final reward is clamped to [-1, 1].

    Args:
        state_before / state_after: Metrics before and after the step.
        resources_used: Resources spent this step ('time', 'money', 'energy').
        actions_taken: Number of intentional actions this step.
        milestones_achieved: IDs of milestones reached so far.
        success_conditions_met: Per-condition success flags for the task.
        exo_events_seen / milestones_after_event: Inputs to the replan bonus.
        routes_remaining: Remaining routes; <= 0 triggers DEAD_END.
        rollback_used / cascade_collapse: Flags that each add a fixed penalty.
        task: The task being solved (milestones, optional ``mutable_world``).
        reasoning: Free-text reasoning scored by reward_reasoning_coherence.
        completion: Raw model completion, forwarded to compute_reward.
        conflict_domain: Accepted for interface compatibility; not used here.
        step_count / max_steps: Current step and episode step limit. A
            ``max_steps`` <= 0 (the default is 0) means "no limit supplied" and
            disables the timeout check.
        metric_changes: Claimed per-metric changes, forwarded to compute_reward.
        cumulative_rel_delta: Relationship change accumulated over the episode.
        action_type: Action label used for reasoning alignment.

    Returns:
        tuple[float, dict]: (final_reward, breakdown), where breakdown embeds the
        ``local_breakdown`` produced by compute_reward.
    """
    # The size of the task's mutable world sets how many metrics may worsen
    # before compute_reward fires CASCADE_SPREAD_WIDER; None lets it fall back
    # to its own default. getattr() also covers task=None.
    mutable_world = getattr(task, "mutable_world", None)
    d_baseline = len(mutable_world) if mutable_world is not None else None
    _, local_breakdown = compute_reward(state_before, state_after, resources_used, actions_taken,
                                        metric_changes=metric_changes, completion=completion,
                                        disruption_baseline=d_baseline, action_type=action_type)

    local_components = local_breakdown["components"]
    outcome_score_local = local_components.get("outcome", 0.0)
    efficiency_score = local_components.get("efficiency", 0.0)
    preservation_score = local_components.get("preservation", 0.0)

    milestone_score = compute_milestone_reward(milestones_achieved, task)
    completion_score = compute_task_completion_reward(success_conditions_met, task)
    replan_score = compute_replan_bonus(exo_events_seen, milestones_after_event)
    reasoning_score = reward_reasoning_coherence(reasoning, action_type=action_type)

    task_done = any(success_conditions_met or ())
    # max_steps defaults to 0 ("unknown"). Checking it anyway would flag every
    # unfinished step as a timeout, because step_count >= 0 always holds.
    timeout_pen = reward_timeout_check(step_count, max_steps, task_done) if max_steps > 0 else 0.0
    dead_end_pen = compute_dead_end_penalty(routes_remaining)

    base_reward = (
        (0.35 * milestone_score) +
        (0.25 * completion_score) +
        (0.10 * outcome_score_local) +
        (0.05 * preservation_score) +
        (0.10 * replan_score) +
        (0.10 * efficiency_score) +
        (0.05 * reasoning_score)
    )

    penalties = 0.0
    fired = []

    if timeout_pen < 0:
        penalties += timeout_pen
        fired.append("TIMEOUT")

    if dead_end_pen < 0:
        penalties += dead_end_pen
        fired.append("DEAD_END")

    if rollback_used:
        penalties += -0.1
        fired.append("ROLLBACK_USED")

    if cascade_collapse:
        penalties += -0.3
        fired.append("CASCADE_COLLAPSE")

    if actions_taken == 0:
        penalties += -0.20
        fired.append("TASK_INACTION_PENALTY")

    # Slow relationship erosion across the episode, even if no single step collapsed.
    if cumulative_rel_delta < -20:
        penalties += -0.15
        fired.append("CUMULATIVE_RELATIONSHIP_EROSION")

    final_reward = max(-1.0, min(1.0, base_reward + penalties))

    breakdown = {
        "components": {
            "local_metric_delta": outcome_score_local,
            "milestone": milestone_score,
            "completion": completion_score,
            "replan": replan_score,
            "efficiency": efficiency_score,
            "reasoning": reasoning_score,
            "format_compliance": local_components.get("format_compliance", 0.0),
            "plausibility": local_components.get("plausibility", 0.0),
            "timeout_penalty": timeout_pen
        },
        "base_reward": base_reward,
        "penalties_total": penalties,
        "penalties_fired": fired,
        "local_breakdown": local_breakdown
    }

    return final_reward, breakdown
|
|
_REQUIRED_COMPLETION_FIELDS = ("action_type", "target_domain", "metric_changes", "resource_cost", "reasoning")


def _score_parsed_completion(data) -> float:
    """Return 1.0 for a dict with every required field present and non-None, else 0.5."""
    if isinstance(data, dict) and all(data.get(k) is not None for k in _REQUIRED_COMPLETION_FIELDS):
        return 1.0
    return 0.5


def reward_format_compliance(completion: str) -> float:
    """
    Scores the completion based on its format (JSON validity and required fields).

    Markdown code fences (```` ```json ```` or bare ```` ``` ````) are stripped
    before parsing. If the text still does not parse, the outermost ``{...}``
    span is tried as a fallback.

    Returns:
        +1.0: Valid JSON with all required fields:
              action_type, target_domain, metric_changes, resource_cost, reasoning
        +0.5: Any parseable JSON (including partial/incomplete dicts)
        -0.5: Invalid JSON / unparseable
        -1.0: Empty strings or refusal content
    """
    if not completion or len(completion.strip()) < 10:
        return -1.0

    # NOTE(review): this is a plain substring test over the whole completion, so a
    # well-formed answer whose reasoning says e.g. "I cannot afford ..." is also
    # scored as a refusal.
    if any(x in completion.lower() for x in ["i cannot", "i'm sorry", "as an ai"]):
        return -1.0

    json_str = completion.strip()
    if "```json" in json_str:
        # Content between the last ```json tag and the next closing fence.
        json_str = json_str.split("```json")[-1].split("```")[0].strip()
    elif "```" in json_str:
        # Content of the first fenced block. Taking the segment after the *last*
        # fence would yield the (usually empty) text after the closing ```.
        json_str = json_str.split("```")[1].strip()

    try:
        return _score_parsed_completion(json.loads(json_str))
    except (ValueError, RecursionError):
        pass

    # Fallback: JSON embedded in surrounding prose.
    match = re.search(r'\{.*\}', json_str, re.DOTALL)
    if match:
        try:
            return _score_parsed_completion(json.loads(match.group(0)))
        except (ValueError, RecursionError):
            pass
    return -0.5
|
|
def reward_plausibility_check(metric_changes: dict, resource_cost: dict) -> float:
    """
    Anti-gaming check. Penalizes claims of large metric changes made while spending
    little or no resources.

    The claimed magnitude is the sum of absolute metric changes. Resource cost is
    normalized to comparable units (time/20h, money/$500, energy/100pts) and summed.

    Returns:
        -0.30 if nothing was spent but the claimed magnitude exceeds 3.0, or if
              magnitude / cost exceeds 150;
        -0.10 if magnitude / cost exceeds 80;
         0.0  otherwise, including when ``metric_changes`` is empty or None.
    """
    if not metric_changes:
        return 0.0
    total_delta = sum(abs(v) for v in metric_changes.values())

    # Zero spend: only a small total change is considered plausible.
    if not resource_cost or all(v == 0 for v in resource_cost.values()):
        return -0.30 if total_delta > 3.0 else 0.0

    total_cost = (
        resource_cost.get('time', 0.0) / 20.0
        + resource_cost.get('money', 0.0) / 500.0
        + resource_cost.get('energy', 0.0) / 100.0
    )

    # Floor the cost so tiny spends (or spends only under other keys) don't divide by ~0.
    ratio = total_delta / max(0.01, total_cost)

    if ratio > 150:
        return -0.30
    if ratio > 80:
        return -0.10
    return 0.0
|
|
def reward_timeout_check(step_count: int, max_steps: int, done: bool) -> float:
    """
    Penalizes episodes that end by reaching the step limit without being resolved.

    Args:
        step_count: Steps taken so far.
        max_steps: Episode step limit. A value <= 0 is treated as "no limit set"
            and never triggers the penalty.
        done: Whether the episode has been resolved.

    Returns:
        float: -0.20 when the limit is reached and the episode is unresolved, else 0.0.
    """
    if max_steps <= 0:
        return 0.0
    if step_count >= max_steps and not done:
        return -0.20
    return 0.0
|
|
# Causal connectors that suggest the reasoning explains *why* an action was chosen.
_REASONING_CONNECTORS = ("because", "since", "therefore", "due to", "resulting in", "consequently")

# Keywords the reasoning should mention to count as aligned with the chosen action.
# NOTE(review): matching is by plain substring on the lower-cased text, so e.g.
# "time" also matches "sometimes" and "move" matches "remove".
_ACTION_KEYWORDS = {
    "spend": ["cost", "price", "expensive", "money", "budget", "finance"],
    "rest": ["energy", "sleep", "exhaustion", "recharge", "break"],
    "communicate": ["talk", "discuss", "speak", "message", "call", "explain"],
    "delegate": ["hand off", "assign", "help", "junior", "colleague"],
    "negotiate": ["bargain", "trade", "deal", "terms"],
    "deprioritize": ["later", "postpone", "unimportant", "drop"],
    "reschedule": ["reschedule", "delay", "postpone", "move", "time", "calendar", "slot"],
    "execute": ["route", "plan", "action", "implement", "complete", "resolve", "execute"],
}


def reward_reasoning_coherence(reasoning: str, action_type: str = "") -> float:
    """
    Hardened check of reasoning quality. It requires both a minimum length and
    alignment with the chosen action, to discourage word-stuffing.

    Scoring:
        -0.20          reasoning missing, not a string, or under 20 characters once stripped;
        +0.05          contains a causal connector;
        +0.10 / -0.20  for a known ``action_type``: whether any of its keywords appears.
    The total is clamped to [-0.30, 0.30].
    """
    # Non-string values can arrive from parsed JSON; score them like missing reasoning.
    if not isinstance(reasoning, str) or len(reasoning.strip()) < 20:
        return -0.20

    reasoning_lower = reasoning.lower()
    score = 0.0

    if any(c in reasoning_lower for c in _REASONING_CONNECTORS):
        score += 0.05

    if action_type and action_type in _ACTION_KEYWORDS:
        if any(kw in reasoning_lower for kw in _ACTION_KEYWORDS[action_type]):
            score += 0.10
        else:
            score -= 0.20

    return max(-0.30, min(0.30, score))
|
|
def main():
    """Smoke-test compute_reward on three hand-built scenarios and print each result."""

    def shifted(state, amount, prefix=""):
        # Deep-copy `state`, then add `amount` to every flattened metric whose
        # "domain.metric" key starts with `prefix` (the empty prefix matches all).
        clone = copy.deepcopy(state)
        for key in clone.flatten().keys():
            if not key.startswith(prefix):
                continue
            domain_name, metric_name = key.split('.')
            domain_obj = getattr(clone, domain_name)
            setattr(domain_obj, metric_name, getattr(domain_obj, metric_name) + amount)
        return clone

    def report(title, state_after, resources, actions):
        # Score one scenario against the shared baseline and print the result.
        reward, breakdown = compute_reward(baseline, state_after, resources, actions_taken=actions)
        print(title)
        print(f"Reward: {reward:.4f}")
        print(f"Breakdown: {breakdown}")

    print("--- TESTING REWARD SYSTEM ---")

    baseline = LifeMetrics()

    # Every metric improves by 10 at a modest resource cost.
    report("\n[SCENARIO 1: PERFECT ACTION]",
           shifted(baseline, 10),
           {"time": 2, "money": 50, "energy": 10},
           5)

    # Every relationship metric drops by 30 at a heavy resource cost.
    report("\n[SCENARIO 2: BAD ACTION (Relationships Tank)]",
           shifted(baseline, -30, 'relationships.'),
           {"time": 10, "money": 300, "energy": 80},
           1)

    # Nothing changes and nothing is spent.
    report("\n[SCENARIO 3: INACTION]",
           copy.deepcopy(baseline),
           {},
           0)


if __name__ == "__main__":
    main()
|
|