Spaces:
Running
Running
| """Tests for environment — lifecycle, all tasks, investigation mode, adaptive difficulty.""" | |
| import pytest | |
| from financial_audit_env.server.environment import FinancialAuditEnvironment | |
| from financial_audit_env.models import AuditAction, Finding | |
class TestResetStep:
    """Tests for the basic reset/step/state lifecycle.

    ``env`` is injected by pytest; the fixture itself is not defined in the
    visible part of this file.
    """

    @staticmethod
    def _single_finding_action(*, submit_final):
        """Return an action holding one ``over_limit`` finding on ``EXP-010``.

        Args:
            submit_final: Value forwarded to ``AuditAction(submit_final=...)``.

        Returns:
            An ``AuditAction`` with exactly one ``Finding``.
        """
        return AuditAction(
            findings=[
                Finding(
                    document_id="EXP-010",
                    error_type="over_limit",
                    description="Test",
                )
            ],
            submit_final=submit_final,
        )

    def test_reset_returns_observation(self, env):
        """reset() returns a not-done observation at step 0 with documents."""
        obs = env.reset(task_id="expense_audit", seed=42)
        assert obs.task_id == "expense_audit"
        assert obs.done is False
        assert obs.step_number == 0
        assert len(obs.documents) > 0

    def test_step_without_reset_raises(self, env):
        """step() before any reset() raises RuntimeError ("not initialized")."""
        with pytest.raises(RuntimeError, match="not initialized"):
            env.step(AuditAction(findings=[], submit_final=True))

    def test_step_with_findings(self, env):
        """A non-final step records the finding and keeps the episode open."""
        # The reset observation is not inspected here, so it is not bound.
        env.reset(task_id="expense_audit", seed=42)
        obs = env.step(self._single_finding_action(submit_final=False))
        assert obs.done is False
        assert obs.step_number == 1
        assert len(obs.findings_so_far) == 1

    def test_submit_final_ends_episode(self, env):
        """A step with submit_final=True marks the observation as done."""
        env.reset(task_id="expense_audit", seed=42)
        obs = env.step(self._single_finding_action(submit_final=True))
        assert obs.done is True

    def test_max_steps_ends_episode(self, env):
        """Taking max_steps non-final steps ends the episode."""
        obs = env.reset(task_id="expense_audit", seed=42)
        # Without at least one step the loop below would be vacuous and the
        # final assertion would be checking the reset observation instead.
        assert obs.max_steps > 0, "max_steps must be positive for this test"
        for _ in range(obs.max_steps):
            obs = env.step(AuditAction(findings=[], submit_final=False))
        assert obs.done is True

    def test_state_tracks_progress(self, env):
        """State right after reset reports the task, zero steps, and errors > 0."""
        env.reset(task_id="expense_audit", seed=42)
        state = env.state
        assert state.task_id == "expense_audit"
        assert state.step_count == 0
        assert state.total_errors > 0
class TestAllTasks:
    """Smoke tests run per task: reset, final submission, and grading.

    ``task_id`` and ``env`` are supplied by pytest; how ``task_id`` is
    parametrized is not visible in this part of the file.
    """

    def test_task_reset_and_step(self, env, task_id):
        """A fresh episode exposes documents and closes on a final submit."""
        initial = env.reset(task_id=task_id, seed=42)
        assert initial.task_id == task_id
        assert not initial.done
        assert len(initial.documents) > 0

        # An empty final submission is enough to end the episode.
        closing = env.step(AuditAction(findings=[], submit_final=True))
        assert closing.done is True

    def test_grader_after_completion(self, env, task_id):
        """After a final submit, a grader result with a score in (0, 1) exists."""
        env.reset(task_id=task_id, seed=42)
        final_action = AuditAction(findings=[], submit_final=True)
        env.step(final_action)

        graded = env.last_grader_result
        assert graded is not None
        assert "score" in graded
        assert 0 < graded["score"] < 1

    def test_invalid_task_raises(self, env):
        """reset() with an unknown task id raises ValueError."""
        with pytest.raises(ValueError, match="Unknown task_id"):
            env.reset(task_id="nonexistent_task", seed=42)

    def test_grader_result_has_enhanced_fields(self, env):
        """The grader result carries the extended metric keys."""
        from financial_audit_env.server.data_generator import generate_expense_data

        # Second return value is used as the list of expected findings;
        # each entry is read via its "document_id" and "error_type" keys.
        _, expected = generate_expense_data(42)
        env.reset(task_id="expense_audit", seed=42)

        # Echo every expected entry back as a finding and submit.
        echoed = [
            Finding(
                document_id=entry["document_id"],
                error_type=entry["error_type"],
                description="Test",
            )
            for entry in expected
        ]
        env.step(AuditAction(findings=echoed, submit_final=True))

        report = env.last_grader_result
        assert "weighted_score" in report
        assert "partial_credit_score" in report
        assert "confusion_matrix" in report
        assert "risk_score" in report
