|
|
| import pytest |
| import numpy as np |
| from unittest.mock import AsyncMock, MagicMock |
| from uuid import uuid4 |
| from datetime import datetime, timedelta |
|
|
| from app.services.learning_agent import RewardComputer, LearningAgent, FairnessBandit |
| from cron.daily_learning import DailyLearningPipeline |
| from app.models import LearningEpisode, AllocationRun |
|
|
class TestRewardComputer:
    """Unit tests covering the RewardComputer service."""

    @pytest.fixture
    def mock_db(self):
        """Provide an AsyncMock database handle with an AsyncMock ``execute``."""
        fake_db = AsyncMock()
        fake_db.execute = AsyncMock()
        return fake_db

    @pytest.fixture
    def reward_computer(self, mock_db):
        """Build a RewardComputer on top of the mocked database handle."""
        computer = RewardComputer(mock_db)
        return computer

    def test_weight_constants(self, reward_computer):
        """The four reward weights must add up to 1.0 within a 0.01 tolerance."""
        fairness = reward_computer.FAIRNESS_WEIGHT
        stress = reward_computer.STRESS_WEIGHT
        completion = reward_computer.COMPLETION_WEIGHT
        retention = reward_computer.RETENTION_WEIGHT

        # Same left-to-right addition order as a plain chained sum.
        weight_sum = fairness + stress + completion + retention
        deviation = abs(weight_sum - 1.0)
        assert deviation < 0.01
|
|
class TestBanditConvergence:
    """Checks that FairnessBandit's arm selection follows the observed rewards."""

    def test_bandit_prefers_high_reward(self):
        """Reward one arm highly and another poorly, then compare selections.

        Fifteen updates of 0.9 go to the first arm and fifteen updates of
        0.2 go to the second. Afterwards the first arm must hold the larger
        alpha, the second arm the larger beta, and over 100 non-experimental
        selections the first arm must be chosen more often than the second.
        """
        bandit = FairnessBandit(MagicMock())

        # Use the first two arms the bandit exposes as the good/bad pair.
        known_hashes = list(bandit.arm_hashes.keys())
        good_hash = known_hashes[0]
        bad_hash = known_hashes[1]

        # Start every arm from the same alpha = beta = 1 state.
        bandit.alpha = np.ones(bandit.n_arms)
        bandit.beta = np.ones(bandit.n_arms)

        # All high-reward updates are applied before any low-reward update.
        for arm_hash, reward in ((good_hash, 0.9), (bad_hash, 0.2)):
            for _ in range(15):
                bandit.update(arm_hash, reward)

        good_idx = bandit.arm_indices[good_hash]
        bad_idx = bandit.arm_indices[bad_hash]

        # The posterior parameters must reflect the rewards that were fed in.
        assert bandit.alpha[good_idx] > bandit.alpha[bad_idx]
        assert bandit.beta[bad_idx] > bandit.beta[good_idx]

        # Draw repeatedly and tally how often each of the two arms wins.
        picks = [
            bandit.select_arm(experimental=False)["arm_idx"]
            for _ in range(100)
        ]
        count0 = picks.count(good_idx)
        count1 = picks.count(bad_idx)

        assert count0 > count1, f"Should prefer arm0 (got {count0} vs {count1})"
|
|
@pytest.mark.asyncio
async def test_learning_integration(db_session):
    """Check LearningAgent's status report when built on ``db_session``.

    The report returned by ``get_learning_status`` must contain a
    non-empty ``bandit_statistics`` entry.
    """
    learning_agent = LearningAgent(db_session)
    report = await learning_agent.get_learning_status()

    assert "bandit_statistics" in report
    bandit_stats = report["bandit_statistics"]
    assert len(bandit_stats) > 0
|
|
@pytest.mark.asyncio
async def test_daily_learning_cron_pipeline(db_session, sample_drivers):
    """Run DailyLearningPipeline end to end over one seeded episode.

    Seeds an AllocationRun dated yesterday and a LearningEpisode created
    25 hours ago, runs the pipeline, and checks that the returned metrics
    report a non-failed status, at least one processed episode, and a
    ``duration_seconds`` entry.

    ``sample_drivers`` is requested as a fixture but not referenced in the
    body.
    """
    # Function-scope import so this test does not depend on the module
    # header also importing ``timezone``.
    from datetime import timezone

    pipeline = DailyLearningPipeline(db_session)

    # ``datetime.utcnow()`` is deprecated since Python 3.12. Take a single
    # aware UTC reading and drop tzinfo: this yields the same naive-UTC value
    # and gives both records below one shared reference instant.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)

    # Allocation run dated one day before the reference instant.
    alloc_run = AllocationRun(
        date=now_utc.date() - timedelta(days=1),
        num_drivers=10,
        num_routes=10,
        num_packages=100,
        status="SUCCESS",
    )
    db_session.add(alloc_run)
    # Presumably populates ``alloc_run.id`` for the episode below -- confirm.
    await db_session.flush()

    # Episode attached to that run, created 25 hours before the reference.
    episode = LearningEpisode(
        allocation_run_id=alloc_run.id,
        config_hash="dummy_hash",
        fairness_config={"gini_threshold": 0.3},
        is_experimental=False,
        created_at=now_utc - timedelta(hours=25),
    )
    db_session.add(episode)
    await db_session.commit()

    metrics = await pipeline.run()

    assert metrics["status"] != "failed"
    assert metrics["episodes_processed"] >= 1

    # NOTE(review): the episode is reloaded but nothing is asserted about it;
    # consider checking whatever field the pipeline writes back.
    await db_session.refresh(episode)

    assert "duration_seconds" in metrics
|
|