|
|
"""Model Layer - ML model management and inference. |
|
|
|
|
|
This module handles the low-level ML operations including: |
|
|
- Model loading and storage via ModelRegistry |
|
|
- Inference execution via ModelPredictor |
|
|
|
|
|
Architecture: |
|
|
- ModelRegistry: Central storage for loaded models with lazy loading |
|
|
- ModelPredictor: Executes inference using registered models |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import os |
|
|
from pathlib import Path |
|
|
import sys |
|
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
|
|
import dagshub |
|
|
import mlflow |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from nygaardcodecommentclassification import config |
|
|
|
|
|
|
|
|
|
|
|
_original_torch_load = torch.load |
|
|
def _patched_torch_load(f, map_location=None, *args, **kwargs): |
|
|
"""Wrapper around torch.load that uses CPU mapping if CUDA unavailable.""" |
|
|
if map_location is None and not torch.cuda.is_available(): |
|
|
map_location = torch.device('cpu') |
|
|
return _original_torch_load(f, map_location=map_location, *args, **kwargs) |
|
|
torch.load = _patched_torch_load |
|
|
|
|
|
|
|
|
# Module logger: emits DEBUG and above to stdout so container platforms
# capture it.  The handler is attached only once, guarding against duplicate
# output when the module is re-imported.
logger = logging.getLogger("nygaard.models")
logger.setLevel(logging.DEBUG)

if not logger.handlers:
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(
        logging.Formatter("[%(levelname)s] %(name)s: %(message)s")
    )
    logger.addHandler(stream_handler)
|
|
|
|
|
|
|
|
class ModelRegistry:
    """Central registry for ML models loaded in memory.

    This class manages the lifecycle of ML models, providing:
    - Automatic discovery and loading of models from the MLflow tracking
      server (DagsHub)
    - Organized storage by language and model type
    - Memory management with explicit cleanup

    Attributes:
        _registry: Internal dictionary storing loaded models, keyed first by
            language code, then by model type.

    Example:
        ```python
        registry = ModelRegistry()
        registry.load_all_models()

        # Access a loaded model
        model_entry = registry.get_model("python", "catboost")
        if model_entry:
            model = model_entry["model"]
            embedder = model_entry.get("embedder")
        ```
    """

    def __init__(self) -> None:
        """Initialize an empty model registry."""
        self._registry: Dict[str, Dict[str, Any]] = {}

    def load_all_models(self) -> None:
        """Load all ML models from MLflow tracking server.

        This method connects to the MLflow tracking server (DagsHub) and loads
        CatBoost classifiers and sentence transformer embedders for all
        configured languages.

        Environment Variables:
            DAGSHUB_USER_TOKEN: Authentication token for DagsHub/MLflow

        Note:
            - Continues loading other models if one fails
            - Logs all loading activity for debugging
        """
        logger.info("Starting to load all models from MLflow")

        # MLflow reads its credentials from these environment variables; the
        # same token is used as both username and password for DagsHub.
        dagshub_token = os.environ.get("DAGSHUB_USER_TOKEN")
        if dagshub_token:
            os.environ["MLFLOW_TRACKING_USERNAME"] = dagshub_token
            os.environ["MLFLOW_TRACKING_PASSWORD"] = dagshub_token
            logger.info("Using DAGSHUB_USER_TOKEN for authentication")
        else:
            logger.warning("DAGSHUB_USER_TOKEN not set - may require interactive login")

        # Points MLflow at the DagsHub-hosted tracking server for this repo.
        dagshub.init(repo_owner="se4ai2526-uniba", repo_name="Nygaard", mlflow=True)
        mlflow.set_experiment("evaluating")

        # One registry bucket per configured language; loading failures for
        # one language are logged inside the helper and do not stop the loop.
        for lang in config.LANGUAGES:
            logger.info("Loading models for language: %s", lang)
            if lang not in self._registry:
                self._registry[lang] = {}
            self._load_catboost_models(lang)
        logger.info("Finished loading all models from MLflow")

    def _load_catboost_models(self, lang: str) -> None:
        """Load CatBoost models for a specific language from MLflow.

        Downloads and loads the CatBoost classifier and sentence transformer
        embedder directly from the MLflow tracking server. The classifier
        comes from the best-scoring CatBoost run; the embedder is matched to
        that run's git commit when possible, with a fallback to a known
        default embedder run.

        Args:
            lang: The programming language code (e.g., "python", "java")
        """
        # Best CatBoost run overall, ranked by metrics.final_score. The run
        # itself is not language-specific — per-language models are stored as
        # separate artifacts (model_<lang>) inside the same run.
        catboost_runs = mlflow.search_runs(
            experiment_names=["evaluating"], filter_string="tags.model = 'catboost'"
        ).sort_values(by="metrics.final_score", ascending=False)
        if catboost_runs.empty:
            logger.error("No CatBoost run found in 'evaluating' experiment")
            return
        catboost_run = catboost_runs.iloc[0]
        catboost_run_id = catboost_run.run_id
        catboost_run_name = catboost_run.get("tags.mlflow.runName", "unknown")
        catboost_git_commit = catboost_run.get("tags.mlflow.source.git.commit")

        logger.info(
            "Found CatBoost run: '%s' (ID: %s, commit: %s)",
            catboost_run_name,
            catboost_run_id,
            catboost_git_commit,
        )

        embedder_run = None
        embedder_run_id = None
        embedder_run_name = None

        # Prefer the embedder produced by the same git commit as the chosen
        # classifier, so the embedding space matches what the classifier was
        # trained on.
        if catboost_git_commit:
            logger.info(
                "[%s] Searching for embedder with git commit: %s",
                lang.upper(),
                catboost_git_commit,
            )
            embedder_runs = mlflow.search_runs(
                experiment_names=["evaluating"],
                filter_string=f"tags.`mlflow.source.git.commit` = '{catboost_git_commit}' and run_name LIKE 'sentence_transformer%'",
            )

            if not embedder_runs.empty:
                embedder_run = embedder_runs.iloc[0]
                embedder_run_id = embedder_run.run_id
                embedder_run_name = embedder_run.get("tags.mlflow.runName", "unknown")
                logger.info(
                    "[%s] Found embedder with matching git commit: '%s' (ID: %s)",
                    lang.upper(),
                    embedder_run_name,
                    embedder_run_id,
                )

        # Fallback: no commit-matched embedder found, use the known default
        # sentence transformer run.
        if not embedder_run_id:
            logger.info(
                "[%s] Falling back to default embedder search",
                lang.upper(),
            )
            embedder_runs = mlflow.search_runs(
                experiment_names=["evaluating"],
                filter_string="run_name = 'sentence_transformer_paraphrase-MiniLM-L6-v2'",
            )
            if embedder_runs.empty:
                logger.error(
                    "No embedder run found for 'sentence_transformer_paraphrase-MiniLM-L6-v2'"
                )
                return
            embedder_run = embedder_runs.iloc[0]
            embedder_run_id = embedder_run.run_id
            embedder_run_name = embedder_run.get("tags.mlflow.runName", "unknown")
            logger.info(
                "Found Embedder run: '%s' (ID: %s)",
                embedder_run_name,
                embedder_run_id,
            )

        try:
            # Per-language artifacts are stored under "model_<lang>" inside
            # each run (both classifier and embedder follow this convention).
            model_uri = f"runs:/{catboost_run_id}/model_{lang}"
            logger.info(
                "[%s] Loading CatBoost classifier from run '%s' (ID: %s)...",
                lang.upper(),
                catboost_run_name,
                catboost_run_id,
            )
            model = mlflow.sklearn.load_model(model_uri)

            # NOTE(review): the embedder is loaded via the sklearn flavor as
            # well — presumably it was logged with mlflow.sklearn; confirm
            # against the training pipeline.
            embedder_uri = f"runs:/{embedder_run_id}/model_{lang}"
            logger.info(
                "[%s] Loading sentence transformer from run '%s' (ID: %s)...",
                lang.upper(),
                embedder_run_name,
                embedder_run_id,
            )
            embedder = mlflow.sklearn.load_model(embedder_uri)

            # Register both under the "catboost" model type; ModelPredictor
            # reads these exact keys.
            self._registry[lang]["catboost"] = {
                "model": model,
                "feature_type": "embeddings",
                "embedder": embedder,
            }
            logger.info(
                "[%s] ✓ Ready: CatBoost + %s embeddings",
                lang.upper(),
                embedder_run_name.replace("sentence_transformer_", ""),
            )

        # Broad catch is deliberate: a failed download/deserialization for one
        # language must not abort loading the remaining languages.
        except Exception as e:
            logger.error("[%s] Error loading models: %s", lang.upper(), e)

    def get_model(self, language: str, model_type: str) -> Optional[Dict[str, Any]]:
        """Retrieve a loaded model entry by language and type.

        Args:
            language: The programming language code
            model_type: The type of model

        Returns:
            Dict containing the model and metadata, or None if not found.
            The dict contains:
            - "model": The loaded ML model object
            - "feature_type": Type of features used
            - "embedder": Optional sentence transformer for embedding generation
        """
        return self._registry.get(language, {}).get(model_type)

    def clear(self) -> None:
        """Clear all models from the registry and free memory.

        This method should be called during application shutdown to
        release GPU memory and other resources.
        """
        self._registry.clear()

        # Dropping registry references alone does not return GPU memory to
        # the driver; explicitly flush PyTorch's CUDA caching allocator.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("CUDA cache cleared")
|
|
|
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.empty_cache() |
|
|
logger.info("CUDA cache cleared") |
|
|
|
|
|
|
|
|
class ModelPredictor:
    """Handles low-level prediction logic.

    Attributes:
        registry: Reference to the ModelRegistry for model access

    Example:
        ```python
        registry = ModelRegistry()
        registry.load_all_models()

        predictor = ModelPredictor(registry)
        predictions, embeddings = predictor.predict(
            texts=["# Calculate sum of list"],
            language="python",
            model_type="catboost",
        )
        # predictions: np.ndarray with shape (1, num_labels)
        ```
    """

    def __init__(self, model_registry: ModelRegistry) -> None:
        """Initialize the predictor with a model registry.

        Args:
            model_registry: The ModelRegistry instance containing loaded models
        """
        self.registry = model_registry

    def predict(
        self, texts: List[str], language: str, model_type: str
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Execute prediction on a list of texts.

        Full inference pipeline: look up the registered model entry, embed
        the texts with the paired sentence transformer, then run classifier
        inference on those embeddings.

        Args:
            texts: List of code comment strings to classify
            language: Programming language context for model selection
            model_type: Type of model to use

        Returns:
            Tuple containing:
            - numpy array of predictions with shape (n_samples, n_labels).
            - numpy array of embeddings (if available, else None).

        Raises:
            ValueError: If the requested model is not available or
                if an unsupported feature/model type is specified
        """
        entry = self.registry.get_model(language, model_type)
        if entry is None or "model" not in entry:
            raise ValueError(f"Model {model_type} not available for {language}")

        classifier = entry["model"]

        # Guard clauses: the registry currently only loads CatBoost models
        # over sentence-transformer embeddings; reject everything else.
        if model_type != "catboost":
            raise ValueError(f"Unknown model type: {model_type}")
        if entry.get("feature_type") != "embeddings":
            raise ValueError("Unsupported feature type for CatBoost")

        embedder = entry.get("embedder")
        if embedder is None:
            raise ValueError(f"Embedder not loaded for {language}")

        # Feature extraction followed by inference; the embeddings are
        # returned alongside the predictions so callers can reuse them.
        embeddings = embedder.encode(texts, show_progress_bar=False)
        return classifier.predict(embeddings), embeddings
|
|
|