Spaces:
Sleeping
Sleeping
"""Lightweight data drift monitoring for the Gradio app.

Tracks input feature distributions and flags when new inputs fall outside
the training distribution. Out-of-distribution inputs are an important
production ML monitoring signal: a model's predictions are least reliable on
inputs unlike anything it saw during training.
"""
import logging
import math
from typing import Optional

import numpy as np
logger = logging.getLogger(__name__)


class DataDriftMonitor:
    """Monitors input features for distribution drift.

    Compares the features of each incoming input against per-feature
    (min, max) ranges taken from the training set. An input is flagged as
    out-of-distribution (OOD) when any feature with a recorded range lies
    outside that range widened by a configurable margin, or is NaN. Running
    counts of checked and flagged inputs are kept to report an OOD rate.
    """

    def __init__(self, margin: float = 0.1) -> None:
        """Create a monitor with no training ranges yet.

        Args:
            margin: Fraction of each feature's training range added below the
                minimum and above the maximum before a value is flagged; the
                default 0.1 widens a [0, 10] range to [-1, 11].
        """
        self.margin = margin  # fractional margin beyond training range
        self.feature_mins: Optional[dict[str, float]] = None
        self.feature_maxs: Optional[dict[str, float]] = None
        self.prediction_count: int = 0  # inputs checked while fitted
        self.ood_count: int = 0  # of those, inputs flagged as OOD

    def fit(self, feature_ranges: dict[str, tuple[float, float]]) -> None:
        """Set training distribution ranges.

        Replaces any previously fitted ranges; the running prediction and OOD
        counts are kept.

        Args:
            feature_ranges: Dict mapping feature name to (min, max) from
                training data.

        Raises:
            ValueError: If any feature's min exceeds its max (e.g. swapped
                bounds). Such a range cannot be checked meaningfully, so it is
                rejected before any state is changed.
        """
        inverted = [
            name for name, bounds in feature_ranges.items() if bounds[0] > bounds[1]
        ]
        if inverted:
            raise ValueError(
                f"Training ranges have min > max for features: {inverted}"
            )
        self.feature_mins = {name: bounds[0] for name, bounds in feature_ranges.items()}
        self.feature_maxs = {name: bounds[1] for name, bounds in feature_ranges.items()}

    def check(self, features: dict[str, float]) -> dict:
        """Check if input features are within training distribution.

        Features that have no fitted range are ignored. Before ``fit`` is
        called every input is reported as in-distribution and the counts are
        not updated.

        Args:
            features: Dict mapping feature name to this input's value.

        Returns:
            Dict with:
              - 'in_distribution': bool, False if any feature was flagged
              - 'warnings': list of out-of-range feature descriptions
              - 'ood_rate': fraction of total predictions that were OOD
        """
        if self.feature_mins is None:
            return {"in_distribution": True, "warnings": [], "ood_rate": 0.0}

        self.prediction_count += 1
        warnings: list[str] = []
        for name, value in features.items():
            warning = self._feature_warning(name, value)
            if warning is not None:
                warnings.append(warning)

        is_ood = bool(warnings)
        if is_ood:
            self.ood_count += 1
            logger.warning("OOD input detected: %s", warnings)
        return {
            "in_distribution": not is_ood,
            "warnings": warnings,
            "ood_rate": self.ood_count / self.prediction_count,
        }

    def _feature_warning(self, name: str, value: float) -> Optional[str]:
        """Describe why one feature value is OOD, or return None if it is not.

        Only called from ``check`` after ``fit``, so the range dicts are set.
        """
        if name not in self.feature_mins:
            return None  # no training range recorded for this feature
        lo = self.feature_mins[name]
        hi = self.feature_maxs[name]
        if math.isnan(value):
            # NaN compares False against every bound, so the range tests below
            # would silently report it as in-distribution.
            return f"{name} is NaN; training range is [{lo:.4g}, {hi:.4g}]"
        # A constant training feature has zero width; fall back to a unit
        # range so the margin (and the percentage below) stays well defined.
        range_size = hi - lo if hi > lo else 1.0
        margin = self.margin * range_size
        if value < lo - margin:
            # Distance is expressed relative to the training range width,
            # not relative to the minimum itself.
            pct = abs(value - lo) / range_size * 100
            return (
                f"{name}={value:.4g} is below training minimum ({lo:.4g}) "
                f"by {pct:.1f}% of the training range"
            )
        if value > hi + margin:
            pct = abs(value - hi) / range_size * 100
            return (
                f"{name}={value:.4g} is above training maximum ({hi:.4g}) "
                f"by {pct:.1f}% of the training range"
            )
        return None

    def get_stats(self) -> dict:
        """Return monitoring statistics.

        Returns:
            Dict with 'total_predictions', 'ood_predictions' and 'ood_rate'
            (0.0 when nothing has been checked yet).
        """
        return {
            "total_predictions": self.prediction_count,
            "ood_predictions": self.ood_count,
            # max(..., 1) avoids dividing by zero before the first check
            "ood_rate": self.ood_count / max(self.prediction_count, 1),
        }