"""One-off patcher that adds attention-history tracking to the Q-learning pipeline.

The script rewrites a pipeline source file in place:

* adds an ``attention_history`` list next to ``action_history`` wherever the
  per-episode tracking variables are initialised,
* widens the history window from 3 to 4, records ``obs.attention`` every step
  and counts a non-decreasing confusion (``>=``) as an increase,
* replaces the whole ``select_action`` function with a new implementation,
* passes ``attention_history`` at the training and evaluation call sites.

Usage: ``python <this script> [path]`` (default path: ``TARGET_FILE``).
"""

import re
import sys
from collections import defaultdict  # not used by the patcher itself; retained

TARGET_FILE = "scripts/qlearning_pipeline.py"

# NOTE(review): the snippets below are matched as whitespace-insensitive token
# sequences, so the exact indentation / line wrapping used in the target file
# does not matter.  Replacements are re-indented to the indentation of the
# first matched line; nested lines use 4 spaces of relative indentation.

# The tracking-variable initialisation is textually the same in the training
# and the evaluation function, so one patch covers every occurrence.
_OLD_INIT = """\
action_history = []
consecutive_conf_increases = 0
consecutive_att_drops = 0
locked_phase_timer = 0
last_reward = 0.0"""

_NEW_INIT = """\
action_history = []
attention_history = []
consecutive_conf_increases = 0
consecutive_att_drops = 0
locked_phase_timer = 0
last_reward = 0.0"""

# Same remark: the per-step bookkeeping is shared by the train and eval loops.
_OLD_LOOP = """\
# Update tracking variables
action_history.append(action_idx)
if len(action_history) > 3:
    action_history.pop(0)

if obs.confusion > prev_conf:
    consecutive_conf_increases += 1
else:
    consecutive_conf_increases = 0

if obs.attention < prev_att:
    consecutive_att_drops += 1
else:
    consecutive_att_drops = 0"""

_NEW_LOOP = """\
# Update tracking variables
action_history.append(action_idx)
if len(action_history) > 4:
    action_history.pop(0)

attention_history.append(obs.attention)
if len(attention_history) > 4:
    attention_history.pop(0)

if obs.confusion >= prev_conf:
    consecutive_conf_increases += 1
else:
    consecutive_conf_increases = 0

if obs.attention < prev_att:
    consecutive_att_drops += 1
else:
    consecutive_att_drops = 0"""

_OLD_CALL_TRAIN = """\
action_idx = select_action(
    q_table, state, current_epsilon, rng,
    obs.attention, obs.confusion,
    consecutive_conf_increases, consecutive_att_drops,
    locked_phase_timer, action_history, last_reward
)"""

_NEW_CALL_TRAIN = """\
action_idx = select_action(
    q_table, state, current_epsilon, rng,
    obs.attention, obs.confusion,
    consecutive_conf_increases, consecutive_att_drops,
    locked_phase_timer, action_history, attention_history, last_reward
)"""

_OLD_CALL_EVAL = """\
action_idx = select_action(
    q_table, state, 0.0, rng,  # epsilon=0 for evaluation
    obs.attention, obs.confusion,
    consecutive_conf_increases, consecutive_att_drops,
    locked_phase_timer, action_history, last_reward
)"""

_NEW_CALL_EVAL = """\
action_idx = select_action(
    q_table, state, 0.0, rng,  # epsilon=0 for evaluation
    obs.attention, obs.confusion,
    consecutive_conf_increases, consecutive_att_drops,
    locked_phase_timer, action_history, attention_history, last_reward
)"""

# Everything from the start marker up to (not including) the end marker is
# replaced by _NEW_SELECT_ACTION.
_SELECT_ACTION_START = "def select_action("
_SELECT_ACTION_END = "# 5. Dataset Loader"

