Spaces:
Sleeping
Sleeping
| """Production drift simulation script. | |
| Sends requests to the FastAPI endpoint to simulate traffic with configurable | |
| drift types and delayed feedback. | |
| Usage: | |
| python scripts/simulate_drift.py | |
| python scripts/simulate_drift.py --drift-type sudden | |
| python scripts/simulate_drift.py --drift-type mixed --steps 1000 --delay 0.05 | |
| """ | |
from __future__ import annotations

import argparse
import json
import math
import random
import sys
import time
from collections import deque
from pathlib import Path

import requests
# Put the directory above this script's own directory on sys.path so the
# `src.*` imports that follow resolve when the file is run directly
# (e.g. `python scripts/simulate_drift.py`).
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
| from src.data.generator import TaxiDataGenerator | |
| from src.data.drift_simulator import DriftSimulator | |
| from src.utils.config import settings | |
| from src.utils.logging_config import get_logger | |
# Module-level logger for this script, created via the project's get_logger helper.
log = get_logger("simulate_drift")
# Default base URL of the prediction service; overridable with --api-url.
API_URL: str = "http://localhost:8000"
# Drift scenarios accepted by --drift-type.
DRIFT_TYPES: list[str] = "gradual sudden seasonal mixed".split()
def _number_at_least(cast, minimum):
    """Build an argparse ``type=`` converter that enforces ``cast(text) >= minimum``.

    Args:
        cast: Conversion applied to the raw command-line string (e.g. ``int``, ``float``).
        minimum: Smallest accepted value.

    Returns:
        A one-argument callable that returns the converted value or raises
        ``argparse.ArgumentTypeError``, which argparse reports as a usage error.
    """

    def convert(text: str):
        try:
            value = cast(text)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"invalid {cast.__name__} value: {text!r}"
            ) from None
        # Written as ``not value >= minimum`` so NaN (which compares false to
        # everything) is rejected as well.
        if not value >= minimum:
            raise argparse.ArgumentTypeError(f"must be >= {minimum}, got {text!r}")
        return value

    return convert


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the simulator's command-line options.

    ``--steps`` and ``--batch-size`` must be >= 1, and ``--delay`` and
    ``--feedback-lag`` must be >= 0. Bad values are rejected by argparse up
    front instead of failing later inside the simulation loop (e.g. an empty
    dataset from a zero batch size, or ``time.sleep`` raising on a negative
    delay).

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        Namespace with ``drift_type``, ``steps``, ``batch_size``, ``delay``,
        ``feedback_lag``, ``api_url`` and ``severity`` attributes.
    """
    p = argparse.ArgumentParser(description="Simulate production drift")
    p.add_argument("--drift-type", choices=DRIFT_TYPES, default="gradual",
                   help="drift pattern applied to request features (default: %(default)s)")
    p.add_argument("--steps", type=_number_at_least(int, 1), default=500,
                   help="number of simulation steps (default: %(default)s)")
    p.add_argument("--batch-size", type=_number_at_least(int, 1), default=5,
                   help="prediction requests sent per step (default: %(default)s)")
    p.add_argument("--delay", type=_number_at_least(float, 0.0), default=0.1,
                   help="seconds to sleep between steps (default: %(default)s)")
    p.add_argument("--feedback-lag", type=_number_at_least(int, 0), default=20,
                   help="steps to wait before sending a prediction's ground truth "
                        "(default: %(default)s)")
    p.add_argument("--api-url", default=API_URL,
                   help="base URL of the prediction API (default: %(default)s)")
    p.add_argument("--severity", type=float, default=1.0,
                   help="drift severity passed to the drift simulator (default: %(default)s)")
    return p.parse_args(argv)
def _finite_or(value, default: float) -> float:
    """Return *value* as a float, or *default* when it is missing, non-numeric, NaN or infinite."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return float(default)
    return number if math.isfinite(number) else float(default)


def _clamped_int(row, key: str, default: int, low: int, high: int) -> int:
    """Read ``row[key]`` (or *default*), clamp it into ``[low, high]`` and round to an int.

    Clamping before rounding gives the same result as rounding first, because
    the bounds are integers.
    """
    return int(round(max(low, min(high, _finite_or(row.get(key), default)))))


