File size: 12,714 Bytes
713632e
df25ba9
713632e
 
 
df25ba9
713632e
 
 
 
 
 
 
 
 
df25ba9
713632e
 
 
 
 
 
 
 
4617060
 
 
 
 
 
 
 
 
 
713632e
 
 
 
 
 
 
 
 
 
 
 
 
 
df25ba9
713632e
 
 
 
df25ba9
713632e
 
df25ba9
713632e
 
 
 
df25ba9
713632e
 
 
 
 
 
 
 
 
 
 
 
 
 
df25ba9
713632e
 
 
df25ba9
713632e
 
df25ba9
713632e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df25ba9
713632e
 
df25ba9
713632e
 
 
 
 
 
df25ba9
713632e
 
 
 
 
 
df25ba9
 
713632e
df25ba9
713632e
df25ba9
 
713632e
 
df25ba9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713632e
 
 
 
 
 
df25ba9
713632e
 
df25ba9
713632e
20503f2
713632e
 
 
 
df25ba9
713632e
 
df25ba9
713632e
20503f2
713632e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df25ba9
713632e
 
 
df25ba9
713632e
 
 
 
 
 
 
 
 
 
 
df25ba9
713632e
 
 
 
 
 
 
 
 
 
 
 
 
df25ba9
713632e
 
df25ba9
713632e
 
 
 
df25ba9
713632e
 
 
 
 
 
 
 
 
 
 
 
df25ba9
713632e
 
 
 
 
df25ba9
 
 
713632e
df25ba9
713632e
 
 
 
 
df25ba9
713632e
 
 
 
df25ba9
713632e
df25ba9
 
 
 
713632e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da1e99f
713632e
 
df25ba9
713632e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
"""Model Layer - ML model management and inference.

This module handles the low-level ML operations including:
- Model loading and storage via ModelRegistry
- Inference execution via ModelPredictor

Architecture:
    - ModelRegistry: Central storage for loaded models with lazy loading
    - ModelPredictor: Executes inference using registered models
"""

import logging
import os
from pathlib import Path
import sys
from typing import Any, Dict, List, Optional, Tuple

import dagshub
import mlflow
import numpy as np
import torch

from nygaardcodecommentclassification import config

# Patch torch.load to use CPU mapping by default if CUDA is not available.
# This prevents "Attempting to deserialize object on a CUDA device" errors
# when loading GPU-trained checkpoints on CPU-only hosts.
_original_torch_load = torch.load
def _patched_torch_load(f, map_location=None, *args, **kwargs):
    """Wrapper around torch.load that uses CPU mapping if CUDA unavailable.

    Only overrides ``map_location`` when the caller did not supply one and no
    CUDA device is available; every other argument passes through untouched.
    """
    if map_location is None and not torch.cuda.is_available():
        map_location = torch.device('cpu')
    # Forward map_location positionally: passing it as a keyword while also
    # unpacking *args would raise "got multiple values for 'map_location'"
    # whenever a caller supplies further positional args (e.g. pickle_module,
    # which is torch.load's third positional parameter).
    return _original_torch_load(f, map_location, *args, **kwargs)
torch.load = _patched_torch_load

# Module logger with an explicit stdout handler so log output stays visible
# even when the host application has not configured logging itself.
logger = logging.getLogger("nygaard.models")
logger.setLevel(logging.DEBUG)
# Guard against attaching a second handler if this module is reloaded.
if not logger.handlers:
    _handler = logging.StreamHandler(sys.stdout)
    _handler.setLevel(logging.DEBUG)
    _handler.setFormatter(
        logging.Formatter("[%(levelname)s] %(name)s: %(message)s")
    )
    logger.addHandler(_handler)