# NOTE(review): block nesting inside this function was reconstructed from a
# whitespace-collapsed copy of the script; statements and comments are
# unchanged, but confirm the nesting of the Step 2, Step 3 and Step 6 branches
# against the intended policy before relying on it.
_NEW_SELECT_ACTION = '''def select_action(
    q_table: defaultdict,
    state: tuple,
    epsilon: float,
    rng: random.Random,
    obs_attention: float,
    obs_confusion: float,
    consecutive_conf_increases: int,
    consecutive_att_drops: int,
    locked_phase_timer: int,
    action_history: list[int],
    attention_history: list[float],
    last_reward: float
) -> int:
    """Strict hierarchical action selection for Guided RL following policy constraints."""
    c, a, m, p, la, ps, ssi = state

    misc_str = "none"
    for k, v in MISCONCEPTION_MAP.items():
        if v == m:
            misc_str = k
            break

    allowed = list(ACTIONS.keys())

    # Helper function to mask out actions
    def mask_action(action_name: str):
        idx = ACTION_TO_IDX.get(action_name, -1)
        if idx in allowed:
            allowed.remove(idx)

    def mask_except(action_names: list[str]):
        idxs = [ACTION_TO_IDX[name] for name in action_names]
        allowed[:] = [act for act in allowed if act in idxs]

    def force_action(action_name: str):
        idx = ACTION_TO_IDX[action_name]
        allowed[:] = [idx]

    # --- Step 1: Calculate Action Diversity & History ---
    action_counts = defaultdict(int)
    for act in action_history:
        action_counts[act] += 1

    worked_example_idx = ACTION_TO_IDX["worked_example"]
    we_used_consecutively = len(action_history) >= 2 and action_history[-1] == worked_example_idx and action_history[-2] == worked_example_idx

    # --- Step 2: "Engagement Floor Rule" (from feedback) ---
    att_below_3_count = sum(1 for att in attention_history if att < 3.0)
    engagement_floor_broken = (consecutive_att_drops >= 3) or (att_below_3_count >= 2)

    if engagement_floor_broken:
        # Switch away from dominant action
        if action_history:
            dominant_action = max(action_counts, key=action_counts.get)
            if dominant_action in allowed:
                allowed.remove(dominant_action)
            # To ensure switch, we force exploration if it was worked_example
            if dominant_action == worked_example_idx:
                mask_except(["explain", "question", "analogize"])

    # --- Step 3: Hard Constraint 1 - Attention Safety ---
    # If attention < 2.5: NO worked_example, prefer explain/question
    if obs_attention < 2.5:
        mask_action("worked_example")
        if not engagement_floor_broken:
            mask_except(["explain", "question"])

    # If attention < 3.0: immediate strategy shift required
    if obs_attention < 3.0:
        if action_history and action_history[-1] in allowed:
            allowed.remove(action_history[-1])

    # If attention is falling for 2 consecutive steps -> switch strategy class entirely
    if consecutive_att_drops >= 2:
        if action_history and action_history[-1] in allowed:
            allowed.remove(action_history[-1])

    # --- Step 4: Hard Constraint 2 - Action Diversity ---
    # No action type may exceed 60% frequency (in window of 4, max 2 occurrences)
    for act, count in action_counts.items():
        if count >= 2:
            if act in allowed:
                allowed.remove(act)

    # No repeated worked_example more than 2 consecutive times.
    if we_used_consecutively:
        mask_action("worked_example")

    # If allowed is empty because of diversity, force exploration action
    if not allowed:
        allowed = list(ACTIONS.keys())
        mask_except(["question", "analogize"])
        if not allowed:  # If those were masked, fallback
            allowed = [ACTION_TO_IDX["question"], ACTION_TO_IDX["analogize"]]

    # --- Step 5: Hard Constraint 3 - Confusion Reduction Rule ---
    # If confusion is NOT decreasing for 2 consecutive steps
    if consecutive_conf_increases >= 2:
        if misc_str == "conceptual":
            mask_except(["analogize", "question"])
        elif misc_str == "factual":
            force_action("correct_fact")
        elif misc_str == "procedural":
            if not we_used_consecutively and ACTION_TO_IDX["worked_example"] in allowed:
                force_action("worked_example")
            else:
                mask_except(["explain", "question"])
        elif misc_str == "transfer":
            mask_except(["explain", "analogize"])

    # --- Step 6: Hard Constraint 4 - Factual Misconception Rule ---
    if misc_str == "factual":
        # Primary actions: correct_fact (must lead), explain (support)
        # worked_example only if confusion < 6
        if obs_confusion >= 6.0:
            mask_action("worked_example")
        if len(action_history) == 0:
            force_action("correct_fact")  # Must lead

    # --- Step 7: Domain Selection Policies (Priors if not overridden) ---
    if len(allowed) > 1 and not (consecutive_conf_increases >= 2 or engagement_floor_broken or obs_attention < 3.0):
        if misc_str == "conceptual":
            mask_except(["explain", "analogize", "question"])
        elif misc_str == "procedural":
            mask_except(["worked_example", "explain", "question"])
        elif misc_str == "factual":
            mask_except(["correct_fact", "explain", "worked_example"])
        elif misc_str == "transfer":
            mask_except(["explain", "analogize", "worked_example", "question"])

    if not allowed:
        allowed = list(ACTIONS.keys())  # Safe fallback

    if rng.random() < epsilon:
        return rng.choice(allowed)

    # Argmax over ALLOWED actions only
    q_vals = [q_table[state][act] if act in allowed else -1e9 for act in range(N_ACTIONS)]
    return int(np.argmax(q_vals))


# ---------------------------------------------------------------------------
'''

