Spaces:
Sleeping
Sleeping
| """ | |
| qlearning_pipeline.py — Q-learning training pipeline for EduForge. | |
| Modular pipeline: | |
| 1. Dataset Loader — load & validate training_samples.json | |
| 2. Q-table Bootstrap — seed Q-values from offline dataset | |
| 3. Training Loop — adaptive epsilon-greedy online Q-learning | |
| 4. Evaluation — greedy policy rollouts with reporting | |
| 5. Interactive REPL — human-in-the-loop tutoring | |
| Entry point: python scripts/qlearning_pipeline.py | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import pickle | |
| import random | |
| import sys | |
| from collections import defaultdict | |
| from typing import Any | |
| import numpy as np | |
| # --------------------------------------------------------------------------- | |
| # Path setup | |
| # --------------------------------------------------------------------------- | |
# Resolve the directory two levels above this file and make sure it is on
# sys.path, so the project-local ``src`` package below can be imported when
# this file is executed directly.
_ROOT = os.path.dirname(os.path.abspath(__file__))  # directory holding this file
_ROOT = os.path.dirname(_ROOT)  # ... and its parent
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

from src.environment.openenv_wrapper import EduForgeEnv  # noqa: E402
| # --------------------------------------------------------------------------- | |
| # Action catalogue | |
| # --------------------------------------------------------------------------- | |
# Index -> strategy name for every tutoring action the agent can pick.
ACTIONS: dict[int, str] = dict(
    enumerate(
        ("explain", "worked_example", "question", "correct_fact", "analogize")
    )
)
# Reverse lookup: strategy name -> index.
ACTION_TO_IDX: dict[str, int] = {name: idx for idx, name in ACTIONS.items()}
N_ACTIONS = len(ACTIONS)

# Misconception label -> integer id used inside the discrete state tuple.
MISCONCEPTION_MAP: dict[str, int] = {
    label: code
    for code, label in enumerate(
        ("none", "procedural", "conceptual", "factual", "transfer")
    )
}
| # --------------------------------------------------------------------------- | |
| # Config | |
| # --------------------------------------------------------------------------- | |
# --- Filesystem layout (everything is resolved relative to _ROOT) ---
DATASET_PATH = os.path.join(_ROOT, "src", "environment", "training_samples.json")
MODEL_DIR = os.path.join(_ROOT, "models")
MODEL_PATH = os.path.join(MODEL_DIR, "q_table.pkl")

# Keys an offline training record must carry to be accepted by the loader.
REQUIRED_FIELDS = {
    "misconception",
    "confusion",
    "attention",
    "action",
    "next_confusion",
    "next_attention",
    "reward",
    "done",
}

# --- Hyperparameters ---
# Offline bootstrap pass
ALPHA_BOOTSTRAP = 0.2
BOOTSTRAP_EPOCHS = 3
# Online Q-learning
ALPHA = 0.15
GAMMA = 0.92
EPSILON_START = 1.0
EPSILON_MIN = 0.01
N_EPISODES = 4000
MAX_STEPS = 15
EVAL_EPISODES = 80  # 4 misconceptions × 20 seeds each
SEED = 42

# --- Episode-termination thresholds (must match openenv_wrapper.py) ---
DONE_CONFUSION_THRESHOLD = 2.0
ATTENTION_FAILURE_THRESHOLD = 0.5  # Match the environment's floor (ATTENTION_FLOOR)

# Each Q-table entry is clamped to [-Q_VALUE_CLIP, +Q_VALUE_CLIP] to prevent explosion.
Q_VALUE_CLIP = 15.0
| # --------------------------------------------------------------------------- | |
| # 1. State discretisation — integer buckets, compact space | |
| # --------------------------------------------------------------------------- | |
| # Coarse bin edges for discretization | |
# Confusion and attention use the same five buckets:
# [0, 2), [2, 4), [4, 6), [6, 8), [8, 10.01).
# Two independent list objects are kept so one can be retuned without
# affecting the other.
_CONF_BINS = [*range(0, 10, 2), 10.01]  # 5 bins
_ATT_BINS = list(_CONF_BINS)  # 5 bins
def _bin_value(val: float, edges: list[float]) -> int:
    """Return the index of the half-open bin ``[edges[i], edges[i + 1])`` holding *val*.

    *val* is first clamped into ``[edges[0], edges[-1] - 0.01]`` so that
    out-of-range inputs land in the first or last bin instead of failing.
    """
    last_bin = len(edges) - 2
    clamped = max(edges[0], min(edges[-1] - 0.01, val))
    # First bin whose upper edge lies strictly above the clamped value;
    # fall back to the last bin if none does.
    return next(
        (idx for idx in range(last_bin + 1) if clamped < edges[idx + 1]),
        last_bin,
    )
def get_state(
    confusion: float,
    attention: float,
    misconception: str | int,
    step_number: int = 1,
    last_action: int | None = None,
    prev_last_action: int | None = None,
    progress_signal: int = 0,
    steps_since_improvement: int = 0
) -> tuple:
    """Discretise raw student metrics into a hashable state key.

    The returned tuple is ``(c, a, m, p, la, pla, ps, ssi)``:

    * ``c`` / ``a`` -- confusion / attention bucket from ``_bin_value``.
    * ``m``  -- misconception id; string labels go through
      ``MISCONCEPTION_MAP`` (unknown labels map to 0), anything else is
      cast with ``int``.
    * ``p``  -- episode phase: 0 for steps 1-5, 1 for steps 6-10, 2 after.
    * ``la`` / ``pla`` -- last and second-to-last action index, with 5
      standing in for "no action yet".
    * ``ps`` -- ``progress_signal`` shifted by +1.
    * ``ssi`` -- stagnation bucket: 0 (<= 1 step), 1 (2-3 steps), 2 (more).
    """
    no_action = 5  # sentinel slot for "no previous action"

    conf_bucket = _bin_value(confusion, _CONF_BINS)
    att_bucket = _bin_value(attention, _ATT_BINS)

    if isinstance(misconception, str):
        misc_id = MISCONCEPTION_MAP.get(misconception, 0)
    else:
        misc_id = int(misconception)

    phase = 0 if step_number <= 5 else (1 if step_number <= 10 else 2)
    last = no_action if last_action is None else int(last_action)
    shifted_progress = progress_signal + 1
    stagnation = (
        0 if steps_since_improvement <= 1
        else (1 if steps_since_improvement <= 3 else 2)
    )
    before_last = no_action if prev_last_action is None else int(prev_last_action)

    return (
        conf_bucket,
        att_bucket,
        misc_id,
        phase,
        last,
        before_last,
        shifted_progress,
        stagnation,
    )
def get_state_from_obs(
    obs,
    last_action_idx: int | None = None,
    prev_last_action_idx: int | None = None,
    progress_signal: int = 0,
    steps_since_improvement: int = 0
) -> tuple:
    """Build the discrete state key for an environment observation.

    Reads ``confusion``, ``attention`` and ``misconception_id.value`` from
    *obs*; the step number comes from ``obs.turn`` when that attribute
    exists and defaults to 1 otherwise. The remaining arguments are
    forwarded to ``get_state`` unchanged.
    """
    return get_state(
        confusion=obs.confusion,
        attention=obs.attention,
        misconception=obs.misconception_id.value,
        step_number=getattr(obs, "turn", 1),
        last_action=last_action_idx,
        prev_last_action=prev_last_action_idx,
        progress_signal=progress_signal,
        steps_since_improvement=steps_since_improvement,
    )
| # --------------------------------------------------------------------------- | |
| # 2. Reward function — Continuous Multi-Component | |
| # --------------------------------------------------------------------------- | |
def compute_reward(
    prev_conf: float,
    new_conf: float,
    prev_att: float,
    new_att: float,
    done: bool,
    success: bool,
    action_idx: int,
    misc_str: str,
    action_history: list[int],
    step: int,
    confusion_history: list[float],
    prev_reward: float = 0.0
) -> float:
    """Score one tutoring transition; the result is clipped to [-10, 10].

    Components are applied in this order:

    1. Misconception-specific shaping of the confusion drop
       (``prev_conf - new_conf``), plus penalties for switching strategy
       on "procedural" and "transfer" students.
    2. Attention safety: below 4.0 the reward is halved and reduced by
       1.0; below 2.0 a further ``(2.0 - new_att) ** 2`` is subtracted
       (a quadratic term).
    3. A -2.0 penalty when a "question" raised confusion or lowered attention.
    4. A -2.5 penalty when the last three ``confusion_history`` entries
       were strictly decreasing and this step raised confusion.
    5. A stagnation penalty after step 10 that grows with the step number.
    6. Terminal terms: -10 on attention collapse, +5 on success, -2 otherwise.
    7. Damping of large jumps relative to ``prev_reward``, followed by a
       per-misconception normalisation factor.

    NOTE(review): whether ``confusion_history`` is expected to already
    contain ``new_conf`` cannot be determined from this function -- confirm
    against the callers before relying on items 4 and 5.
    """
    explain_idx = ACTION_TO_IDX["explain"]
    analogize_idx = ACTION_TO_IDX["analogize"]
    question_idx = ACTION_TO_IDX["question"]
    correct_fact_idx = ACTION_TO_IDX["correct_fact"]
    worked_example_idx = ACTION_TO_IDX["worked_example"]

    conf_delta = prev_conf - new_conf  # positive means confusion went down
    gain = max(0, conf_delta)
    switched = bool(action_history) and action_idx != action_history[-1]

    reward = 0.0

    # --- Misconception-dependent shaping ---
    if misc_str == "conceptual":
        if action_idx in (explain_idx, analogize_idx):
            reward += 1.5 * gain
        elif action_idx == question_idx:
            reward -= 0.5  # mild penalty for over-questioning
    elif misc_str == "factual":
        factual_weight = {
            correct_fact_idx: 2.0,
            explain_idx: 1.0,
            question_idx: 0.2,
        }.get(action_idx)
        if factual_weight is not None:
            reward += factual_weight * gain
    elif misc_str == "procedural":
        if action_idx == worked_example_idx:
            reward += 1.5 * gain
        if switched:
            reward -= 0.5  # stability is preferred over exploration
    elif misc_str == "transfer":
        if switched:
            reward -= 1.0  # rapid strategy switching is penalised
        if conf_delta > 0:
            reward += 1.2 * conf_delta

    # --- Attention safety ---
    if new_att < 4.0:
        reward *= 0.5
        reward -= 1.0
        if new_att < 2.0:
            reward -= (2.0 - new_att) ** 2

    # --- Questions that backfire ---
    if action_idx == question_idx and (new_conf > prev_conf or new_att < prev_att):
        reward -= 2.0

    # --- Breaking a streak of falling confusion ---
    if len(confusion_history) >= 3:
        was_falling = (
            confusion_history[-2] < confusion_history[-3]
            and confusion_history[-1] < confusion_history[-2]
        )
        if was_falling and new_conf > prev_conf:
            reward -= 2.5

    # --- Late-episode stagnation ---
    if step > 10 and len(confusion_history) >= 4:
        recent_drop = confusion_history[-4] - new_conf
        if recent_drop <= 0.5:
            reward -= 1.5 * (step - 10)

    # --- Terminal outcome ---
    if done:
        if new_att <= 0.5:
            reward -= 10.0
        elif success:
            reward += 5.0
        else:
            reward -= 2.0

    # --- Variance control and per-misconception normalisation ---
    jump = abs(reward - prev_reward)
    if jump > 5.0:
        reward -= 0.5 * (jump - 5.0)

    scale_by_type = {
        "conceptual": 1.0,
        "factual": 0.8,
        "procedural": 1.2,
        "transfer": 1.5,
    }
    reward /= scale_by_type.get(misc_str, 1.0)

    return float(np.clip(reward, -10.0, 10.0))
| # --------------------------------------------------------------------------- | |
| # 3. Q-Table Architecture & Update | |
| # --------------------------------------------------------------------------- | |
def create_q_system() -> dict[str, defaultdict]:
    """Return a fresh set of empty Q-tables.

    There is one "shared" table plus one table per misconception label
    ("conceptual", "factual", "procedural", "transfer", "none"). Each
    table is a ``defaultdict`` that lazily creates a zero-filled float32
    vector of length ``N_ACTIONS`` for any unseen state.
    """
    def _zero_row() -> np.ndarray:
        return np.zeros(N_ACTIONS, dtype=np.float32)

    table_names = (
        "shared",
        "conceptual",
        "factual",
        "procedural",
        "transfer",
        "none",
    )
    return {name: defaultdict(_zero_row) for name in table_names}
def get_q_values(q_system: dict[str, defaultdict], state: tuple, misc_str: str) -> np.ndarray:
    """Return the combined action values ``Q_shared(s, .) + Q_type(s, .)``.

    The result is a new array, so callers may modify it freely. Looking
    up an unseen state goes through the tables' ``defaultdict`` factory,
    which inserts a row for that state in both tables.
    """
    return q_system["shared"][state] + q_system[misc_str][state]
def update_q(
    q_system: dict[str, defaultdict],
    state: tuple,
    misc_str: str,
    action_idx: int,
    reward: float,
    next_state: tuple,
    done: bool,
    alpha: float = ALPHA,
    gamma: float = GAMMA,
) -> None:
    """Apply one Q-learning update on the combined (shared + typed) value.

    The TD error is computed against the summed Q-value, and half of the
    ``alpha``-scaled correction is written to each of the two tables.
    Afterwards each touched entry is clamped to ``±Q_VALUE_CLIP``
    individually. For terminal transitions (``done``) no value is
    bootstrapped from *next_state*.
    """
    current_q_vals = get_q_values(q_system, state, misc_str)

    if done:
        best_next = 0.0
    else:
        best_next = float(np.max(get_q_values(q_system, next_state, misc_str)))

    td_error = (reward + gamma * best_next) - current_q_vals[action_idx]
    half_step = (alpha / 2.0) * td_error

    shared_row = q_system["shared"][state]
    typed_row = q_system[misc_str][state]

    # Both increments happen before either clamp.
    shared_row[action_idx] += half_step
    typed_row[action_idx] += half_step

    shared_row[action_idx] = float(
        np.clip(shared_row[action_idx], -Q_VALUE_CLIP, Q_VALUE_CLIP)
    )
    typed_row[action_idx] = float(
        np.clip(typed_row[action_idx], -Q_VALUE_CLIP, Q_VALUE_CLIP)
    )
| # --------------------------------------------------------------------------- | |
| # 4. Action selection — Adaptive Constraints | |
| # --------------------------------------------------------------------------- | |
def apply_constraints(
    attempted_action: int,
    action_history: list[int],
    prev_att: float,
    misc_str: str,
    confusion_history: list[float]
) -> tuple[int, float]:
    """Apply rule-based overrides to a proposed action.

    Args:
        attempted_action: Action index proposed by the policy.
        action_history: Indices of the actions already executed this episode.
        prev_att: Student attention before the action is taken.
        misc_str: Misconception label; accepted for interface
            compatibility but not consulted by any rule.
        confusion_history: Confusion readings observed so far.

    Returns:
        ``(final_action, penalty)``: the action to execute, and a
        non-positive penalty accumulated from every rule that fired.

    Rules, in order:
        1. Attention safety -- below 2.5 only "explain" is allowed
           (penalty 5.0); below 4.0 only "explain" or "worked_example"
           (penalty 2.0).
        2. Question cap -- a "question" is replaced by "explain" when the
           last five actions already contain two or more questions
           (penalty ``2.0 * (count - 1)``).
        3. Anti-oscillation -- after four steps alternating
           worked_example/question, either of the two is replaced by
           "explain" (penalty 2.5).
        4. Rising confusion -- after two consecutive increases only
           "explain" or "worked_example" is allowed (penalty 3.0).

    The former "two questions in a row" check was removed: it could only
    fire when the last two actions were both questions, and in that case
    rule 2 has already replaced the action, so it was unreachable.
    """
    explain_idx = ACTION_TO_IDX["explain"]
    worked_idx = ACTION_TO_IDX["worked_example"]
    question_idx = ACTION_TO_IDX["question"]
    safe_actions = (explain_idx, worked_idx)

    final_action = attempted_action
    penalty = 0.0

    # 1. Attention safety (hard limits)
    if prev_att < 2.5:
        if final_action != explain_idx:
            final_action = explain_idx
            penalty -= 5.0
    elif prev_att < 4.0:
        if final_action not in safe_actions:
            final_action = explain_idx
            penalty -= 2.0

    # 2. At most two questions in any five-step window
    if final_action == question_idx:
        q_count = action_history[-5:].count(question_idx)
        if q_count >= 2:
            penalty -= 2.0 * (q_count - 1)
            final_action = explain_idx

    # 3. Break worked_example <-> question ping-pong
    if len(action_history) >= 4:
        recent_4 = action_history[-4:]
        is_oscillation = (
            recent_4 == [worked_idx, question_idx, worked_idx, question_idx]
            or recent_4 == [question_idx, worked_idx, question_idx, worked_idx]
        )
        if is_oscillation and final_action in (worked_idx, question_idx):
            penalty -= 2.5
            final_action = explain_idx

    # 4. Confusion rose twice in a row -> fall back to a safe action
    if len(confusion_history) >= 3:
        c_curr, c_prev, c_prev2 = confusion_history[-1], confusion_history[-2], confusion_history[-3]
        if c_curr > c_prev and c_prev > c_prev2:
            if final_action not in safe_actions:
                final_action = explain_idx
                penalty -= 3.0

    return final_action, penalty
def select_action(
    q_system: dict[str, defaultdict],
    state: tuple,
    epsilon: float,
    rng: random.Random,
    obs_attention: float,
    misc_str: str,
    action_history: list[int],
    confusion_history: list[float]
) -> int:
    """Pick an action epsilon-greedily among the currently permitted ones.

    Permitted actions are narrowed by two rules before sampling:

    * attention below 2.5 leaves only "explain"; below 4.0 leaves
      "explain" and "worked_example";
    * three strictly increasing confusion readings also leave only
      "explain" and "worked_example".

    With probability *epsilon* a permitted action is drawn uniformly from
    *rng*; otherwise the permitted action with the highest combined
    Q-value is returned. *action_history* is accepted but not consulted.
    """
    safe_pair = ("explain", "worked_example")

    # Collect every restriction that applies to this step.
    restrictions: list[tuple[str, ...]] = []
    if obs_attention < 2.5:
        restrictions.append(("explain",))
    elif obs_attention < 4.0:
        restrictions.append(safe_pair)
    if len(confusion_history) >= 3:
        if confusion_history[-1] > confusion_history[-2] > confusion_history[-3]:
            restrictions.append(safe_pair)

    # Intersect the restrictions while keeping the catalogue order.
    allowed = list(ACTIONS.keys())
    for names in restrictions:
        keep = {ACTION_TO_IDX[name] for name in names}
        allowed = [idx for idx in allowed if idx in keep]

    # Forbidden actions can never win the argmax.
    q_vals = get_q_values(q_system, state, misc_str).copy()
    for idx in ACTIONS:
        if idx not in allowed:
            q_vals[idx] = -1e9

    if rng.random() < epsilon and allowed:
        return rng.choice(allowed)
    return int(np.argmax(q_vals))
| # --------------------------------------------------------------------------- | |
| # 5. Dataset Loader | |
| # --------------------------------------------------------------------------- | |
def load_dataset(path: str) -> list[dict[str, Any]]:
    """Load and normalise the offline training samples stored at *path*.

    The file must contain a non-empty JSON array. A record is skipped,
    not fatal, when it is not a JSON object, lacks any key in
    ``REQUIRED_FIELDS``, or holds a numeric field that cannot be
    converted to ``float``. Accepted records are updated in place: the
    four metric fields and ``reward`` become ``float`` and ``done``
    becomes ``bool``.

    Args:
        path: Location of the JSON dataset.

    Returns:
        The accepted records, in file order.

    Raises:
        FileNotFoundError: *path* is not an existing file.
        ValueError: The top-level JSON value is not a non-empty list.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Dataset not found: {path}")
    with open(path, "r", encoding="utf-8") as fh:
        raw = json.load(fh)
    if not isinstance(raw, list) or len(raw) == 0:
        raise ValueError("Dataset must be a non-empty JSON array.")

    float_fields = (
        "confusion",
        "attention",
        "next_confusion",
        "next_attention",
        "reward",
    )
    validated: list[dict[str, Any]] = []
    for record in raw:
        # Non-object elements (lists, numbers, null...) have no keys to check.
        if not isinstance(record, dict) or REQUIRED_FIELDS - record.keys():
            continue
        try:
            # Convert into a scratch dict first so a failure part-way
            # through leaves the skipped record untouched.
            converted = {name: float(record[name]) for name in float_fields}
        except (TypeError, ValueError):
            continue
        record.update(converted)
        record["done"] = bool(record["done"])
        validated.append(record)

    print(f"[Loader] {len(validated)}/{len(raw)} samples loaded from {path}")
    return validated