class ModelRegistry:
    """Central registry for ML models loaded in memory.

    This class manages the lifecycle of ML models, providing:
    - Automatic discovery and loading of models from the MLflow tracking server
    - Organized storage by language and model type
    - Memory management with explicit cleanup

    Attributes:
        _registry: Internal dictionary storing loaded models, keyed first by
            language and then by model type (e.g. ``_registry["python"]["catboost"]``)

    Example:
        ```python
        registry = ModelRegistry()
        registry.load_all_models()  # connects to MLflow/DagsHub

        # Access a loaded model
        model_entry = registry.get_model("python", "catboost")
        if model_entry:
            model = model_entry["model"]
            embedder = model_entry.get("embedder")
        ```
    """

    def __init__(self) -> None:
        """Initialize an empty model registry."""
        self._registry: Dict[str, Dict[str, Any]] = {}

    def load_all_models(self) -> None:
        """Load all ML models from MLflow tracking server.

        This method connects to the MLflow tracking server (DagsHub) and loads
        CatBoost classifiers and sentence transformer embedders for all
        configured languages (``config.LANGUAGES``).

        Environment Variables:
            DAGSHUB_USER_TOKEN: Authentication token for DagsHub/MLflow

        Side Effects:
            Sets MLFLOW_TRACKING_USERNAME / MLFLOW_TRACKING_PASSWORD in the
            process environment when DAGSHUB_USER_TOKEN is present.

        Note:
            - Continues loading other models if one fails
            - Logs all loading activity for debugging
        """
        logger.info("Starting to load all models from MLflow")
        # Initialize MLflow with DagsHub - uses DAGSHUB_USER_TOKEN env var for auth
        # Set DAGSHUB_USER_TOKEN in your environment to avoid interactive login
        dagshub_token = os.environ.get("DAGSHUB_USER_TOKEN")
        if dagshub_token:
            # MLflow basic auth: the token serves as both username and password
            os.environ["MLFLOW_TRACKING_USERNAME"] = dagshub_token
            os.environ["MLFLOW_TRACKING_PASSWORD"] = dagshub_token
            logger.info("Using DAGSHUB_USER_TOKEN for authentication")
        else:
            logger.warning("DAGSHUB_USER_TOKEN not set - may require interactive login")

        dagshub.init(repo_owner="se4ai2526-uniba", repo_name="Nygaard", mlflow=True)
        mlflow.set_experiment("evaluating")

        # Load models for all configured languages directly from MLflow
        # No need for local directory structure
        for lang in config.LANGUAGES:
            logger.info("Loading models for language: %s", lang)
            if lang not in self._registry:
                self._registry[lang] = {}
            self._load_catboost_models(lang)
        logger.info("Finished loading all models from MLflow")

    def _load_catboost_models(self, lang: str) -> None:
        """Load CatBoost models for a specific language from MLflow.

        Downloads and loads the CatBoost classifier and sentence transformer
        embedder directly from MLflow tracking server, storing them under
        ``self._registry[lang]["catboost"]``.

        Args:
            lang: The programming language code (e.g., "python", "java")
        """
        # Find the best CatBoost run, ranked by metrics.final_score.
        # NOTE(review): this search is not filtered by `lang` — presumably a
        # single best run logs one artifact per language (model_{lang});
        # confirm against the training pipeline.
        catboost_runs = mlflow.search_runs(
            experiment_names=["evaluating"], filter_string="tags.model = 'catboost'"
        ).sort_values(by="metrics.final_score", ascending=False)
        if catboost_runs.empty:
            logger.error("No CatBoost run found in 'evaluating' experiment")
            return
        catboost_run = catboost_runs.iloc[0]
        catboost_run_id = catboost_run.run_id
        catboost_run_name = catboost_run.get("tags.mlflow.runName", "unknown")
        catboost_git_commit = catboost_run.get("tags.mlflow.source.git.commit")

        logger.info(
            "Found CatBoost run: '%s' (ID: %s, commit: %s)",
            catboost_run_name,
            catboost_run_id,
            catboost_git_commit,
        )

        # Find the embedder run with same git commit and source file
        embedder_run = None
        embedder_run_id = None
        embedder_run_name = None

        if catboost_git_commit:
            # Search for sentence transformer with same git commit, so the
            # embedder matches the features the classifier was trained on.
            logger.info(
                "[%s] Searching for embedder with git commit: %s",
                lang.upper(),
                catboost_git_commit,
            )
            # NOTE(review): the commit hash is interpolated into filter_string
            # via an f-string; it comes from MLflow's own tags, so it is
            # trusted, but verify if the data source ever changes.
            embedder_runs = mlflow.search_runs(
                experiment_names=["evaluating"],
                filter_string=f"tags.`mlflow.source.git.commit` = '{catboost_git_commit}' and run_name LIKE 'sentence_transformer%'",
            )

            if not embedder_runs.empty:
                # Take the first match; ordering among multiple matches is
                # whatever mlflow.search_runs returns by default.
                embedder_run = embedder_runs.iloc[0]
                embedder_run_id = embedder_run.run_id
                embedder_run_name = embedder_run.get("tags.mlflow.runName", "unknown")
                logger.info(
                    "[%s] Found embedder with matching git commit: '%s' (ID: %s)",
                    lang.upper(),
                    embedder_run_name,
                    embedder_run_id,
                )

        # Fallback: search by default name if git commit search failed
        if not embedder_run_id:
            logger.info(
                "[%s] Falling back to default embedder search",
                lang.upper(),
            )
            embedder_runs = mlflow.search_runs(
                experiment_names=["evaluating"],
                filter_string="run_name = 'sentence_transformer_paraphrase-MiniLM-L6-v2'",
            )
            if embedder_runs.empty:
                # Without an embedder the classifier is unusable; give up on
                # this language but leave the registry entry untouched.
                logger.error(
                    "No embedder run found for 'sentence_transformer_paraphrase-MiniLM-L6-v2'"
                )
                return
            embedder_run = embedder_runs.iloc[0]
            embedder_run_id = embedder_run.run_id
            embedder_run_name = embedder_run.get("tags.mlflow.runName", "unknown")
            logger.info(
                "Found Embedder run: '%s' (ID: %s)",
                embedder_run_name,
                embedder_run_id,
            )

        try:
            # Load the CatBoost model from MLflow; artifact is named per language.
            model_uri = f"runs:/{catboost_run_id}/model_{lang}"
            logger.info(
                "[%s] Loading CatBoost classifier from run '%s' (ID: %s)...",
                lang.upper(),
                catboost_run_name,
                catboost_run_id,
            )
            model = mlflow.sklearn.load_model(model_uri)

            # Load the sentence transformer embedder from MLflow
            embedder_uri = f"runs:/{embedder_run_id}/model_{lang}"
            logger.info(
                "[%s] Loading sentence transformer from run '%s' (ID: %s)...",
                lang.upper(),
                embedder_run_name,
                embedder_run_id,
            )
            embedder = mlflow.sklearn.load_model(embedder_uri)

            # Register the model with its metadata
            self._registry[lang]["catboost"] = {
                "model": model,
                "feature_type": "embeddings",
                "embedder": embedder,
            }
            logger.info(
                "[%s] ✓ Ready: CatBoost + %s embeddings",
                lang.upper(),
                embedder_run_name.replace("sentence_transformer_", ""),
            )

        except Exception as e:
            # Broad catch is deliberate: a failure for one language must not
            # abort loading of the others (see load_all_models).
            logger.error("[%s] Error loading models: %s", lang.upper(), e)

    def get_model(self, language: str, model_type: str) -> Optional[Dict[str, Any]]:
        """Retrieve a loaded model entry by language and type.

        Args:
            language: The programming language code
            model_type: The type of model

        Returns:
            Dict containing the model and metadata, or None if not found.
            The dict contains:
            - "model": The loaded ML model object
            - "feature_type": Type of features used
            - "embedder": Optional sentence transformer for embedding generation
        """
        return self._registry.get(language, {}).get(model_type)

    def clear(self) -> None:
        """Clear all models from the registry and free memory.

        This method should be called during application shutdown to
        release GPU memory and other resources.
        """
        self._registry.clear()

        # Clear CUDA cache if GPU was used
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("CUDA cache cleared")


