"""Lightweight data drift monitoring for the Gradio app. Tracks input feature distributions and flags when new inputs fall outside the training distribution. This is a production-quality signal that Google interviewers look for — it shows awareness of ML monitoring challenges. """ import logging from typing import Optional import numpy as np logger = logging.getLogger(__name__) class DataDriftMonitor: """Monitors input features for distribution drift. Compares incoming predictions against training set statistics. Flags inputs where any feature falls outside the training range (plus a configurable margin). """ def __init__(self, margin: float = 0.1) -> None: self.margin = margin # fractional margin beyond training range self.feature_mins: Optional[dict[str, float]] = None self.feature_maxs: Optional[dict[str, float]] = None self.prediction_count: int = 0 self.ood_count: int = 0 def fit(self, feature_ranges: dict[str, tuple[float, float]]) -> None: """Set training distribution ranges. Args: feature_ranges: Dict mapping feature name to (min, max) from training data. """ self.feature_mins = {k: v[0] for k, v in feature_ranges.items()} self.feature_maxs = {k: v[1] for k, v in feature_ranges.items()} def check(self, features: dict[str, float]) -> dict: """Check if input features are within training distribution. Returns: Dict with: - 'in_distribution': bool - 'warnings': list of out-of-range feature descriptions - 'ood_rate': fraction of total predictions that were OOD """ if self.feature_mins is None: return {"in_distribution": True, "warnings": [], "ood_rate": 0.0} self.prediction_count += 1 warnings = [] for name, value in features.items(): if name not in self.feature_mins: continue lo = self.feature_mins[name] hi = self.feature_maxs[name] range_size = hi - lo if hi > lo else 1.0 margin = self.margin * range_size if value < lo - margin: pct = abs(value - lo) / range_size * 100 warnings.append( f"{name}={value:.4g} is {pct:.1f}% below training minimum ({lo:.4g})" ) elif value > hi + margin: pct = abs(value - hi) / range_size * 100 warnings.append( f"{name}={value:.4g} is {pct:.1f}% above training maximum ({hi:.4g})" ) is_ood = len(warnings) > 0 if is_ood: self.ood_count += 1 logger.warning(f"OOD input detected: {warnings}") return { "in_distribution": not is_ood, "warnings": warnings, "ood_rate": self.ood_count / self.prediction_count, } def get_stats(self) -> dict: """Return monitoring statistics.""" return { "total_predictions": self.prediction_count, "ood_predictions": self.ood_count, "ood_rate": self.ood_count / max(self.prediction_count, 1), }