| |
| """ |
| Daily Learning Pipeline for Phase 8. |
| |
| Run daily at 6 AM (recommended) via cron: |
| 0 6 * * * cd /path/to/fair-dispatch && python -m cron.daily_learning |
| |
| This script: |
| 1. Processes completed allocation runs from yesterday |
| 2. Computes episode rewards from driver feedback |
| 3. Updates bandit posteriors |
| 4. Selects new fairness config for today |
| 5. Retrains per-driver XGBoost models |
| 6. Logs metrics for monitoring |
| """ |
|
|
| import asyncio |
| import logging |
| import sys |
| from datetime import datetime, date, timedelta |
| from typing import List, Optional |
|
|
| |
# Put the directory that contains the "cron" package at the front of sys.path
# so the "app.*" imports below resolve when this file is run directly.
# Backslashes are normalised first so the split also works on Windows paths.
_normalised_file = str(__file__).replace("\\", "/")
_project_root = _normalised_file.rsplit("/cron", 1)[0]
sys.path.insert(0, _project_root)
|
|
| from sqlalchemy import select, and_ |
| from sqlalchemy.ext.asyncio import AsyncSession |
|
|
| from app.database import async_session_maker |
| from app.models import ( |
| AllocationRun, |
| AllocationRunStatus, |
| LearningEpisode, |
| FairnessConfig, |
| Driver, |
| DriverEffortModel, |
| ) |
| from app.services.learning_agent import ( |
| LearningAgent, |
| FairnessBandit, |
| hash_config, |
| ) |
|
|
|
|
| |
# Root logging setup for the cron job: INFO level, timestamped records.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-wide logger used by the pipeline class and the CLI entry point.
logger = logging.getLogger("daily_learning")
|
|
|
|
class DailyLearningPipeline:
    """
    Daily learning pipeline that processes feedback, updates models,
    and deploys new configurations.

    ``run`` executes three steps in order: reward computation for pending
    episodes, selection of today's fairness configuration, and per-driver
    effort model updates. Counters and per-item error strings are collected
    in ``self.metrics``, which ``run`` returns.
    """

    # Minimum total bandit samples before a bandit-selected config is deployed.
    MIN_EPISODES_FOR_BANDIT = 10
    # NOTE(review): the next two thresholds are not referenced inside this
    # class; presumably consumed elsewhere - confirm before removing.
    MIN_FEEDBACK_RATE = 0.3
    EXPERIMENTAL_COHORT_PCT = 0.10
    # Episodes younger than this many hours are left for a later run.
    FEEDBACK_WINDOW_HOURS = 24
    # Upper bound on the number of episodes handled in a single run.
    MAX_EPISODES_PER_RUN = 100

    def __init__(self, db: AsyncSession):
        """
        Args:
            db: Async SQLAlchemy session used for every query and commit.
        """
        self.db = db
        self.learning_agent = LearningAgent(db)
        self.metrics = {
            "episodes_processed": 0,
            "rewards_computed": 0,
            "models_updated": 0,
            "errors": [],
        }

    async def run(self) -> dict:
        """Run the full daily learning pipeline.

        Returns:
            The metrics dict. On success it carries ``status == "success"``,
            ``duration_seconds`` and ``completed_at``.

        Raises:
            Exception: any error escaping a pipeline step is logged with its
                traceback, recorded in ``metrics`` (``error``,
                ``status == "failed"``, ``duration_seconds``) and re-raised.
        """
        logger.info("Starting daily learning pipeline...")
        start_time = datetime.utcnow()

        try:
            await self._process_pending_episodes()
            await self._select_todays_config()
            await self._update_driver_models()
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"Daily learning failed: {e}")
            self.metrics["error"] = str(e)
            self.metrics["status"] = "failed"
            self.metrics["duration_seconds"] = (
                datetime.utcnow() - start_time
            ).total_seconds()
            raise
        else:
            finished_at = datetime.utcnow()
            self.metrics["duration_seconds"] = (finished_at - start_time).total_seconds()
            self.metrics["completed_at"] = finished_at.isoformat()
            # Explicit success marker so consumers need not infer it from
            # the absence of a "failed" status.
            self.metrics["status"] = "success"

            logger.info(f"Daily learning completed: {self.metrics}")
            return self.metrics

    async def _process_pending_episodes(self) -> None:
        """Process all pending episodes with sufficient feedback time (24h+).

        An episode is pending when its ``episode_reward`` is NULL and it was
        created at least ``FEEDBACK_WINDOW_HOURS`` ago. Oldest episodes are
        handled first, at most ``MAX_EPISODES_PER_RUN`` per run. A failure on
        one episode is recorded in ``metrics["errors"]`` and does not stop
        the remaining episodes.
        """
        logger.info("Processing pending episodes...")

        cutoff = datetime.utcnow() - timedelta(hours=self.FEEDBACK_WINDOW_HOURS)

        # Only the primary key is needed below, so select just that column
        # instead of loading full ORM entities.
        query_result = await self.db.execute(
            select(LearningEpisode.id)
            .where(LearningEpisode.created_at <= cutoff)
            .where(LearningEpisode.episode_reward.is_(None))
            .order_by(LearningEpisode.created_at.asc())
            .limit(self.MAX_EPISODES_PER_RUN)
        )
        pending_episode_ids = query_result.scalars().all()

        logger.info(f"Found {len(pending_episode_ids)} pending episodes")

        for episode_id in pending_episode_ids:
            try:
                outcome = await self.learning_agent.process_episode_reward(episode_id)
                if outcome["status"] == "success":
                    self.metrics["rewards_computed"] += 1
                    logger.debug(f"Computed reward for episode {episode_id}: {outcome['reward']:.3f}")
                self.metrics["episodes_processed"] += 1
            except Exception as e:
                # Best effort: keep going with the remaining episodes.
                logger.warning(f"Failed to process episode {episode_id}: {e}")
                self.metrics["errors"].append(f"episode_{episode_id}: {str(e)}")

        await self.db.commit()

    async def _select_todays_config(self) -> None:
        """Select and deploy today's fairness configuration.

        When the bandit has fewer than ``MIN_EPISODES_FOR_BANDIT`` total
        samples, nothing is written and the currently active config stays in
        place. Otherwise the stable (non-experimental) selection is stored
        as the new active ``FairnessConfig``.
        """
        logger.info("Selecting today's configuration...")

        await self.learning_agent.bandit.load_priors()
        total_samples = int(sum(self.learning_agent.bandit.samples))

        if total_samples < self.MIN_EPISODES_FOR_BANDIT:
            logger.info(f"Insufficient bandit data ({total_samples}/{self.MIN_EPISODES_FOR_BANDIT}). Using default config.")
            self.metrics["config_selection"] = "default_insufficient_data"
            return

        stable_selection = await self.learning_agent.select_config(experimental=False)
        stable_config = stable_selection["config"]

        logger.info(f"Selected stable config (arm {stable_selection['arm_idx']}): "
                    f"gini={stable_config['gini_threshold']}, "
                    f"stddev={stable_config['stddev_threshold']}")

        await self._update_active_config(stable_config)

        self.metrics["config_selection"] = "bandit_selected"
        self.metrics["selected_arm"] = stable_selection["arm_idx"]
        self.metrics["selected_config_hash"] = stable_selection["config_hash"]

    async def _update_active_config(self, config: dict) -> None:
        """Update the active FairnessConfig in database.

        Every currently active row is deactivated and a new active row is
        inserted; both changes are committed together.

        Args:
            config: Selected configuration. ``gini_threshold``,
                ``stddev_threshold``, ``recovery_lightening_factor`` and
                ``ev_charging_penalty_weight`` are required; the remaining
                fields fall back to the defaults written below.

        Raises:
            KeyError: if one of the required keys is missing.
        """
        # "== True" is deliberate: it builds a SQL expression, not a Python
        # boolean comparison.
        active_result = await self.db.execute(
            select(FairnessConfig).where(FairnessConfig.is_active == True)  # noqa: E712
        )
        for existing in active_result.scalars().all():
            existing.is_active = False

        new_config = FairnessConfig(
            is_active=True,
            gini_threshold=config["gini_threshold"],
            stddev_threshold=config["stddev_threshold"],
            recovery_lightening_factor=config["recovery_lightening_factor"],
            ev_charging_penalty_weight=config["ev_charging_penalty_weight"],
            max_gap_threshold=config.get("max_gap_threshold", 25.0),
            workload_weight_packages=config.get("workload_weight_packages", 1.0),
            workload_weight_weight_kg=config.get("workload_weight_weight_kg", 0.5),
            workload_weight_difficulty=config.get("workload_weight_difficulty", 10.0),
            workload_weight_time=config.get("workload_weight_time", 0.2),
            recovery_mode_enabled=config.get("recovery_mode_enabled", True),
            complexity_debt_hard_threshold=config.get("complexity_debt_hard_threshold", 2.0),
            recovery_penalty_weight=config.get("recovery_penalty_weight", 3.0),
            ev_safety_margin_pct=config.get("ev_safety_margin_pct", 10.0),
        )
        self.db.add(new_config)

        await self.db.commit()
        logger.info("Updated active FairnessConfig")

    async def _update_driver_models(self) -> None:
        """Update per-driver XGBoost models.

        Calls the effort learner once per driver, oldest driver first, and
        tallies the outcomes into ``metrics`` as ``models_updated``,
        ``models_skipped`` and ``models_failed``. A status other than
        "success" or "skipped", or a raised exception, counts as failed.
        """
        logger.info("Updating per-driver effort models...")

        query_result = await self.db.execute(
            select(Driver.id)
            .order_by(Driver.created_at.asc())
        )
        driver_ids = query_result.scalars().all()

        logger.info(f"Found {len(driver_ids)} drivers to update")

        successful = 0
        skipped = 0
        failed = 0

        for driver_id in driver_ids:
            try:
                outcome = await self.learning_agent.effort_learner.update_model(driver_id)

                if outcome["status"] == "success":
                    successful += 1
                    logger.debug(f"Updated model for driver {driver_id}: MSE={outcome['mse']:.3f}")
                elif outcome["status"] == "skipped":
                    skipped += 1
                else:
                    failed += 1

            except Exception as e:
                # Best effort: one driver's failure must not block the rest.
                failed += 1
                logger.warning(f"Failed to update model for driver {driver_id}: {e}")
                self.metrics["errors"].append(f"driver_model_{driver_id}: {str(e)}")

        await self.db.commit()

        self.metrics["models_updated"] = successful
        self.metrics["models_skipped"] = skipped
        self.metrics["models_failed"] = failed

        logger.info(f"Model updates: {successful} success, {skipped} skipped, {failed} failed")
