# Extracted from a hosted file view (revision f6db629, file size 6,692 bytes).
"""
engine.py — Multi-component reward engine for EduForge.
REVISED: Aggressive penalty for information dumping to favor Query over Example.
"""
from __future__ import annotations
import os
import sys
# Ensure the root directory is in the path so AutoTrain can find your files
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '.')))
from dataclasses import dataclass, field
from typing import Optional, List
import numpy as np
from src.environment.student_fsm import MisconceptionType, TutorAction
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Outcome magnitudes. Only R_OUTCOME_RESOLUTION is referenced by the reward
# code visible in this file; the other two are kept for other users of the
# module (NOTE(review): confirm they are still needed).
R_OUTCOME_RESOLUTION = 10.0
R_OUTCOME_FAILURE = -3.0
CONFUSION_RESOLUTION = 2.0

# Action values grouped by teaching style: scaffolding actions prompt the
# student, direct-tell actions hand over the answer.
SCAFFOLD_ACTIONS = {"question", "hint", "analogize"}
DIRECT_TELL_ACTIONS = {"worked_example", "correct_fact"}

# A "success" that ends before this many steps is penalised instead of
# rewarded.
MIN_DIAGNOSTIC_DEPTH = 5

# Direct-tell actions are penalised while the episode length is still within
# this many steps.
EARLY_SESSION_THRESHOLD = 5

# -- Rebalanced domain priority ----------------------------------------------
# For each misconception type, the preferred actions ordered from best to
# worst fit. The position in the list selects the entry of ALIGNMENT_REWARDS.
DOMAIN_PRIORITY = {
    MisconceptionType.PROCEDURAL: ["hint", "question", "worked_example"],
    MisconceptionType.FACTUAL: ["question", "correct_fact", "explain"],
    MisconceptionType.TRANSFER: ["analogize", "question", "worked_example"],
    MisconceptionType.CONCEPTUAL: ["question", "analogize", "hint"],
}

# Alignment bonus by priority rank; the first-priority action earns far more
# than the second and third.
ALIGNMENT_REWARDS = [4.0, 1.5, 0.5]
@dataclass
class RewardComponents:
    """Itemised view of a single reward computation.

    Every numeric field defaults to ``0.0`` and ``breakdown`` defaults to a
    fresh empty dict per instance, so a bare ``RewardComponents()`` represents
    "no reward information".
    """

    # Terminal (end-of-episode) reward.
    r_outcome: float = 0.0
    # Reward for the change in student confusion.
    r_process: float = 0.0
    # Bonus for choosing an action that suits the misconception type.
    r_alignment: float = 0.0
    # Bonus for scaffolding actions that improved the student's state.
    r_scaffolding: float = 0.0
    # Recovery reward slot.
    r_recovery: float = 0.0
    # Penalty magnitude.
    p_penalty: float = 0.0
    # Final scalar reward.
    total: float = 0.0
    # Optional free-form mapping with finer-grained terms.
    breakdown: dict = field(default_factory=dict)
