# financial_audit_env/tests/test_environment.py
# balloonmann's picture
# fix: increase score epsilon to 0.01 for validator safety
# b4d5e6a
"""Tests for environment — lifecycle, all tasks, investigation mode, adaptive difficulty."""
import pytest
from financial_audit_env.server.environment import FinancialAuditEnvironment
from financial_audit_env.models import AuditAction, Finding
class TestResetStep:
    """Lifecycle checks for the environment's reset(), step() and state."""

    @staticmethod
    def _single_finding_action(submit_final):
        """Build an action holding one over_limit finding on EXP-010."""
        finding = Finding(
            document_id="EXP-010",
            error_type="over_limit",
            description="Test",
        )
        return AuditAction(findings=[finding], submit_final=submit_final)

    def test_reset_returns_observation(self, env):
        """reset() is expected to yield a non-terminal step-0 observation with documents."""
        first_obs = env.reset(task_id="expense_audit", seed=42)
        assert first_obs.task_id == "expense_audit"
        assert first_obs.done is False
        assert first_obs.step_number == 0
        assert len(first_obs.documents) > 0

    def test_step_without_reset_raises(self, env):
        """step() before any reset() is expected to raise RuntimeError."""
        final_action = AuditAction(findings=[], submit_final=True)
        with pytest.raises(RuntimeError, match="not initialized"):
            env.step(final_action)

    def test_step_with_findings(self, env):
        """A non-final step advances the counter and records the finding."""
        env.reset(task_id="expense_audit", seed=42)
        after_step = env.step(self._single_finding_action(submit_final=False))
        assert after_step.done is False
        assert after_step.step_number == 1
        assert len(after_step.findings_so_far) == 1

    def test_submit_final_ends_episode(self, env):
        """A step with submit_final=True marks the observation as done."""
        env.reset(task_id="expense_audit", seed=42)
        after_step = env.step(self._single_finding_action(submit_final=True))
        assert after_step.done is True

    def test_max_steps_ends_episode(self, env):
        """Taking max_steps non-final steps is expected to end the episode."""
        latest = env.reset(task_id="expense_audit", seed=42)
        # The step budget is read once from the reset observation.
        budget = latest.max_steps
        for _ in range(budget):
            latest = env.step(AuditAction(findings=[], submit_final=False))
        assert latest.done is True

    def test_state_tracks_progress(self, env):
        """After reset, state reports the task, zero steps and a positive error count."""
        env.reset(task_id="expense_audit", seed=42)
        snapshot = env.state
        assert snapshot.task_id == "expense_audit"
        assert snapshot.step_count == 0
        assert snapshot.total_errors > 0
class TestAllTasks:
    """Smoke tests that run every listed task through reset, step and grading."""

    # Single source of truth for the parametrized tests below, so the task
    # list cannot drift between them.
    TASK_IDS = (
        "expense_audit",
        "invoice_match",
        "gst_reconciliation",
        "fraud_detection",
    )

    @pytest.mark.parametrize("task_id", TASK_IDS)
    def test_task_reset_and_step(self, env, task_id):
        """Each task resets to a non-terminal observation and ends on final submit."""
        obs = env.reset(task_id=task_id, seed=42)
        assert obs.task_id == task_id
        assert not obs.done
        assert len(obs.documents) > 0

        # Submit empty findings to end the episode.
        obs = env.step(AuditAction(findings=[], submit_final=True))
        assert obs.done is True

    @pytest.mark.parametrize("task_id", TASK_IDS)
    def test_grader_after_completion(self, env, task_id):
        """After a final submit, a grader result exists with a score strictly in (0, 1)."""
        env.reset(task_id=task_id, seed=42)
        env.step(AuditAction(findings=[], submit_final=True))

        result = env.last_grader_result
        assert result is not None
        assert "score" in result
        assert 0 < result["score"] < 1

    def test_invalid_task_raises(self, env):
        """reset() with an unknown task_id is expected to raise ValueError."""
        with pytest.raises(ValueError, match="Unknown task_id"):
            env.reset(task_id="nonexistent_task", seed=42)

    def test_grader_result_has_enhanced_fields(self, env):
        """The grader result exposes the extended scoring fields."""
        # Kept as a function-local import, as in the original test.
        from financial_audit_env.server.data_generator import generate_expense_data

        _, ground_truth = generate_expense_data(42)
        env.reset(task_id="expense_audit", seed=42)

        # One finding per ground-truth entry, generated with the same seed as
        # the reset (intended as a perfect submission).
        action = AuditAction(
            findings=[
                Finding(
                    document_id=entry["document_id"],
                    error_type=entry["error_type"],
                    description="Test",
                )
                for entry in ground_truth
            ],
            submit_final=True,
        )
        env.step(action)

        result = env.last_grader_result
        # Guard first so a missing result fails as a clear assertion instead
        # of a TypeError from the membership checks below.
        assert result is not None
        for field in (
            "weighted_score",
            "partial_credit_score",
            "confusion_matrix",
            "risk_score",
        ):
            assert field in result, f"grader result is missing {field!r}"