| # --------------------------------------------------------------------------- | |
| # 6. Q-table Bootstrap | |
| # --------------------------------------------------------------------------- | |
def bootstrap_qtable(
    dataset: list[dict[str, Any]],
    alpha: float = ALPHA_BOOTSTRAP,
    gamma: float = GAMMA,
    n_epochs: int = BOOTSTRAP_EPOCHS,
) -> dict[str, defaultdict]:
    """Seed a fresh Q-table system by replaying offline samples.

    Every sample whose action is in ``ACTION_TO_IDX`` is replayed once
    per epoch. The reward is recomputed with ``compute_reward`` (the
    sample's stored ``reward`` and ``done`` fields are not used), and a
    transition counts as both terminal and successful when
    ``next_confusion`` falls below ``DONE_CONFUSION_THRESHOLD``.
    Misconception labels without a dedicated table are written to "none".
    """
    q_system = create_q_system()
    total_updates = 0

    for _ in range(n_epochs):
        for sample in dataset:
            action_idx = ACTION_TO_IDX.get(sample["action"])
            if action_idx is None:
                continue  # action not in the catalogue

            label = sample["misconception"]
            conf, next_conf = sample["confusion"], sample["next_confusion"]
            att, next_att = sample["attention"], sample["next_attention"]

            state = get_state(conf, att, label)
            next_state = get_state(next_conf, next_att, label)
            resolved = next_conf < DONE_CONFUSION_THRESHOLD

            # Offline samples carry no episode context, so placeholder
            # histories are supplied.
            r = compute_reward(
                conf, next_conf,
                att, next_att,
                done=resolved, success=resolved,
                action_idx=action_idx,
                misc_str=label,
                action_history=[],
                step=1,
                confusion_history=[conf, next_conf],
                prev_reward=0.0
            )

            table_name = label if label in q_system else "none"
            update_q(q_system, state, table_name, action_idx, r, next_state, resolved, alpha, gamma)
            total_updates += 1

    print(f"[Bootstrap] Done — {total_updates} total updates")
    return q_system