class RewardEngine:
    """Stateful per-step reward calculator for a tutoring episode.

    The engine tracks how many scaffolding actions have been taken since the
    last direct-tell action (``_scaffold_streak``) so that a successful
    episode reached through sustained scaffolding can earn an extra bonus.
    Call :meth:`reset` at the start of every episode.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Clear all per-episode state."""
        self._action_streak = {}
        self._prev_confusion = None
        self._prev_attention = None
        self._prev_action = None
        self._scaffold_streak = 0

    def compute(
        self,
        *,
        confusion_before: float,
        confusion_after: float,
        attention_after: float,
        action_text: str,
        format_valid: bool,
        done: bool,
        done_reason: Optional[str],
        attention_before: float = 5.0,
        action: Optional[TutorAction] = None,
        action_history: Optional[List[str]] = None,
        misconception: Optional[MisconceptionType] = None,
        episode_length: int = 1,
    ) -> tuple[float, RewardComponents]:
        """Compute the reward for one tutor step.

        The total is the sum of the terms below, clipped to ``[-15, 15]``:

        * process reward: ``2 * w_method * (confusion_before - confusion_after)``
          with ``w_method`` 1.5 for scaffolding actions, 0.4 for direct-tell
          actions and 1.0 otherwise;
        * information-dump penalty: subtracted when a direct-tell action is
          used while ``episode_length <= EARLY_SESSION_THRESHOLD``;
        * alignment bonus: looked up from ``ALIGNMENT_REWARDS`` by the rank of
          the action in ``DOMAIN_PRIORITY[misconception]``;
        * scaffolding bonus: ``min(2 * confusion_delta, 4)`` when a
          scaffolding action reduced confusion;
        * outcome: on ``done_reason == "success"``, ``-5`` if the episode was
          shorter than ``MIN_DIAGNOSTIC_DEPTH`` steps, otherwise
          ``R_OUTCOME_RESOLUTION`` plus a further ``10`` when the scaffold
          streak is at least 3.

        Args:
            confusion_before: Student confusion before the action.
            confusion_after: Student confusion after the action.
            attention_after: Accepted for interface compatibility; not used
                by the current reward terms.
            action_text: Accepted for interface compatibility; not used.
            format_valid: Accepted for interface compatibility; not used.
            done: Whether the episode ended on this step.
            done_reason: Why the episode ended; only ``"success"`` triggers
                an outcome term.
            attention_before: Accepted for interface compatibility; not used.
            action: The tutor action taken; only its ``.value`` is read.
            action_history: Accepted for interface compatibility; not used.
            misconception: The student's misconception type, used as the key
                into ``DOMAIN_PRIORITY``.
            episode_length: 1-based count of steps taken so far.

        Returns:
            ``(total, components)`` where ``components.total == total`` and
            the individual component fields and ``components.breakdown``
            itemise the terms that were summed before clipping.
        """
        # Compare against None explicitly so an enum member that happens to
        # be falsy is still honoured.
        action_val = action.value if action is not None else None
        confusion_delta = confusion_before - confusion_after
        is_scaffold = action_val in SCAFFOLD_ACTIONS
        is_direct_tell = action_val in DIRECT_TELL_ACTIONS

        # --------------------------------------------------------------
        # 1. Confusion progress, weighted by teaching method
        # --------------------------------------------------------------
        if is_scaffold:
            w_method = 1.5
            self._scaffold_streak += 1
        elif is_direct_tell:
            w_method = 0.4
            # Any direct-tell action breaks the scaffolding streak.
            self._scaffold_streak = 0
        else:
            # Neutral actions neither extend nor break the streak.
            w_method = 1.0
        r_process = 2.0 * w_method * confusion_delta

        # --------------------------------------------------------------
        # 2. Information-dumping penalty (early direct-tell actions)
        # --------------------------------------------------------------
        p_info_dump = 0.0
        if is_direct_tell:
            early_steps_remaining = EARLY_SESSION_THRESHOLD - episode_length + 1
            if early_steps_remaining > 0:
                # Linear decay: 2.5 per remaining early step, i.e. 12.5 at
                # step 1 down to 2.5 at step EARLY_SESSION_THRESHOLD.
                p_info_dump = 2.5 * early_steps_remaining

        # --------------------------------------------------------------
        # 3. Pedagogical alignment
        # --------------------------------------------------------------
        bonus_alignment = 0.0
        if action_val is not None and misconception is not None:
            priority_list = DOMAIN_PRIORITY.get(misconception, [])
            if action_val in priority_list:
                rank = priority_list.index(action_val)
                # Ranks beyond the reward table reuse its last entry.
                bonus_alignment = ALIGNMENT_REWARDS[
                    min(rank, len(ALIGNMENT_REWARDS) - 1)
                ]

        # --------------------------------------------------------------
        # 4. Scaffolding that actually improved the student's state
        # --------------------------------------------------------------
        r_state_improvement = 0.0
        if is_scaffold and confusion_delta > 0:
            r_state_improvement = min(2.0 * confusion_delta, 4.0)
        bonus_scaffolding = r_state_improvement

        # --------------------------------------------------------------
        # 5. Outcome, gated on MIN_DIAGNOSTIC_DEPTH
        # --------------------------------------------------------------
        # NOTE(review): non-success endings receive no outcome term here even
        # though R_OUTCOME_FAILURE exists at module level - confirm intended.
        r_outcome = 0.0
        r_scaffolded_resolution = 0.0
        if done and done_reason == "success":
            if episode_length < MIN_DIAGNOSTIC_DEPTH:
                # A "success" that arrives too quickly is penalised.
                r_outcome = -5.0
            else:
                r_outcome = R_OUTCOME_RESOLUTION
                if self._scaffold_streak >= 3:
                    r_scaffolded_resolution = 10.0

        # --------------------------------------------------------------
        # 6. Aggregation
        # --------------------------------------------------------------
        # The scaffolded-resolution bonus is part of the sum so that the
        # scalar reward agrees with the reported r_outcome component.
        raw_total = (
            r_process
            + bonus_alignment
            + bonus_scaffolding
            + r_outcome
            + r_scaffolded_resolution
            - p_info_dump
        )
        total = float(np.clip(raw_total, -15.0, 15.0))

        breakdown: dict[str, float] = {
            "r_process": r_process,
            "r_alignment": bonus_alignment,
            "r_state_improvement": r_state_improvement,
            "r_outcome": r_outcome,
            "r_scaffolded_resolution": r_scaffolded_resolution,
            "p_info_dump": p_info_dump,
        }
        components = RewardComponents(
            r_outcome=r_outcome + r_scaffolded_resolution,
            r_process=r_process,
            r_alignment=bonus_alignment,
            r_scaffolding=bonus_scaffolding,
            p_penalty=p_info_dump,
            total=total,
            breakdown=breakdown,
        )

        self._prev_confusion = confusion_before
        self._prev_action = action
        return total, components