# EduForge-Tutor / patch_script.py
# hari15prasad
# Initial clean deployment to Hugging Face
# 6f44ddb
import sys
from collections import defaultdict
def _replace_snippet(content, old, new, label):
    """Return *content* with every occurrence of *old* replaced by *new*.

    ``str.replace`` silently does nothing when *old* is absent, which would let
    the patch write a half-updated file (for example a new ``select_action``
    signature with stale call sites). Instead this raises, unless *new* is
    already present, in which case the step was applied earlier (by a previous
    run, or by an earlier step using an identical snippet) and is skipped.

    Raises:
        ValueError: if neither *old* nor *new* occurs in *content*.
    """
    if old in content:
        return content.replace(old, new)
    if new in content:
        return content
    raise ValueError(f"patch step {label!r}: expected snippet not found in target file")


def _replace_region(content, start_marker, end_marker, replacement):
    """Replace the text from *start_marker* up to (not including) *end_marker*.

    The end marker is searched only after the start marker, so an earlier
    mention of it cannot produce an inverted slice.

    Raises:
        ValueError: if either marker is missing (``str.find`` would return -1,
            and slicing with -1 would silently corrupt the file).
    """
    start = content.find(start_marker)
    if start == -1:
        raise ValueError(f"start marker {start_marker!r} not found in target file")
    end = content.find(end_marker, start)
    if end == -1:
        raise ValueError(f"end marker {end_marker!r} not found after {start_marker!r}")
    return content[:start] + replacement + content[end:]


def patch_file(filepath):
    """Patch the Q-learning pipeline source at *filepath* in place.

    Edits applied:
      1-4. In ``train_phase`` and ``evaluate``: declare ``attention_history``
           next to ``action_history``, append ``obs.attention`` to it every
           step (keeping the last 4 values), widen the ``action_history``
           window from 3 to 4, and count a confusion "increase" whenever
           confusion did not decrease (``>=`` instead of ``>``).
      5.   Replace everything from ``def select_action(`` up to the
           ``# 5. Dataset Loader`` marker with the new ``select_action``.
      6.   Pass ``attention_history`` at both ``select_action`` call sites.

    All edits are computed in memory and the file is rewritten only after
    every step has matched, so a mismatch never leaves it half-patched.
    Already-applied steps are skipped, so re-running yields the same file.

    Raises:
        ValueError: if a snippet or marker this patch depends on is missing.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    # NOTE(review): the snippets below are matched byte-for-byte against the
    # target file, yet as seen here their continuation lines carry no
    # indentation; confirm they still match scripts/qlearning_pipeline.py.
    # A mismatch now raises ValueError instead of being silently skipped.

    # 1. Add attention_history to train_phase
    old_train_init = """ action_history = []
consecutive_conf_increases = 0
consecutive_att_drops = 0
locked_phase_timer = 0
last_reward = 0.0"""
    new_train_init = """ action_history = []
attention_history = []
consecutive_conf_increases = 0
consecutive_att_drops = 0
locked_phase_timer = 0
last_reward = 0.0"""
    # replace() rewrites every occurrence; evaluate() uses the identical text,
    # so this normally patches both sites and step 3 becomes a no-op.
    content = _replace_snippet(content, old_train_init, new_train_init, "train_phase init")

    # 2. Add attention_history update in train_phase loop
    old_train_loop = """ # Update tracking variables
action_history.append(action_idx)
if len(action_history) > 3:
action_history.pop(0)
if obs.confusion > prev_conf:
consecutive_conf_increases += 1
else:
consecutive_conf_increases = 0
if obs.attention < prev_att:
consecutive_att_drops += 1
else:
consecutive_att_drops = 0"""
    new_train_loop = """ # Update tracking variables
action_history.append(action_idx)
if len(action_history) > 4:
action_history.pop(0)
attention_history.append(obs.attention)
if len(attention_history) > 4:
attention_history.pop(0)
if obs.confusion >= prev_conf:
consecutive_conf_increases += 1
else:
consecutive_conf_increases = 0
if obs.attention < prev_att:
consecutive_att_drops += 1
else:
consecutive_att_drops = 0"""
    content = _replace_snippet(content, old_train_loop, new_train_loop, "train_phase loop")

    # 3. Add attention_history to evaluate
    old_eval_init = """ action_history = []
