# feature-flag-cleanup/tests/test_env.py
# Commit 5f6895d (Falgunisharma): Improve environment depth: investigate action,
# cascading deps, rich observations, harder hard task
"""Tests for the Feature Flag Cleanup environment."""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env.environment import FeatureFlagCleanupEnv
def test_reset():
    """Check that resetting to the "easy" task yields a fresh, unfinished episode."""
    environment = FeatureFlagCleanupEnv()
    first_obs = environment.reset("easy")

    # The initial observation must expose these fields and not be terminal.
    for field in ("flag_name", "code_snippet", "previously_removed"):
        assert field in first_obs
    assert first_obs["done"] is False

    # No steps have been taken yet and the requested task is the active one.
    snapshot = environment.state()
    assert snapshot["step_count"] == 0
    assert snapshot["task_id"] == "easy"
    print("PASS: test_reset")
def test_step():
    """Check the result types of one "remove" step and that the step counter advances."""
    environment = FeatureFlagCleanupEnv()
    environment.reset("easy")

    _obs, reward, done, info = environment.step({"action": "remove"})

    # step() returns a 4-tuple: observation, float reward, bool done, info mapping.
    assert isinstance(reward, float)
    assert isinstance(done, bool)
    assert "correct_action" in info

    snapshot = environment.state()
    assert snapshot["step_count"] == 1
    print("PASS: test_step")
def test_investigate():
    """Check that "investigate" reveals notes at a small cost without ending the episode."""
    environment = FeatureFlagCleanupEnv()

    # Before any investigation the notes field is an empty string.
    initial = environment.reset("easy")
    assert initial["investigation_notes"] == ""

    # Investigating costs -0.05, keeps the episode open and fills in the notes.
    after_probe, probe_reward, probe_done, probe_info = environment.step(
        {"action": "investigate"}
    )
    assert probe_reward == -0.05
    assert probe_done is False
    assert after_probe["investigation_notes"] != ""
    assert probe_info["agent_action"] == "investigate"

    # Follow-up decision: info should record that an investigation came first.
    _, _, _, decision_info = environment.step({"action": "remove"})
    assert decision_info["correct_action"] == "remove"
    assert decision_info["investigated"] is True
    print("PASS: test_investigate")
def test_investigate_twice():
    """Check that a repeated "investigate" is still penalized and flagged in info."""
    environment = FeatureFlagCleanupEnv()
    environment.reset("easy")

    probe = {"action": "investigate"}
    environment.step(probe)
    _, second_reward, _, second_info = environment.step(probe)

    # The second probe costs the same -0.05 and carries an explanatory message.
    assert second_reward == -0.05
    note = second_info.get("message", "")
    assert "Already investigated" in note
    print("PASS: test_investigate_twice")
def test_cascade_tracking():
    """Check that a removed flag is recorded under ``flags_removed`` in the state.

    Assumes the first flag served by the "easy" task is
    ``enable_new_checkout_flow``.
    """
    environment = FeatureFlagCleanupEnv()
    environment.reset("easy")

    # Remove whichever flag is presented first, then inspect the bookkeeping.
    environment.step({"action": "remove"})
    removed = environment.state()["flags_removed"]
    assert "enable_new_checkout_flow" in removed
    print("PASS: test_cascade_tracking")
def test_full_episode():
    """Play the "easy" task to completion with "remove" and check the grade bounds."""
    environment = FeatureFlagCleanupEnv()
    environment.reset("easy")

    # Keep answering "remove" until the environment reports the episode is over.
    steps = 0
    while True:
        _, _, done, _ = environment.step({"action": "remove"})
        steps += 1
        if done:
            break

    assert done is True
    assert steps > 0

    # The final grade must lie in the closed interval [0, 1].
    score = environment.grade()
    assert 0.0 <= score <= 1.0
    print(f"PASS: test_full_episode (steps={steps}, score={score:.4f})")
def test_all_tasks():
    """Play every difficulty to completion with "keep" and check the grade bounds."""
    environment = FeatureFlagCleanupEnv()
    for task_id in ("easy", "medium", "hard"):
        initial = environment.reset(task_id)
        assert initial["task_id"] == task_id

        # Always answer "keep" until the environment signals the episode is done.
        while True:
            _, _, done, _ = environment.step({"action": "keep"})
            if done:
                break

        score = environment.grade()
        assert 0.0 <= score <= 1.0
        print(f"PASS: test_{task_id} (score={score:.4f})")
def test_reward_range():
    """Check that the first-step reward for each decision action lies in [-1, 1]."""
    environment = FeatureFlagCleanupEnv()
    for action in ("remove", "keep", "deprecate", "escalate"):
        # Start a new episode so each action is scored on the first step.
        environment.reset("easy")
        _, reward, _, _ = environment.step({"action": action})
        within_bounds = -1.0 <= reward <= 1.0
        assert within_bounds, f"Reward {reward} out of range for action {action}"
    print("PASS: test_reward_range")
def test_invalid_action():
    """Check that ``step`` raises when given an unrecognized action name.

    The failure is raised from the ``else`` clause rather than from inside
    the ``try`` body: ``AssertionError`` is a subclass of ``Exception``, so
    an ``assert False`` placed inside the ``try`` would be swallowed by the
    ``except Exception`` handler and the test could never fail.

    Raises:
        AssertionError: if ``step`` accepts the invalid action silently.
    """
    env = FeatureFlagCleanupEnv()
    env.reset("easy")
    try:
        env.step({"action": "invalid_action"})
    except Exception:
        # The concrete exception type is chosen by the environment and is not
        # visible here, so any Exception subclass counts as a rejection.
        pass
    else:
        # Explicit raise (not ``assert``) so the check also survives ``python -O``.
        raise AssertionError("Should have raised an error")
    print("PASS: test_invalid_action")
def test_rich_observations():
    """Check that the initial "easy" observation carries populated context fields."""
    environment = FeatureFlagCleanupEnv()
    first_obs = environment.reset("easy")

    # Context fields that must not be the empty string.
    for text_field in ("code_snippet", "last_commit_message", "pr_context"):
        assert first_obs[text_field] != ""

    # History fields that must be lists.
    for list_field in ("related_incidents", "previously_removed"):
        assert isinstance(first_obs[list_field], list)
    print("PASS: test_rich_observations")
def test_hard_task_has_more_flags():
    """Check that the "hard" task reports at least ten flags in its state."""
    environment = FeatureFlagCleanupEnv()
    environment.reset("hard")

    state = environment.state()
    minimum_flags = 10  # lower bound this test expects for the hard task
    assert state["total_flags"] >= minimum_flags
    print(f"PASS: test_hard_task_has_more_flags (total={state['total_flags']})")
if __name__ == "__main__":
    # Script entry point: run every test in file order. The first failing
    # assertion propagates and aborts the run before the summary is printed.
    for run_test in (
        test_reset,
        test_step,
        test_investigate,
        test_investigate_twice,
        test_cascade_tracking,
        test_full_episode,
        test_all_tasks,
        test_reward_range,
        test_invalid_action,
        test_rich_observations,
        test_hard_task_has_more_flags,
    ):
        run_test()
    print("\nAll tests passed!")