| # --------------------------------------------------------------------------- | |
| # 7. Save / Load Q-table | |
| # --------------------------------------------------------------------------- | |
def save_q_table(q_system: dict[str, defaultdict], path: str = MODEL_PATH) -> None:
    """Pickle the Q-table system to *path*.

    Each table is converted to a plain ``dict`` first, because the
    ``defaultdict`` factories are local callables that pickle cannot
    serialise. The parent directory is created if missing.

    Args:
        q_system: Mapping of table name to Q-table.
        path: Destination file. A bare filename is written to the
            current working directory.
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises, so only create a directory when the path has one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    serializable = {k: dict(v) for k, v in q_system.items()}
    with open(path, "wb") as fh:
        pickle.dump(serializable, fh)
    print(f"[Model] Q-table system saved -> {path}")
def load_q_table(path: str = MODEL_PATH) -> dict[str, defaultdict]:
    """Restore a Q-table system from the pickle at *path*.

    A fresh system is created and every stored table whose name it knows
    is merged in; unknown names in the file are ignored.

    NOTE: ``pickle.load`` can execute arbitrary code -- only load files
    from a trusted source.

    Raises:
        FileNotFoundError: *path* is not an existing file.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"No saved Q-table at: {path}")

    with open(path, "rb") as fh:
        stored = pickle.load(fh)

    q_system = create_q_system()
    for name, table in q_system.items():
        if name in stored:
            table.update(stored[name])

    print(f"[Model] Q-table system loaded <- {path}")
    return q_system
