# Patch by Davy592 to run models on CPU (commit 4617060)
"""Model Layer - ML model management and inference.
This module handles the low-level ML operations including:
- Model loading and storage via ModelRegistry
- Inference execution via ModelPredictor
Architecture:
- ModelRegistry: Central storage for models loaded from MLflow at startup
- ModelPredictor: Executes inference using registered models
"""
import logging
import os
from pathlib import Path
import sys
from typing import Any, Dict, List, Optional, Tuple
import dagshub
import mlflow
import numpy as np
import torch
from nygaardcodecommentclassification import config
# Patch torch.load to use CPU mapping by default if CUDA is not available.
# This prevents "Attempting to deserialize object on a CUDA device" errors
# when checkpoints saved on a GPU machine are loaded on a CPU-only host.
_original_torch_load = torch.load


def _patched_torch_load(f, map_location=None, *args, **kwargs):
    """Wrapper around torch.load that uses CPU mapping if CUDA unavailable.

    Args:
        f: File path, file-like object, or buffer accepted by ``torch.load``.
        map_location: Device remapping; when None and no CUDA device is
            present, it is forced to ``torch.device('cpu')``.
        *args: Extra positional arguments forwarded to ``torch.load``
            (e.g. ``pickle_module``).
        **kwargs: Extra keyword arguments forwarded to ``torch.load``.

    Returns:
        Whatever the original ``torch.load`` deserializes.
    """
    if map_location is None and not torch.cuda.is_available():
        map_location = torch.device('cpu')
    # Pass map_location POSITIONALLY (it is torch.load's 2nd parameter).
    # The previous form `_original_torch_load(f, map_location=map_location,
    # *args, **kwargs)` raised "got multiple values for argument
    # 'map_location'" whenever a caller supplied any positional argument
    # after map_location, because *args also bound into that slot.
    return _original_torch_load(f, map_location, *args, **kwargs)


torch.load = _patched_torch_load
# Module logger wired to stdout so output is visible in container logs.
logger = logging.getLogger("nygaard.models")
logger.setLevel(logging.DEBUG)

# Guard against attaching a second handler when the module is reloaded.
if not logger.handlers:
    _stdout_handler = logging.StreamHandler(sys.stdout)
    _stdout_handler.setLevel(logging.DEBUG)
    _stdout_handler.setFormatter(
        logging.Formatter("[%(levelname)s] %(name)s: %(message)s")
    )
    logger.addHandler(_stdout_handler)
