FairRelay / brain /cron /daily_learning.py
MouleeswaranM's picture
Upload folder using huggingface_hub
fcf8749 verified
raw
history blame
10.5 kB
#!/usr/bin/env python3
"""
Daily Learning Pipeline for Phase 8.
Run daily at 6 AM (recommended) via cron:
0 6 * * * cd /path/to/fair-dispatch && python -m cron.daily_learning
This script:
1. Processes pending learning episodes (created 24h+ ago with no reward yet; oldest first, in batches)
2. Computes episode rewards from driver feedback
3. Updates bandit posteriors
4. Selects new fairness config for today
5. Retrains per-driver XGBoost models
6. Logs metrics for monitoring
"""
import asyncio
import logging
import sys
from datetime import datetime, date, timedelta
from typing import List, Optional
# Add the parent directory (the one containing this file's `cron/` folder) to
# sys.path so the `app.*` imports below resolve when this file is run directly.
# The path is everything before the LAST "/cron" in this file's path, after
# normalizing Windows backslashes to forward slashes. NOTE: if the path contains
# no "/cron", rsplit returns the whole file path unchanged and that is inserted.
sys.path.insert(0, str(__file__).replace("\\", "/").rsplit("/cron", 1)[0])
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import async_session_maker
from app.models import (
AllocationRun,
AllocationRunStatus,
LearningEpisode,
FairnessConfig,
Driver,
DriverEffortModel,
)
from app.services.learning_agent import (
LearningAgent,
FairnessBandit,
hash_config,
)
# Root logging setup for this cron job: INFO and above, each record prefixed
# with a timestamp, the logger name and the level.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Named logger shared by everything in this module.
logger = logging.getLogger("daily_learning")
class DailyLearningPipeline:
    """
    Daily learning pipeline that processes feedback, updates models,
    and deploys new configurations.

    Each step records counters in ``self.metrics``. A per-item failure (one
    episode or one driver) is rolled back, appended to
    ``self.metrics["errors"]`` and does not abort the run. An exception that
    escapes a step marks the run as failed and is re-raised by ``run``.
    """

    # Safety rails
    # Total bandit samples required before the bandit's arm choice is used.
    MIN_EPISODES_FOR_BANDIT = 10
    # NOTE(review): the next two constants are not referenced in this class;
    # confirm whether anything else reads them.
    MIN_FEEDBACK_RATE = 0.3  # At least 30% of drivers must give feedback
    EXPERIMENTAL_COHORT_PCT = 0.10  # 10% get experimental config
    # Maximum number of pending episodes scored per run (oldest first).
    EPISODE_BATCH_SIZE = 100

    def __init__(self, db: AsyncSession):
        """
        Args:
            db: Session used for every query and commit in this pipeline. It is
                also passed to ``LearningAgent``, so both share one transaction
                scope.
        """
        self.db = db
        self.learning_agent = LearningAgent(db)
        self.metrics = {
            "episodes_processed": 0,
            "rewards_computed": 0,
            "models_updated": 0,
            "errors": [],
        }

    async def run(self) -> dict:
        """Run the full daily learning pipeline.

        Returns:
            ``self.metrics`` with ``status="success"``, ``duration_seconds``
            and ``completed_at`` added.

        Raises:
            Exception: whatever a step raised. Before re-raising, ``status`` is
            set to ``"failed"`` and ``error`` and ``duration_seconds`` are
            recorded in ``self.metrics``.
        """
        logger.info("Starting daily learning pipeline...")
        start_time = datetime.utcnow()
        try:
            # Step 1: Process episodes that have waited 24h+ for feedback
            await self._process_pending_episodes()
            # Step 2: Select today's config
            await self._select_todays_config()
            # Step 3: Update per-driver models
            await self._update_driver_models()
        except Exception as e:
            # logger.exception keeps the traceback, which a bare error line loses.
            logger.exception("Daily learning failed: %s", e)
            self.metrics["error"] = str(e)
            self.metrics["status"] = "failed"
            self.metrics["duration_seconds"] = (datetime.utcnow() - start_time).total_seconds()
            raise

        # Step 4: Log metrics
        self.metrics["status"] = "success"
        self.metrics["duration_seconds"] = (datetime.utcnow() - start_time).total_seconds()
        self.metrics["completed_at"] = datetime.utcnow().isoformat()
        logger.info("Daily learning completed: %s", self.metrics)
        return self.metrics

    async def _process_pending_episodes(self) -> None:
        """Compute rewards for pending episodes with sufficient feedback time (24h+).

        Selects up to ``EPISODE_BATCH_SIZE`` learning episodes, oldest first,
        that were created at least 24 hours ago and still have no
        ``episode_reward``. Each episode is committed on its own. A failing
        episode is rolled back and recorded in ``metrics["errors"]`` so earlier
        and later episodes are unaffected.
        """
        logger.info("Processing pending episodes...")
        cutoff = datetime.utcnow() - timedelta(hours=24)
        # Fetch only IDs. A commit (with expire_on_commit) or a rollback expires
        # loaded ORM instances, and reading an expired attribute on an
        # AsyncSession would trigger implicit IO.
        id_rows = await self.db.execute(
            select(LearningEpisode.id)
            .where(LearningEpisode.created_at <= cutoff)
            .where(LearningEpisode.episode_reward.is_(None))
            .order_by(LearningEpisode.created_at.asc())
            .limit(self.EPISODE_BATCH_SIZE)  # Process in batches
        )
        episode_ids = list(id_rows.scalars().all())
        logger.info("Found %d pending episodes", len(episode_ids))

        for episode_id in episode_ids:
            try:
                outcome = await self.learning_agent.process_episode_reward(episode_id)
                # Commit per episode so a later failure cannot discard this work.
                await self.db.commit()
                if outcome["status"] == "success":
                    self.metrics["rewards_computed"] += 1
                    # Lazy %-args: a missing/None reward can no longer raise
                    # here and turn a success into a recorded failure.
                    logger.debug(
                        "Computed reward for episode %s: %s",
                        episode_id,
                        outcome.get("reward"),
                    )
                self.metrics["episodes_processed"] += 1
            except Exception as e:
                # Clear the failed transaction so the session stays usable.
                await self.db.rollback()
                logger.warning("Failed to process episode %s: %s", episode_id, e)
                self.metrics["errors"].append(f"episode_{episode_id}: {str(e)}")

    async def _select_todays_config(self) -> None:
        """Select today's fairness configuration and make it the active one.

        If the bandit holds fewer than ``MIN_EPISODES_FOR_BANDIT`` total
        samples, nothing is written and the currently active config is left as is.
        """
        logger.info("Selecting today's configuration...")
        # Check if we have enough data for bandit selection
        await self.learning_agent.bandit.load_priors()
        total_samples = int(sum(self.learning_agent.bandit.samples))
        if total_samples < self.MIN_EPISODES_FOR_BANDIT:
            logger.info(
                "Insufficient bandit data (%d/%d). Using default config.",
                total_samples,
                self.MIN_EPISODES_FOR_BANDIT,
            )
            self.metrics["config_selection"] = "default_insufficient_data"
            return

        # Select stable config (non-experimental)
        stable_selection = await self.learning_agent.select_config(experimental=False)
        stable_config = stable_selection["config"]
        logger.info(
            "Selected stable config (arm %s): gini=%s, stddev=%s",
            stable_selection["arm_idx"],
            stable_config["gini_threshold"],
            stable_config["stddev_threshold"],
        )

        # Update active FairnessConfig in database
        await self._update_active_config(stable_config)
        self.metrics["config_selection"] = "bandit_selected"
        self.metrics["selected_arm"] = stable_selection["arm_idx"]
        self.metrics["selected_config_hash"] = stable_selection["config_hash"]

    async def _update_active_config(self, config: dict) -> None:
        """Replace the active FairnessConfig with one built from ``config``.

        Every currently active row is deactivated and a new active row is
        inserted. Both changes are committed together. Keys other than
        gini/stddev/recovery-lightening/EV-penalty fall back to the defaults below.
        """
        # Deactivate all existing configs. `== True` (not `is True`) is required
        # so SQLAlchemy builds a SQL comparison.
        result = await self.db.execute(
            select(FairnessConfig).where(FairnessConfig.is_active == True)  # noqa: E712
        )
        for existing in result.scalars().all():
            existing.is_active = False

        # Insert a new active config with the bandit-selected values
        new_config = FairnessConfig(
            is_active=True,
            gini_threshold=config["gini_threshold"],
            stddev_threshold=config["stddev_threshold"],
            recovery_lightening_factor=config["recovery_lightening_factor"],
            ev_charging_penalty_weight=config["ev_charging_penalty_weight"],
            max_gap_threshold=config.get("max_gap_threshold", 25.0),
            workload_weight_packages=config.get("workload_weight_packages", 1.0),
            workload_weight_weight_kg=config.get("workload_weight_weight_kg", 0.5),
            workload_weight_difficulty=config.get("workload_weight_difficulty", 10.0),
            workload_weight_time=config.get("workload_weight_time", 0.2),
            recovery_mode_enabled=config.get("recovery_mode_enabled", True),
            complexity_debt_hard_threshold=config.get("complexity_debt_hard_threshold", 2.0),
            recovery_penalty_weight=config.get("recovery_penalty_weight", 3.0),
            ev_safety_margin_pct=config.get("ev_safety_margin_pct", 10.0),
        )
        self.db.add(new_config)
        await self.db.commit()
        logger.info("Updated active FairnessConfig")

    async def _update_driver_models(self) -> None:
        """Update the per-driver effort model for every driver.

        Each driver is committed on its own. A failure rolls the session back so
        the remaining drivers are still processed. Records ``models_updated``,
        ``models_skipped`` and ``models_failed`` in ``self.metrics``.
        """
        logger.info("Updating per-driver effort models...")
        # All drivers, oldest first. NOTE(review): no active-status filter is
        # applied; add one here if only active drivers should be retrained.
        result = await self.db.execute(
            select(Driver.id)
            .order_by(Driver.created_at.asc())
        )
        driver_ids = [row[0] for row in result.fetchall()]
        logger.info("Found %d drivers to update", len(driver_ids))

        successful = 0
        skipped = 0
        failed = 0
        for driver_id in driver_ids:
            try:
                outcome = await self.learning_agent.effort_learner.update_model(driver_id)
                # Commit per driver so a later failure cannot discard this work.
                await self.db.commit()
                status = outcome["status"]
                if status == "success":
                    successful += 1
                    # Lazy %-args: a missing/None mse can no longer raise here
                    # and count the same driver as both success and failure.
                    logger.debug(
                        "Updated model for driver %s: MSE=%s",
                        driver_id,
                        outcome.get("mse"),
                    )
                elif status == "skipped":
                    skipped += 1
                else:
                    failed += 1
            except Exception as e:
                # Clear the failed transaction so the session stays usable.
                await self.db.rollback()
                failed += 1
                logger.warning("Failed to update model for driver %s: %s", driver_id, e)
                self.metrics["errors"].append(f"driver_model_{driver_id}: {str(e)}")

        self.metrics["models_updated"] = successful
        self.metrics["models_skipped"] = skipped
        self.metrics["models_failed"] = failed
        logger.info(
            "Model updates: %d success, %d skipped, %d failed",
            successful,
            skipped,
            failed,
        )
