# argus-mlops / scripts/simulate_drift.py
# Author: hodfa840
# Last commit 50145b8: "Fix flat performance graph on HF Spaces"
"""Production drift simulation script.
Sends requests to the FastAPI endpoint to simulate traffic with configurable
drift types and delayed feedback.
Usage:
python scripts/simulate_drift.py
python scripts/simulate_drift.py --drift-type sudden
python scripts/simulate_drift.py --drift-type mixed --steps 1000 --delay 0.05
"""
from __future__ import annotations
import argparse
import json
import math
import random
import sys
import time
from collections import deque
from pathlib import Path

import requests
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.data.generator import TaxiDataGenerator
from src.data.drift_simulator import DriftSimulator
from src.utils.config import settings
from src.utils.logging_config import get_logger
log = get_logger("simulate_drift")
API_URL = "http://localhost:8000"
DRIFT_TYPES = ["gradual", "sudden", "seasonal", "mixed"]
def _int_at_least(minimum: int):
    """Build an argparse ``type=`` callable accepting integers ``>= minimum``."""

    def parse(text: str) -> int:
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"expected an integer, got {text!r}"
            ) from None
        if value < minimum:
            raise argparse.ArgumentTypeError(f"must be >= {minimum}, got {value}")
        return value

    # argparse uses __name__ in some of its error messages.
    parse.__name__ = f"int>={minimum}"
    return parse


def _non_negative_float(text: str) -> float:
    """argparse ``type=`` callable accepting a finite float ``>= 0``."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a number, got {text!r}") from None
    # time.sleep rejects negative values and NaN, and overflows on inf.
    if not math.isfinite(value) or value < 0:
        raise argparse.ArgumentTypeError(
            f"must be a finite number >= 0, got {text!r}"
        )
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the drift simulation.

    Args:
        argv: Arguments to parse. ``None`` (the default) reads ``sys.argv[1:]``,
            so existing ``parse_args()`` callers are unaffected.

    Returns:
        Namespace with ``drift_type``, ``steps``, ``batch_size``, ``delay``,
        ``feedback_lag``, ``api_url`` and ``severity``.

    Counts and the delay are validated up front, so bad values fail with a clean
    usage error instead of crashing mid-run. A zero ``--batch-size`` would give
    an empty dataset and a modulo-by-zero. A non-positive ``--steps`` gives an
    empty run. A negative or non-finite ``--delay`` makes ``time.sleep`` raise.
    """
    p = argparse.ArgumentParser(
        description="Simulate production drift",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--drift-type", choices=DRIFT_TYPES, default="gradual",
                   help="drift pattern injected into the request features")
    p.add_argument("--steps", type=_int_at_least(1), default=500,
                   help="number of simulation steps")
    p.add_argument("--batch-size", type=_int_at_least(1), default=5,
                   help="prediction requests sent per step")
    p.add_argument("--delay", type=_non_negative_float, default=0.1,
                   help="seconds to sleep between steps")
    p.add_argument("--feedback-lag", type=_int_at_least(0), default=20,
                   help="steps to wait before sending ground-truth feedback")
    p.add_argument("--api-url", default=API_URL,
                   help="base URL of the FastAPI service")
    p.add_argument("--severity", type=float, default=1.0,
                   help="drift severity passed to the drift simulator")
    return p.parse_args(argv)
def _clamped(row, name: str, default: float, lo: float, hi: float) -> float:
    """Return ``row[name]`` (or *default* when absent) as a float clamped to ``[lo, hi]``.

    A NaN value is replaced by *default*. ``round(nan)`` raises ``ValueError``,
    and payload construction is not inside a request ``try``, so one NaN feature
    would otherwise abort the whole simulation.
    """
    value = float(row.get(name, default))
    if math.isnan(value):
        value = float(default)
    return float(max(lo, min(hi, value)))


def _clamped_int(row, name: str, default: int, lo: int, hi: int) -> int:
    """Like :func:`_clamped`, then rounded to the nearest integer (ties to even)."""
    # With integer bounds, clamp-then-round equals round-then-clamp. Clamping
    # first also tolerates +/-inf, where round(inf) raises OverflowError.
    return round(_clamped(row, name, default, lo, hi))