class ModelRegistry:
    """Central registry for ML models loaded in memory.

    This class manages the lifecycle of ML models, providing:
    - Loading of CatBoost classifiers and sentence-transformer embedders
      from the MLflow tracking server (DagsHub)
    - Organized storage by language and model type
    - Memory management with explicit cleanup

    Attributes:
        _registry: Internal dictionary storing loaded models, keyed by
            language code, then by model type.

    Example:
        ```python
        registry = ModelRegistry()
        registry.load_all_models()

        # Access a loaded model
        model_entry = registry.get_model("python", "catboost")
        if model_entry:
            model = model_entry["model"]
            embedder = model_entry.get("embedder")
        ```
    """

    def __init__(self) -> None:
        """Initialize an empty model registry."""
        # Shape: {language: {model_type: {"model", "feature_type", "embedder"}}}
        self._registry: Dict[str, Dict[str, Any]] = {}

    def load_all_models(self) -> None:
        """Load all ML models from the MLflow tracking server.

        Connects to the MLflow tracking server (DagsHub) and loads CatBoost
        classifiers and sentence-transformer embedders for every language in
        ``config.LANGUAGES``.

        Environment Variables:
            DAGSHUB_USER_TOKEN: Authentication token for DagsHub/MLflow

        Note:
            - Continues loading other models if one fails
            - Logs all loading activity for debugging
        """
        logger.info("Starting to load all models from MLflow")

        # Initialize MLflow with DagsHub - uses DAGSHUB_USER_TOKEN env var for
        # auth. Set DAGSHUB_USER_TOKEN in your environment to avoid an
        # interactive login prompt from dagshub.init().
        dagshub_token = os.environ.get("DAGSHUB_USER_TOKEN")
        if dagshub_token:
            # DagsHub accepts the token as both MLflow basic-auth fields.
            os.environ["MLFLOW_TRACKING_USERNAME"] = dagshub_token
            os.environ["MLFLOW_TRACKING_PASSWORD"] = dagshub_token
            logger.info("Using DAGSHUB_USER_TOKEN for authentication")
        else:
            logger.warning("DAGSHUB_USER_TOKEN not set - may require interactive login")
        dagshub.init(repo_owner="se4ai2526-uniba", repo_name="Nygaard", mlflow=True)
        mlflow.set_experiment("evaluating")

        # Load models for all configured languages directly from MLflow.
        # No need for local directory structure.
        for lang in config.LANGUAGES:
            logger.info("Loading models for language: %s", lang)
            if lang not in self._registry:
                self._registry[lang] = {}
            self._load_catboost_models(lang)
        logger.info("Finished loading all models from MLflow")

    def _load_catboost_models(self, lang: str) -> None:
        """Load CatBoost models for a specific language from MLflow.

        Downloads and loads the CatBoost classifier and sentence transformer
        embedder directly from the MLflow tracking server, then registers them
        under ``self._registry[lang]["catboost"]``.

        Args:
            lang: The programming language code (e.g., "python", "java")
        """
        # Find the best CatBoost run, ranked by its final_score metric.
        # NOTE(review): this search is not filtered by language; it appears to
        # assume one run logs per-language artifacts named "model_<lang>"
        # (see model_uri below) — confirm against the training pipeline.
        catboost_runs = mlflow.search_runs(
            experiment_names=["evaluating"], filter_string="tags.model = 'catboost'"
        ).sort_values(by="metrics.final_score", ascending=False)
        if catboost_runs.empty:
            logger.error("No CatBoost run found in 'evaluating' experiment")
            return
        catboost_run = catboost_runs.iloc[0]
        catboost_run_id = catboost_run.run_id
        catboost_run_name = catboost_run.get("tags.mlflow.runName", "unknown")
        catboost_git_commit = catboost_run.get("tags.mlflow.source.git.commit")
        logger.info(
            "Found CatBoost run: '%s' (ID: %s, commit: %s)",
            catboost_run_name,
            catboost_run_id,
            catboost_git_commit,
        )

        # Find an embedder run from the same git commit, so the classifier
        # and embedder come from the same code version.
        embedder_run = None
        embedder_run_id = None
        embedder_run_name = None
        if catboost_git_commit:
            # Search for sentence transformer with same git commit
            logger.info(
                "[%s] Searching for embedder with git commit: %s",
                lang.upper(),
                catboost_git_commit,
            )
            embedder_runs = mlflow.search_runs(
                experiment_names=["evaluating"],
                filter_string=f"tags.`mlflow.source.git.commit` = '{catboost_git_commit}' and run_name LIKE 'sentence_transformer%'",
            )
            if not embedder_runs.empty:
                embedder_run = embedder_runs.iloc[0]
                embedder_run_id = embedder_run.run_id
                embedder_run_name = embedder_run.get("tags.mlflow.runName", "unknown")
                logger.info(
                    "[%s] Found embedder with matching git commit: '%s' (ID: %s)",
                    lang.upper(),
                    embedder_run_name,
                    embedder_run_id,
                )

        # Fallback: search by default embedder name if the git-commit search
        # found nothing (or the CatBoost run had no commit tag).
        if not embedder_run_id:
            logger.info(
                "[%s] Falling back to default embedder search",
                lang.upper(),
            )
            embedder_runs = mlflow.search_runs(
                experiment_names=["evaluating"],
                filter_string="run_name = 'sentence_transformer_paraphrase-MiniLM-L6-v2'",
            )
            if embedder_runs.empty:
                logger.error(
                    "No embedder run found for 'sentence_transformer_paraphrase-MiniLM-L6-v2'"
                )
                return
            embedder_run = embedder_runs.iloc[0]
            embedder_run_id = embedder_run.run_id
            embedder_run_name = embedder_run.get("tags.mlflow.runName", "unknown")
            logger.info(
                "Found Embedder run: '%s' (ID: %s)",
                embedder_run_name,
                embedder_run_id,
            )

        try:
            # Load the CatBoost model from MLflow
            model_uri = f"runs:/{catboost_run_id}/model_{lang}"
            logger.info(
                "[%s] Loading CatBoost classifier from run '%s' (ID: %s)...",
                lang.upper(),
                catboost_run_name,
                catboost_run_id,
            )
            model = mlflow.sklearn.load_model(model_uri)

            # Load the sentence transformer embedder from MLflow
            embedder_uri = f"runs:/{embedder_run_id}/model_{lang}"
            logger.info(
                "[%s] Loading sentence transformer from run '%s' (ID: %s)...",
                lang.upper(),
                embedder_run_name,
                embedder_run_id,
            )
            embedder = mlflow.sklearn.load_model(embedder_uri)

            # Register the model with its metadata
            self._registry[lang]["catboost"] = {
                "model": model,
                "feature_type": "embeddings",
                "embedder": embedder,
            }
            logger.info(
                "[%s] ✓ Ready: CatBoost + %s embeddings",
                lang.upper(),
                embedder_run_name.replace("sentence_transformer_", ""),
            )
        except Exception as e:
            # Best-effort: a failure for one language must not abort the rest.
            logger.error("[%s] Error loading models: %s", lang.upper(), e)

    def get_model(self, language: str, model_type: str) -> Optional[Dict[str, Any]]:
        """Retrieve a loaded model entry by language and type.

        Args:
            language: The programming language code
            model_type: The type of model (e.g., "catboost")

        Returns:
            Dict containing the model and metadata, or None if not found.
            The dict contains:
            - "model": The loaded ML model object
            - "feature_type": Type of features used
            - "embedder": Optional sentence transformer for embedding generation
        """
        return self._registry.get(language, {}).get(model_type)

    def clear(self) -> None:
        """Clear all models from the registry and free memory.

        This method should be called during application shutdown to
        release GPU memory and other resources.
        """
        self._registry.clear()
        # Clear CUDA cache if GPU was used
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("CUDA cache cleared")
class ModelPredictor:
    """Runs low-level inference using models held in a ModelRegistry.

    Attributes:
        registry: The ModelRegistry consulted for model lookups

    Example:
        ```python
        registry = ModelRegistry()
        registry.load_all_models()
        predictor = ModelPredictor(registry)
        predictions = predictor.predict(
            texts=["# Calculate sum of list"],
            language="python",
            model_type="catboost"
        )
        # predictions: np.ndarray with shape (1, num_labels)
        ```
    """

    def __init__(self, model_registry: ModelRegistry) -> None:
        """Store the registry that supplies models for inference.

        Args:
            model_registry: The ModelRegistry instance containing loaded models
        """
        self.registry = model_registry

    def predict(
        self, texts: List[str], language: str, model_type: str
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Classify a batch of texts with the requested model.

        The inference pipeline: look up the model entry in the registry,
        build features (sentence-transformer embeddings), then run the
        classifier over them.

        Args:
            texts: Code comment strings to classify
            language: Programming language context for model selection
            model_type: Type of model to use

        Returns:
            Tuple containing:
            - numpy array of predictions with shape (n_samples, n_labels).
            - numpy array of embeddings (if available, else None).

        Raises:
            ValueError: If the requested model is not registered, the
                embedder is missing, or the feature/model type is unsupported
        """
        # Guard clauses: reject every unsupported combination up front.
        entry = self.registry.get_model(language, model_type)
        if not entry or "model" not in entry:
            raise ValueError(f"Model {model_type} not available for {language}")

        classifier = entry["model"]

        if model_type != "catboost":
            raise ValueError(f"Unknown model type: {model_type}")

        if entry.get("feature_type") != "embeddings":
            raise ValueError("Unsupported feature type for CatBoost")

        encoder = entry.get("embedder")
        if encoder is None:
            raise ValueError(f"Embedder not loaded for {language}")

        # Dense embeddings from the SetFit sentence transformer; progress
        # bar disabled since this runs inside an API request.
        features = encoder.encode(texts, show_progress_bar=False)
        return classifier.predict(features), features