File size: 17,918 Bytes
77da5ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
import math
import copy
import json
import re
from core.life_state import LifeMetrics
from core.task import Task



def _extract_completion_reasoning(completion: str) -> str:
    """Best-effort recovery of the "reasoning" field from a model completion.

    Candidates are tried in order: the contents of each fenced code block
    (``` or ```json, last block first), then the whole completion. Each
    candidate is parsed as JSON and, failing that, its outermost ``{...}``
    span is parsed. The "reasoning" value of the first JSON object recovered
    is returned if it is a string; otherwise (no object found, field
    missing, or non-string value) "" is returned.
    """
    text = completion.strip()
    candidates = re.findall(r"```(?:json)?(.*?)```", text, re.DOTALL)
    candidates.reverse()  # mirror the "last ```json block wins" convention
    candidates.append(text)
    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            brace = re.search(r"\{.*\}", candidate, re.DOTALL)
            if brace is None:
                continue
            try:
                data = json.loads(brace.group(0))
            except json.JSONDecodeError:
                continue
        if isinstance(data, dict):
            reasoning = data.get("reasoning", "")
            # A non-string value would crash reward_reasoning_coherence (.strip()).
            return reasoning if isinstance(reasoning, str) else ""
    return ""


def compute_reward(
    state_before: LifeMetrics,
    state_after: LifeMetrics,
    resources_used: dict,
    actions_taken: int,
    metric_changes: dict = None,
    completion: str = None,
    disruption_baseline: int = None,
    action_type: str = ""
) -> tuple[float, dict]:
    """
    Computes the reward for a life step based on changes in LifeMetrics and resource usage.

    base_reward = 0.40*outcome + 0.25*containment + 0.20*efficiency + 0.15*preservation.
    Penalties are added to it and the sum is clamped to [-1, 1]. Format
    compliance and reasoning alignment are reported in the breakdown only;
    they are not added to the returned reward.

    Args:
        state_before: The state at the start of the step.
        state_after: The state after actions and cascades.
        resources_used: Dict with keys 'time', 'money', 'energy' (missing keys count as 0).
        actions_taken: Integer count of intentional actions performed.
        metric_changes: Optional mapping of metric -> claimed change. Used for the
            plausibility check and as the fallback cascade baseline.
        completion: Optional raw model completion. Scored for format compliance
            and searched for a "reasoning" field (raw, fenced, or embedded JSON).
        disruption_baseline: Number of worsened metrics tolerated before
            CASCADE_SPREAD_WIDER fires. Defaults to len(metric_changes), or 2
            when metric_changes is empty/None.
        action_type: Action category used to check reasoning alignment.

    Returns:
        tuple[float, dict]: (final_reward, breakdown_dict)
    """
    before_flat = state_before.flatten()
    after_flat = state_after.flatten()

    # 1. OUTCOME SCORE (weighted sum of positive deltas)
    domain_weights = {
        "career": 1/6,
        "finances": 1/6,
        "relationships": 1/6,
        "physical_health": 1/6,
        "mental_wellbeing": 1/6,
        "time": 1/6
    }

    # Count sub-metrics per domain so each domain's 1/6 is split evenly among them.
    submetrics_per_domain = {}
    for k in before_flat:
        domain = k.split('.')[0]
        submetrics_per_domain[domain] = submetrics_per_domain.get(domain, 0) + 1

    outcome_score = 0.0
    for k in before_flat:
        domain = k.split('.')[0]
        delta = after_flat[k] - before_flat[k]
        if delta > 0:
            # Normalize delta by 100 (max possible increase is 100).
            weight = domain_weights[domain] / submetrics_per_domain[domain]
            outcome_score += (delta / 100.0) * weight

    # 2. CASCADE CONTAINMENT SCORE: fraction of metrics that did not worsen.
    worsened_count = sum(1 for k in before_flat if after_flat[k] < before_flat[k])
    total_metrics = len(before_flat)
    cascade_containment_score = (1.0 - (worsened_count / total_metrics)) if total_metrics else 1.0

    # 3. RESOURCE EFFICIENCY SCORE
    # Available: time 20, money 500, energy 100
    m_time = resources_used.get('time', 0.0) / 20.0
    m_money = resources_used.get('money', 0.0) / 500.0
    m_energy = resources_used.get('energy', 0.0) / 100.0

    # Normalize by total slots (3 resources), clamp to [0, 1].
    resource_efficiency_score = 1.0 - ((m_time + m_money + m_energy) / 3.0)
    resource_efficiency_score = max(0.0, min(1.0, resource_efficiency_score))

    # 4. RELATIONSHIP PRESERVATION SCORE (sigmoid of the average delta)
    rel_keys = [k for k in before_flat if k.startswith('relationships.')]
    if rel_keys:
        avg_rel_before = sum(before_flat[k] for k in rel_keys) / len(rel_keys)
        avg_rel_after = sum(after_flat[k] for k in rel_keys) / len(rel_keys)
        delta_rel = avg_rel_after - avg_rel_before
    else:
        # No relationship metrics: treat as unchanged (sigmoid 0.5, no collapse penalty).
        delta_rel = 0.0

    # score = 1 / (1 + exp(-delta/10))
    relationship_preservation_score = 1.0 / (1.0 + math.exp(-delta_rel / 10.0))

    # FINAL REWARD FORMULA
    base_reward = (
        (0.40 * outcome_score) +
        (0.25 * cascade_containment_score) +
        (0.20 * resource_efficiency_score) +
        (0.15 * relationship_preservation_score)
    )

    # PENALTIES
    penalties = 0.0
    fired = []

    # -0.50 if ANY metric is below 20 after the step
    if any(v < 20 for v in after_flat.values()):
        penalties -= 0.50
        fired.append("CRITICAL_FLOOR_VIOLATION")

    # -0.30 if the cascade spread wider than the tolerated baseline.
    # A scaled baseline from task metadata is preferred over the hardcoded default.
    if disruption_baseline is None:
        disruption_baseline = len(metric_changes) if metric_changes else 2

    if worsened_count > disruption_baseline:
        penalties -= 0.30
        fired.append("CASCADE_SPREAD_WIDER")

    # -0.40 if actions_taken == 0
    if actions_taken == 0:
        penalties -= 0.40
        fired.append("INACTION_PENALTY")

    # -0.15 if relationships domain average dropped more than 20 points
    if delta_rel < -20:
        penalties -= 0.15
        fired.append("RELATIONSHIP_COLLAPSE")

    # Plausibility penalty (claimed change vs. resources spent)
    plaus = 0.0
    if metric_changes:
        plaus = reward_plausibility_check(metric_changes, resources_used)
        if plaus < 0:
            penalties += plaus
            fired.append("PLAUSIBILITY_VIOLATION")

    # Format compliance & reasoning (reported only; not added to final_reward)
    comp_reward = 0.0
    reasoning = ""
    if completion:
        comp_reward = reward_format_compliance(completion)
        reasoning = _extract_completion_reasoning(completion)

    # Reasoning alignment (tied to action_type)
    reasoning_score = reward_reasoning_coherence(reasoning, action_type=action_type)

    final_reward = max(-1.0, min(1.0, base_reward + penalties))

    breakdown = {
        "components": {
            "outcome": outcome_score,
            "containment": cascade_containment_score,
            "efficiency": resource_efficiency_score,
            "preservation": relationship_preservation_score,
            "format_compliance": comp_reward,
            "plausibility": plaus,
            "reasoning_alignment": reasoning_score
        },
        "base_reward": base_reward,
        "penalties_total": penalties,
        "penalties_fired": fired,
        "metrics_worsened": worsened_count,
        "rel_delta": delta_rel
    }

    return final_reward, breakdown