def _build_payload(row) -> dict[str, int | float]:
    """Convert one (possibly drifted) data row into a ``/predict`` request body.

    Drift can push features outside their valid domain, so every field is
    clamped into a fixed range. Missing columns fall back to a default value.
    """
    return {
        "passenger_count": _clamped_int(row, "passenger_count", 2, 1, 6),
        "trip_distance": _clamped(row, "trip_distance", 3, 0.1, 50),
        "pickup_hour": _clamped_int(row, "pickup_hour", 8, 0, 23),
        "pickup_dow": _clamped_int(row, "pickup_dow", 1, 0, 6),
        "pickup_month": _clamped_int(row, "pickup_month", 1, 1, 12),
        # Binary flag: rounded and clamped to {0, 1} like the other fields.
        # A bare int() would pass drifted values such as 2 or -1 through.
        "pickup_is_weekend": _clamped_int(row, "pickup_is_weekend", 0, 0, 1),
        "rate_code_id": _clamped_int(row, "rate_code_id", 1, 1, 5),
        "payment_type": _clamped_int(row, "payment_type", 1, 1, 4),
        "pu_location_zone": _clamped_int(row, "pu_location_zone", 10, 1, 50),
        "do_location_zone": _clamped_int(row, "do_location_zone", 25, 1, 50),
        "vendor_id": _clamped_int(row, "vendor_id", 1, 1, 2),
    }


def _check_api(session: requests.Session, api: str) -> None:
    """Verify the API answers ``/health``; exit with status 1 if it does not."""
    try:
        health = session.get(f"{api}/health", timeout=5).json()
        log.info(
            "API online — model=%s, uptime=%.0fs",
            health.get("model_version"),
            # A missing/None uptime would break the %.0f format.
            float(health.get("uptime_seconds") or 0.0),
        )
    except Exception as e:  # startup boundary: any failure means the API is unusable
        log.error("Cannot reach API at %s: %s", api, e)
        log.error("Start the API first: uvicorn app:app --reload")
        sys.exit(1)


def _send_prediction(session: requests.Session, api: str, row, step: int,
                     pending: deque, stats: dict[str, int]) -> None:
    """POST one row to ``/predict`` and queue its ground truth for delayed feedback.

    On success, appends ``(step, request_id, actual)`` to *pending* and
    increments ``stats["predictions"]``. Failures are logged at debug level and
    skipped, so the simulation keeps running.
    """
    payload = _build_payload(row)
    try:
        resp = session.post(f"{api}/predict", json=payload, timeout=5)
        if resp.status_code != 200:
            log.debug("Prediction rejected (HTTP %d): %s",
                      resp.status_code, resp.text[:200])
            return
        result = resp.json()
        label = row.get("trip_duration_min")
        if label is None or math.isnan(float(label)):
            # No usable label in the row: synthesize one within ±20% of the
            # prediction (computed only when needed).
            label = result["predicted_duration_min"] * random.uniform(0.8, 1.2)
        pending.append((step, result["request_id"], float(label)))
    except (requests.RequestException, ValueError, KeyError, TypeError) as e:
        log.debug("Prediction failed: %s", e)
        return
    stats["predictions"] += 1


def _flush_feedback(session: requests.Session, api: str, pending: deque,
                    step: int, lag: int, stats: dict[str, int]) -> None:
    """Send feedback for every queued prediction that is at least *lag* steps old.

    ``stats["feedback_sent"]`` counts only feedback the API accepted (2xx).
    """
    while pending and (step - pending[0][0]) >= lag:
        _, req_id, actual = pending.popleft()
        try:
            resp = session.post(
                f"{api}/predict/feedback",
                json={"request_id": req_id, "actual_duration_min": actual},
                timeout=3,
            )
        except requests.RequestException as e:
            log.debug("Feedback for %s failed: %s", req_id, e)
            continue
        if resp.ok:
            stats["feedback_sent"] += 1
        else:
            log.debug("Feedback for %s rejected (HTTP %d)", req_id, resp.status_code)


def _snapshot_metrics(session: requests.Session, api: str) -> None:
    """Request ``/monitor/metrics`` and ignore the body.

    The call is made so the service's performance log records a point on the
    drift curve.
    """
    try:
        session.get(f"{api}/monitor/metrics", timeout=5)
    except requests.RequestException as e:
        log.debug("Metrics snapshot failed: %s", e)