# (label, old snippet, new snippet) applied before select_action is replaced.
_TRACKING_PATCHES = (
    ("tracking-variable initialisation", _OLD_INIT, _NEW_INIT),
    ("tracking-variable update", _OLD_LOOP, _NEW_LOOP),
)

# (label, old snippet, new snippet) applied after select_action is replaced.
_CALL_PATCHES = (
    ("select_action call (training)", _OLD_CALL_TRAIN, _NEW_CALL_TRAIN),
    ("select_action call (evaluation)", _OLD_CALL_EVAL, _NEW_CALL_EVAL),
)


def _snippet_pattern(snippet: str) -> re.Pattern:
    """Compile a regex matching *snippet*'s tokens with any whitespace between them.

    The match must start at the beginning of a line; that line's leading
    indentation is captured in the ``indent`` group.
    """
    tokens = [re.escape(token) for token in snippet.split()]
    return re.compile(r"^(?P<indent>[ \t]*)" + r"\s+".join(tokens), re.MULTILINE)


def _reindent(snippet: str, indent: str) -> str:
    """Prefix every non-empty line of *snippet* with *indent*."""
    return "\n".join(indent + line if line else line for line in snippet.split("\n"))


def _replace_snippet(content: str, old: str, new: str) -> tuple:
    """Replace every occurrence of *old* in *content* by *new*.

    Returns ``(new_content, number_of_replacements)``.
    """
    # A callable replacement keeps backslashes in *new* from being interpreted
    # as regex group references.
    return _snippet_pattern(old).subn(
        lambda match: _reindent(new, match.group("indent")), content
    )


def _replace_select_action(content: str) -> str:
    """Swap the existing ``select_action`` definition for the new one.

    Raises:
        ValueError: if the function start or the following section marker
            cannot be located.
    """
    start_idx = content.find(_SELECT_ACTION_START)
    if start_idx == -1:
        raise ValueError(f"start marker {_SELECT_ACTION_START!r} not found")
    # Search from the function onwards: an earlier mention of the section
    # title (e.g. in a header comment) must not be taken as the end marker.
    end_idx = content.find(_SELECT_ACTION_END, start_idx)
    if end_idx == -1:
        raise ValueError(
            f"end marker {_SELECT_ACTION_END!r} not found after "
            f"{_SELECT_ACTION_START!r}"
        )
    return content[:start_idx] + _NEW_SELECT_ACTION + content[end_idx:]


def patch_file(filepath: str) -> None:
    """Apply all attention-history patches to the file at *filepath*, in place.

    Snippet patches that do not match (for instance because the file was
    already patched) are skipped with a warning on stderr.  The file is only
    written once every step has succeeded.

    Args:
        filepath: path of the UTF-8 encoded Python source file to rewrite.

    Raises:
        ValueError: if the ``select_action`` function or the section marker
            that follows it cannot be found; the file is left untouched.
        OSError: if the file cannot be read or written.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    skipped = []

    # 1-4. Track attention_history alongside action_history (train + eval).
    for label, old, new in _TRACKING_PATCHES:
        content, count = _replace_snippet(content, old, new)
        if not count:
            skipped.append(label)

    # 5. Replace the select_action function wholesale.
    content = _replace_select_action(content)

    # 6. Pass attention_history at both select_action call sites.
    for label, old, new in _CALL_PATCHES:
        content, count = _replace_snippet(content, old, new)
        if not count:
            skipped.append(label)

    for label in skipped:
        print(
            f"patch_file: {label} not found in {filepath} "
            "(already patched?); skipped",
            file=sys.stderr,
        )

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(content)


if __name__ == "__main__":
    patch_file(sys.argv[1] if len(sys.argv) > 1 else TARGET_FILE)