def compute_milestone_reward(milestones_achieved: list[str], task: Task) -> float:
    """Return the reward-weighted fraction of the task's milestones achieved.

    Each milestone in ``task.milestones`` contributes its ``reward`` when its
    ``id`` appears in ``milestones_achieved``. The sum is divided by the total
    reward of all milestones and capped at 1.0.

    Args:
        milestones_achieved: Ids of milestones reached so far. ``None`` is
            treated as empty, and repeated ids do not count twice.
        task: Task whose ``milestones`` (objects with ``id`` and ``reward``)
            define the achievable total. A ``None`` task or an empty
            milestone list yields 0.0.

    Returns:
        The achieved fraction (capped at 1.0), or 0.0 when there is nothing
        to score or the total possible reward is 0.
    """
    milestones = getattr(task, "milestones", None)
    if not milestones:
        return 0.0
    total_possible = sum(m.reward for m in milestones)
    if total_possible == 0:
        return 0.0
    # Set lookup instead of scanning the achieved list once per milestone.
    achieved_ids = set(milestones_achieved or ())
    achieved = sum(m.reward for m in milestones if m.id in achieved_ids)
    return min(1.0, achieved / total_possible)

def compute_task_completion_reward(success_conditions_met: list[bool], task: Task) -> float:
    """Return 1.0 if at least one success condition holds, else 0.0.

    A task may define several alternative goal states (e.g. a choice of
    routes), so satisfying any single one counts as completion. An empty or
    missing list scores 0.0. ``task`` is accepted for interface symmetry but
    is not consulted.
    """
    return float(any(success_conditions_met or ()))

