"""
qlearning_pipeline.py — Q-learning training pipeline for EduForge.

Modular pipeline:
  1. Dataset Loader — load & validate training_samples.json
  2. Q-table Bootstrap — seed Q-values from offline dataset
  3. Training Loop — adaptive epsilon-greedy online Q-learning
  4. Evaluation — greedy policy rollouts with reporting
  5. Interactive REPL — human-in-the-loop tutoring

Entry point:
  python scripts/qlearning_pipeline.py
"""
from __future__ import annotations

import json
import os
import pickle
import random
import sys
from collections import defaultdict
from typing import Any

import numpy as np

# ---------------------------------------------------------------------------
# Path setup
# ---------------------------------------------------------------------------
# The repository root is two directory levels above this file; it is put on
# sys.path so that the `src.` package import below resolves when the script is
# run directly.
_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

from src.environment.openenv_wrapper import EduForgeEnv  # noqa: E402

# ---------------------------------------------------------------------------
# Action catalogue
# ---------------------------------------------------------------------------
# Action index -> action tag. The tag is the string handed to `env.step`.
ACTIONS: dict[int, str] = {
    0: "explain",
    1: "worked_example",
    2: "question",
    3: "correct_fact",
    4: "analogize",
}
# Inverse lookup: action tag -> action index.
ACTION_TO_IDX: dict[str, int] = {v: k for k, v in ACTIONS.items()}
N_ACTIONS = len(ACTIONS)

# Misconception name -> integer code used as one component of the state tuple.
MISCONCEPTION_MAP: dict[str, int] = {
    "none": 0,
    "procedural": 1,
    "conceptual": 2,
    "factual": 3,
    "transfer": 4,
}

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
DATASET_PATH = os.path.join(_ROOT, "src", "environment", "training_samples.json")
MODEL_DIR = os.path.join(_ROOT, "models")
MODEL_PATH = os.path.join(MODEL_DIR, "q_table.pkl")

# Keys every dataset record must carry; the loader skips records missing any.
REQUIRED_FIELDS = {
    "misconception",
    "confusion",
    "attention",
    "action",
    "next_confusion",
    "next_attention",
    "reward",
    "done",
}

# Hyperparameters
ALPHA_BOOTSTRAP = 0.2  # default learning rate for the offline bootstrap pass
BOOTSTRAP_EPOCHS = 3  # default number of passes over the offline dataset
ALPHA = 0.15  # default learning rate for online updates and feedback hooks
GAMMA = 0.92
EPSILON_START = 1.0
EPSILON_MIN = 0.01
N_EPISODES = 4000
MAX_STEPS = 15
EVAL_EPISODES = 80  # 4 misconceptions × 20 seeds each
SEED = 42

# Thresholds (must match openenv_wrapper.py)
DONE_CONFUSION_THRESHOLD = 2.0
ATTENTION_FAILURE_THRESHOLD = 0.5  # Match the environment's floor (ATTENTION_FLOOR)

# Q-value clipping — prevent explosion
Q_VALUE_CLIP = 15.0

# ---------------------------------------------------------------------------
# 1. State discretisation — integer buckets, compact space
# ---------------------------------------------------------------------------

# Coarse bin edges for discretization
_CONF_BINS = [0, 2, 4, 6, 8, 10.01]  # 5 bins
_ATT_BINS = [0, 2, 4, 6, 8, 10.01]  # 5 bins


def _bin_value(val: float, edges: list[float]) -> int:
    """Return the index of the half-open bin ``[edges[i], edges[i + 1])`` holding *val*.

    The value is first clamped into ``[edges[0], edges[-1] - 0.01]`` so that
    out-of-range inputs land in the first or last bin instead of failing.
    """
    clamped = max(edges[0], min(edges[-1] - 0.01, val))
    for idx, upper in enumerate(edges[1:]):
        if clamped < upper:
            return idx
    return len(edges) - 2


def get_state(
    confusion: float,
    attention: float,
    misconception: str | int,
    step_number: int = 1,
    last_action: int | None = None,
    prev_last_action: int | None = None,
    progress_signal: int = 0,
    steps_since_improvement: int = 0
) -> tuple:
    """Map student metrics to a coarse discrete state tuple.

    Returns ``(c, a, m, p, la, pla, ps, ssi)`` where:
      * ``c`` / ``a`` — confusion / attention bin indices,
      * ``m`` — misconception code (unknown names map to 0),
      * ``p`` — episode phase: 0 for steps <= 5, 1 for steps <= 10, else 2,
      * ``la`` / ``pla`` — last and second-to-last action index (5 = none yet),
      * ``ps`` — ``progress_signal`` shifted by +1 so it is non-negative,
      * ``ssi`` — steps since improvement bucketed as <=1, <=3, or more.
    """
    c = _bin_value(confusion, _CONF_BINS)
    a = _bin_value(attention, _ATT_BINS)

    if isinstance(misconception, str):
        m = MISCONCEPTION_MAP.get(misconception, 0)
    else:
        m = int(misconception)

    if step_number <= 5:
        p = 0
    elif step_number <= 10:
        p = 1
    else:
        p = 2

    # Index 5 is one past the last real action and stands for "no action yet".
    la = 5 if last_action is None else int(last_action)
    pla = 5 if prev_last_action is None else int(prev_last_action)

    ps = progress_signal + 1

    if steps_since_improvement <= 1:
        ssi = 0
    elif steps_since_improvement <= 3:
        ssi = 1
    else:
        ssi = 2

    return (c, a, m, p, la, pla, ps, ssi)


