Spaces:
Paused
Paused
File size: 2,511 Bytes
"""Test grader property and grade_episode changes."""
import sys
sys.path.insert(0, '.')
from payops_env.tasks import TASKS
from payops_env.grader import grade_episode
# Test 1: every task must expose a ``grader`` attribute.
print('=== Test 1: task.grader property ===')
tasks_without_grader = [
    task.task_id for task in TASKS if not hasattr(task, 'grader')
]
if not tasks_without_grader:
    print(f'PASS: all {len(TASKS)} tasks have grader property')
else:
    print(f'FAIL: tasks missing grader property: {tasks_without_grader}')
    sys.exit(1)

# Spot-check the grader of the first task: it must be of type
# 'action_match' and carry each of the keys listed below.
first_grader = TASKS[0].grader
assert 'type' in first_grader and first_grader['type'] == 'action_match', f'grader bad: {first_grader}'
for required_key in (
    'correct_action',
    'partial_credit',
    'requires_investigation',
    'regulatory_action',
    'key_flags',
):
    assert required_key in first_grader
print(f'PASS: grader property has required keys: {sorted(first_grader.keys())}')
# Test 2: every entry of grade_episode(...).per_task_rewards must carry a
# 'grader' key.
print()
print('=== Test 2: grade_episode per_task_rewards have grader key ===')

# First pass: grade a small sample, answering each task with its own
# correct_action.
sample_tasks = list(TASKS[:5])
sample_actions = [t.correct_action for t in sample_tasks]
result = grade_episode(sample_actions, sample_tasks)
missing_gr = [pt['task_id'] for pt in result.per_task_rewards if 'grader' not in pt]
if missing_gr:
    print(f'FAIL: per_task_rewards entries missing grader key: {missing_gr}')
    sys.exit(1)
else:
    print(f'PASS: all {len(result.per_task_rewards)} per_task_rewards entries have grader key')
    print(f'PASS: score={result.normalised_score}')

# Second pass: same check over the complete task list.
result_all = grade_episode([t.correct_action for t in TASKS], list(TASKS))
missing_all = [pt['task_id'] for pt in result_all.per_task_rewards if 'grader' not in pt]
assert not missing_all, f'FAIL: {missing_all}'
# Report the actual number of tasks rather than a hard-coded count, so the
# message stays truthful when tasks are added or removed.
print(f'PASS: all {len(TASKS)} tasks graded with grader key in per_task_rewards')
print(f'PASS: score={result_all.normalised_score}')
# Test 3: openenv.yaml must declare at least 3 task definitions that have a
# 'grader' entry.
print()
print('=== Test 3: openenv.yaml task definitions ===')
import yaml

# Read with an explicit encoding so the result does not depend on the
# platform's locale default.
with open('openenv.yaml', encoding='utf-8') as f:
    # safe_load returns None for an empty document; normalise to a dict so
    # the lookups below fail on the assertion, not with an AttributeError.
    d = yaml.safe_load(f) or {}

# tasks: is now a flat list (changed from dict+definitions to list for platform compat)
# `or []` also covers an explicit `tasks: null` / `definitions: null`.
tasks_section = d.get('tasks') or []
if isinstance(tasks_section, list):
    defs = tasks_section
else:
    defs = tasks_section.get('definitions') or []

# Only mapping entries can carry a 'grader' key; skipping anything else avoids
# substring matches on plain strings and TypeErrors on null entries.
tasks_with_grader = [t for t in defs if isinstance(t, dict) and 'grader' in t]

# Check first, report afterwards: PASS must not be printed for a failing run.
assert len(tasks_with_grader) >= 3, f'FAIL: only {len(tasks_with_grader)} tasks with grader'
print(f'PASS: openenv.yaml has {len(defs)} task definitions, {len(tasks_with_grader)} with grader')
print()
print('ALL TESTS PASSED')
|