class ModelPredictor:
    """Handles low-level prediction logic.

    Attributes:
        registry: Reference to the ModelRegistry for model access

    Example:
        ```python
        registry = ModelRegistry()
        registry.load_all_models()

        predictor = ModelPredictor(registry)
        predictions, embeddings = predictor.predict(
            texts=["# Calculate sum of list"],
            language="python",
            model_type="catboost",
        )
        # predictions: np.ndarray with shape (1, num_labels)
        ```
    """

    def __init__(self, model_registry: ModelRegistry) -> None:
        """Initialize the predictor with a model registry.

        Args:
            model_registry: The ModelRegistry instance containing loaded models
        """
        self.registry = model_registry

    def predict(
        self, texts: List[str], language: str, model_type: str
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Execute prediction on a list of texts.

        Runs the full inference pipeline: look up the model in the registry,
        build the features it expects (sentence-transformer embeddings), and
        return the raw model output together with those features.

        Args:
            texts: List of code comment strings to classify
            language: Programming language context for model selection
            model_type: Type of model to use

        Returns:
            Tuple containing:
            - numpy array of predictions with shape (n_samples, n_labels).
            - numpy array of embeddings (if available, else None).

        Raises:
            ValueError: If the requested model is not available or
                       if an unsupported feature/model type is specified
        """
        entry = self.registry.get_model(language, model_type)
        if not entry or "model" not in entry:
            raise ValueError(f"Model {model_type} not available for {language}")

        classifier = entry["model"]

        # Guard clauses: only CatBoost over sentence embeddings is supported.
        if model_type != "catboost":
            raise ValueError(f"Unknown model type: {model_type}")
        if entry.get("feature_type") != "embeddings":
            raise ValueError("Unsupported feature type for CatBoost")

        embedder = entry.get("embedder")
        if embedder is None:
            raise ValueError(f"Embedder not loaded for {language}")

        # Encode to dense embeddings; progress bar suppressed for API use.
        vectors = embedder.encode(texts, show_progress_bar=False)
        return classifier.predict(vectors), vectors