def _check_drift(session: requests.Session, api: str, step: int,
                 stats: dict[str, int]) -> None:
    """Query ``/monitor/drift``, log the outcome and update alert/retrain counters."""
    try:
        resp = session.get(f"{api}/monitor/drift", timeout=10)
        # An error response must not be mistaken for "no drift detected".
        resp.raise_for_status()
        drift = resp.json()
    except (requests.RequestException, ValueError) as e:
        log.debug("Drift check failed: %s", e)
        return
    if not isinstance(drift, dict):
        log.debug("Drift check failed: unexpected response %r", drift)
        return
    if drift.get("drift_detected"):
        stats["drift_alerts"] += 1
        log.warning(
            "Step %d — DRIFT DETECTED features=%s action=%s",
            step, drift.get("drifted_features"), drift.get("action"),
        )
        if drift.get("action") == "retraining_triggered":
            stats["retrain_events"] += 1
    else:
        log.info("Step %d — predictions=%d feedback=%d",
                 step, stats["predictions"], stats["feedback_sent"])


def main() -> None:
    """Run the drift simulation against the API given by ``--api-url``.

    Each step takes the next ``--batch-size`` rows of a seeded synthetic dataset
    and applies the selected drift. Every row is then POSTed to ``/predict``.
    Ground-truth feedback is released once ``--feedback-lag`` steps have passed.
    The monitoring endpoints are polled periodically: ``/monitor/metrics`` every
    10 steps and ``/monitor/drift`` every 50 steps.
    """
    args = parse_args()
    # Strip a trailing slash so f"{api}/health" never becomes "...//health".
    api = args.api_url.rstrip("/")

    # One pooled session reuses TCP connections across thousands of requests.
    with requests.Session() as session:
        _check_api(session, api)

        gen = TaxiDataGenerator(random_seed=42)
        simulator = DriftSimulator(random_seed=99)
        base_df = gen.generate(n_samples=args.steps * args.batch_size)
        if base_df.empty:
            log.error("Data generator returned no rows (steps=%d, batch_size=%d)",
                      args.steps, args.batch_size)
            sys.exit(1)
        feature_cols = [c for c in settings.data.features if c in base_df.columns]
        # "sudden" drift switches on at one third of the run and then stays on.
        sudden_onset = args.steps // 3

        log.info("Starting drift simulation: type=%s, steps=%d, severity=%.2f",
                 args.drift_type, args.steps, args.severity)

        pending_feedback: deque = deque()
        stats = {"predictions": 0, "feedback_sent": 0, "drift_alerts": 0, "retrain_events": 0}

        for step in range(args.steps):
            batch_start = (step * args.batch_size) % len(base_df)
            batch = base_df.iloc[batch_start: batch_start + args.batch_size].copy()

            # Non-sudden types are applied on every step; the simulator receives
            # step/total_steps. Sudden drift must persist after its onset:
            # drifting only the single onset batch leaves the performance curve flat.
            if args.drift_type != "sudden" or step >= sudden_onset:
                drifted = simulator.apply(
                    batch[feature_cols],
                    drift_type=args.drift_type,
                    severity=args.severity,
                    step=step,
                    total_steps=args.steps,
                )
                for col in feature_cols:
                    if col in drifted.columns:
                        batch[col] = drifted[col].values

            for _, row in batch.iterrows():
                _send_prediction(session, api, row, step, pending_feedback, stats)

            _flush_feedback(session, api, pending_feedback, step,
                            args.feedback_lag, stats)

            # The drift check runs every 50 steps. The cheaper metrics snapshot
            # fills in the remaining multiples of 10.
            if step > 0 and step % 50 == 0:
                _check_drift(session, api, step, stats)
            elif step > 0 and step % 10 == 0:
                _snapshot_metrics(session, api)

            time.sleep(args.delay)

    log.info("Simulation complete: predictions=%d, feedback=%d, drift_alerts=%d, retrain=%d",
             stats["predictions"], stats["feedback_sent"], stats["drift_alerts"], stats["retrain_events"])
    if pending_feedback:
        # Predictions from the final feedback-lag window never reach their lag.
        log.info("%d prediction(s) received no feedback: their %d-step lag had not elapsed",
                 len(pending_feedback), args.feedback_lag)


if __name__ == "__main__":
    main()