def _build_payload(row) -> dict:
    """Turn one feature row into a ``/predict`` request body.

    *row* is anything with a ``.get(key)`` method (a pandas row from
    ``iterrows`` in ``main``, or a plain dict). Missing, NaN or infinite values
    fall back to a per-field default. Every field is clamped into a fixed range,
    so drifted values still yield a well-formed request.
    """
    return {
        "passenger_count": _clamped_int(row, "passenger_count", 2, 1, 6),
        "trip_distance": float(max(0.1, min(50, _finite_or(row.get("trip_distance"), 3)))),
        "pickup_hour": _clamped_int(row, "pickup_hour", 8, 0, 23),
        "pickup_dow": _clamped_int(row, "pickup_dow", 1, 0, 6),
        "pickup_month": _clamped_int(row, "pickup_month", 1, 1, 12),
        # Rounded and clamped like every other field (previously truncated with
        # int() and left unbounded, so drifted values could leave {0, 1}).
        "pickup_is_weekend": _clamped_int(row, "pickup_is_weekend", 0, 0, 1),
        "rate_code_id": _clamped_int(row, "rate_code_id", 1, 1, 5),
        "payment_type": _clamped_int(row, "payment_type", 1, 1, 4),
        "pu_location_zone": _clamped_int(row, "pu_location_zone", 10, 1, 50),
        "do_location_zone": _clamped_int(row, "do_location_zone", 25, 1, 50),
        "vendor_id": _clamped_int(row, "vendor_id", 1, 1, 2),
    }


def _check_api(session: "requests.Session", api: str) -> None:
    """Exit with status 1 unless ``GET {api}/health`` returns a successful JSON response."""
    try:
        resp = session.get(f"{api}/health", timeout=5)
        resp.raise_for_status()
        health = resp.json()
        log.info("API online — model=%s, uptime=%.0fs",
                 health.get("model_version"), float(health.get("uptime_seconds") or 0))
    except Exception as e:  # startup boundary: any failure means we cannot simulate
        log.error("Cannot reach API at %s: %s", api, e)
        log.error("Start the API first: uvicorn app:app --reload")
        sys.exit(1)


def _send_due_feedback(session: "requests.Session", api: str, pending: deque,
                       step: int, lag: int) -> int:
    """Send every queued label that is at least *lag* steps old.

    Entries are ``(step_queued, request_id, actual_duration)`` tuples and are
    removed from *pending* whether or not delivery succeeds.

    Returns:
        The number of labels the API accepted (2xx response).
    """
    sent = 0
    while pending and step - pending[0][0] >= lag:
        _, req_id, actual = pending.popleft()
        try:
            resp = session.post(
                f"{api}/predict/feedback",
                json={"request_id": req_id, "actual_duration_min": actual},
                timeout=3,
            )
            resp.raise_for_status()
        except requests.RequestException as e:
            log.debug("Feedback for %s failed: %s", req_id, e)
            continue
        sent += 1
    return sent


def _poll_monitoring(session: "requests.Session", api: str, step: int, stats: dict) -> None:
    """Query the monitoring endpoints on their schedule, updating *stats* in place.

    Every 50th step (excluding step 0) calls ``/monitor/drift`` and counts drift
    alerts and retraining triggers. The remaining multiples of 10 call
    ``/monitor/metrics``.
    """
    if step == 0 or step % 10:
        return
    if step % 50:
        # Emit a metrics snapshot so the performance log captures the drift curve
        try:
            session.get(f"{api}/monitor/metrics", timeout=5)
        except requests.RequestException as e:
            log.debug("Metrics snapshot failed: %s", e)
        return

    try:
        resp = session.get(f"{api}/monitor/drift", timeout=10)
        resp.raise_for_status()
        drift = resp.json()
    except (requests.RequestException, ValueError) as e:
        log.debug("Drift check failed: %s", e)
        return
    if not isinstance(drift, dict):
        log.debug("Drift check failed: unexpected response %r", drift)
        return

    if drift.get("drift_detected"):
        stats["drift_alerts"] += 1
        log.warning(
            "Step %d — DRIFT DETECTED features=%s action=%s",
            step, drift.get("drifted_features"), drift.get("action"),
        )
        if drift.get("action") == "retraining_triggered":
            stats["retrain_events"] += 1
    else:
        log.info("Step %d — predictions=%d feedback=%d",
                 step, stats["predictions"], stats["feedback_sent"])


