File size: 5,615 Bytes
6f44ddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import sys
import re

def _warn_unmatched(step):
    """Report on stderr that a patch step found nothing to rewrite."""
    print(
        f"patch_file: step '{step}' matched nothing; target left unchanged for this step",
        file=sys.stderr,
    )


def _substitute(content, pattern, replacement, step):
    """Apply one regex rewrite to *content* and warn when it matches nothing.

    Args:
        content: Text being patched.
        pattern: Regular expression to search for.
        replacement: ``re.sub``-style replacement (template string or callable).
        step: Human-readable label used in the warning message.

    Returns:
        The rewritten text (identical to *content* when nothing matched).
    """
    patched, hits = re.subn(pattern, replacement, content)
    if hits == 0:
        _warn_unmatched(step)
    return patched


def patch_file(filepath):
    """Rewrite the Python source at *filepath* in place with a fixed set of edits.

    The file is read as UTF-8, a sequence of regex / string substitutions is
    applied to its text, and the result is written back to the same path.
    The steps, in order:

    1. add ``prev_last_action_idx`` to ``get_state_from_obs`` and its returned tuple;
    2. replace the ``compute_reward`` signature and its reward logic;
    3. insert action-masking code before each ``# Safe fallback`` comment;
    4. replace the piecewise epsilon decay;
    5. replace the Q-table merge loop;
    6. update ``get_state_from_obs`` / ``compute_reward`` call sites and add
       ``prev_prev_conf`` tracking.

    A step whose pattern is not found leaves the text unchanged and emits a
    warning on stderr instead of failing silently.  The two insertions whose
    anchor text survives the patch (masking and ``prev_prev_conf`` tracking)
    are skipped when their inserted code is already present, so running the
    script twice does not duplicate them.

    Args:
        filepath: Path of the file to patch.

    Raises:
        OSError: If the file cannot be read or written.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    # 1. State Enrichment (get_state_from_obs)
    content = _substitute(
        content,
        r'def get_state_from_obs\(obs, last_action_idx, progress_signal, steps_since_improvement\):',
        r'def get_state_from_obs(obs, last_action_idx, prev_last_action_idx, progress_signal, steps_since_improvement):',
        'get_state_from_obs signature',
    )
    content = _substitute(
        content,
        r'return \(c_bin, a_bin, m, progress_signal, last_act, ssi\)',
        r'prev_last_act = prev_last_action_idx if prev_last_action_idx is not None else -1\n    return (c_bin, a_bin, m, progress_signal, last_act, prev_last_act, ssi)',
        'get_state_from_obs return tuple',
    )

    # 2. Reward Shaping (compute_reward)
    content = _substitute(
        content,
        r'def compute_reward\([\s\S]*?done: bool,[\s\S]*?success: bool,[\s\S]*?\) -> float:',
        r'def compute_reward(\n    prev_conf: float,\n    new_conf: float,\n    prev_att: float,\n    new_att: float,\n    done: bool,\n    success: bool,\n    step_number: int,\n    action_idx: int,\n    last_action_idx: int | None,\n    prev_last_action_idx: int | None,\n    steps_since_improvement: int,\n    prev_prev_conf: float = -1.0\n) -> float:',
        'compute_reward signature',
    )

    # Inside compute_reward, replace the old confusion logic
    reward_logic = """    if success:
        return 5.0

    if done and not success:
        return -5.0

    import numpy as np
    # Base: Directional Clamping
    delta = prev_conf - new_conf
    reward = float(np.clip(np.sign(delta) * 1.5, -1.5, 1.5))
    
    # Smoothness & Reversal
    if prev_prev_conf != -1.0:
        if new_conf < prev_conf < prev_prev_conf:
            reward += 1.0  # Smoothness bonus
        elif prev_conf < prev_prev_conf and new_conf > prev_conf:
            reward -= 3.0  # Reversal penalty
            
    # Early Attention Floor
    if new_att < 4.0:
        reward -= 1.5
