Spaces:
Running
Running
| """ | |
| Baseline Management Module | |
| Handles extraction of baseline statistics from training data, | |
| storage as MLflow artifacts, and retrieval for drift detection. | |
| """ | |
| import json | |
| from pathlib import Path | |
| import pickle | |
| from typing import Dict, List, Optional | |
| from loguru import logger | |
| import numpy as np | |
| from turing import config | |
| try: | |
| import mlflow | |
| from mlflow.tracking import MlflowClient | |
| except ImportError: | |
| mlflow = None | |
def extract_baseline_statistics(
    X_train: List[str],
    y_train: np.ndarray,
    language: str = "java",
) -> Dict:
    """
    Extract baseline statistics from training data.

    Args:
        X_train: List of training comment texts.
        y_train: Training labels -- either a 1-D array of label indices
            or a 2-D binary indicator matrix (samples x labels).
        language: Language of the training data.

    Returns:
        Dictionary containing the raw per-sample distributions
        (text lengths, word counts), per-label counts, and scalar
        summary statistics (mean/std/min/max).

    Raises:
        ValueError: If X_train is empty -- min/max/mean are undefined
            for an empty sample and would otherwise surface as an
            opaque numpy error.
    """
    if not X_train:
        raise ValueError("Cannot extract baseline statistics from empty training data")

    text_lengths = np.array([len(text) for text in X_train])
    word_counts = np.array([len(text.split()) for text in X_train])

    if y_train.ndim == 1:
        # 1-D labels are class indices; infer label count from the max index.
        n_labels = int(np.max(y_train)) + 1
        label_counts = np.bincount(y_train.astype(int), minlength=n_labels)
    else:
        # 2-D binary indicator matrix: column sums give per-label counts.
        label_counts = np.sum(y_train, axis=0)
        n_labels = y_train.shape[1]

    baseline_stats = {
        "text_length_distribution": text_lengths.tolist(),
        "word_count_distribution": word_counts.tolist(),
        "label_counts": label_counts.tolist(),
        "language": language,
        "num_samples": len(X_train),
        "n_labels": int(n_labels),
        "text_length_mean": float(np.mean(text_lengths)),
        "text_length_std": float(np.std(text_lengths)),
        "text_length_min": float(np.min(text_lengths)),
        "text_length_max": float(np.max(text_lengths)),
        "word_count_mean": float(np.mean(word_counts)),
        "word_count_std": float(np.std(word_counts)),
    }
    logger.info(f"Extracted baseline for {language}: {len(X_train)} samples")
    return baseline_stats
class BaselineManager:
    """
    Manages baseline statistics for drift detection.

    Baselines are always written to a local pickle cache; when MLflow is
    importable and enabled, a JSON-serializable summary is additionally
    logged as an artifact of a specific MLflow run.
    """

    def __init__(self, mlflow_enabled: bool = True, local_cache_dir: Optional[Path] = None):
        """
        Initialize baseline manager.

        Args:
            mlflow_enabled: Enable MLflow artifact logging (silently
                disabled when mlflow failed to import at module load).
            local_cache_dir: Local cache directory (default from
                config.BASELINE_CACHE_DIR).
        """
        # mlflow is an optional dependency; it is None when its import failed.
        self.mlflow_enabled = mlflow_enabled and mlflow is not None
        self.local_cache_dir = local_cache_dir or config.BASELINE_CACHE_DIR
        self.local_cache_dir.mkdir(parents=True, exist_ok=True)
        if self.mlflow_enabled:
            self.mlflow_client = MlflowClient()
        logger.info(f"BaselineManager initialized (cache: {self.local_cache_dir})")

    def save_baseline(
        self,
        baseline_stats: Dict,
        language: str,
        dataset_name: str,
        model_id: str = "default",
        run_id: Optional[str] = None,
    ) -> None:
        """
        Save baseline statistics to the local cache and, when MLflow is
        enabled and a run_id is given, log a JSON summary to that run.

        Args:
            baseline_stats: Statistics dictionary (e.g. produced by
                extract_baseline_statistics).
            language: Language the baseline was computed for.
            dataset_name: Name of the source dataset.
            model_id: Identifier of the associated model.
            run_id: MLflow run to attach the artifact to; when None the
                artifact is not logged to MLflow.
        """
        baseline_path = self._get_baseline_path(language, dataset_name, model_id)
        baseline_path.parent.mkdir(parents=True, exist_ok=True)
        with open(baseline_path, "wb") as f:
            pickle.dump(baseline_stats, f)
        logger.info(f"Saved baseline to {baseline_path}")
        if self.mlflow_enabled and run_id:
            try:
                json_path = baseline_path.with_suffix(".json")
                # Keep only JSON-serializable entries for the artifact.
                json_stats = {
                    k: v
                    for k, v in baseline_stats.items()
                    if isinstance(v, (int, float, str, list, bool))
                }
                with open(json_path, "w") as f:
                    json.dump(json_stats, f, indent=2)
                # Log to the *specified* run via the tracking client;
                # mlflow.log_artifact would ignore run_id and attach the
                # file to whatever run happens to be active.
                self.mlflow_client.log_artifact(
                    run_id, str(json_path), artifact_path=f"baselines/{language}"
                )
                logger.info("Logged baseline to MLflow")
            except Exception as e:
                # Best-effort: an MLflow outage must not break training.
                logger.warning(f"Failed to log baseline to MLflow: {e}")

    def load_baseline(
        self,
        language: str,
        dataset_name: str,
        model_id: str = "default",
    ) -> Dict:
        """
        Load baseline statistics from the local cache.

        Args:
            language: Language the baseline was computed for.
            dataset_name: Name of the source dataset.
            model_id: Identifier of the associated model.

        Returns:
            The cached baseline statistics dictionary.

        Raises:
            FileNotFoundError: If no baseline exists for this
                language/dataset/model combination.
        """
        baseline_path = self._get_baseline_path(language, dataset_name, model_id)
        if baseline_path.exists():
            with open(baseline_path, "rb") as f:
                baseline_stats = pickle.load(f)
            logger.info(f"Loaded baseline from cache: {baseline_path}")
            return baseline_stats
        raise FileNotFoundError(f"Baseline not found at {baseline_path}")

    def _get_baseline_path(self, language: str, dataset_name: str, model_id: str) -> Path:
        """Generate local cache path for baseline."""
        return self.local_cache_dir / language / f"{dataset_name}_{model_id}_baseline.pkl"