turing-space / turing /monitoring /mlflow_logger.py
github-actions[bot]
Sync turing folder from GitHub
38593e7
"""
MLflow Logging for Drift Detection
Handles logging drift metrics and alerts to MLflow experiment runs.
"""
from typing import Dict, Optional
from loguru import logger
try:
import mlflow
except ImportError:
mlflow = None
class DriftLogger:
    """
    Logs drift detection results to MLflow.

    All MLflow logging is best-effort: when MLflow is not installed the
    logging methods return without doing anything, and any exception raised
    while logging is reported as a loguru warning instead of propagating to
    the caller.
    """

    # Severity names accepted by ``log_alert``; each is the name of a loguru
    # logger method. Anything else falls back to "warning".
    _ALERT_LEVELS = frozenset(
        {"trace", "debug", "info", "success", "warning", "error", "critical"}
    )

    # Keys read from the baseline statistics dict and logged as metrics.
    _BASELINE_FIELDS = ("num_samples", "text_length_mean", "word_count_mean")

    def __init__(self, log_artifacts: bool = False):
        """
        Initialize drift logger.

        Args:
            log_artifacts: Stored as ``self.log_artifacts``.
                NOTE(review): no method in this class reads it — confirm
                whether external callers rely on it.
        """
        self.log_artifacts = log_artifacts
        # The module-level import guard binds ``mlflow`` to None when the
        # package is missing; resolve that once per instance.
        self.has_mlflow = mlflow is not None

    def log_drift_results(
        self,
        drift_results: Dict,
        step: Optional[int] = None,
        prefix: str = "drift",
    ) -> None:
        """
        Log drift detection results to MLflow.

        Logs ``{prefix}/drifted`` and ``{prefix}/num_drifts`` from the
        ``"overall"`` entry (0.0 when absent), plus
        ``{prefix}/{drift_type}/p_value`` for every other entry that is a dict
        carrying a non-None ``"p_value"``.

        Args:
            drift_results: Mapping of drift type to its result dict, with an
                optional ``"overall"`` summary entry.
            step: Optional MLflow step attached to every metric.
            prefix: Namespace prepended to every metric name.
        """
        if not self.has_mlflow:
            logger.debug("MLflow not available")
            return
        try:
            metrics = self._collect_drift_metrics(drift_results, prefix)
            # One batched request instead of one request per metric.
            mlflow.log_metrics(metrics, step=step)
            logger.debug("Logged drift results to MLflow")
        except Exception as e:
            # Best-effort by design: monitoring must not break the caller.
            logger.warning(f"Failed to log drift to MLflow: {e}")

    def log_baseline_statistics(
        self,
        baseline_stats: Dict,
        prefix: str = "baseline",
    ) -> None:
        """
        Log baseline statistics to MLflow.

        Logs each of ``num_samples``, ``text_length_mean`` and
        ``word_count_mean`` that is present and not None as a metric, and
        ``language`` (default ``"unknown"``) as a parameter.

        Args:
            baseline_stats: Dict of baseline statistics.
            prefix: Namespace prepended to every metric/param name.
        """
        if not self.has_mlflow:
            logger.debug("MLflow not available")
            return
        try:
            mlflow.log_metrics(self._collect_baseline_metrics(baseline_stats, prefix))
            mlflow.log_param(f"{prefix}/language", baseline_stats.get("language", "unknown"))
            logger.debug("Logged baseline to MLflow")
        except Exception as e:
            logger.warning(f"Failed to log baseline: {e}")

    def log_alert(self, message: str, severity: str = "warning") -> None:
        """
        Log a drift alert message through loguru.

        Args:
            message: Alert text; emitted as ``"DRIFT ALERT: {message}"``.
            severity: Level name, case-insensitive (e.g. ``"warning"``,
                ``"ERROR"``). Unknown names fall back to ``"warning"``.
        """
        level = self._resolve_severity(severity)
        getattr(logger, level)(f"DRIFT ALERT: {message}")

    @staticmethod
    def _collect_drift_metrics(drift_results: Dict, prefix: str) -> Dict[str, float]:
        """Build the metric-name -> value mapping logged by ``log_drift_results``."""
        # ``or {}`` also covers an explicit ``"overall": None`` entry.
        overall = drift_results.get("overall") or {}
        metrics = {
            f"{prefix}/drifted": float(overall.get("drifted", False)),
            f"{prefix}/num_drifts": float(overall.get("num_drifts", 0)),
        }
        for drift_type, result in drift_results.items():
            # Skip the summary and any non-dict entries (e.g. None or scalar
            # metadata) so one odd entry cannot abort logging of the rest.
            if drift_type == "overall" or not isinstance(result, dict):
                continue
            p_value = result.get("p_value")
            if p_value is not None:
                metrics[f"{prefix}/{drift_type}/p_value"] = float(p_value)
        return metrics

    @classmethod
    def _collect_baseline_metrics(cls, baseline_stats: Dict, prefix: str) -> Dict[str, float]:
        """Build the metric-name -> value mapping for the present baseline fields."""
        return {
            f"{prefix}/{field}": float(baseline_stats[field])
            for field in cls._BASELINE_FIELDS
            if baseline_stats.get(field) is not None
        }

    @classmethod
    def _resolve_severity(cls, severity: str) -> str:
        """Map ``severity`` to an allowed loguru level method name."""
        # Whitelisting matters: a bare getattr would dispatch to any logger
        # attribute, e.g. severity="add" would call ``logger.add`` with the
        # alert text as its sink.
        level = str(severity).strip().lower()
        return level if level in cls._ALERT_LEVELS else "warning"