def compute_replan_bonus(exo_events_seen: int, milestones_after_event: int) -> float:
    """Reward recovering after exogenous events.

    Returns half the ratio of milestones reached after events to events seen,
    capped at 1.0. Returns 0.0 when no exogenous event has occurred.
    """
    if exo_events_seen == 0:
        return 0.0
    recovery_ratio = milestones_after_event / exo_events_seen
    return min(1.0, recovery_ratio * 0.5)

def compute_dead_end_penalty(routes_remaining: int) -> float:
    """Return -0.5 when no viable routes remain (routes_remaining <= 0), else 0.0."""
    if routes_remaining <= 0:
        return -0.5
    return 0.0

def compute_task_reward(
    state_before: LifeMetrics,
    state_after: LifeMetrics,
    resources_used: dict,
    actions_taken: int,
    milestones_achieved: list[str],
    success_conditions_met: list[bool],
    exo_events_seen: int,
    milestones_after_event: int,
    routes_remaining: int,
    rollback_used: bool,
    cascade_collapse: bool,
    task: Task,
    reasoning: str = "",
    completion: str = "",
    conflict_domain: str = "",
    step_count: int = 0,
    max_steps: int = 0,
    metric_changes: dict = None,
    cumulative_rel_delta: float = 0.0,
    action_type: str = ""
) -> tuple[float, dict]:
    """Compute the task-level reward for one step of a multi-step task.

    Blends task-progress signals with selected raw components of the local
    step reward (compute_reward):

        base = 0.35*milestone + 0.25*completion + 0.10*outcome
             + 0.05*preservation + 0.10*replan + 0.10*efficiency
             + 0.05*reasoning

    Penalties are then added and the result is clamped to [-1, 1]:
    - timeout (only when a positive max_steps is given);
    - dead end;
    - rollback -0.10;
    - cascade collapse -0.30;
    - inaction -0.20;
    - cumulative relationship erosion below -20: -0.15.

    ``reasoning`` is scored directly when given. Otherwise the reasoning
    alignment that compute_reward derived from ``completion`` is reused.
    ``conflict_domain`` is accepted but not used.

    Returns:
        (final_reward, breakdown_dict). ``breakdown_dict["local_breakdown"]``
        holds the full compute_reward breakdown.
    """
    # 1. Base local components. When the task exposes mutable_world, its size
    #    scales the cascade-spread baseline. Otherwise compute_reward uses its default.
    mutable_world = getattr(task, "mutable_world", None)
    d_baseline = len(mutable_world) if mutable_world is not None else None
    _, local_breakdown = compute_reward(
        state_before, state_after, resources_used, actions_taken,
        metric_changes=metric_changes, completion=completion,
        disruption_baseline=d_baseline, action_type=action_type,
    )
    local_components = local_breakdown["components"]

    # 2. Orchestrator components. Only raw local components are reused (not
    #    the blended local reward), so nothing is double-counted.
    outcome_score_local = local_components.get("outcome", 0.0)
    milestone_score = compute_milestone_reward(milestones_achieved, task)
    completion_score = compute_task_completion_reward(success_conditions_met, task)
    replan_score = compute_replan_bonus(exo_events_seen, milestones_after_event)
    efficiency_score = local_components.get("efficiency", 0.0)
    preservation_score = local_components.get("preservation", 0.0)
    if reasoning:
        reasoning_score = reward_reasoning_coherence(reasoning, action_type=action_type)
    else:
        # No explicit reasoning: reuse the score compute_reward computed from
        # the completion's "reasoning" field (or from "" when there is none).
        reasoning_score = local_components["reasoning_alignment"]

    # Failure checks. max_steps <= 0 (the default) means no step limit was
    # supplied, so skip the timeout instead of firing it at step 0.
    task_done = any(success_conditions_met or ())
    timeout_pen = (
        reward_timeout_check(step_count, max_steps, task_done) if max_steps > 0 else 0.0
    )
    dead_end_pen = compute_dead_end_penalty(routes_remaining)

    # 3. Final weighting (all components are unique/non-overlapping)
    base_reward = (
        (0.35 * milestone_score) +
        (0.25 * completion_score) +
        (0.10 * outcome_score_local) +
        (0.05 * preservation_score) +
        (0.10 * replan_score) +
        (0.10 * efficiency_score) +
        (0.05 * reasoning_score)
    )

    # 4. Penalties
    penalties = 0.0
    fired = []

    if timeout_pen < 0:
        penalties += timeout_pen
        fired.append("TIMEOUT")

    if dead_end_pen < 0:
        penalties += dead_end_pen
        fired.append("DEAD_END")

    if rollback_used:
        penalties += -0.1
        fired.append("ROLLBACK_USED")

    if cascade_collapse:
        penalties += -0.3
        fired.append("CASCADE_COLLAPSE")

    # Direct inaction penalty. The local reward (which has its own inaction
    # penalty) is not part of base_reward, so it is applied here explicitly.
    if actions_taken == 0:
        penalties += -0.20
        fired.append("TASK_INACTION_PENALTY")

    # Cumulative relationship erosion across the episode
    if cumulative_rel_delta < -20:
        penalties += -0.15
        fired.append("CUMULATIVE_RELATIONSHIP_EROSION")

    final_reward = max(-1.0, min(1.0, base_reward + penalties))

    breakdown = {
        "components": {
            "local_metric_delta": outcome_score_local,
            "milestone": milestone_score,
            "completion": completion_score,
            "replan": replan_score,
            "efficiency": efficiency_score,
            "preservation": preservation_score,
            "reasoning": reasoning_score,
            "format_compliance": local_components.get("format_compliance", 0.0),
            "plausibility": local_components.get("plausibility", 0.0),
            "timeout_penalty": timeout_pen
        },
        "base_reward": base_reward,
        "penalties_total": penalties,
        "penalties_fired": fired,
        "local_breakdown": local_breakdown
    }

    return final_reward, breakdown