def get_state_from_obs(
    obs,
    last_action_idx: int | None = None,
    prev_last_action_idx: int | None = None,
    progress_signal: int = 0,
    steps_since_improvement: int = 0
) -> tuple:
    """Extract and discretise the state from an Observation object.

    Reads ``obs.confusion``, ``obs.attention`` and ``obs.misconception_id.value``;
    ``obs.turn`` is used as the step number when present, otherwise 1.
    """
    return get_state(
        obs.confusion,
        obs.attention,
        obs.misconception_id.value,
        obs.turn if hasattr(obs, 'turn') else 1,
        last_action_idx,
        prev_last_action_idx,
        progress_signal,
        steps_since_improvement
    )


# ---------------------------------------------------------------------------
# 2. Reward function — Continuous Multi-Component
# ---------------------------------------------------------------------------

def compute_reward(
    prev_conf: float,
    new_conf: float,
    prev_att: float,
    new_att: float,
    done: bool,
    success: bool,
    action_idx: int,
    misc_str: str,
    action_history: list[int],
    step: int,
    confusion_history: list[float],
    prev_reward: float = 0.0
) -> float:
    """Compute the shaped reward for one transition.

    Combines misconception-specific credit for confusion reduction, attention
    penalties, question-misuse and streak-breaking penalties, a late-episode
    stagnation penalty, terminal bonuses/penalties, smoothing against
    *prev_reward*, and a per-misconception normalisation.

    Returns:
        The reward clipped to ``[-10.0, 10.0]`` as a Python float.
    """
    reward = 0.0
    conf_delta = prev_conf - new_conf  # Positive delta is good
    gain = max(0, conf_delta)
    switched = len(action_history) > 0 and action_idx != action_history[-1]

    # 4. Mode-Dependent Reward System
    if misc_str == "conceptual":
        if action_idx in [ACTION_TO_IDX["explain"], ACTION_TO_IDX["analogize"]]:
            reward += 1.5 * gain
        elif action_idx == ACTION_TO_IDX["question"]:
            reward -= 0.5  # Mild penalty for over-questioning
    elif misc_str == "factual":
        if action_idx == ACTION_TO_IDX["correct_fact"]:
            reward += 2.0 * gain
        elif action_idx == ACTION_TO_IDX["explain"]:
            reward += 1.0 * gain
        elif action_idx == ACTION_TO_IDX["question"]:
            reward += 0.2 * gain
    elif misc_str == "procedural":
        if action_idx == ACTION_TO_IDX["worked_example"]:
            reward += 1.5 * gain
        if switched:
            reward -= 0.5  # Stability > exploration
    elif misc_str == "transfer":
        if switched:
            reward -= 1.0  # Penalize rapid strategy switching
        if conf_delta > 0:
            reward += 1.2 * conf_delta

    # 1. Attention Safety Continuous Penalty
    if new_att < 4.0:
        reward *= 0.5  # Halve whatever was earned so far
        reward -= 1.0
    if new_att < 2.0:
        reward -= (2.0 - new_att) ** 2  # Quadratic penalty below 2.0

    # 2. Question Action Control (Negative consequences)
    if action_idx == ACTION_TO_IDX["question"]:
        if new_conf > prev_conf or new_att < prev_att:
            reward -= 2.0  # Immediate negative reward

    # 5. Confusion Reduction Rule (Monotonicity Bias)
    if len(confusion_history) >= 3:
        falling = confusion_history[-1] < confusion_history[-2] < confusion_history[-3]
        if falling and new_conf > prev_conf:  # Broke a reduction streak
            reward -= 2.5

    # 7. Failure Prevention Objective
    if step > 10 and len(confusion_history) >= 4:
        recent_conf_drop = confusion_history[-4] - new_conf
        if recent_conf_drop <= 0.5:
            reward -= 1.5 * (step - 10)  # Scaling penalty for stagnation

    # Terminal Rewards
    if done:
        if new_att <= ATTENTION_FAILURE_THRESHOLD:
            reward -= 10.0
        elif success:
            reward += 5.0
        else:
            reward -= 2.0

    # 6. Reward Variance Control
    jump = abs(reward - prev_reward)
    if jump > 5.0:
        reward -= 0.5 * (jump - 5.0)  # Smoothing

    norm_factor = {
        "conceptual": 1.0,
        "factual": 0.8,
        "procedural": 1.2,
        "transfer": 1.5,
    }.get(misc_str, 1.0)
    reward /= norm_factor

    return float(np.clip(reward, -10.0, 10.0))


# ---------------------------------------------------------------------------
# 3. Q-Table Architecture & Update
# ---------------------------------------------------------------------------

def create_q_system() -> dict[str, defaultdict]:
    """Create a structured dictionary of Q-tables.

    One ``"shared"`` table plus one table per misconception type (and
    ``"none"``). Each maps a state tuple to a float32 vector of length
    ``N_ACTIONS`` that defaults to zeros.
    """
    table_names = ("shared", "conceptual", "factual", "procedural", "transfer", "none")
    return {
        name: defaultdict(lambda: np.zeros(N_ACTIONS, dtype=np.float32))
        for name in table_names
    }


def _table_key(q_system: dict[str, defaultdict], misc_str: str) -> str:
    """Return *misc_str* if it names a table, otherwise the ``"none"`` table."""
    return misc_str if misc_str in q_system else "none"