| # --------------------------------------------------------------------------- | |
| # 8. Training Loop | |
| # --------------------------------------------------------------------------- | |
def train(
    q_system: dict[str, defaultdict],
    n_episodes: int = N_EPISODES,
    max_steps: int = MAX_STEPS,
    alpha: float = ALPHA,
    gamma: float = GAMMA,
    epsilon_start: float = EPSILON_START,
    epsilon_min: float = EPSILON_MIN,
    seed: int = SEED,
) -> tuple[dict[str, defaultdict], list[float]]:
    """Run online epsilon-greedy Q-learning against ``EduForgeEnv``.

    Each episode draws a misconception type at random, builds a freshly
    seeded environment and runs until confusion drops below
    ``DONE_CONFUSION_THRESHOLD``, attention reaches
    ``ATTENTION_FAILURE_THRESHOLD``, or the step budget is used up. The
    environment's own done flag is not consulted.

    Args:
        q_system: Q-table system, updated in place.
        n_episodes: Number of training episodes.
        max_steps: Upper bound on steps per episode. The per-type budget
            (15 for "procedural", 10 otherwise) still applies, so the
            effective limit is the smaller of the two.
        alpha: Learning rate.
        gamma: Discount factor.
        epsilon_start: Initial exploration rate.
        epsilon_min: Floor for the exploration rate.
        seed: Seed for the private random generator, which also derives
            the per-episode environment seeds.

    Returns:
        ``(q_system, episode_rewards)`` with one total reward per episode.
    """
    rng = random.Random(seed)
    episode_rewards: list[float] = []
    recent_rewards: list[float] = []  # rolling window: last 100 episode returns
    epsilon = epsilon_start
    misconceptions = ["conceptual", "factual", "procedural", "transfer"]

    print(f"\n[Training] {n_episodes} episodes | eps={epsilon_start:.2f}->{epsilon_min:.2f}")
    print("-" * 60)

    for ep in range(1, n_episodes + 1):
        misc = rng.choice(misconceptions)
        env = EduForgeEnv(seed=rng.randint(0, 99_999), misconception_init=misc)
        obs = env.reset()

        last_action_idx: int | None = None
        prev_last_action_idx: int | None = None
        progress_signal = 0
        steps_since_improvement = 0
        action_history: list[int] = []
        confusion_history = [obs.confusion]

        state = get_state_from_obs(
            obs, last_action_idx, prev_last_action_idx,
            progress_signal, steps_since_improvement,
        )
        total_reward = 0.0

        # Honour the caller's cap on top of the per-type step budget.
        domain_max_steps = min(max_steps, 15 if misc == "procedural" else 10)

        for step in range(1, domain_max_steps + 1):
            prev_conf = obs.confusion
            prev_att = obs.attention

            attempted_action = select_action(
                q_system, state, epsilon, rng,
                obs.attention, misc, action_history, confusion_history
            )
            action_idx, penalty = apply_constraints(
                attempted_action, action_history, prev_att, misc, confusion_history
            )

            action_tag = f"<STRATEGY>{ACTIONS[action_idx]}</STRATEGY>"
            obs, env_reward, _, _ = env.step(action_tag)
            action_history.append(action_idx)
            confusion_history.append(obs.confusion)

            success = obs.confusion < DONE_CONFUSION_THRESHOLD
            att_fail = obs.attention <= ATTENTION_FAILURE_THRESHOLD
            timeout = step >= domain_max_steps
            done = success or att_fail or timeout

            # Environment reward plus any hard-constraint penalty.
            step_reward = env_reward + penalty

            # Progress signal: +1 if either metric improved, -1 if either
            # got worse (and neither improved), 0 if nothing changed.
            confusion_delta = prev_conf - obs.confusion
            attention_delta = obs.attention - prev_att
            if confusion_delta > 0 or attention_delta > 0:
                progress_signal = 1
                steps_since_improvement = 0
            elif confusion_delta < 0 or attention_delta < 0:
                progress_signal = -1
                steps_since_improvement += 1
            else:
                progress_signal = 0
                steps_since_improvement += 1

            next_state = get_state_from_obs(
                obs, action_idx, last_action_idx,
                progress_signal, steps_since_improvement,
            )

            # The update is credited to the action the policy *attempted*,
            # so an overridden choice absorbs the constraint penalty.
            update_q(q_system, state, misc, attempted_action, step_reward, next_state, done, alpha, gamma)

            total_reward += step_reward
            state = next_state
            prev_last_action_idx = last_action_idx
            last_action_idx = action_idx

            if done:
                break

        episode_rewards.append(total_reward)
        recent_rewards.append(total_reward)
        if len(recent_rewards) > 100:
            recent_rewards.pop(0)

        # Adaptive epsilon: every 10 episodes compare the older and newer
        # halves of the 100-episode window.
        if len(recent_rewards) == 100 and ep % 10 == 0:
            avg_first_half = np.mean(recent_rewards[:50])
            avg_second_half = np.mean(recent_rewards[50:])
            if avg_second_half > avg_first_half + 0.5:
                # Improving -> decay faster
                epsilon = max(epsilon_min, epsilon * 0.95)
            elif avg_second_half < avg_first_half - 0.5:
                # Dropping -> increase noise
                epsilon = min(1.0, epsilon * 1.1)
            else:
                # Plateau -> slow decay
                epsilon = max(epsilon_min, epsilon * 0.99)

        # Base decay early on so epsilon does not stay at its start value
        # while the comparison window is still filling.
        if ep < 200:
            epsilon = max(epsilon_min, epsilon * 0.995)

        if ep % 500 == 0 or ep == 1:
            avg = float(np.mean(recent_rewards))
            print(f" Ep {ep:>5}/{n_episodes} | eps={epsilon:.4f} | avg_reward(last 100)={avg:+.4f}")

    return q_system, episode_rewards
