Spaces:
Paused
Paused
File size: 7,279 Bytes
a52f96d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
"""
Fast unit tests for student agent with progress bars.
Optimized for speed with tqdm progress bars:
- Shows progress during slow operations (model loading, training, evaluation)
- Shared student instance where possible
- Reduced iteration counts for fast tests
- Minimal evaluation sets
"""
import sys
from student_agent import StudentAgent
from mock_task_generator import MockTaskGenerator
import torch
# Optional dependency: use the real tqdm when installed, otherwise fall back
# to a silent stand-in so the tests below still run (without progress bars).
try:
    from tqdm import tqdm
    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False
    print("⚠️ tqdm not installed. Install with: pip install tqdm")

    class tqdm:
        """Minimal no-op replacement for ``tqdm.tqdm``.

        Wraps an iterable and exposes the small subset of the tqdm API used
        by progress-reporting loops (iteration, context manager, ``update``,
        ``set_postfix``, ``set_description``, ``close``). Nothing is drawn.
        """

        def __init__(self, iterable=None, *args, **kwargs):
            # Extra positional/keyword options (desc, leave, ...) are accepted
            # for call compatibility and ignored.
            self.iterable = iterable

        def __enter__(self):
            # Return the bar itself (as the real tqdm does) so that
            # ``with tqdm(...) as bar: bar.update()`` works in the fallback.
            return self

        def __exit__(self, *args):
            self.close()

        def __iter__(self):
            # Compare against None rather than relying on truthiness, which
            # is ambiguous or an error for some iterables (e.g. arrays).
            if self.iterable is None:
                return iter(())
            return iter(self.iterable)

        def update(self, n=1):
            """Accept a progress increment; nothing is displayed."""

        def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
            """Accept postfix stats; nothing is displayed."""

        def set_description(self, desc=None, refresh=True):
            """Accept a description; nothing is displayed."""

        def close(self):
            """No resources to release."""
def test_student_can_load():
    """Check that a ``StudentAgent`` can be constructed on the CPU.

    Returns:
        The constructed ``StudentAgent`` instance.

    Raises:
        Exception: any error raised by ``StudentAgent(device='cpu')`` is
            printed and then re-raised unchanged.
    """
    print("Testing student initialization...", end=" ", flush=True)
    # Model loading can be slow; the flushed message above shows we're working.
    try:
        student = StudentAgent(device='cpu')
    except Exception as e:
        print(f"⚠️ Error: {e}")
        raise
    print("✅ Student model initialized")
    return student
def test_student_can_answer():
    """Check that the student returns a valid answer index for one task.

    Generates a single easy 'history' task and asserts the value returned by
    ``student.answer`` lies in the range 0-3.

    Raises:
        AssertionError: if the answer is outside ``0 <= answer < 4``.
    """
    print("Testing answer prediction...", end=" ", flush=True)
    student = StudentAgent(device='cpu')
    generator = MockTaskGenerator()
    task = generator.generate_task('history', 'easy')
    answer = student.answer(task)
    assert 0 <= answer < 4, f"Answer should be 0-3, got {answer}"
    print("✅ Student can answer tasks")
def test_student_learns():
    """Train the student on easy 'science' tasks and report accuracy change.

    Evaluates on a fixed set of 5 tasks before and after 15 calls to
    ``student.learn`` and prints the difference. The improvement is only
    reported, not asserted, so this test fails only if a call raises.
    """
    print("Testing learning capability...", flush=True)
    student = StudentAgent(device='cpu')
    generator = MockTaskGenerator()
    topic = 'science'

    # Smaller eval set for speed; the same tasks are reused for both
    # measurements so the two accuracies are comparable.
    print(" Generating eval set...", end=" ", flush=True)
    eval_tasks = [generator.generate_task(topic, 'easy') for _ in range(5)]
    print("Done")

    # Measure initial accuracy
    print(" Evaluating initial accuracy...", end=" ", flush=True)
    initial_acc = student.evaluate(eval_tasks)
    print(f"{initial_acc:.3f}")

    # Training with progress bar
    num_iterations = 15
    print(f" Training on {num_iterations} tasks:")
    if HAS_TQDM:
        pbar = tqdm(range(num_iterations), desc=" Progress", leave=False)
        for i in pbar:
            task = generator.generate_task(topic, 'easy')
            student.learn(task)
            pbar.set_postfix({'tasks': i + 1})
    else:
        # Fallback: simple progress indicator rewritten in place via "\r".
        for i in range(num_iterations):
            if (i + 1) % 5 == 0:
                print(f" {i+1}/{num_iterations}...", end="\r", flush=True)
            task = generator.generate_task(topic, 'easy')
            student.learn(task)
        # Trailing spaces overwrite the "..." left by the in-place updates.
        print(f" {num_iterations}/{num_iterations}   ")

    # Measure final accuracy
    print(" Evaluating final accuracy...", end=" ", flush=True)
    final_acc = student.evaluate(eval_tasks)
    print(f"{final_acc:.3f}")

    improvement = final_acc - initial_acc
    print(f"✅ Learning verified (improvement: {improvement:+.3f})")