async def run_daily_learning() -> dict:
    """
    Entry coroutine for the daily learning cron job.

    Opens one database session, runs a single ``DailyLearningPipeline`` with it,
    and closes the session when the run finishes or raises.

    Returns:
        Dict with execution metrics, as returned by ``DailyLearningPipeline.run``.
    """
    async with async_session_maker() as session:
        return await DailyLearningPipeline(session).run()
def main():
    """Command-line entry point for the daily learning job.

    Runs the async pipeline, prints a short summary (including at most five
    recorded per-item errors), and maps the outcome to a process exit code.

    Returns:
        0 if the pipeline completed, 1 if anything raised.
    """
    heavy_rule = "=" * 60
    logger.info(heavy_rule)
    logger.info("DAILY LEARNING PIPELINE - " + datetime.utcnow().isoformat())
    logger.info(heavy_rule)

    try:
        summary = asyncio.run(run_daily_learning())

        light_rule = "=" * 40
        print("\n" + light_rule)
        print("PIPELINE SUMMARY")
        print(light_rule)
        print(f"Episodes Processed: {summary.get('episodes_processed', 0)}")
        print(f"Rewards Computed: {summary.get('rewards_computed', 0)}")
        print(f"Models Updated: {summary.get('models_updated', 0)}")
        print(f"Config Selection: {summary.get('config_selection', 'N/A')}")
        print(f"Duration: {summary.get('duration_seconds', 0):.2f}s")

        recorded_errors = summary.get("errors")
        if recorded_errors:
            print(f"\nErrors: {len(recorded_errors)}")
            for message in recorded_errors[:5]:
                print(f" - {message}")

        print(light_rule)
        return 0
    except Exception as exc:
        logger.error("Pipeline failed: %s", exc)
        return 1
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())