consecutive_conf_increases = 0
consecutive_att_drops = 0
locked_phase_timer = 0
last_reward = 0.0"""
    new_eval_init = """ action_history = []
attention_history = []
consecutive_conf_increases = 0
consecutive_att_drops = 0
locked_phase_timer = 0
last_reward = 0.0"""
    content = _replace_snippet(content, old_eval_init, new_eval_init, "evaluate init")

    # 4. Add attention_history update in evaluate loop
    old_eval_loop = """ # Update tracking variables
action_history.append(action_idx)
if len(action_history) > 3:
action_history.pop(0)
if obs.confusion > prev_conf:
consecutive_conf_increases += 1
else:
consecutive_conf_increases = 0
if obs.attention < prev_att:
consecutive_att_drops += 1
else:
consecutive_att_drops = 0"""
    new_eval_loop = """ # Update tracking variables
action_history.append(action_idx)
if len(action_history) > 4:
action_history.pop(0)
attention_history.append(obs.attention)
if len(attention_history) > 4:
attention_history.pop(0)
if obs.confusion >= prev_conf:
consecutive_conf_increases += 1
else:
consecutive_conf_increases = 0
if obs.attention < prev_att:
consecutive_att_drops += 1
else:
consecutive_att_drops = 0"""
    content = _replace_snippet(content, old_eval_loop, new_eval_loop, "evaluate loop")

    # 5. Replace select_action function.
    # It is too long to match verbatim, so the region between two markers is
    # swapped out instead.
    start_str = "def select_action("
    end_str = "# 5. Dataset Loader"
    new_select_action = '''def select_action(
q_table: defaultdict,
state: tuple,
epsilon: float,
rng: random.Random,
obs_attention: float,
obs_confusion: float,
consecutive_conf_increases: int,
consecutive_att_drops: int,
locked_phase_timer: int,
action_history: list[int],
attention_history: list[float],
last_reward: float
) -> int:
"""Strict hierarchical action selection for Guided RL following policy constraints."""
c, a, m, p, la, ps, ssi = state
misc_str = "none"
for k, v in MISCONCEPTION_MAP.items():
if v == m:
misc_str = k
break
allowed = list(ACTIONS.keys())
# Helper function to mask out actions
def mask_action(action_name: str):
idx = ACTION_TO_IDX.get(action_name, -1)
if idx in allowed:
allowed.remove(idx)
def mask_except(action_names: list[str]):
idxs = [ACTION_TO_IDX[name] for name in action_names]
allowed[:] = [act for act in allowed if act in idxs]
def force_action(action_name: str):
idx = ACTION_TO_IDX[action_name]
allowed[:] = [idx]
# --- Step 1: Calculate Action Diversity & History ---
action_counts = defaultdict(int)
for act in action_history:
action_counts[act] += 1
worked_example_idx = ACTION_TO_IDX["worked_example"]
we_used_consecutively = len(action_history) >= 2 and action_history[-1] == worked_example_idx and action_history[-2] == worked_example_idx
# --- Step 2: "Engagement Floor Rule" (from feedback) ---
att_below_3_count = sum(1 for att in attention_history if att < 3.0)
engagement_floor_broken = (consecutive_att_drops >= 3) or (att_below_3_count >= 2)
if engagement_floor_broken:
# Switch away from dominant action
if action_history:
dominant_action = max(action_counts, key=action_counts.get)
if dominant_action in allowed:
allowed.remove(dominant_action)
# To ensure switch, we force exploration if it was worked_example
if dominant_action == worked_example_idx:
mask_except(["explain", "question", "analogize"])
# --- Step 3: Hard Constraint 1 - Attention Safety ---
# If attention < 2.5: NO worked_example, prefer explain/question
if obs_attention < 2.5:
mask_action("worked_example")
if not engagement_floor_broken:
mask_except(["explain", "question"])
# If attention < 3.0: immediate strategy shift required
if obs_attention < 3.0:
if action_history and action_history[-1] in allowed:
allowed.remove(action_history[-1])
# If attention is falling for 2 consecutive steps -> switch strategy class entirely
if consecutive_att_drops >= 2:
if action_history and action_history[-1] in allowed:
allowed.remove(action_history[-1])
# --- Step 4: Hard Constraint 2 - Action Diversity ---
# No action type may exceed 60% frequency (in window of 4, max 2 occurrences)
for act, count in action_counts.items():
if count >= 2:
if act in allowed:
allowed.remove(act)
# No repeated worked_example more than 2 consecutive times.
if we_used_consecutively:
mask_action("worked_example")
# If allowed is empty because of diversity, force exploration action
if not allowed:
allowed = list(ACTIONS.keys())
mask_except(["question", "analogize"])
if not allowed: # If those were masked, fallback
allowed = [ACTION_TO_IDX["question"], ACTION_TO_IDX["analogize"]]
# --- Step 5: Hard Constraint 3 - Confusion Reduction Rule ---
# If confusion is NOT decreasing for 2 consecutive steps
if consecutive_conf_increases >= 2:
if misc_str == "conceptual":
mask_except(["analogize", "question"])
elif misc_str == "factual":
force_action("correct_fact")
elif misc_str == "procedural":
if not we_used_consecutively and ACTION_TO_IDX["worked_example"] in allowed:
force_action("worked_example")
else:
mask_except(["explain", "question"])
elif misc_str == "transfer":
mask_except(["explain", "analogize"])
# --- Step 6: Hard Constraint 4 - Factual Misconception Rule ---
if misc_str == "factual":
# Primary actions: correct_fact (must lead), explain (support)
# worked_example only if confusion < 6
if obs_confusion >= 6.0:
mask_action("worked_example")
if len(action_history) == 0:
force_action("correct_fact") # Must lead
# --- Step 7: Domain Selection Policies (Priors if not overridden) ---
if len(allowed) > 1 and not (consecutive_conf_increases >= 2 or engagement_floor_broken or obs_attention < 3.0):
if misc_str == "conceptual":
mask_except(["explain", "analogize", "question"])
elif misc_str == "procedural":
mask_except(["worked_example", "explain", "question"])
elif misc_str == "factual":
mask_except(["correct_fact", "explain", "worked_example"])
elif misc_str == "transfer":
mask_except(["explain", "analogize", "worked_example", "question"])
if not allowed:
allowed = list(ACTIONS.keys()) # Safe fallback
if rng.random() < epsilon:
return rng.choice(allowed)
# Argmax over ALLOWED actions only
q_vals = [q_table[state][act] if act in allowed else -1e9 for act in range(N_ACTIONS)]
return int(np.argmax(q_vals))
# ---------------------------------------------------------------------------
'''
    content = _replace_region(content, start_str, end_str, new_select_action)

    # 6. We must also fix the function call signatures for select_action
    old_call_train = """ action_idx = select_action(
q_table, state, current_epsilon, rng,
obs.attention, obs.confusion,
consecutive_conf_increases, consecutive_att_drops,
locked_phase_timer, action_history, last_reward
)"""
    new_call_train = """ action_idx = select_action(
q_table, state, current_epsilon, rng,
obs.attention, obs.confusion,
consecutive_conf_increases, consecutive_att_drops,
locked_phase_timer, action_history, attention_history, last_reward
)"""
    content = _replace_snippet(content, old_call_train, new_call_train, "train_phase call")

    old_call_eval = """ action_idx = select_action(
q_table, state, 0.0, rng, # epsilon=0 for evaluation
obs.attention, obs.confusion,
consecutive_conf_increases, consecutive_att_drops,
locked_phase_timer, action_history, last_reward
)"""
    new_call_eval = """ action_idx = select_action(
q_table, state, 0.0, rng, # epsilon=0 for evaluation
obs.attention, obs.confusion,
consecutive_conf_increases, consecutive_att_drops,
locked_phase_timer, action_history, attention_history, last_reward
)"""
    content = _replace_snippet(content, old_call_eval, new_call_eval, "evaluate call")

    # Only reached once every step matched, so the target is never half-patched.
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(content)
if __name__ == "__main__":
    # Patch the pipeline this script was written for by default; an explicit
    # target path may be given as the first command-line argument. The guard
    # keeps a plain import of this module from rewriting files.
    patch_file(sys.argv[1] if len(sys.argv) > 1 else "scripts/qlearning_pipeline.py")