def get_q_values(q_system: dict[str, defaultdict], state: tuple, misc_str: str) -> np.ndarray:
    """Q_final(s, a) = Q_shared(s, a) + Q_type(s, a)

    Rows are read with ``.get`` so a lookup never inserts an all-zero row into
    the tables; unseen states simply contribute zeros. Unknown misconception
    names fall back to the ``"none"`` table. A fresh array is returned, so the
    caller may modify it freely.
    """
    combined = np.zeros(N_ACTIONS, dtype=np.float32)
    for name in ("shared", _table_key(q_system, misc_str)):
        row = q_system[name].get(state)
        if row is not None:
            combined += row
    return combined


def update_q(
    q_system: dict[str, defaultdict],
    state: tuple,
    misc_str: str,
    action_idx: int,
    reward: float,
    next_state: tuple,
    done: bool,
    alpha: float = ALPHA,
    gamma: float = GAMMA,
) -> None:
    """Apply one Bellman update computed on the combined Q-value.

    The TD error is split evenly between the shared table and the
    misconception-specific table (``"none"`` for unknown names), and each
    updated cell is clipped to ``[-Q_VALUE_CLIP, Q_VALUE_CLIP]``.
    """
    key = _table_key(q_system, misc_str)
    current = get_q_values(q_system, state, key)[action_idx]

    if done:
        best_next = 0.0  # Terminal transitions do not bootstrap
    else:
        best_next = float(np.max(get_q_values(q_system, next_state, key)))

    td_error = reward + gamma * best_next - current
    half_step = (alpha / 2.0) * td_error

    for name in ("shared", key):
        row = q_system[name][state]  # Creates the row on first write
        row[action_idx] = float(
            np.clip(row[action_idx] + half_step, -Q_VALUE_CLIP, Q_VALUE_CLIP)
        )


# ---------------------------------------------------------------------------
# 4. Action selection — Adaptive Constraints
# ---------------------------------------------------------------------------

def apply_constraints(
    attempted_action: int,
    action_history: list[int],
    prev_att: float,
    misc_str: str,
    confusion_history: list[float]
) -> tuple[int, float]:
    """Hard safety constraints and rule-based action corrections.

    Returns:
        ``(final_action, penalty)`` — the action to actually execute and a
        non-positive penalty accumulated for every rule that had to override
        the attempted action. ``misc_str`` is accepted for interface
        compatibility and is not consulted.
    """
    final_action = attempted_action
    penalty = 0.0

    we_idx = ACTION_TO_IDX["worked_example"]
    q_idx = ACTION_TO_IDX["question"]
    ex_idx = ACTION_TO_IDX["explain"]

    # 1. Attention Safety Enforcement (Hard limits)
    if prev_att < 2.5:
        if final_action != ex_idx:
            final_action = ex_idx
            penalty -= 5.0
    elif prev_att < 4.0:
        if final_action not in [ex_idx, we_idx]:
            final_action = ex_idx
            penalty -= 2.0

    # 2. Question Action Control (Max 2 per 5-step window)
    # This also covers "two questions in a row": both would fall inside the
    # window, so no separate consecutive-question rule is needed.
    if final_action == q_idx:
        q_count = action_history[-5:].count(q_idx)
        if q_count >= 2:
            penalty -= 2.0 * (q_count - 1)
            final_action = ex_idx

    # 3. Action Stability Rule (Anti-Oscillation)
    if len(action_history) >= 4:
        recent_4 = action_history[-4:]
        is_oscillation = recent_4 in (
            [we_idx, q_idx, we_idx, q_idx],
            [q_idx, we_idx, q_idx, we_idx],
        )
        if is_oscillation and final_action in [we_idx, q_idx]:
            penalty -= 2.5
            final_action = ex_idx

    # 5. Confusion Reduction Rule (Monotonicity Bias)
    if len(confusion_history) >= 3:
        c_prev2, c_prev, c_curr = confusion_history[-3:]
        if c_curr > c_prev > c_prev2 and final_action not in [ex_idx, we_idx]:
            final_action = ex_idx
            penalty -= 3.0

    return final_action, penalty


def select_action(
    q_system: dict[str, defaultdict],
    state: tuple,
    epsilon: float,
    rng: random.Random,
    obs_attention: float,
    misc_str: str,
    action_history: list[int],
    confusion_history: list[float]
) -> int:
    """Epsilon-greedy action selection with strict safety pre-masking.

    Actions disallowed by the attention or rising-confusion rules are removed
    from the exploration pool and given a Q-value of ``-1e9`` so the greedy
    choice cannot pick them. ``action_history`` is accepted for interface
    compatibility and is not consulted.
    """
    q_vals = get_q_values(q_system, state, misc_str)
    allowed = list(ACTIONS.keys())

    def restrict_to(names: list[str]) -> None:
        keep = {ACTION_TO_IDX[name] for name in names}
        for action in [a for a in allowed if a not in keep]:
            allowed.remove(action)
            q_vals[action] = -1e9

    # 1. Attention Safety Enforcement
    if obs_attention < 2.5:
        restrict_to(["explain"])
    elif obs_attention < 4.0:
        restrict_to(["explain", "worked_example"])

    # 5. Confusion Monotonicity Force Switch
    if len(confusion_history) >= 3:
        if confusion_history[-1] > confusion_history[-2] > confusion_history[-3]:
            restrict_to(["explain", "worked_example"])

    # The draw happens unconditionally so the RNG stream does not depend on
    # whether any action is still allowed.
    if rng.random() < epsilon and allowed:
        return rng.choice(allowed)
    return int(np.argmax(q_vals))