class TestInvestigationMode:
    """Checks of reset() and step() with investigation_mode switched on and off."""

    def test_investigation_mode_no_docs_on_reset(self, env):
        """With investigation_mode=True, reset shows categories and a summary but no documents."""
        initial = env.reset(
            task_id="expense_audit",
            seed=42,
            investigation_mode=True,
        )
        assert initial.investigation_mode is True
        # No documents are expected on the reset observation in this mode.
        assert len(initial.documents) == 0
        assert len(initial.available_categories) > 0
        assert initial.data_summary is not None

    def test_standard_mode_has_docs_on_reset(self, env):
        """With investigation_mode=False, reset returns documents right away."""
        initial = env.reset(
            task_id="expense_audit",
            seed=42,
            investigation_mode=False,
        )
        assert initial.investigation_mode is False
        assert len(initial.documents) > 0

    def test_investigation_mode_submit_still_works(self, env):
        """A final submission ends the episode even in investigation mode."""
        env.reset(task_id="expense_audit", seed=42, investigation_mode=True)
        lone_finding = Finding(
            document_id="EXP-010",
            error_type="over_limit",
            description="Test",
        )
        outcome = env.step(
            AuditAction(findings=[lone_finding], submit_final=True)
        )
        assert outcome.done is True
class TestAdaptiveDifficulty:
    """Checks of get_adaptive_difficulty() against different score histories."""

    def test_no_history_returns_default(self, env):
        """Without any history set by the test, the suggestion is 'default'."""
        advice = env.get_adaptive_difficulty()
        assert advice["suggestion"] == "default"

    def test_high_scores_increase_difficulty(self, env):
        """A run of high scores yields 'increase_difficulty' and noise_level >= 0.7."""
        # Inject the history directly rather than playing full episodes.
        strong_run = [0.9, 0.85, 0.88]
        env._score_history = strong_run
        advice = env.get_adaptive_difficulty()
        assert advice["suggestion"] == "increase_difficulty"
        assert advice["noise_level"] >= 0.7

    def test_low_scores_decrease_difficulty(self, env):
        """A run of low scores yields 'decrease_difficulty'."""
        weak_run = [0.1, 0.2, 0.15]
        env._score_history = weak_run
        advice = env.get_adaptive_difficulty()
        assert advice["suggestion"] == "decrease_difficulty"
class TestEdgeCases:
    """Boundary scenarios: empty, invalid and bulk submissions, plus re-reset."""

    def test_empty_findings_submission(self, env):
        """A final submit with no findings ends the episode with a score in (0, 1)."""
        env.reset(task_id="expense_audit", seed=42)
        outcome = env.step(AuditAction(findings=[], submit_final=True))
        assert outcome.done is True
        assert 0 < env.last_grader_result["score"] < 1

    def test_invalid_error_type_filtered(self, env):
        """An unknown error_type is reported back through the feedback text."""
        env.reset(task_id="expense_audit", seed=42)
        bogus = Finding(
            document_id="EXP-010",
            error_type="nonexistent_type",
            description="Test",
        )
        outcome = env.step(AuditAction(findings=[bogus], submit_final=False))
        # The feedback must flag the unrecognised error type.
        assert "Invalid error_type" in outcome.feedback

    def test_many_findings_accepted(self, env):
        """Fifty findings in one final action still complete the episode."""
        env.reset(task_id="expense_audit", seed=42)
        bulk = [
            Finding(
                document_id=f"DOC-{i}",
                error_type="over_limit",
                description=f"Test {i}",
            )
            for i in range(50)
        ]
        outcome = env.step(AuditAction(findings=bulk, submit_final=True))
        assert outcome.done is True

    def test_reset_clears_previous_state(self, env):
        """A second reset starts a clean episode for the newly requested task."""
        env.reset(task_id="expense_audit", seed=42)
        prior_finding = Finding(
            document_id="EXP-010",
            error_type="over_limit",
            description="T",
        )
        env.step(AuditAction(findings=[prior_finding], submit_final=True))

        # Reset onto a different task and confirm nothing carried over.
        fresh = env.reset(task_id="invoice_match", seed=42)
        assert fresh.task_id == "invoice_match"
        assert len(fresh.findings_so_far) == 0
        assert env.state.step_count == 0