File size: 4,873 Bytes
38593e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
Baseline Management Module

Handles extraction of baseline statistics from training data,
storage as MLflow artifacts, and retrieval for drift detection.
"""

import json
from pathlib import Path
import pickle
from typing import Dict, List, Optional

from loguru import logger
import numpy as np

from turing import config

try:
    import mlflow
    from mlflow.tracking import MlflowClient
except ImportError:
    mlflow = None


def extract_baseline_statistics(
    X_train: List[str],
    y_train: np.ndarray,
    language: str = "java",
) -> Dict:
    """
    Extract baseline statistics from training data for later drift detection.

    Args:
        X_train: List of training comment texts.
        y_train: Training labels. Either a 1-D array of label indices or a
            2-D binary indicator matrix of shape (n_samples, n_labels).
        language: Language tag stored alongside the statistics.

    Returns:
        Dictionary with raw distributions (text lengths, word counts,
        per-label counts) plus scalar summaries (mean/std/min/max).

    Raises:
        ValueError: If X_train is empty — min/max/std are undefined and
            would otherwise fail inside numpy with an obscure error.
    """
    if len(X_train) == 0:
        raise ValueError("X_train must contain at least one sample")

    text_lengths = np.array([len(text) for text in X_train])
    word_counts = np.array([len(text.split()) for text in X_train])

    if y_train.ndim == 1:
        # Label-index vector: count occurrences of each class id, padding
        # so every class up to the max index gets an entry.
        n_labels = int(np.max(y_train)) + 1
        label_counts = np.bincount(y_train.astype(int), minlength=n_labels)
    else:
        # Multi-label indicator matrix: per-column positive counts.
        label_counts = np.sum(y_train, axis=0)
        n_labels = y_train.shape[1]

    baseline_stats = {
        "text_length_distribution": text_lengths.tolist(),
        "word_count_distribution": word_counts.tolist(),
        "label_counts": label_counts.tolist(),
        "language": language,
        "num_samples": len(X_train),
        "n_labels": int(n_labels),
        "text_length_mean": float(np.mean(text_lengths)),
        "text_length_std": float(np.std(text_lengths)),
        "text_length_min": float(np.min(text_lengths)),
        "text_length_max": float(np.max(text_lengths)),
        "word_count_mean": float(np.mean(word_counts)),
        "word_count_std": float(np.std(word_counts)),
    }

    logger.info(f"Extracted baseline for {language}: {len(X_train)} samples")

    return baseline_stats


class BaselineManager:
    """
    Manages baseline statistics for drift detection.

    Baselines are pickled to a local cache directory; when MLflow is both
    installed and enabled, a JSON summary is additionally logged as a run
    artifact under ``baselines/<language>``.
    """

    def __init__(self, mlflow_enabled: bool = True, local_cache_dir: Optional[Path] = None):
        """
        Initialize baseline manager.

        Args:
            mlflow_enabled: Enable MLflow artifact logging. Silently ignored
                when the mlflow package is not installed.
            local_cache_dir: Local cache directory (default from
                config.BASELINE_CACHE_DIR).
        """
        self.mlflow_enabled = mlflow_enabled and mlflow is not None
        self.local_cache_dir = local_cache_dir or config.BASELINE_CACHE_DIR
        self.local_cache_dir.mkdir(parents=True, exist_ok=True)

        # Always define the attribute so later accesses never raise
        # AttributeError when MLflow is disabled or unavailable.
        self.mlflow_client = MlflowClient() if self.mlflow_enabled else None

        logger.info(f"BaselineManager initialized (cache: {self.local_cache_dir})")

    def save_baseline(
        self,
        baseline_stats: Dict,
        language: str,
        dataset_name: str,
        model_id: str = "default",
        run_id: Optional[str] = None,
    ) -> None:
        """
        Save baseline statistics to the local cache and, optionally, MLflow.

        Args:
            baseline_stats: Statistics dict (as produced by
                extract_baseline_statistics).
            language: Language key used to shard the cache directory.
            dataset_name: Dataset identifier used in the file name.
            model_id: Model identifier used in the file name.
            run_id: Active MLflow run id; MLflow logging only happens when
                this is truthy and MLflow is enabled.
        """
        baseline_path = self._get_baseline_path(language, dataset_name, model_id)

        baseline_path.parent.mkdir(parents=True, exist_ok=True)
        with open(baseline_path, "wb") as f:
            pickle.dump(baseline_stats, f)
        logger.info(f"Saved baseline to {baseline_path}")

        if self.mlflow_enabled and run_id:
            # MLflow logging is best-effort: a tracking-server failure must
            # not break training, so errors are logged and swallowed.
            try:
                json_path = baseline_path.with_suffix(".json")
                # Only JSON-serializable values go into the artifact copy.
                json_stats = {
                    k: v
                    for k, v in baseline_stats.items()
                    if isinstance(v, (int, float, str, list, bool))
                }
                with open(json_path, "w") as f:
                    json.dump(json_stats, f, indent=2)

                mlflow.log_artifact(str(json_path), artifact_path=f"baselines/{language}")
                logger.info("Logged baseline to MLflow")
            except Exception as e:
                logger.warning(f"Failed to log baseline to MLflow: {e}")

    def load_baseline(
        self,
        language: str,
        dataset_name: str,
        model_id: str = "default",
    ) -> Dict:
        """
        Load baseline statistics from the local cache.

        Args:
            language: Language key used when the baseline was saved.
            dataset_name: Dataset identifier used when the baseline was saved.
            model_id: Model identifier used when the baseline was saved.

        Returns:
            The unpickled baseline statistics dictionary.

        Raises:
            FileNotFoundError: If no cached baseline exists at the
                derived path.
        """
        baseline_path = self._get_baseline_path(language, dataset_name, model_id)

        if baseline_path.exists():
            # NOTE(review): pickle.load executes arbitrary code if the cache
            # file is tampered with — the cache dir is assumed trusted.
            with open(baseline_path, "rb") as f:
                baseline_stats = pickle.load(f)
            logger.info(f"Loaded baseline from cache: {baseline_path}")
            return baseline_stats

        raise FileNotFoundError(f"Baseline not found at {baseline_path}")

    def _get_baseline_path(self, language: str, dataset_name: str, model_id: str) -> Path:
        """Generate the local cache path for a (language, dataset, model) baseline."""
        return self.local_cache_dir / language / f"{dataset_name}_{model_id}_baseline.pkl"