"""
Baseline Management Module

Handles extraction of baseline statistics from training data, storage as
MLflow artifacts, and retrieval for drift detection.
"""

import json
from pathlib import Path
import pickle
from typing import Dict, List, Optional

from loguru import logger
import numpy as np

from turing import config

try:
    import mlflow
    from mlflow.tracking import MlflowClient
except ImportError:
    # MLflow is optional: without it, baselines are only written to the local cache.
    mlflow = None
    MlflowClient = None


def extract_baseline_statistics(
    X_train: List[str],
    y_train: np.ndarray,
    language: str = "java",
) -> Dict:
    """
    Extract baseline statistics from training data.

    Args:
        X_train: List of training comment texts
        y_train: Training labels. A 1-D array is treated as label indices;
            anything else is treated as a (samples x labels) matrix whose
            columns are summed.
        language: Language of the training data

    Returns:
        Dictionary containing the raw character-length and word-count
        distributions, per-label counts, the language, the sample count,
        the number of labels, and summary statistics (mean/std/min/max of
        text length, mean/std of word count). All values are plain Python
        types (lists, ints, floats, strings).

    Raises:
        ValueError: If X_train is empty (no statistics can be computed).
    """
    num_samples = len(X_train)
    if num_samples == 0:
        # Fail early with a clear message instead of NumPy's opaque
        # "zero-size array to reduction operation" error further down.
        raise ValueError("Cannot extract baseline statistics from empty training data")

    text_lengths = np.array([len(text) for text in X_train])
    # str.split() with no argument splits on any run of whitespace.
    word_counts = np.array([len(text.split()) for text in X_train])

    if len(y_train.shape) == 1:
        # Label indices: the number of labels is inferred from the largest index.
        n_labels = int(np.max(y_train)) + 1
        label_counts = np.bincount(y_train.astype(int), minlength=n_labels)
    else:
        # Label matrix: one column per label, column sums give per-label counts.
        label_counts = np.sum(y_train, axis=0)
        n_labels = y_train.shape[1]

    baseline_stats = {
        "text_length_distribution": text_lengths.tolist(),
        "word_count_distribution": word_counts.tolist(),
        "label_counts": label_counts.tolist(),
        "language": language,
        "num_samples": num_samples,
        "n_labels": int(n_labels),
        "text_length_mean": float(np.mean(text_lengths)),
        "text_length_std": float(np.std(text_lengths)),
        "text_length_min": float(np.min(text_lengths)),
        "text_length_max": float(np.max(text_lengths)),
        "word_count_mean": float(np.mean(word_counts)),
        "word_count_std": float(np.std(word_counts)),
    }

    logger.info(f"Extracted baseline for {language}: {len(X_train)} samples")

    return baseline_stats


class BaselineManager:
    """
    Manages baseline statistics for drift detection.

    Baselines are pickled into a local cache directory and, when MLflow is
    available and a run id is supplied, a JSON copy is logged as an artifact
    of that run.
    """

    def __init__(self, mlflow_enabled: bool = True, local_cache_dir: Optional[Path] = None):
        """
        Initialize baseline manager.

        Args:
            mlflow_enabled: Enable MLflow artifact logging. Ignored (treated
                as False) when the mlflow package could not be imported.
            local_cache_dir: Local cache directory (default from
                config.BASELINE_CACHE_DIR). The directory is created if it
                does not exist.
        """
        self.mlflow_enabled = mlflow_enabled and mlflow is not None
        # Coerce to Path so a plain string path is accepted as well.
        self.local_cache_dir = Path(local_cache_dir or config.BASELINE_CACHE_DIR)
        self.local_cache_dir.mkdir(parents=True, exist_ok=True)

        if self.mlflow_enabled:
            self.mlflow_client = MlflowClient()

        logger.info(f"BaselineManager initialized (cache: {self.local_cache_dir})")

    def save_baseline(
        self,
        baseline_stats: Dict,
        language: str,
        dataset_name: str,
        model_id: str = "default",
        run_id: Optional[str] = None,
    ) -> None:
        """
        Save baseline statistics to MLflow and local cache.

        The baseline is always pickled to the local cache. When MLflow is
        enabled and run_id is given, a JSON copy is additionally logged as an
        artifact of that run under "baselines/<language>". MLflow logging is
        best-effort: failures are logged as warnings and not raised.

        Args:
            baseline_stats: Baseline statistics dictionary to persist
            language: Language the baseline belongs to (cache sub-directory)
            dataset_name: Dataset name used in the cache file name
            model_id: Model identifier used in the cache file name
            run_id: MLflow run to attach the JSON artifact to; when omitted,
                nothing is logged to MLflow
        """
        baseline_path = self._get_baseline_path(language, dataset_name, model_id)
        baseline_path.parent.mkdir(parents=True, exist_ok=True)

        with open(baseline_path, "wb") as f:
            pickle.dump(baseline_stats, f)
        logger.info(f"Saved baseline to {baseline_path}")

        if not (self.mlflow_enabled and run_id):
            return

        try:
            json_path = baseline_path.with_suffix(".json")
            # Only values of these plain types are exported; anything else
            # (e.g. NumPy arrays) is left out of the JSON copy.
            json_stats = {
                k: v
                for k, v in baseline_stats.items()
                if isinstance(v, (int, float, str, list, bool))
            }
            with open(json_path, "w", encoding="utf-8") as f:
                json.dump(json_stats, f, indent=2)

            # Log through the client so the artifact is attached to the run
            # identified by run_id rather than to whichever run is active.
            self.mlflow_client.log_artifact(
                run_id, str(json_path), artifact_path=f"baselines/{language}"
            )
            logger.info("Logged baseline to MLflow")
        except Exception as e:
            # Deliberate best-effort: the local cache copy is already saved.
            logger.warning(f"Failed to log baseline to MLflow: {e}")

    def load_baseline(
        self,
        language: str,
        dataset_name: str,
        model_id: str = "default",
    ) -> Dict:
        """
        Load baseline statistics from local cache.

        Args:
            language: Language the baseline belongs to
            dataset_name: Dataset name used in the cache file name
            model_id: Model identifier used in the cache file name

        Returns:
            The unpickled baseline statistics

        Raises:
            FileNotFoundError: If no cached baseline exists at the expected path
        """
        baseline_path = self._get_baseline_path(language, dataset_name, model_id)

        if not baseline_path.exists():
            raise FileNotFoundError(f"Baseline not found at {baseline_path}")

        # NOTE(security): pickle.load can execute arbitrary code. The cache
        # directory must only contain files written by save_baseline.
        with open(baseline_path, "rb") as f:
            baseline_stats = pickle.load(f)
        logger.info(f"Loaded baseline from cache: {baseline_path}")
        return baseline_stats

    def _get_baseline_path(self, language: str, dataset_name: str, model_id: str) -> Path:
        """Generate local cache path for baseline."""
        return self.local_cache_dir / language / f"{dataset_name}_{model_id}_baseline.pkl"