def test_student_forgets():
    """Train the student, advance its clock, and report the accuracy change.

    Builds a ``StudentAgent`` with ``retention_constant=20.0``, trains it on
    20 easy 'literature' tasks, evaluates on 5 tasks, calls
    ``student.advance_time(50.0)`` and evaluates the same tasks again.
    A drop is reported as success; no drop only prints a warning, so this
    test fails only if a call raises.
    """
    print("Testing memory decay...", flush=True)
    # NOTE(review): units/semantics of retention_constant and advance_time
    # are defined by StudentAgent and are not visible here.
    student = StudentAgent(device='cpu', retention_constant=20.0)
    generator = MockTaskGenerator()
    topic = 'literature'

    # Training with progress bar
    num_iterations = 20
    print(f" Training on {num_iterations} tasks:")
    if HAS_TQDM:
        pbar = tqdm(range(num_iterations), desc=" Progress", leave=False)
        for i in pbar:
            task = generator.generate_task(topic, 'easy')
            student.learn(task)
            pbar.set_postfix({'tasks': i + 1})
    else:
        # Fallback: simple progress indicator rewritten in place via "\r".
        for i in range(num_iterations):
            if (i + 1) % 5 == 0:
                print(f" {i+1}/{num_iterations}...", end="\r", flush=True)
            task = generator.generate_task(topic, 'easy')
            student.learn(task)
        # Trailing spaces overwrite the "..." left by the in-place updates.
        print(f" {num_iterations}/{num_iterations}   ")

    print(" Evaluating before forgetting...", end=" ", flush=True)
    eval_tasks = [generator.generate_task(topic, 'easy') for _ in range(5)]
    acc_before = student.evaluate(eval_tasks)
    print(f"{acc_before:.3f}")

    # Time passes
    print(" Simulating time passage (forgetting)...", end=" ", flush=True)
    student.advance_time(50.0)
    print("Done")

    print(" Evaluating after forgetting...", end=" ", flush=True)
    acc_after = student.evaluate(eval_tasks)
    print(f"{acc_after:.3f}")

    if acc_after < acc_before:
        print(f"✅ Forgetting verified (drop: {acc_before - acc_after:.3f})")
    else:
        print(f"⚠️ Forgetting minimal (change: {acc_after - acc_before:+.3f})")
def test_student_state():
    """Check that ``student.get_state()`` reflects the training performed.

    Trains on 5 easy tasks for each of two topics, then asserts that the
    returned state has at least one entry in ``topic_accuracies`` and that
    ``total_timesteps`` is at least the number of tasks learned.

    Raises:
        AssertionError: if either state check fails.
    """
    print("Testing state reporting...", flush=True)
    student = StudentAgent(device='cpu')
    generator = MockTaskGenerator()

    # Training with progress bar
    topics_to_test = ['history', 'science']
    tasks_per_topic = 5
    total_tasks = len(topics_to_test) * tasks_per_topic
    print(f" Training on {total_tasks} tasks:")
    for topic in topics_to_test:
        if HAS_TQDM:
            steps = tqdm(range(tasks_per_topic), desc=f" {topic}", leave=False)
        else:
            steps = range(tasks_per_topic)
        for _ in steps:
            task = generator.generate_task(topic, 'easy')
            student.learn(task)

    state = student.get_state()
    assert len(state.topic_accuracies) > 0, "State reports no topic accuracies"
    # Compare against the computed task count instead of a duplicated literal.
    assert state.total_timesteps >= total_tasks, (
        f"Expected at least {total_tasks} timesteps, "
        f"got {state.total_timesteps}"
    )
    print("✅ State reporting works")
def run_all_tests():
    """Run all tests in order with progress indicators and a summary banner.

    Returns:
        bool: ``True`` if every test finished without raising, ``False`` if
        any test raised (the error and its traceback are printed).
    """
    # Imported locally so the function does not depend on extra
    # module-level imports.
    import time
    import traceback

    banner = "=" * 60
    print(banner)
    print("RUNNING STUDENT AGENT TESTS")
    print(banner)
    if not HAS_TQDM:
        print("💡 Tip: Install tqdm for progress bars: pip install tqdm")
    print()

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during the run.
    start_time = time.perf_counter()
    try:
        test_student_can_load()
        test_student_can_answer()
        test_student_learns()
        test_student_forgets()
        test_student_state()
    except Exception as e:
        # Top-level boundary: report the first failure and stop the run.
        elapsed = time.perf_counter() - start_time
        print()
        print(banner)
        print(f"❌ Test failed after {elapsed:.2f}s")
        print(f"Error: {e}")
        print(banner)
        traceback.print_exc()
        return False
    else:
        elapsed = time.perf_counter() - start_time
        print()
        print(banner)
        print(f"🎉 All tests passed! (Total time: {elapsed:.2f}s)")
        print(banner)
        return True
if __name__ == "__main__":
    # Script entry point: run the suite and mirror its outcome in the
    # process exit status (0 = all tests passed, 1 = a test failed).
    success = run_all_tests()
    if success:
        sys.exit(0)
    sys.exit(1)
|