# VLAarchtests / tests/test_teacher_audit.py
# Last commit by lsnu: "2026-03-25 runpod handoff update" (e7d8e79, verified)
import numpy as np
from sim_reveal.procedural_envs import available_proxy_names, make_proxy_env
def _mean_utility(
    proxy_name: str,
    baseline_name: str,
    seeds: range,
    *,
    resolution: int = 32,
    chunk_horizon: int = 4,
    rollout_horizon: int = 4,
) -> float:
    """Return the mean scalar utility of one baseline on one proxy env.

    For every seed, a fresh proxy environment is built with that seed and
    reset with the same seed. The named baseline then produces an action
    chunk, and that chunk is evaluated. The per-seed utility is::

        retrieval_success - final_disturbance_cost - reocclusion_rate

    Args:
        proxy_name: Name of the proxy env passed to ``make_proxy_env``.
        baseline_name: Baseline passed to ``env.baseline_action_chunk``
            (callers use e.g. ``"teacher"`` and ``"random"``).
        seeds: Seeds used both to construct and to reset each env.
        resolution: Resolution passed to ``make_proxy_env``.
        chunk_horizon: Chunk length requested from the baseline.
        rollout_horizon: Rollout horizon used both when building the env and
            when evaluating the chunk.

    Returns:
        The arithmetic mean of the per-seed utilities.

    Raises:
        ValueError: If ``seeds`` yields no values. Without this guard,
            ``np.mean`` of an empty list returns ``nan``. Every ``>=``
            comparison against ``nan`` is ``False``, so callers would fail
            with a confusing assertion instead of a clear error.
    """
    utilities = []
    for seed in seeds:
        env = make_proxy_env(
            proxy_name=proxy_name,
            resolution=resolution,
            seed=seed,
            rollout_horizon=rollout_horizon,
        )
        # The two-value unpack is kept deliberately. It also checks that
        # reset() returns exactly two items; their values are unused here.
        _, _ = env.reset(seed=seed)
        chunk = env.baseline_action_chunk(baseline_name, chunk_horizon=chunk_horizon)
        outcome = env.evaluate_action_chunk(chunk, rollout_horizon=rollout_horizon)
        utilities.append(
            float(outcome["retrieval_success"])
            - float(outcome["final_disturbance_cost"])
            - float(outcome["reocclusion_rate"])
        )
    if not utilities:
        raise ValueError(
            f"no seeds given for proxy {proxy_name!r}, baseline {baseline_name!r}"
        )
    return float(np.mean(utilities))
def test_teacher_audit():
    """Check that the teacher baseline is never worse than the other baselines.

    For every available proxy env, the teacher's mean utility over a fixed
    seed set must be >= the mean utility of the ``random``,
    ``retrieve_only``, ``reveal_only`` and ``no_hold`` baselines on the same
    seeds.
    """
    seeds = range(5)
    proxy_names = list(available_proxy_names())
    # Without this guard, an empty proxy list would make the test pass
    # without auditing anything.
    assert proxy_names, "available_proxy_names() returned no proxies to audit"
    for proxy_name in proxy_names:
        teacher_utility = _mean_utility(proxy_name, "teacher", seeds)
        for baseline_name in ("random", "retrieve_only", "reveal_only", "no_hold"):
            baseline_utility = _mean_utility(proxy_name, baseline_name, seeds)
            assert teacher_utility >= baseline_utility, (
                f"{proxy_name}: teacher utility {teacher_utility:.4f} is below "
                f"{baseline_name} utility {baseline_utility:.4f}"
            )