def main() -> None:
    """Drive the drift simulation against a running prediction API.

    Main loop:
    1. Verify the API is reachable.
    2. Generate ``steps * batch_size`` base rows.
    3. Per step:
       - Take the next batch and apply the configured drift to its feature
         columns.
       - Request a prediction per row and queue the row's
         ``trip_duration_min`` as its label (or a synthetic label when that
         column is absent).
       - Deliver labels once they are ``feedback_lag`` steps old.
       - Poll the monitoring endpoints.

    Exits with status 1 if the health check fails or no base data is generated.
    """
    args = parse_args()
    api = args.api_url

    # One session for the whole run so HTTP connections are reused.
    with requests.Session() as session:
        _check_api(session, api)

        gen = TaxiDataGenerator(random_seed=42)
        simulator = DriftSimulator(random_seed=99)
        base_df = gen.generate(n_samples=args.steps * args.batch_size)
        if len(base_df) == 0:
            log.error("No base data generated (steps=%d, batch_size=%d); nothing to simulate",
                      args.steps, args.batch_size)
            sys.exit(1)
        feature_cols = [c for c in settings.data.features if c in base_df.columns]
        # Sudden drift switches on a third of the way through and stays on afterwards
        # (previously it was applied to the single batch at exactly this step).
        sudden_onset = args.steps // 3

        log.info("Starting drift simulation: type=%s, steps=%d, severity=%.2f",
                 args.drift_type, args.steps, args.severity)

        pending_feedback: deque = deque()
        stats = {"predictions": 0, "feedback_sent": 0, "drift_alerts": 0, "retrain_events": 0}

        for step in range(args.steps):
            batch_start = (step * args.batch_size) % len(base_df)
            batch = base_df.iloc[batch_start: batch_start + args.batch_size].copy()

            if args.drift_type != "sudden" or step >= sudden_onset:
                drifted = simulator.apply(
                    batch[feature_cols],
                    drift_type=args.drift_type,
                    severity=args.severity,
                    step=step,
                    total_steps=args.steps,
                )
                for col in feature_cols:
                    if col in drifted.columns:
                        batch[col] = drifted[col].values

            for _, row in batch.iterrows():
                try:
                    resp = session.post(f"{api}/predict", json=_build_payload(row), timeout=5)
                    if resp.status_code != 200:
                        log.debug("Prediction rejected: HTTP %d", resp.status_code)
                        continue
                    result = resp.json()
                    actual = row.get("trip_duration_min")
                    if actual is None:
                        # No ground-truth column: synthesise a label within ±20% of the prediction.
                        actual = result["predicted_duration_min"] * random.uniform(0.8, 1.2)
                    pending_feedback.append((step, result["request_id"], float(actual)))
                except (requests.RequestException, ValueError, KeyError, TypeError) as e:
                    log.debug("Prediction failed: %s", e)
                    continue
                stats["predictions"] += 1

            stats["feedback_sent"] += _send_due_feedback(
                session, api, pending_feedback, step, args.feedback_lag
            )
            _poll_monitoring(session, api, step, stats)

            # time.sleep rejects negative values; treat a non-positive delay as "no pause".
            if args.delay > 0:
                time.sleep(args.delay)

    log.info("Simulation complete: predictions=%d, feedback=%d, drift_alerts=%d, retrain=%d",
             stats["predictions"], stats["feedback_sent"], stats["drift_alerts"], stats["retrain_events"])
    if pending_feedback:
        log.info("Feedback still inside the lag window at exit (not sent): %d", len(pending_feedback))
if __name__ == "__main__":
    # Script entry point: main() returns None, so this exits with status 0 on
    # success; any exception or sys.exit() inside main() propagates unchanged.
    raise SystemExit(main())