fea-surrogate / src /app /monitoring.py
WolfDavid's picture
Upload folder using huggingface_hub
8e5ba9e verified
"""Lightweight data drift monitoring for the Gradio app.
Tracks input feature distributions and flags when new inputs fall outside
the training distribution. This is a production-quality signal that Google
interviewers look for — it shows awareness of ML monitoring challenges.
"""
import logging
from typing import Optional
import numpy as np
logger = logging.getLogger(__name__)


class DataDriftMonitor:
    """Flag inputs whose features fall outside the training range.

    The monitor stores a ``(min, max)`` pair per feature (see :meth:`fit`)
    and, for every call to :meth:`check`, reports each feature whose value
    lies outside that range widened by ``margin * (max - min)`` on both
    sides. It also keeps running counters so callers can read the overall
    out-of-distribution (OOD) rate.
    """

    def __init__(self, margin: float = 0.1) -> None:
        """Create an unfitted monitor.

        Args:
            margin: Fractional tolerance beyond the training range,
                expressed relative to the range width ``max - min``.
        """
        self.margin = margin
        # Both stay None until fit() is called; check() treats that state
        # as "nothing to compare against".
        self.feature_mins: Optional[dict[str, float]] = None
        self.feature_maxs: Optional[dict[str, float]] = None
        # Running counters, updated by check() only once the monitor is fitted.
        self.prediction_count: int = 0
        self.ood_count: int = 0

    def fit(self, feature_ranges: dict[str, tuple[float, float]]) -> None:
        """Set training distribution ranges.

        Args:
            feature_ranges: Dict mapping feature name to ``(min, max)``
                from training data.
        """
        self.feature_mins = {name: bounds[0] for name, bounds in feature_ranges.items()}
        self.feature_maxs = {name: bounds[1] for name, bounds in feature_ranges.items()}

    def check(self, features: dict[str, float]) -> dict:
        """Check whether input features are within the training distribution.

        Features that were not passed to :meth:`fit` are ignored. A NaN
        value is always reported as out of distribution, because it cannot
        be ordered against the training bounds.

        Args:
            features: Dict mapping feature name to the incoming value.

        Returns:
            Dict with:
                - 'in_distribution': bool
                - 'warnings': list of out-of-range feature descriptions
                - 'ood_rate': fraction of total predictions that were OOD

            When the monitor has not been fitted, the input is reported as
            in-distribution with an ``ood_rate`` of 0.0 and the counters
            are left untouched.
        """
        if self.feature_mins is None:
            return {"in_distribution": True, "warnings": [], "ood_rate": 0.0}

        self.prediction_count += 1
        issues = []
        for name, value in features.items():
            if name not in self.feature_mins:
                continue
            lo = self.feature_mins[name]
            hi = self.feature_maxs[name]

            # NaN is the only value that is not equal to itself. Both range
            # comparisons below evaluate to False for NaN, so without this
            # guard a NaN input would be silently treated as in-distribution.
            if value != value:
                issues.append(
                    f"{name} is NaN and cannot be compared to the training "
                    f"range [{lo:.4g}, {hi:.4g}]"
                )
                continue

            # A degenerate (constant or inverted) range falls back to a unit
            # width so the margin and percentages stay well defined.
            range_size = hi - lo if hi > lo else 1.0
            margin = self.margin * range_size
            if value < lo - margin:
                pct = abs(value - lo) / range_size * 100
                issues.append(
                    f"{name}={value:.4g} is {pct:.1f}% below training minimum ({lo:.4g})"
                )
            elif value > hi + margin:
                pct = abs(value - hi) / range_size * 100
                issues.append(
                    f"{name}={value:.4g} is {pct:.1f}% above training maximum ({hi:.4g})"
                )

        is_ood = bool(issues)
        if is_ood:
            self.ood_count += 1
            # Lazy %-formatting: the message is only rendered if the record
            # is actually emitted.
            logger.warning("OOD input detected: %s", issues)

        return {
            "in_distribution": not is_ood,
            "warnings": issues,
            "ood_rate": self.ood_count / self.prediction_count,
        }

    def get_stats(self) -> dict:
        """Return monitoring statistics.

        Returns:
            Dict with 'total_predictions', 'ood_predictions' and 'ood_rate'.
            'ood_rate' is 0.0 when no predictions have been counted yet.
        """
        return {
            "total_predictions": self.prediction_count,
            "ood_predictions": self.ood_count,
            # max(..., 1) avoids division by zero before the first check().
            "ood_rate": self.ood_count / max(self.prediction_count, 1),
        }