# ---------------------------------------------------------------------------
# 5.
# Dataset Loader
# ---------------------------------------------------------------------------

def load_dataset(path: str) -> list[dict[str, Any]]:
    """Load the offline transition dataset and normalise its field types.

    A record is kept only if it is a JSON object containing every key in
    ``REQUIRED_FIELDS`` and all of its numeric fields convert to ``float``.
    Invalid records are skipped rather than aborting the load; the printed
    summary reports how many survived. Kept records are shallow copies with
    the numeric fields as ``float`` and ``done`` as ``bool``.

    Raises:
        FileNotFoundError: *path* is not an existing file.
        ValueError: the file does not hold a non-empty JSON array.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Dataset not found: {path}")

    with open(path, "r", encoding="utf-8") as fh:
        raw = json.load(fh)

    if not isinstance(raw, list) or len(raw) == 0:
        raise ValueError("Dataset must be a non-empty JSON array.")

    float_fields = ("confusion", "attention", "next_confusion", "next_attention", "reward")

    validated: list[dict[str, Any]] = []
    for record in raw:
        # Non-object entries (numbers, strings, lists) have no keys to check.
        if not isinstance(record, dict) or REQUIRED_FIELDS - record.keys():
            continue
        try:
            numeric = {name: float(record[name]) for name in float_fields}
        except (TypeError, ValueError):
            # One malformed value should cost a single sample, not the run.
            continue
        cleaned = {**record, **numeric, "done": bool(record["done"])}
        validated.append(cleaned)

    print(f"[Loader] {len(validated)}/{len(raw)} samples loaded from {path}")
    return validated


# ---------------------------------------------------------------------------
# 6.
# Q-table Bootstrap
# ---------------------------------------------------------------------------

def bootstrap_qtable(
    dataset: list[dict[str, Any]],
    alpha: float = ALPHA_BOOTSTRAP,
    gamma: float = GAMMA,
    n_epochs: int = BOOTSTRAP_EPOCHS,
) -> dict[str, defaultdict]:
    """Seed a fresh Q-table system from offline transitions.

    Each usable sample (one whose ``action`` is a known action tag) is turned
    into a ``(state, action, reward, next_state, done)`` transition once, then
    replayed *n_epochs* times in dataset order. Rewards are recomputed with
    ``compute_reward``; the stored ``reward`` field is not used.

    A transition is terminal when the next confusion is below
    ``DONE_CONFUSION_THRESHOLD`` or when the sample's own ``done`` flag is set,
    so recorded timeouts and attention failures do not bootstrap from their
    successor state.
    """
    q_system = create_q_system()

    transitions = []
    for sample in dataset:
        action_idx = ACTION_TO_IDX.get(sample["action"])
        if action_idx is None:
            continue

        misconception = sample["misconception"]
        state = get_state(sample["confusion"], sample["attention"], misconception)
        next_state = get_state(
            sample["next_confusion"], sample["next_attention"], misconception
        )

        success = sample["next_confusion"] < DONE_CONFUSION_THRESHOLD
        done = success or bool(sample.get("done", False))

        # Placeholder histories for bootstrap samples
        reward = compute_reward(
            sample["confusion"],
            sample["next_confusion"],
            sample["attention"],
            sample["next_attention"],
            done=done,
            success=success,
            action_idx=action_idx,
            misc_str=misconception,
            action_history=[],
            step=1,
            confusion_history=[sample["confusion"], sample["next_confusion"]],
            prev_reward=0.0
        )

        # Misconceptions without a dedicated table share the "none" table.
        table = misconception if misconception in q_system else "none"
        transitions.append((state, table, action_idx, reward, next_state, done))

    total_updates = 0
    for _ in range(n_epochs):
        for state, table, action_idx, reward, next_state, done in transitions:
            update_q(q_system, state, table, action_idx, reward, next_state, done, alpha, gamma)
        total_updates += len(transitions)

    print(f"[Bootstrap] Done — {total_updates} total updates")
    return q_system


# ---------------------------------------------------------------------------
# 7.
# Save / Load Q-table
# ---------------------------------------------------------------------------

def save_q_table(q_system: dict[str, defaultdict], path: str = MODEL_PATH) -> None:
    """Pickle the Q-table system to *path*, creating its directory if needed.

    The defaultdicts are converted to plain dicts first because their
    lambda default factories cannot be pickled.
    """
    parent = os.path.dirname(path)
    if parent:  # A bare filename has no directory; os.makedirs("") would raise
        os.makedirs(parent, exist_ok=True)
    serializable = {name: dict(table) for name, table in q_system.items()}
    with open(path, "wb") as fh:
        pickle.dump(serializable, fh)
    print(f"[Model] Q-table system saved -> {path}")


def load_q_table(path: str = MODEL_PATH) -> dict[str, defaultdict]:
    """Load a pickled Q-table system into fresh defaultdict tables.

    Tables in the file whose names are not produced by ``create_q_system``
    are ignored.

    Raises:
        FileNotFoundError: *path* is not an existing file.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"No saved Q-table at: {path}")
    # NOTE(security): unpickling can execute arbitrary code. Only load files
    # written by save_q_table from a trusted location.
    with open(path, "rb") as fh:
        data = pickle.load(fh)
    q_system = create_q_system()
    for name, table in data.items():
        if name in q_system:
            q_system[name].update(table)
    print(f"[Model] Q-table system loaded <- {path}")
    return q_system


