File size: 7,279 Bytes
a52f96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""
Fast unit tests for student agent with progress bars.

Optimized for speed with tqdm progress bars:
- Shows progress during slow operations (model loading, training, evaluation)
- Shared student instance where possible
- Reduced iteration counts for fast tests
- Minimal evaluation sets
"""

import sys
from student_agent import StudentAgent
from mock_task_generator import MockTaskGenerator
import torch

try:
    from tqdm import tqdm
    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False
    print("⚠️  tqdm not installed. Install with: pip install tqdm")

    class tqdm:
        """Minimal no-op stand-in used when the real ``tqdm`` is missing.

        Only the small part of the progress-bar surface this file relies on
        is provided: iteration, use as a context manager, ``update``,
        ``set_postfix``, ``set_description`` and ``close``. Nothing is ever
        displayed.
        """

        def __init__(self, iterable=None, *args, **kwargs):
            # Extra positional/keyword options (desc, leave, total, ...) are
            # accepted and ignored so call sites need no changes.
            self.iterable = iterable

        def __enter__(self):
            # Return the bar itself so that ``with tqdm(...) as bar`` yields
            # an object exposing update()/set_postfix().
            return self

        def __exit__(self, *args):
            self.close()

        def __iter__(self):
            # Compare against None explicitly: an empty iterable is still a
            # valid iterable, and some containers have no usable truth value.
            if self.iterable is None:
                return iter(())
            return iter(self.iterable)

        def update(self, n=1):
            """Accept and ignore a progress increment."""

        def set_postfix(self, *args, **kwargs):
            """Accept and ignore postfix statistics."""

        def set_description(self, *args, **kwargs):
            """Accept and ignore a description change."""

        def close(self):
            """No resources are held, so there is nothing to release."""


def test_student_can_load():
    """Check that a ``StudentAgent`` can be constructed on the CPU.

    Returns:
        The constructed ``StudentAgent`` instance.

    Raises:
        Exception: Whatever ``StudentAgent(device='cpu')`` raises is reported
            on stdout and then re-raised unchanged.
    """
    print("Testing student initialization...", end=" ", flush=True)

    # Construction may take a while; the message above is flushed first so
    # the user sees activity before it starts.
    try:
        student = StudentAgent(device='cpu')
    except Exception as e:
        print(f"⚠️ Error: {e}")
        raise
    else:
        print("✅ Student model initialized")
        return student


def test_student_can_answer():
    """Check that ``StudentAgent.answer`` returns a choice index in 0-3.

    Raises:
        AssertionError: If the returned answer is outside ``0 <= answer < 4``.
    """
    print("Testing answer prediction...", end=" ", flush=True)
    student = StudentAgent(device='cpu')
    generator = MockTaskGenerator()

    task = generator.generate_task('history', 'easy')
    answer = student.answer(task)

    assert 0 <= answer < 4, f"Answer should be 0-3, got {answer}"
    print("✅ Student can answer tasks")


def test_student_learns():
    """Train on a handful of tasks and report the change in accuracy.

    The student is evaluated on a fixed set of five 'science' tasks, trained
    on fifteen freshly generated tasks, then evaluated on the same set again.
    The outcome is reported rather than asserted: an improvement prints a
    success line, anything else prints a warning, mirroring how
    ``test_student_forgets`` reports its result.
    """
    print("Testing learning capability...", flush=True)
    student = StudentAgent(device='cpu')
    generator = MockTaskGenerator()

    topic = 'science'

    # Smaller eval set for speed
    print("   Generating eval set...", end=" ", flush=True)
    eval_tasks = [generator.generate_task(topic, 'easy') for _ in range(5)]
    print("Done")

    # Measure initial accuracy
    print("   Evaluating initial accuracy...", end=" ", flush=True)
    initial_acc = student.evaluate(eval_tasks)
    print(f"{initial_acc:.3f}")

    # Training with progress bar
    num_iterations = 15
    print(f"   Training on {num_iterations} tasks:")

    if HAS_TQDM:
        pbar = tqdm(range(num_iterations), desc="      Progress", leave=False)
        for i in pbar:
            task = generator.generate_task(topic, 'easy')
            student.learn(task)
            pbar.set_postfix({'tasks': i+1})
    else:
        # Fallback: simple progress indicator, rewritten in place via "\r"
        for i in range(num_iterations):
            if (i + 1) % 5 == 0:
                print(f"      {i+1}/{num_iterations}...", end="\r", flush=True)
            task = generator.generate_task(topic, 'easy')
            student.learn(task)
        print(f"      {num_iterations}/{num_iterations}     ")  # Clear line

    # Measure final accuracy on the same eval set used before training
    print("   Evaluating final accuracy...", end=" ", flush=True)
    final_acc = student.evaluate(eval_tasks)
    print(f"{final_acc:.3f}")

    improvement = final_acc - initial_acc
    # Only claim success when accuracy actually went up; with such a small
    # eval set a flat or negative change is possible and should be visible.
    if improvement > 0:
        print(f"✅ Learning verified (improvement: {improvement:+.3f})")
    else:
        print(f"⚠️ Learning minimal (improvement: {improvement:+.3f})")


def test_student_forgets():
    """Train, advance the student's clock, and report the accuracy change.

    The student is built with ``retention_constant=20.0``, trained on twenty
    'literature' tasks and evaluated on five tasks. ``advance_time(50.0)`` is
    then called and the same tasks are evaluated again. The outcome is
    reported rather than asserted: a drop prints a success line, anything
    else prints a warning.
    """
    print("Testing memory decay...", flush=True)
    student = StudentAgent(device='cpu', retention_constant=20.0)
    generator = MockTaskGenerator()

    topic = 'literature'

    # Training with progress bar
    num_iterations = 20
    print(f"   Training on {num_iterations} tasks:")

    if HAS_TQDM:
        pbar = tqdm(range(num_iterations), desc="      Progress", leave=False)
        for i in pbar:
            task = generator.generate_task(topic, 'easy')
            student.learn(task)
            pbar.set_postfix({'tasks': i+1})
    else:
        # Fallback: simple progress indicator, rewritten in place via "\r"
        for i in range(num_iterations):
            if (i + 1) % 5 == 0:
                print(f"      {i+1}/{num_iterations}...", end="\r", flush=True)
            task = generator.generate_task(topic, 'easy')
            student.learn(task)
        print(f"      {num_iterations}/{num_iterations}     ")

    print("   Evaluating before forgetting...", end=" ", flush=True)
    eval_tasks = [generator.generate_task(topic, 'easy') for _ in range(5)]
    acc_before = student.evaluate(eval_tasks)
    print(f"{acc_before:.3f}")

    # Time passes
    print("   Simulating time passage (forgetting)...", end=" ", flush=True)
    student.advance_time(50.0)
    print("Done")

    # Re-evaluate on the same tasks so the two numbers are comparable
    print("   Evaluating after forgetting...", end=" ", flush=True)
    acc_after = student.evaluate(eval_tasks)
    print(f"{acc_after:.3f}")

    if acc_after < acc_before:
        print(f"✅ Forgetting verified (drop: {acc_before - acc_after:.3f})")
    else:
        print(f"⚠️ Forgetting minimal (change: {acc_after - acc_before:+.3f})")

def test_student_state():
    """Check that ``get_state`` reflects the training that was performed.

    The student learns five tasks for each of two topics; the reported state
    must then contain at least one topic accuracy and a timestep count of at
    least the number of tasks trained on.

    Raises:
        AssertionError: If the state has no topic accuracies or reports
            fewer timesteps than tasks trained.
    """
    print("Testing state reporting...", flush=True)
    student = StudentAgent(device='cpu')
    generator = MockTaskGenerator()

    # Training with progress bar
    topics_to_test = ['history', 'science']
    tasks_per_topic = 5
    total_tasks = len(topics_to_test) * tasks_per_topic

    print(f"   Training on {total_tasks} tasks:")

    for topic in topics_to_test:
        # Wrap the range in a progress bar only when the real tqdm is present
        if HAS_TQDM:
            steps = tqdm(range(tasks_per_topic), desc=f"      {topic}", leave=False)
        else:
            steps = range(tasks_per_topic)
        for _ in steps:
            task = generator.generate_task(topic, 'easy')
            student.learn(task)

    state = student.get_state()

    assert len(state.topic_accuracies) > 0, "State reports no topic accuracies"
    # Compare against the computed task count rather than a literal so the
    # check stays correct if the topics or per-topic count change.
    assert state.total_timesteps >= total_tasks, (
        f"Expected at least {total_tasks} timesteps, got {state.total_timesteps}"
    )

    print("✅ State reporting works")


def run_all_tests():
    """Run every test in order and print a summary banner.

    Execution stops at the first test that raises; the error and its
    traceback are printed together with the elapsed time.

    Returns:
        bool: ``True`` if all tests completed without raising, ``False``
        otherwise.
    """
    import time
    import traceback

    print("=" * 60)
    print("RUNNING STUDENT AGENT TESTS")
    print("=" * 60)
    if not HAS_TQDM:
        print("💡 Tip: Install tqdm for progress bars: pip install tqdm")
    print()

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during the run.
    start_time = time.perf_counter()

    try:
        test_student_can_load()
        test_student_can_answer()
        test_student_learns()
        test_student_forgets()
        test_student_state()
    except Exception as e:
        elapsed = time.perf_counter() - start_time
        print()
        print("=" * 60)
        print(f"❌ Test failed after {elapsed:.2f}s")
        print(f"Error: {e}")
        print("=" * 60)
        traceback.print_exc()
        return False

    elapsed = time.perf_counter() - start_time
    print()
    print("=" * 60)
    print(f"🎉 All tests passed! (Total time: {elapsed:.2f}s)")
    print("=" * 60)
    return True


if __name__ == "__main__":
    # Exit status: 0 when every test passed, 1 otherwise.
    sys.exit(int(not run_all_tests()))