# turing/monitoring/baseline_manager.py
# Synced from GitHub by github-actions[bot] ("Sync turing folder from GitHub", commit 38593e7).
"""
Baseline Management Module
Handles extraction of baseline statistics from training data,
storage as MLflow artifacts, and retrieval for drift detection.
"""
import json
from pathlib import Path
import pickle
from typing import Dict, List, Optional
from loguru import logger
import numpy as np
from turing import config
try:
    import mlflow
    from mlflow.tracking import MlflowClient
except ImportError:
    # MLflow is an optional dependency; fall back to local-cache-only mode.
    # Bind BOTH names so later code can safely test them against None —
    # the original left MlflowClient undefined, risking NameError.
    mlflow = None
    MlflowClient = None
def extract_baseline_statistics(
    X_train: List[str],
    y_train: np.ndarray,
    language: str = "java",
) -> Dict:
    """
    Extract baseline statistics from training data.

    Args:
        X_train: List of training comment texts
        y_train: Training labels (binary indicator matrix or 1-D label
            indices). Array-likes such as plain lists are accepted and
            converted with ``np.asarray``.
        language: Language of the training data

    Returns:
        Dictionary containing baseline statistics: raw text-length and
        word-count distributions, per-label counts, and summary moments
        (mean/std/min/max) used later for drift detection.

    Raises:
        ValueError: If X_train is empty (the statistics are undefined).
    """
    if len(X_train) == 0:
        # np.min/np.max on an empty array raise an obscure reduction error
        # and np.mean returns NaN with a warning — fail fast and clearly.
        raise ValueError("X_train must contain at least one sample")
    y_train = np.asarray(y_train)
    text_lengths = np.array([len(text) for text in X_train])
    word_counts = np.array([len(text.split()) for text in X_train])
    if y_train.ndim == 1:
        # 1-D label indices: count occurrences of each class id.
        n_labels = int(np.max(y_train)) + 1
        label_counts = np.bincount(y_train.astype(int), minlength=n_labels)
    else:
        # Binary indicator matrix: column sums give per-label counts.
        label_counts = np.sum(y_train, axis=0)
        n_labels = y_train.shape[1]
    baseline_stats = {
        "text_length_distribution": text_lengths.tolist(),
        "word_count_distribution": word_counts.tolist(),
        "label_counts": label_counts.tolist(),
        "language": language,
        "num_samples": len(X_train),
        "n_labels": int(n_labels),
        "text_length_mean": float(np.mean(text_lengths)),
        "text_length_std": float(np.std(text_lengths)),
        "text_length_min": float(np.min(text_lengths)),
        "text_length_max": float(np.max(text_lengths)),
        "word_count_mean": float(np.mean(word_counts)),
        "word_count_std": float(np.std(word_counts)),
    }
    logger.info(f"Extracted baseline for {language}: {len(X_train)} samples")
    return baseline_stats
class BaselineManager:
    """
    Manages baseline statistics for drift detection.

    Baselines are persisted as pickle files in a local cache directory;
    when MLflow is available, a JSON view is additionally logged as a
    run artifact.
    """

    def __init__(self, mlflow_enabled: bool = True, local_cache_dir: Optional[Path] = None):
        """
        Initialize baseline manager.

        Args:
            mlflow_enabled: Enable MLflow artifact logging
            local_cache_dir: Local cache directory (default from config.BASELINE_CACHE_DIR)
        """
        # MLflow logging is only possible when the optional import succeeded.
        self.mlflow_enabled = mlflow_enabled and mlflow is not None
        self.local_cache_dir = config.BASELINE_CACHE_DIR if not local_cache_dir else local_cache_dir
        self.local_cache_dir.mkdir(parents=True, exist_ok=True)
        if self.mlflow_enabled:
            self.mlflow_client = MlflowClient()
        logger.info(f"BaselineManager initialized (cache: {self.local_cache_dir})")

    def save_baseline(
        self,
        baseline_stats: Dict,
        language: str,
        dataset_name: str,
        model_id: str = "default",
        run_id: Optional[str] = None,
    ) -> None:
        """
        Save baseline statistics to MLflow and local cache.
        """
        target = self._get_baseline_path(language, dataset_name, model_id)
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "wb") as fh:
            pickle.dump(baseline_stats, fh)
        logger.info(f"Saved baseline to {target}")
        if self.mlflow_enabled and run_id:
            try:
                # Keep only JSON-serializable entries for the artifact view.
                serializable = {}
                for key, value in baseline_stats.items():
                    if isinstance(value, (int, float, str, list, bool)):
                        serializable[key] = value
                json_path = target.with_suffix(".json")
                json_path.write_text(json.dumps(serializable, indent=2))
                mlflow.log_artifact(str(json_path), artifact_path=f"baselines/{language}")
                logger.info("Logged baseline to MLflow")
            except Exception as e:
                # Best-effort: a failed artifact upload must not break saving.
                logger.warning(f"Failed to log baseline to MLflow: {e}")

    def load_baseline(
        self,
        language: str,
        dataset_name: str,
        model_id: str = "default",
    ) -> Dict:
        """
        Load baseline statistics from local cache.
        """
        cached = self._get_baseline_path(language, dataset_name, model_id)
        if not cached.exists():
            raise FileNotFoundError(f"Baseline not found at {cached}")
        with open(cached, "rb") as fh:
            stats = pickle.load(fh)
        logger.info(f"Loaded baseline from cache: {cached}")
        return stats

    def _get_baseline_path(self, language: str, dataset_name: str, model_id: str) -> Path:
        """Generate local cache path for baseline."""
        filename = f"{dataset_name}_{model_id}_baseline.pkl"
        return self.local_cache_dir / language / filename