def _episode_step_limit(misc: str, max_steps: int) -> int:
    """Return the per-episode step cap: the domain limit bounded by *max_steps*."""
    domain_limit = 15 if misc == "procedural" else 10
    return max(1, min(max_steps, domain_limit))


def _progress_after_step(
    prev_conf: float,
    prev_att: float,
    new_conf: float,
    new_att: float,
    steps_since_improvement: int,
) -> tuple[int, int]:
    """Return ``(progress_signal, steps_since_improvement)`` after one step.

    Any drop in confusion or rise in attention counts as improvement (+1) and
    resets the counter; otherwise the signal is -1 if either metric worsened,
    0 if nothing moved, and the counter is incremented.
    """
    confusion_delta = prev_conf - new_conf
    attention_delta = new_att - prev_att
    if confusion_delta > 0 or attention_delta > 0:
        return 1, 0
    if confusion_delta < 0 or attention_delta < 0:
        return -1, steps_since_improvement + 1
    return 0, steps_since_improvement + 1


# ---------------------------------------------------------------------------
# 8. Training Loop
# ---------------------------------------------------------------------------

def train(
    q_system: dict[str, defaultdict],
    n_episodes: int = N_EPISODES,
    max_steps: int = MAX_STEPS,
    alpha: float = ALPHA,
    gamma: float = GAMMA,
    epsilon_start: float = EPSILON_START,
    epsilon_min: float = EPSILON_MIN,
    seed: int = SEED,
) -> tuple[dict[str, defaultdict], list[float]]:
    """Run online epsilon-greedy Q-learning against ``EduForgeEnv``.

    Each episode picks a random misconception and runs for at most
    ``min(max_steps, domain limit)`` steps (domain limit: 15 for procedural,
    10 otherwise). The Q-update is credited to the action the policy
    *attempted*, with the constraint penalty added to the environment reward,
    so the agent learns the cost of proposing an action that gets overridden.

    Epsilon adapts every 10 episodes once 100 episode returns are available,
    based on whether the newer half of that window beats the older half; a
    fixed decay is additionally applied during the first 199 episodes.

    Returns:
        ``(q_system, episode_rewards)`` — the table system (updated in place)
        and the total reward of every episode.
    """
    rng = random.Random(seed)
    episode_rewards: list[float] = []
    recent_rewards: list[float] = []
    epsilon = epsilon_start
    misconceptions = ["conceptual", "factual", "procedural", "transfer"]

    print(f"\n[Training] {n_episodes} episodes | eps={epsilon_start:.2f}->{epsilon_min:.2f}")
    print("-" * 60)

    for ep in range(1, n_episodes + 1):
        misc = rng.choice(misconceptions)
        env = EduForgeEnv(seed=rng.randint(0, 99_999), misconception_init=misc)
        obs = env.reset()

        last_action_idx: int | None = None
        progress_signal = 0
        steps_since_improvement = 0
        action_history: list[int] = []
        confusion_history = [obs.confusion]

        state = get_state_from_obs(obs, None, None, progress_signal, steps_since_improvement)
        total_reward = 0.0

        step_limit = _episode_step_limit(misc, max_steps)
        for step in range(1, step_limit + 1):
            prev_conf = obs.confusion
            prev_att = obs.attention

            attempted_action = select_action(
                q_system, state, epsilon, rng, obs.attention, misc,
                action_history, confusion_history
            )
            action_idx, penalty = apply_constraints(
                attempted_action, action_history, prev_att, misc, confusion_history
            )

            obs, env_reward, _, _ = env.step(ACTIONS[action_idx])
            action_history.append(action_idx)
            confusion_history.append(obs.confusion)

            success = obs.confusion < DONE_CONFUSION_THRESHOLD
            att_fail = obs.attention <= ATTENTION_FAILURE_THRESHOLD
            timeout = step >= step_limit
            done = success or att_fail or timeout

            # Apply hard constraint penalty
            step_reward = env_reward + penalty

            # Update progress signal
            progress_signal, steps_since_improvement = _progress_after_step(
                prev_conf, prev_att, obs.confusion, obs.attention, steps_since_improvement
            )

            next_state = get_state_from_obs(
                obs, action_idx, last_action_idx, progress_signal, steps_since_improvement
            )
            update_q(
                q_system, state, misc, attempted_action, step_reward,
                next_state, done, alpha, gamma
            )

            total_reward += step_reward
            state = next_state
            last_action_idx = action_idx

            if done:
                break

        episode_rewards.append(total_reward)
        recent_rewards.append(total_reward)
        if len(recent_rewards) > 100:
            recent_rewards.pop(0)

        # Adaptive Epsilon Update
        if len(recent_rewards) == 100 and ep % 10 == 0:
            avg_first_half = np.mean(recent_rewards[:50])
            avg_second_half = np.mean(recent_rewards[50:])
            if avg_second_half > avg_first_half + 0.5:
                # Improving -> decay faster
                epsilon = max(epsilon_min, epsilon * 0.95)
            elif avg_second_half < avg_first_half - 0.5:
                # Dropping -> increase noise
                epsilon = min(1.0, epsilon * 1.1)
            else:
                # Plateau -> slow decay
                epsilon = max(epsilon_min, epsilon * 0.99)

        # Base decay early on to ensure it doesn't get stuck at 1.0 initially
        if ep < 200:
            epsilon = max(epsilon_min, epsilon * 0.995)

        if ep % 500 == 0 or ep == 1:
            avg = float(np.mean(recent_rewards))
            print(f" Ep {ep:>5}/{n_episodes} | eps={epsilon:.4f} | avg_reward(last 100)={avg:+.4f}")

    return q_system, episode_rewards