def _load_completion_json(text: str) -> tuple[bool, object]:
    """Return (True, value) for the first JSON payload recoverable from text, else (False, None).

    Candidates are tried in order: each fenced code block (``` or ```json,
    last block first), then the whole text. Each candidate is parsed as JSON
    and, failing that, its outermost ``{...}`` span is parsed.
    """
    candidates = re.findall(r"```(?:json)?(.*?)```", text, re.DOTALL)
    candidates.reverse()  # the last ```json block wins, as before
    candidates.append(text)
    for candidate in candidates:
        try:
            return True, json.loads(candidate)
        except json.JSONDecodeError:
            pass
        brace = re.search(r"\{.*\}", candidate, re.DOTALL)
        if brace is not None:
            try:
                return True, json.loads(brace.group(0))
            except json.JSONDecodeError:
                pass
    return False, None


def reward_format_compliance(completion: str) -> float:
    """
    Scores the completion based on its format (JSON validity and required fields).

    JSON may be raw, inside a ``` / ```json code fence, or embedded in
    surrounding text. Refusal phrases are only checked when no JSON can be
    recovered, so a valid action whose reasoning happens to contain
    "I cannot ..." is not misclassified as a refusal.

    Returns:
        +1.0: Valid JSON object with all required fields non-null:
              action_type, target_domain, metric_changes, resource_cost, reasoning
        +0.5: Any other parseable JSON (including partial/incomplete dicts)
        -0.5: Invalid JSON / unparseable
        -1.0: Empty or very short (< 10 chars) strings, or unparseable refusal content
    """
    if not completion or len(completion.strip()) < 10:
        return -1.0

    parsed, data = _load_completion_json(completion.strip())
    if parsed:
        required = ("action_type", "target_domain", "metric_changes", "resource_cost", "reasoning")
        if isinstance(data, dict) and all(data.get(k) is not None for k in required):
            return 1.0
        return 0.5

    # Potential refusal indicators (ASCII and typographic apostrophes)
    lowered = completion.lower()
    refusal_markers = ("i cannot", "i'm sorry", "i\u2019m sorry", "as an ai")
    if any(marker in lowered for marker in refusal_markers):
        return -1.0
    return -0.5

def reward_plausibility_check(metric_changes: dict, resource_cost: dict) -> float:
    """
    Anti-gaming check: penalize large claimed metric changes bought with few resources.

    The claimed change is the sum of absolute metric deltas. Spending is the
    sum of time/20 + money/500 + energy/100.
    - No resources spent at all: -0.30 if the claim exceeds 3.0.
    - Otherwise, claim/spend ratio above 150: -0.30.
    - Ratio above 80: -0.10.
    - Anything else: 0.0.
    """
    claimed_change = sum(abs(delta) for delta in metric_changes.values())

    # Zero-cost shortcut (also covers an empty/missing resource_cost).
    nothing_spent = not resource_cost or not any(v != 0 for v in resource_cost.values())
    if nothing_spent:
        return -0.30 if claimed_change > 3.0 else 0.0

    # Bring each resource onto a comparable [0, 1]-ish scale before adding.
    spent = 0.0
    for name, scale in (('time', 20.0), ('money', 500.0), ('energy', 100.0)):
        spent += resource_cost.get(name, 0.0) / scale

    efficiency = claimed_change / max(0.01, spent)
    if efficiency > 150:
        return -0.30   # massive change for virtually free
    if efficiency > 80:
        return -0.10   # suspiciously efficient
    return 0.0

