FairRelay / brain /tests /test_learning_agent.py
MouleeswaranM's picture
Upload folder using huggingface_hub
fcf8749 verified
import pytest
import numpy as np
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
from datetime import datetime, timedelta
from app.services.learning_agent import RewardComputer, LearningAgent, FairnessBandit
from cron.daily_learning import DailyLearningPipeline
from app.models import LearningEpisode, AllocationRun
class TestRewardComputer:
"""Tests for RewardComputer class."""
@pytest.fixture
def mock_db(self):
mock = AsyncMock()
mock.execute = AsyncMock()
return mock
@pytest.fixture
def reward_computer(self, mock_db):
return RewardComputer(mock_db)
def test_weight_constants(self, reward_computer):
total = (
reward_computer.FAIRNESS_WEIGHT +
reward_computer.STRESS_WEIGHT +
reward_computer.COMPLETION_WEIGHT +
reward_computer.RETENTION_WEIGHT
)
assert abs(total - 1.0) < 0.01
class TestBanditConvergence:
"""Test Thompson Sampling convergence logic."""
def test_bandit_prefers_high_reward(self):
"""Simulate 20 updates -> bandit prefers high-reward config."""
mock_db = MagicMock()
bandit = FairnessBandit(mock_db)
# Get first two arms
arm0_hash = list(bandit.arm_hashes.keys())[0]
arm1_hash = list(bandit.arm_hashes.keys())[1]
# Initial priors loaded (mocked, effectively 1.0/1.0)
bandit.alpha = np.ones(bandit.n_arms)
bandit.beta = np.ones(bandit.n_arms)
# Simulate 15 good updates for Arm 0 (Reward 0.9)
for _ in range(15):
bandit.update(arm0_hash, 0.9)
# Simulate 15 bad updates for Arm 1 (Reward 0.2)
for _ in range(15):
bandit.update(arm1_hash, 0.2)
# Check updated parameters verification
idx0 = bandit.arm_indices[arm0_hash]
idx1 = bandit.arm_indices[arm1_hash]
# Alpha should be higher for arm0 (1 + 15*0.9 = 14.5)
# Beta should be higher for arm1 (1 + 15*(1-0.2) = 13.0) vs (1 + 15*(1-0.9) = 2.5)
assert bandit.alpha[idx0] > bandit.alpha[idx1]
assert bandit.beta[idx1] > bandit.beta[idx0]
# Sampling should pick arm0 most of the time
selections = []
for _ in range(100):
res = bandit.select_arm(experimental=False)
selections.append(res["arm_idx"])
count0 = selections.count(idx0)
count1 = selections.count(idx1)
assert count0 > count1, f"Should prefer arm0 (got {count0} vs {count1})"
@pytest.mark.asyncio
async def test_learning_integration(db_session):
"""Integration test for Learning Agent interacting with DB."""
agent = LearningAgent(db_session)
# Test getting status with real DB
status = await agent.get_learning_status()
assert "bandit_statistics" in status
assert len(status["bandit_statistics"]) > 0
@pytest.mark.asyncio
async def test_daily_learning_cron_pipeline(db_session, sample_drivers):
"""Test the full daily learning pipeline execution."""
pipeline = DailyLearningPipeline(db_session)
# 1. Setup: Create a past allocation run and learning episode
alloc_run = AllocationRun(
date=datetime.utcnow().date() - timedelta(days=1),
num_drivers=10,
num_routes=10,
num_packages=100,
status="SUCCESS"
)
db_session.add(alloc_run)
await db_session.flush()
# Create episode (created > 24h ago)
episode = LearningEpisode(
allocation_run_id=alloc_run.id,
config_hash="dummy_hash",
fairness_config={"gini_threshold": 0.3},
is_experimental=False,
created_at=datetime.utcnow() - timedelta(hours=25)
)
db_session.add(episode)
await db_session.commit()
# 2. Run pipeline
metrics = await pipeline.run()
# 3. Verify
assert metrics["status"] != "failed"
assert metrics["episodes_processed"] >= 1
# Note: Reward might be 0.5 (neutral) if no feedback, but processed count should increment.
# Verify episode was updated
await db_session.refresh(episode)
# If compute_reward succeeded (even with neutral), it writes result to DB
# Actually, process_episode_reward does: episode.episode_reward = reward
# But only if no error. RewardComputer returns "no_assignments" if no assignments.
# We didn't create assignments for alloc_run. So reward might not be set?
# Let's check RewardComputer behavior.
# It updates the episode if successful.
# Even if reward logic skipped due to no assignments, pipeline should complete.
assert "duration_seconds" in metrics