class TestInvestigationMode:
    """Tests for resetting and submitting with ``investigation_mode`` set."""

    def test_investigation_mode_no_docs_on_reset(self, env):
        """Investigation mode starts with no documents but with categories."""
        first = env.reset(
            task_id="expense_audit",
            seed=42,
            investigation_mode=True,
        )
        assert first.investigation_mode is True
        # Documents are withheld at reset in this mode.
        assert len(first.documents) == 0
        assert len(first.available_categories) > 0
        assert first.data_summary is not None

    def test_standard_mode_has_docs_on_reset(self, env):
        """Standard mode hands out documents immediately on reset."""
        first = env.reset(
            task_id="expense_audit",
            seed=42,
            investigation_mode=False,
        )
        assert first.investigation_mode is False
        assert len(first.documents) > 0

    def test_investigation_mode_submit_still_works(self, env):
        """A final submission ends the episode in investigation mode too."""
        env.reset(task_id="expense_audit", seed=42, investigation_mode=True)

        lone_finding = Finding(
            document_id="EXP-010",
            error_type="over_limit",
            description="Test",
        )
        outcome = env.step(
            AuditAction(findings=[lone_finding], submit_final=True)
        )
        assert outcome.done is True
class TestAdaptiveDifficulty:
    """Tests for ``get_adaptive_difficulty`` under different score histories."""

    def test_no_history_returns_default(self, env):
        """With an untouched environment the suggestion is "default"."""
        advice = env.get_adaptive_difficulty()
        assert advice["suggestion"] == "default"

    def test_high_scores_increase_difficulty(self, env):
        """High recorded scores suggest a harder setting with noise >= 0.7."""
        # Seed the private history directly rather than playing episodes.
        env._score_history = [0.9, 0.85, 0.88]
        advice = env.get_adaptive_difficulty()
        assert advice["suggestion"] == "increase_difficulty"
        assert advice["noise_level"] >= 0.7

    def test_low_scores_decrease_difficulty(self, env):
        """Low recorded scores suggest an easier setting."""
        env._score_history = [0.1, 0.2, 0.15]
        advice = env.get_adaptive_difficulty()
        assert advice["suggestion"] == "decrease_difficulty"
class TestEdgeCases:
    """Edge-case tests: empty, invalid, and bulk submissions, plus re-reset."""

    def test_empty_findings_submission(self, env):
        """An empty final submission ends the episode with a score in (0, 1)."""
        env.reset(task_id="expense_audit", seed=42)
        outcome = env.step(AuditAction(findings=[], submit_final=True))
        assert outcome.done is True
        assert 0 < env.last_grader_result["score"] < 1

    def test_invalid_error_type_filtered(self, env):
        """An unknown error_type is reported back through the feedback text."""
        env.reset(task_id="expense_audit", seed=42)
        bogus = Finding(
            document_id="EXP-010",
            error_type="nonexistent_type",
            description="Test",
        )
        outcome = env.step(AuditAction(findings=[bogus], submit_final=False))
        # Only the feedback message is checked here.
        assert "Invalid error_type" in outcome.feedback

    def test_many_findings_accepted(self, env):
        """Fifty findings in a single final submission still end the episode."""
        env.reset(task_id="expense_audit", seed=42)
        bulk = [
            Finding(
                document_id=f"DOC-{n}",
                error_type="over_limit",
                description=f"Test {n}",
            )
            for n in range(50)
        ]
        outcome = env.step(AuditAction(findings=bulk, submit_final=True))
        assert outcome.done is True

    def test_reset_clears_previous_state(self, env):
        """Resetting after a finished episode starts the new task from scratch."""
        env.reset(task_id="expense_audit", seed=42)
        earlier_finding = Finding(
            document_id="EXP-010",
            error_type="over_limit",
            description="T",
        )
        env.step(AuditAction(findings=[earlier_finding], submit_final=True))

        # Start a second episode on a different task.
        fresh = env.reset(task_id="invoice_match", seed=42)
        assert fresh.task_id == "invoice_match"
        assert len(fresh.findings_so_far) == 0
        assert env.state.step_count == 0