def reward_timeout_check(step_count: int, max_steps: int, done: bool) -> float:
    """
    Penalizes episodes that end by reaching the step limit without being resolved.

    Returns -0.20 when step_count has reached a positive max_steps and the
    episode is not done. Otherwise returns 0.0. A non-positive max_steps
    means no step limit was configured (compute_task_reward defaults it to
    0), so no timeout can fire; previously step 0 with max_steps=0 was
    penalized.
    """
    if max_steps <= 0:
        return 0.0
    if step_count >= max_steps and not done:
        return -0.20
    return 0.0

def reward_reasoning_coherence(reasoning: str, action_type: str = "") -> float:
    """
    Score reasoning for minimal effort, logical structure and action alignment.

    - Missing, non-string, or shorter than 20 non-blank characters: -0.20.
    - Uses a logical connector ("because", "since", ...): +0.05.
    - If action_type is a known category:
      - +0.10 when a category keyword appears;
      - -0.20 otherwise.

    Connectors and keywords match at the start of a word. So "call" matches
    "calls"/"calling" but not "physically", and "time" does not match
    "sometimes". Plain substring matching let unrelated words satisfy the
    alignment check. The result is clamped to [-0.30, 0.30].
    """
    if not isinstance(reasoning, str) or len(reasoning.strip()) < 20:
        return -0.20  # Severe penalty for lack of effort

    text = reasoning.lower()

    def mentions(keyword: str) -> bool:
        # Word-start match; keyword may contain spaces (e.g. "due to", "hand off").
        return re.search(r"\b" + re.escape(keyword), text) is not None

    score = 0.0

    # 1. Structural logic check: reward causal connectors rather than lists of facts.
    connectors = ("because", "since", "therefore", "due to", "resulting in", "consequently")
    if any(mentions(c) for c in connectors):
        score += 0.05

    # 2. Action alignment: the reasoning must justify the chosen category.
    action_keywords = {
        "spend": ["cost", "price", "expensive", "money", "budget", "finance"],
        "rest": ["energy", "sleep", "exhaustion", "recharge", "break"],
        "communicate": ["talk", "discuss", "speak", "message", "call", "explain"],
        "delegate": ["hand off", "assign", "help", "junior", "colleague"],
        "negotiate": ["bargain", "trade", "deal", "terms"],
        "deprioritize": ["later", "postpone", "unimportant", "drop"],
        "reschedule": ["reschedule", "delay", "postpone", "move", "time", "calendar", "slot"],
        "execute": ["route", "plan", "action", "implement", "complete", "resolve", "execute"],
    }

    if action_type and action_type in action_keywords:
        if any(mentions(kw) for kw in action_keywords[action_type]):
            score += 0.10
        else:
            score -= 0.20

    return max(-0.30, min(0.30, score))

def main():
    """Print compute_reward results for three hand-built smoke-test scenarios.

    All scenarios start from a default LifeMetrics():
      1. every sub-metric +10 with modest resource use (5 actions);
      2. every relationships.* sub-metric -30 with heavy resource use (1 action);
      3. no change and no resources spent (0 actions).
    """
    print("--- TESTING REWARD SYSTEM ---")

    baseline = LifeMetrics()  # Defaults at 70

    def shifted(amount, prefix=""):
        # Deep copy of the baseline with `amount` added to each sub-metric whose key starts with `prefix`.
        variant = copy.deepcopy(baseline)
        for key in variant.flatten():
            if key.startswith(prefix):
                domain_name, field = key.split('.')
                section = getattr(variant, domain_name)
                setattr(section, field, getattr(section, field) + amount)
        return variant

    scenarios = [
        ("[SCENARIO 1: PERFECT ACTION]",
         shifted(10), {"time": 2, "money": 50, "energy": 10}, 5),
        ("[SCENARIO 2: BAD ACTION (Relationships Tank)]",
         shifted(-30, 'relationships.'), {"time": 10, "money": 300, "energy": 80}, 1),
        ("[SCENARIO 3: INACTION]",
         copy.deepcopy(baseline), {}, 0),
    ]

    for title, after, spent, n_actions in scenarios:
        reward, breakdown = compute_reward(baseline, after, spent, actions_taken=n_actions)
        print(f"\n{title}")
        print(f"Reward: {reward:.4f}")
        print(f"Breakdown: {breakdown}")

if __name__ == "__main__":
    main()