"""
    # Literal code blocks are inserted through a callable so that re never
    # interprets backslashes or group references inside the injected code.
    content = _substitute(
        content,
        r'    if success:[\s\S]*?reward = float\(\(prev_conf - new_conf\) \* 1\.5\)',
        lambda _match: reward_logic,
        'compute_reward body',
    )

    # 3. Sequence Masking (select_action)
    mask_logic = """    # Action Masking based on Sequence
    for a in range(N_ACTIONS):
        if len(action_history) >= 3 and action_history[-3:] == [a, a, a]:
            mask[a] = -np.inf
            
    if len(action_history) >= 2:
        we_idx = ACTION_TO_IDX["worked_example"]
        if action_history[-2:] == [we_idx, we_idx]:
            mask[we_idx] = -np.inf"""

    fallback_marker = '    # Safe fallback'
    if mask_logic in content:
        # The marker is still present after patching, so inserting again
        # would stack a second copy of the masking code on every re-run.
        pass
    elif fallback_marker in content:
        content = content.replace(fallback_marker, mask_logic + '\n\n    # Safe fallback')
    else:
        _warn_unmatched('select_action masking')

    # 4. Decoupled Epsilon (train_phase)
    # Replaces the piecewise decay loop in train_phase
    epsilon_decay_logic = """        # Decoupled Domain Epsilon
        decay_rate = 0.999 if misconception in ["procedural", "factual"] else 0.995
        if ep < 500:
            epsilon = 1.0
        else:
            epsilon = max(epsilon_min, epsilon * decay_rate)"""

    content = _substitute(
        content,
        r'        # Piecewise decay[\s\S]*?epsilon = max\(epsilon_min, epsilon \* 0\.999\)',
        lambda _match: epsilon_decay_logic,
        'epsilon decay',
    )

    # 5. Domain Confidence Merge (train)
    merge_logic = """    for s in all_states:
        m_val = s[2] # misconception_int
        m_str = ""
        for name, val in MISCONCEPTION_MAP.items():
            if val == m_val:
                m_str = name
                break
                
        vals = []
        weights = []
        for i, misc in enumerate(misconceptions):
            if s in q_tables[i]:
                vals.append(q_tables[i][s])
                weights.append(0.9 if misc == m_str else 0.1)
                
        if vals:
            weights = np.array(weights)
            weights = weights / np.sum(weights)
            merged_q[s] = np.average(vals, axis=0, weights=weights)"""

    content = _substitute(
        content,
        r'    for s in all_states:[\s\S]*?merged_q\[s\] = 0\.7 \* np\.max\(vals, axis=0\) \+ 0\.3 \* np\.mean\(vals, axis=0\)',
        lambda _match: merge_logic,
        'Q-table merge',
    )

    # 6. Fix references to get_state_from_obs and compute_reward
    content = _substitute(
        content,
        r'state = get_state_from_obs\(obs, last_action_idx, progress_signal, steps_since_improvement\)',
        r'state = get_state_from_obs(obs, last_action_idx, prev_last_action_idx, progress_signal, steps_since_improvement)',
        'get_state_from_obs call (state)',
    )

    content = _substitute(
        content,
        r'next_state = get_state_from_obs\(obs, action_idx, progress_signal, steps_since_improvement\)',
        r'next_state = get_state_from_obs(obs, action_idx, last_action_idx, progress_signal, steps_since_improvement)',
        'get_state_from_obs call (next_state)',
    )

    # For compute_reward, we need to pass prev_prev_conf. We must track it in the main loops.
    # The anchor lines are kept by the replacement, so the rewrite would match
    # again on a second run; skip it once the tracking statement exists.
    tracking_stmt = 'prev_prev_conf = prev_conf if step > 1 else -1.0'
    if tracking_stmt not in content:
        content = _substitute(
            content,
            r'            prev_conf = obs.confusion\n            prev_att  = obs.attention',
            r'            prev_prev_conf = prev_conf if step > 1 else -1.0\n            prev_conf = obs.confusion\n            prev_att  = obs.attention',
            'prev_prev_conf tracking',
        )

    content = _substitute(
        content,
        r'                steps_since_improvement=steps_since_improvement,\n            \)',
        r'                steps_since_improvement=steps_since_improvement,\n                prev_prev_conf=prev_prev_conf\n            )',
        'compute_reward call',
    )

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(content)
        
if __name__ == "__main__":
    # An optional first command-line argument selects the file to patch;
    # without it the script targets the default pipeline path as before.
    target_path = sys.argv[1] if len(sys.argv) > 1 else "scripts/qlearning_pipeline.py"
    patch_file(target_path)
    print("Patch applied.")