| # --------------------------------------------------------------------------- | |
| # 9. Evaluation | |
| # --------------------------------------------------------------------------- | |
def evaluate(
    q_system: dict[str, defaultdict],
    n_episodes: int = EVAL_EPISODES,
    max_steps: int = MAX_STEPS,
    seed: int = SEED + 1,
) -> dict[str, Any]:
    """Roll out the greedy policy and print a per-misconception report.

    Each misconception type is evaluated on the same environment seeds
    (``seed``, ``seed + 1``, ...), with attention initialised to 8.0.

    Args:
        q_system: Trained Q-table system. It is only read by this
            function, although lookups of unseen states insert zero rows.
        n_episodes: Total episode budget, split evenly across the four
            misconception types (integer division, at least one each).
        max_steps: Upper bound on steps per episode. The per-type budget
            (15 for "procedural", 10 otherwise) still applies.
        seed: Base seed for the environments and the random generator.

    Returns:
        A dict holding the overall ``"resolved"``, ``"failed_timeout"``
        and ``"failed_attention"`` counts, plus one entry per
        misconception type with the same three counters and the
        per-episode ``"steps"`` and ``"rewards"`` lists.
    """
    rng = random.Random(seed)
    print("\n" + "=" * 60)
    print("EVALUATION — Greedy Policy")
    print("=" * 60)

    results: dict[str, Any] = {"resolved": 0, "failed_timeout": 0, "failed_attention": 0}
    misconception_actions: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
    misconceptions = ["conceptual", "factual", "procedural", "transfer"]
    seeds_per_misc = max(1, n_episodes // len(misconceptions))

    for misc_idx, m_str in enumerate(misconceptions):
        results[m_str] = {
            "resolved": 0,
            "failed_timeout": 0,
            "failed_attention": 0,
            "steps": [],
            "rewards": [],
        }

        for seed_idx in range(seeds_per_misc):
            ep = misc_idx * seeds_per_misc + seed_idx + 1
            env = EduForgeEnv(seed=seed + seed_idx, misconception_init=m_str, attention_init=8.0)
            obs = env.reset()

            last_action_idx: int | None = None
            prev_last_action_idx: int | None = None
            total_reward = 0.0
            final_step = 0
            outcome = ""
            progress_signal = 0
            steps_since_improvement = 0
            action_history: list[int] = []
            confusion_history = [obs.confusion]

            state = get_state_from_obs(
                obs, last_action_idx, prev_last_action_idx,
                progress_signal, steps_since_improvement,
            )

            print(f"\n--- Episode {ep} ---")
            print(f" Misconception : {m_str}")
            print(f" Initial : confusion={obs.confusion:.2f} attention={obs.attention:.2f}")

            # Honour the caller's cap on top of the per-type step budget.
            domain_max_steps = min(max_steps, 15 if m_str == "procedural" else 10)

            for step in range(1, domain_max_steps + 1):
                prev_conf = obs.confusion
                prev_att = obs.attention

                # epsilon = 0.0 -> purely greedy selection
                attempted_action = select_action(
                    q_system, state, 0.0, rng,
                    obs.attention, m_str, action_history, confusion_history
                )
                action_idx, penalty = apply_constraints(
                    attempted_action, action_history, prev_att, m_str, confusion_history
                )

                chosen = ACTIONS[action_idx]
                action_tag = f"<STRATEGY>{chosen}</STRATEGY>"
                misconception_actions[m_str][chosen] += 1

                obs, env_reward, _, _ = env.step(action_tag)
                action_history.append(action_idx)
                confusion_history.append(obs.confusion)

                success = obs.confusion < DONE_CONFUSION_THRESHOLD
                att_fail = obs.attention <= ATTENTION_FAILURE_THRESHOLD
                timeout = step >= domain_max_steps
                done = success or att_fail or timeout

                step_reward = env_reward + penalty

                confusion_delta = prev_conf - obs.confusion
                attention_delta = obs.attention - prev_att
                if confusion_delta > 0 or attention_delta > 0:
                    progress_signal = 1
                    steps_since_improvement = 0
                elif confusion_delta < 0 or attention_delta < 0:
                    progress_signal = -1
                    steps_since_improvement += 1
                else:
                    progress_signal = 0
                    steps_since_improvement += 1

                total_reward += step_reward
                state = get_state_from_obs(
                    obs, action_idx, last_action_idx,
                    progress_signal, steps_since_improvement,
                )

                print(f" Step {step:>2} | action={chosen:<15} | "
                      f"confusion={obs.confusion:.2f} attention={obs.attention:.2f} | "
                      f"reward={step_reward:+.2f}")

                prev_last_action_idx = last_action_idx
                last_action_idx = action_idx
                final_step = step

                if done:
                    if success:
                        outcome = "[RESOLVED]"
                        results[m_str]["resolved"] += 1
                        print(f" >> RESOLVED confusion={obs.confusion:.2f} < {DONE_CONFUSION_THRESHOLD}")
                    elif att_fail:
                        outcome = "[FAILED: disengaged]"
                        results[m_str]["failed_attention"] += 1
                        print(f" >> FAILED attention={obs.attention:.2f} < {ATTENTION_FAILURE_THRESHOLD}")
                    else:
                        outcome = "[FAILED: timeout]"
                        results[m_str]["failed_timeout"] += 1
                        print(f" >> FAILED confusion={obs.confusion:.2f} > {DONE_CONFUSION_THRESHOLD} (max steps)")
                    break

            results[m_str]["rewards"].append(total_reward)
            results[m_str]["steps"].append(final_step)
            print(f" {outcome} after {final_step} step(s) | total_reward={total_reward:+.2f}")

    print("\n" + "=" * 60)
    print("EVALUATION SUMMARY")
    print("=" * 60)

    per_misc = [results[name] for name in misconceptions]
    total_res = sum(entry["resolved"] for entry in per_misc)
    total_tout = sum(entry["failed_timeout"] for entry in per_misc)
    total_att = sum(entry["failed_attention"] for entry in per_misc)
    total_eps = total_res + total_tout + total_att

    all_r = [r for entry in per_misc for r in entry["rewards"]]
    all_s = [s for entry in per_misc for s in entry["steps"]]

    sr = total_res / total_eps * 100 if total_eps > 0 else 0
    var_r = np.var(all_r) if all_r else 0.0
    avg_steps = np.mean(all_s) if all_s else 0.0

    print(f" Overall Success: {total_res}/{total_eps} ({sr:.0f}%)")
    print(f" Overall Avg steps: {avg_steps:.1f}")
    print(f" Reward Variance: {var_r:.2f}")

    for m in misconceptions:
        m_data = results[m]
        m_total = m_data["resolved"] + m_data["failed_timeout"] + m_data["failed_attention"]
        if m_total == 0:
            continue
        m_sr = m_data["resolved"] / m_total * 100
        print(f"\n [{m.upper()}] Success: {m_data['resolved']}/{m_total} ({m_sr:.0f}%)")
        print(f" Avg steps: {np.mean(m_data['steps']):.1f} | Avg reward: {np.mean(m_data['rewards']):+.2f}")
        print(f" Failures: {m_data['failed_timeout']} timeout, {m_data['failed_attention']} attention")

    print("\n POLICY — Dominant Strategies per Misconception")
    print(" " + "-" * 50)
    for m, counts in sorted(misconception_actions.items()):
        t = sum(counts.values())
        print(f"\n {m} ({t} actions):")
        for act, cnt in sorted(counts.items(), key=lambda x: x[1], reverse=True):
            print(f" {act:<15} : {cnt:>3} ({cnt/t*100:>5.1f}%)")
    print("\n" + "=" * 60)

    # Publish the overall totals; these top-level counters previously
    # stayed at their initial 0.
    results["resolved"] = total_res
    results["failed_timeout"] = total_tout
    results["failed_attention"] = total_att
    return results
| # --------------------------------------------------------------------------- | |
| # 10. Human Feedback Hooks | |
| # --------------------------------------------------------------------------- | |
class FeedbackHook:
    """Turn human feedback into immediate Q-table updates.

    Each feedback kind maps to a fixed reward that is applied to the
    given (state, action) pair as a terminal transition, so no future
    value is bootstrapped.
    """

    # Fixed reward attached to each kind of feedback.
    REWARD_GOOD = +2.0
    REWARD_CONFUSING = -1.5
    REWARD_BORING = -1.0

    def __init__(self, q_system: dict[str, defaultdict], alpha: float = ALPHA) -> None:
        """Bind the hook to *q_system* and set the learning rate it uses."""
        self.q_system = q_system
        self.alpha = alpha

    def _apply(self, state: tuple, misc_str: str, action_idx: int, reward: float) -> float:
        """Apply *reward* to ``(state, action_idx)`` and return it."""
        update_q(
            q_system=self.q_system,
            state=state,
            misc_str=misc_str,
            action_idx=action_idx,
            reward=reward,
            next_state=state,  # ignored by update_q because done=True
            done=True,
            alpha=self.alpha,
            gamma=GAMMA,
        )
        return reward

    def good(self, state: tuple, misc_str: str, action_idx: int) -> float:
        """Record helpful feedback; returns the reward applied."""
        return self._apply(state, misc_str, action_idx, self.REWARD_GOOD)

    def confusing(self, state: tuple, misc_str: str, action_idx: int) -> float:
        """Record confusing feedback; returns the reward applied."""
        return self._apply(state, misc_str, action_idx, self.REWARD_CONFUSING)

    def boring(self, state: tuple, misc_str: str, action_idx: int) -> float:
        """Record boring feedback; returns the reward applied."""
        return self._apply(state, misc_str, action_idx, self.REWARD_BORING)
| # --------------------------------------------------------------------------- | |
| # 11. Interactive REPL | |
| # --------------------------------------------------------------------------- | |
# Phrases treated as a sign that the student is badly lost.
_HIGH_CONFUSION_KW = {
    "don't understand",
    "dont understand",
    "lost",
    "confused",
    "no idea",
    "what",
    "help",
    "stuck",
    "not clear",
    "makes no sense",
}

# Hedging phrases treated as a sign of partial understanding.
_MED_CONFUSION_KW = {
    "somewhat",
    "maybe",
    "kind of",
    "sort of",
    "not sure",
    "partially",
    "a bit",
    "a little",
}

# Tutor-facing description of each strategy, keyed by action name.
_ACTION_DESC: dict[str, str] = dict(
    explain="Give a clear, step-by-step explanation of the concept.",
    worked_example="Walk through a fully worked example together.",
    question="Ask the student a probing question to test understanding.",
    correct_fact="Directly correct the factual error the student has made.",
    analogize="Use a real-world analogy to build intuition.",
)
def estimate_state(
    user_input: str, misconception: str = "none",
) -> tuple[float, float, str]:
    """Guess ``(confusion, attention, misconception)`` from a student message.

    The lower-cased text is searched for keyword substrings: a
    high-confusion keyword yields ``(8.0, 5.0)``, otherwise a
    medium-confusion keyword yields ``(5.0, 5.0)``, otherwise
    ``(3.0, 6.0)``. *misconception* is passed through unchanged.
    """
    lowered = user_input.lower()

    def mentions(keywords) -> bool:
        return any(kw in lowered for kw in keywords)

    if mentions(_HIGH_CONFUSION_KW):
        confusion, attention = 8.0, 5.0
    elif mentions(_MED_CONFUSION_KW):
        confusion, attention = 5.0, 5.0
    else:
        confusion, attention = 3.0, 6.0
    return confusion, attention, misconception
def interact(q_system: dict[str, defaultdict] | None = None) -> None:
    """Run an interactive tutoring loop driven by human feedback.

    For every student message the state is estimated from keywords, the
    greedy action is shown, and the operator's feedback (``y`` helpful,
    ``n`` confusing, ``b`` boring) is applied to the Q-table. Any other
    feedback input is reported and ignored, so an empty line or a typo
    does not penalise the agent. The Q-table is saved on exit if at
    least one feedback signal was recorded.

    Args:
        q_system: Q-table system to use and update; when ``None`` it is
            loaded from ``MODEL_PATH``.
    """
    if q_system is None:
        q_system = load_q_table(MODEL_PATH)
    hook = FeedbackHook(q_system)

    print("\n" + "=" * 60)
    print("EduForge Interactive Tutoring Session")
    print("=" * 60)
    print(" Misconception types: " + ", ".join(MISCONCEPTION_MAP.keys()))
    print(" Commands: 'switch <type>', 'quit'")
    print(" Feedback: y = helpful, n = confusing, b = boring")
    print("=" * 60)

    misconception = "none"
    session_pos, session_neg, session_bored = 0, 0, 0
    total_reward = 0.0

    while True:
        print(f"\n[Active misconception: {misconception}]")
        try:
            user_input = input("Student > ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n[Session ended]")
            break

        if not user_input:
            continue
        command = user_input.lower()
        if command in {"quit", "exit"}:
            print("[Session ended]")
            break
        if command.startswith("switch "):
            req = user_input[7:].strip().lower()
            if req in MISCONCEPTION_MAP:
                misconception = req
                print(f" [System] Switched to: {misconception}")
            else:
                print(f" [System] Unknown. Options: {list(MISCONCEPTION_MAP)}")
            continue

        confusion, attention, m = estimate_state(user_input, misconception)
        state = get_state(confusion, attention, m)
        q_vals = get_q_values(q_system, state, misconception)
        action_idx = int(np.argmax(q_vals))
        action_name = ACTIONS[action_idx]

        print(f" [State] confusion={confusion:.1f} attention={attention:.1f}")
        print(f" [Action] {action_name}")
        print(f" [Tutor] {_ACTION_DESC[action_name]}")

        try:
            fb = input(" Feedback (y/n/b): ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            print("\n[Session ended]")
            break

        if fb == "y":
            total_reward += hook.good(state, misconception, action_idx)
            session_pos += 1
            print(" [+] Positive signal recorded.")
        elif fb == "b":
            total_reward += hook.boring(state, misconception, action_idx)
            session_bored += 1
            print(" [~] Boredom signal recorded — agent adjusting.")
        elif fb == "n":
            total_reward += hook.confusing(state, misconception, action_idx)
            session_neg += 1
            print(" [-] Negative signal recorded — agent adjusting.")
        else:
            # Unrecognised input must not be learned as negative feedback.
            print(" [?] Unrecognised feedback — no update applied.")

        if confusion < DONE_CONFUSION_THRESHOLD:
            print(" [EduForge] Student appears to understand. Great job!")

    total_turns = session_pos + session_neg + session_bored
    print("\n" + "=" * 60)
    print("Session Summary")
    print("=" * 60)
    if total_turns > 0:
        print(f" Helpful : {session_pos} ({session_pos/total_turns*100:.0f}%)")
        print(f" Confusing: {session_neg}")
        print(f" Boring : {session_bored}")
        print(f" Reward : {total_reward:+.1f}")
        save_q_table(q_system, MODEL_PATH)
    else:
        print(" No feedback — Q-table unchanged.")
    print("=" * 60)
| # --------------------------------------------------------------------------- | |
| # Entry point | |
| # --------------------------------------------------------------------------- | |
def main() -> None:
    """Run the full pipeline: load dataset, train online, save, evaluate.

    The global ``random`` and NumPy generators are seeded with ``SEED``
    first. Offline bootstrapping is disabled, so training starts from an
    empty Q-table; the dataset is still loaded in step 1, which keeps the
    loader's validation and its summary line.
    """
    random.seed(SEED)
    np.random.seed(SEED)

    print("=" * 60)
    print("EduForge Q-Learning Pipeline")
    print("=" * 60)

    print("\n[1/4] Loading dataset...")
    # The samples are not used for training while bootstrapping is off.
    load_dataset(DATASET_PATH)

    # Bootstrapping is disabled because offline data does not follow the
    # new hard constraints and would poison the initial Q-table. The log
    # line says so instead of announcing a bootstrap that never runs.
    print("\n[2/4] Initialising empty Q-table (offline bootstrap disabled)...")
    q_system = create_q_system()

    print("\n[3/4] Online Q-learning training...")
    q_system, reward_history = train(
        q_system,
        n_episodes=N_EPISODES, max_steps=MAX_STEPS,
        alpha=ALPHA, gamma=GAMMA,
        epsilon_start=EPSILON_START, epsilon_min=EPSILON_MIN,
        seed=SEED,
    )

    # Compare the first and last third of training (at least one episode each).
    thirds = len(reward_history) // 3 or 1
    print(f"\n Reward trend - "
          f"first 3rd avg: {float(np.mean(reward_history[:thirds])):+.4f} | "
          f"last 3rd avg: {float(np.mean(reward_history[-thirds:])):+.4f}")
    save_q_table(q_system, MODEL_PATH)

    print("\n[4/4] Evaluating greedy policy...")
    evaluate(q_system, n_episodes=EVAL_EPISODES, max_steps=MAX_STEPS, seed=SEED + 1)
    print("\nPipeline complete.\n")
if __name__ == "__main__":
    import argparse

    # --interact starts the human-feedback REPL instead of the training pipeline.
    cli = argparse.ArgumentParser(description="EduForge Q-Learning Pipeline")
    cli.add_argument("--interact", action="store_true")
    entry_point = interact if cli.parse_args().interact else main
    entry_point()