|
|
|
|
async def run_daily_learning() -> dict:
    """
    Main entry point for the daily learning cron job.

    Opens one database session, runs a DailyLearningPipeline on it, and
    leaves the session context once the run has finished.

    Returns:
        Dict with execution metrics
    """
    async with async_session_maker() as session:
        metrics = await DailyLearningPipeline(session).run()
        return metrics
|
|
|
|
def main():
    """CLI entry point.

    Runs the pipeline, prints a short summary to stdout, and returns a
    process exit code: 0 when the run completed, 1 when anything raised.
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info("DAILY LEARNING PIPELINE - " + datetime.utcnow().isoformat())
    logger.info(banner)

    # (label, metrics key, fallback) for the plain counter rows of the summary.
    summary_rows = (
        ("Episodes Processed", "episodes_processed", 0),
        ("Rewards Computed", "rewards_computed", 0),
        ("Models Updated", "models_updated", 0),
        ("Config Selection", "config_selection", "N/A"),
    )

    try:
        metrics = asyncio.run(run_daily_learning())

        rule = "=" * 40
        print("\n" + rule)
        print("PIPELINE SUMMARY")
        print(rule)
        for label, key, fallback in summary_rows:
            print(f"{label}: {metrics.get(key, fallback)}")
        print(f"Duration: {metrics.get('duration_seconds', 0):.2f}s")

        errors = metrics.get("errors")
        if errors:
            print(f"\nErrors: {len(errors)}")
            # Only the first five errors are shown to keep the summary short.
            for error in errors[:5]:
                print(f" - {error}")

        print(rule)

        return 0

    except Exception as exc:
        logger.error(f"Pipeline failed: {exc}")
        return 1
|
|
|
|
# Script entry: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|
|