# ---------------------------------------------------------------------------
# 9. Evaluation
# ---------------------------------------------------------------------------

def evaluate(
    q_system: dict[str, defaultdict],
    n_episodes: int = EVAL_EPISODES,
    max_steps: int = MAX_STEPS,
    seed: int = SEED + 1,
) -> dict[str, Any]:
    """Roll out the greedy policy and print a per-episode and summary report.

    *n_episodes* is divided evenly across the four misconception types (at
    least one episode each); every type reuses the seeds ``seed``,
    ``seed + 1``, ... so the types are compared on identical seeds. Episodes
    are capped at ``min(max_steps, domain limit)`` steps.

    Returns:
        A dict with overall ``resolved`` / ``failed_timeout`` /
        ``failed_attention`` totals plus one entry per misconception holding
        the same counters and the per-episode ``steps`` and ``rewards`` lists.
    """
    rng = random.Random(seed)

    print("\n" + "=" * 60)
    print("EVALUATION — Greedy Policy")
    print("=" * 60)

    results: dict[str, Any] = {"resolved": 0, "failed_timeout": 0, "failed_attention": 0}
    misconception_actions: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))

    misconceptions = ["conceptual", "factual", "procedural", "transfer"]
    seeds_per_misc = max(1, n_episodes // len(misconceptions))

    for misc_idx, m_str in enumerate(misconceptions):
        if m_str not in results:
            results[m_str] = {
                "resolved": 0,
                "failed_timeout": 0,
                "failed_attention": 0,
                "steps": [],
                "rewards": [],
            }

        for seed_idx in range(seeds_per_misc):
            ep = misc_idx * seeds_per_misc + seed_idx + 1
            env = EduForgeEnv(seed=seed + seed_idx, misconception_init=m_str, attention_init=8.0)
            obs = env.reset()

            last_action_idx: int | None = None
            total_reward = 0.0
            final_step = 0
            outcome = ""
            progress_signal = 0
            steps_since_improvement = 0
            action_history: list[int] = []
            confusion_history = [obs.confusion]

            state = get_state_from_obs(obs, None, None, progress_signal, steps_since_improvement)

            print(f"\n--- Episode {ep} ---")
            print(f" Misconception : {m_str}")
            print(f" Initial : confusion={obs.confusion:.2f} attention={obs.attention:.2f}")

            step_limit = _episode_step_limit(m_str, max_steps)
            for step in range(1, step_limit + 1):
                prev_conf = obs.confusion
                prev_att = obs.attention

                # epsilon = 0.0 -> purely greedy
                attempted_action = select_action(
                    q_system, state, 0.0, rng, obs.attention, m_str,
                    action_history, confusion_history
                )
                action_idx, penalty = apply_constraints(
                    attempted_action, action_history, prev_att, m_str, confusion_history
                )

                chosen = ACTIONS[action_idx]
                misconception_actions[m_str][chosen] += 1

                obs, env_reward, _, _ = env.step(chosen)
                action_history.append(action_idx)
                confusion_history.append(obs.confusion)

                success = obs.confusion < DONE_CONFUSION_THRESHOLD
                att_fail = obs.attention <= ATTENTION_FAILURE_THRESHOLD
                timeout = step >= step_limit
                done = success or att_fail or timeout

                step_reward = env_reward + penalty

                progress_signal, steps_since_improvement = _progress_after_step(
                    prev_conf, prev_att, obs.confusion, obs.attention, steps_since_improvement
                )

                total_reward += step_reward
                state = get_state_from_obs(
                    obs, action_idx, last_action_idx, progress_signal, steps_since_improvement
                )

                print(f" Step {step:>2} | action={chosen:<15} | "
                      f"confusion={obs.confusion:.2f} attention={obs.attention:.2f} | "
                      f"reward={step_reward:+.2f}")

                last_action_idx = action_idx
                final_step = step

                if done:
                    if success:
                        outcome = "[RESOLVED]"
                        results[m_str]["resolved"] += 1
                        print(f" >> RESOLVED confusion={obs.confusion:.2f} < {DONE_CONFUSION_THRESHOLD}")
                    elif att_fail:
                        outcome = "[FAILED: disengaged]"
                        results[m_str]["failed_attention"] += 1
                        print(f" >> FAILED attention={obs.attention:.2f} < {ATTENTION_FAILURE_THRESHOLD}")
                    else:
                        outcome = "[FAILED: timeout]"
                        results[m_str]["failed_timeout"] += 1
                        print(f" >> FAILED confusion={obs.confusion:.2f} > {DONE_CONFUSION_THRESHOLD} (max steps)")
                    break

            results[m_str]["rewards"].append(total_reward)
            results[m_str]["steps"].append(final_step)
            print(f" {outcome} after {final_step} step(s) | total_reward={total_reward:+.2f}")

    print("\n" + "=" * 60)
    print("EVALUATION SUMMARY")
    print("=" * 60)

    per_misc = [v for v in results.values() if isinstance(v, dict)]
    total_res = sum(v["resolved"] for v in per_misc)
    total_tout = sum(v["failed_timeout"] for v in per_misc)
    total_att = sum(v["failed_attention"] for v in per_misc)
    total_eps = total_res + total_tout + total_att

    # Publish the overall totals; previously these keys were returned as 0.
    results["resolved"] = total_res
    results["failed_timeout"] = total_tout
    results["failed_attention"] = total_att

    all_r: list[float] = []
    all_s: list[int] = []
    for v in per_misc:
        all_r.extend(v["rewards"])
        all_s.extend(v["steps"])

    sr = total_res / total_eps * 100 if total_eps > 0 else 0
    var_r = np.var(all_r) if all_r else 0.0

    print(f" Overall Success: {total_res}/{total_eps} ({sr:.0f}%)")
    print(f" Overall Avg steps: {np.mean(all_s):.1f}")
    print(f" Reward Variance: {var_r:.2f}")

    for m, m_data in results.items():
        if not isinstance(m_data, dict):
            continue
        m_total = m_data["resolved"] + m_data["failed_timeout"] + m_data["failed_attention"]
        if m_total == 0:
            continue
        m_sr = m_data["resolved"] / m_total * 100
        print(f"\n [{m.upper()}] Success: {m_data['resolved']}/{m_total} ({m_sr:.0f}%)")
        print(f" Avg steps: {np.mean(m_data['steps']):.1f} | Avg reward: {np.mean(m_data['rewards']):+.2f}")
        print(f" Failures: {m_data['failed_timeout']} timeout, {m_data['failed_attention']} attention")

    print("\n POLICY — Dominant Strategies per Misconception")
    print(" " + "-" * 50)
    for m, counts in sorted(misconception_actions.items()):
        t = sum(counts.values())
        print(f"\n {m} ({t} actions):")
        for act, cnt in sorted(counts.items(), key=lambda x: x[1], reverse=True):
            print(f" {act:<15} : {cnt:>3} ({cnt/t*100:>5.1f}%)")

    print("\n" + "=" * 60)
    return results
# ---------------------------------------------------------------------------
# 10. Human Feedback Hooks
# ---------------------------------------------------------------------------

class FeedbackHook:
    """Turn a human verdict on a tutor action into a terminal Q-update.

    Each verdict maps to a fixed reward and is applied to the wrapped Q-table
    system as a ``done=True`` transition, so no future value is bootstrapped.
    """

    REWARD_GOOD = +2.0
    REWARD_CONFUSING = -1.5
    REWARD_BORING = -1.0

    def __init__(self, q_system: dict[str, defaultdict], alpha: float = ALPHA) -> None:
        """Remember the Q-table system to update and the learning rate to use."""
        self.q_system = q_system
        self.alpha = alpha

    def _apply(self, state: tuple, misc_str: str, action_idx: int, reward: float) -> float:
        """Apply *reward* to ``(state, action_idx)`` and return it unchanged."""
        # The next state is irrelevant for a terminal update; the current
        # state is passed only to fill the positional slot.
        update_q(
            self.q_system,
            state,
            misc_str,
            action_idx,
            reward,
            next_state=state,
            done=True,
            alpha=self.alpha,
            gamma=GAMMA,
        )
        return reward

    def good(self, state: tuple, misc_str: str, action_idx: int) -> float:
        """Record that the action was helpful; returns ``REWARD_GOOD``."""
        signal = self.REWARD_GOOD
        return self._apply(state, misc_str, action_idx, signal)

    def confusing(self, state: tuple, misc_str: str, action_idx: int) -> float:
        """Record that the action was confusing; returns ``REWARD_CONFUSING``."""
        signal = self.REWARD_CONFUSING
        return self._apply(state, misc_str, action_idx, signal)

    def boring(self, state: tuple, misc_str: str, action_idx: int) -> float:
        """Record that the action was boring; returns ``REWARD_BORING``."""
        signal = self.REWARD_BORING
        return self._apply(state, misc_str, action_idx, signal)


# ---------------------------------------------------------------------------
# 11.
# Interactive REPL
# ---------------------------------------------------------------------------

# Substrings that mark a message as highly confused.
_HIGH_CONFUSION_KW = {
    "don't understand",
    "dont understand",
    "lost",
    "confused",
    "no idea",
    "what",
    "help",
    "stuck",
    "not clear",
    "makes no sense",
}

# Substrings that mark a message as moderately confused.
_MED_CONFUSION_KW = {
    "somewhat",
    "maybe",
    "kind of",
    "sort of",
    "not sure",
    "partially",
    "a bit",
    "a little",
}

# Tutor-facing description printed for each action tag.
_ACTION_DESC: dict[str, str] = {
    "explain": "Give a clear, step-by-step explanation of the concept.",
    "worked_example": "Walk through a fully worked example together.",
    "question": "Ask the student a probing question to test understanding.",
    "correct_fact": "Directly correct the factual error the student has made.",
    "analogize": "Use a real-world analogy to build intuition.",
}


def estimate_state(
    user_input: str,
    misconception: str = "none",
) -> tuple[float, float, str]:
    """Estimate ``(confusion, attention, misconception)`` from free text.

    Matching is case-insensitive substring search. High-confusion keywords win
    over medium ones; text matching neither is treated as low confusion.
    *misconception* is passed through unchanged.
    """
    text = user_input.lower()

    def mentions(keywords: set[str]) -> bool:
        return any(kw in text for kw in keywords)

    if mentions(_HIGH_CONFUSION_KW):
        return 8.0, 5.0, misconception
    if mentions(_MED_CONFUSION_KW):
        return 5.0, 5.0, misconception
    return 3.0, 6.0, misconception


def interact(q_system: dict[str, defaultdict] | None = None) -> None:
    """Run the interactive tutoring loop on stdin/stdout.

    For each student message the greedy action for the estimated state is
    shown and feedback is requested: ``y`` (helpful), ``n`` (confusing) or
    ``b`` (boring). Only those three answers update the Q-table; anything
    else — including an empty answer — is ignored so a stray keypress is not
    learned as a negative signal. If at least one feedback was recorded, the
    Q-table is saved to ``MODEL_PATH`` when the session ends.

    Args:
        q_system: Table system to use; loaded from ``MODEL_PATH`` when None.
    """
    if q_system is None:
        q_system = load_q_table(MODEL_PATH)

    hook = FeedbackHook(q_system)

    print("\n" + "=" * 60)
    print("EduForge Interactive Tutoring Session")
    print("=" * 60)
    print(" Misconception types: " + ", ".join(MISCONCEPTION_MAP.keys()))
    print(" Commands: 'switch ', 'quit'")
    print(" Feedback: y = helpful, n = confusing, b = boring")
    print("=" * 60)

    misconception = "none"
    session_pos, session_neg, session_bored = 0, 0, 0
    total_reward = 0.0

    while True:
        print(f"\n[Active misconception: {misconception}]")
        try:
            user_input = input("Student > ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n[Session ended]")
            break

        if not user_input:
            continue

        command = user_input.lower()
        if command in {"quit", "exit"}:
            print("[Session ended]")
            break

        if command.startswith("switch "):
            req = user_input[7:].strip().lower()
            if req in MISCONCEPTION_MAP:
                misconception = req
                print(f" [System] Switched to: {misconception}")
            else:
                print(f" [System] Unknown. Options: {list(MISCONCEPTION_MAP)}")
            continue

        confusion, attention, m = estimate_state(user_input, misconception)
        state = get_state(confusion, attention, m)
        q_vals = get_q_values(q_system, state, misconception)
        action_idx = int(np.argmax(q_vals))
        action_name = ACTIONS[action_idx]

        print(f" [State] confusion={confusion:.1f} attention={attention:.1f}")
        print(f" [Action] {action_name}")
        print(f" [Tutor] {_ACTION_DESC[action_name]}")

        try:
            fb = input(" Feedback (y/n/b): ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            print("\n[Session ended]")
            break

        if fb == "y":
            r = hook.good(state, misconception, action_idx)
            session_pos += 1
            print(" [+] Positive signal recorded.")
        elif fb == "b":
            r = hook.boring(state, misconception, action_idx)
            session_bored += 1
            print(" [~] Boredom signal recorded — agent adjusting.")
        elif fb == "n":
            r = hook.confusing(state, misconception, action_idx)
            session_neg += 1
            print(" [-] Negative signal recorded — agent adjusting.")
        else:
            # Do not train on input that is not one of the offered answers.
            print(" [?] Unrecognised feedback — ignored.")
            continue

        total_reward += r

        # NOTE(review): estimate_state currently never returns a confusion
        # below 3.0, so with DONE_CONFUSION_THRESHOLD = 2.0 this message is
        # not reachable — confirm whether the estimator should go lower.
        if confusion < DONE_CONFUSION_THRESHOLD:
            print(" [EduForge] Student appears to understand. Great job!")

    total_turns = session_pos + session_neg + session_bored
    print("\n" + "=" * 60)
    print("Session Summary")
    print("=" * 60)
    if total_turns > 0:
        print(f" Helpful : {session_pos} ({session_pos/total_turns*100:.0f}%)")
        print(f" Confusing: {session_neg}")
        print(f" Boring : {session_bored}")
        print(f" Reward : {total_reward:+.1f}")
        save_q_table(q_system, MODEL_PATH)
    else:
        print(" No feedback — Q-table unchanged.")
    print("=" * 60)


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main() -> None:
    """Run the full pipeline: load data, train online, save, then evaluate."""
    random.seed(SEED)
    np.random.seed(SEED)

    print("=" * 60)
    print("EduForge Q-Learning Pipeline")
    print("=" * 60)

    print("\n[1/4] Loading dataset...")
    # The samples are not used further while bootstrapping is disabled; the
    # call still validates that the dataset file is present and well-formed.
    load_dataset(DATASET_PATH)

    print(f"\n[2/4] Bootstrapping Q-table ({BOOTSTRAP_EPOCHS} epochs)...")
    # Disable bootstrapping because offline data does not follow the new hard constraints
    # and would poison the initial Q-table.
    q_system = create_q_system()

    print("\n[3/4] Online Q-learning training...")
    q_system, reward_history = train(
        q_system,
        n_episodes=N_EPISODES,
        max_steps=MAX_STEPS,
        alpha=ALPHA,
        gamma=GAMMA,
        epsilon_start=EPSILON_START,
        epsilon_min=EPSILON_MIN,
        seed=SEED,
    )

    # `or 1` keeps the slices non-empty when fewer than three episodes ran.
    thirds = len(reward_history) // 3 or 1
    print(f"\n Reward trend - "
          f"first 3rd avg: {float(np.mean(reward_history[:thirds])):+.4f} | "
          f"last 3rd avg: {float(np.mean(reward_history[-thirds:])):+.4f}")

    save_q_table(q_system, MODEL_PATH)

    print("\n[4/4] Evaluating greedy policy...")
    evaluate(q_system, n_episodes=EVAL_EPISODES, max_steps=MAX_STEPS, seed=SEED + 1)

    print("\nPipeline complete.\n")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="EduForge Q-Learning Pipeline")
    parser.add_argument("--interact", action="store_true")
    args = parser.parse_args()

    if args.interact:
        interact()
    else:
        main()