File size: 6,988 Bytes
"""Controller Layer - Business logic for prediction operations.
This module implements the business logic layer following the MVC pattern.
It acts as an intermediary between the API endpoints (views) and the
ML models (models layer), handling:
- Model lifecycle management (loading/unloading)
- Request validation and preprocessing
- Response formatting and label mapping
- Error handling and logging
The controller keeps no per-request state of its own; whether concurrent
predictions are safe depends on the underlying ModelRegistry/ModelPredictor.
"""
import logging
from typing import Any, Dict, List
import numpy as np
from nygaardcodecommentclassification import config
from nygaardcodecommentclassification.api.models import ModelPredictor, ModelRegistry
# Module-level logger. It uses the fixed name "controllers" (not __name__),
# so any logging configuration must target that exact name.
logger = logging.getLogger("controllers")
class PredictionController:
    """Manages prediction logic, model lifecycle, and response formatting.

    This controller orchestrates the ML prediction pipeline, including:
    - Loading and managing ML models via ModelRegistry
    - Validating prediction requests against supported languages/models
    - Executing predictions through ModelPredictor
    - Mapping binary prediction vectors to human-readable labels

    Attributes:
        registry: ModelRegistry instance for model storage
        predictor: ModelPredictor instance for inference

    Example:
        ```python
        controller = PredictionController()
        controller.startup()  # Load models (delegates to ModelRegistry)
        results = controller.predict(
            texts=["# Calculate sum"],
            language="python",
            model_type="catboost"
        )
        # results: [{"text": "# Calculate sum", "labels": ["summary"]}]
        controller.shutdown()  # Release resources
        ```
    """

    # Model types offered for every configured language (currently only
    # CatBoost). A single constant keeps get_models_info() and predict()
    # from drifting apart.
    _MODEL_TYPES = ("catboost",)

    def __init__(self) -> None:
        """Create the model registry and a predictor bound to it.

        No models are loaded here; call startup() to load them.
        """
        self.registry = ModelRegistry()
        self.predictor = ModelPredictor(self.registry)

    def startup(self) -> None:
        """Load all ML models into the registry.

        Intended to be called once during application startup. The actual
        loading is delegated to ``ModelRegistry.load_all_models()`` (models
        are expected to come from MLflow, per the log messages) and may take
        a while depending on the number and size of the models.
        """
        logger.info("Loading models from MLflow...")
        self.registry.load_all_models()
        logger.info("Models loaded successfully")

    def shutdown(self) -> None:
        """Release model resources by clearing the registry.

        Intended to be called during application shutdown. Whether this also
        frees GPU memory depends on ``ModelRegistry.clear()``.
        """
        self.registry.clear()
        logger.info("Models cleared and resources released")

    def get_models_info(self) -> Dict[str, List[str]]:
        """Return available model types grouped by configured language.

        Returns:
            Dict mapping each entry of ``config.LANGUAGES`` to a fresh list of
            model types, e.g. {"java": ["catboost"], "python": ["catboost"]}.
        """
        return {lang: list(self._MODEL_TYPES) for lang in config.LANGUAGES}

    def predict(self, texts: List[str], language: str, model_type: str) -> List[Dict[str, Any]]:
        """Execute multi-label classification on code comments.

        Validates the request, runs ML inference, and maps each binary
        prediction row to human-readable labels.

        Args:
            texts: Code comment strings to classify. An empty list returns
                ``[]`` without invoking the model.
            language: Programming language context, matched case-insensitively
                against ``config.LANGUAGES``.
            model_type: Type of model to use ("catboost").

        Returns:
            One dict per input text, in input order, with keys:
            - "text": the original input text
            - "labels": list of predicted category labels (strings)

        Raises:
            ValueError: If the language is not supported or the model type
                is unavailable.
            RuntimeError: If the labels configuration is missing, inference
                fails, or the prediction shape does not match the input.
        """
        # --- 1. Request validation ---
        canonical_language = self._resolve_language(language, config.LANGUAGES)

        available_types = list(self._MODEL_TYPES)
        if model_type not in available_types:
            raise ValueError(
                f"Model '{model_type}' unavailable for {language}. Available: {available_types}"
            )

        # Resolve the label map before inference so a misconfiguration fails
        # fast instead of after a potentially expensive model call.
        try:
            labels_map = config.LABELS[canonical_language]
        except KeyError as e:
            raise RuntimeError(
                f"Configuration error: Labels map missing for {canonical_language}"
            ) from e

        # Nothing to classify: skip the model call entirely.
        if not texts:
            return []

        # --- 2. Model inference ---
        try:
            # canonical_language is the exact spelling used by config (and,
            # presumably, by the registry keys).
            y_pred = self.predictor.predict(texts, canonical_language, model_type)
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(
                "Prediction failed for %s/%s: %s", canonical_language, model_type, e
            )
            self._log_registry_keys()
            raise RuntimeError(f"Internal model error: {e}") from e

        # --- 3. Result formatting ---
        return self._format_results(texts, y_pred, labels_map)

    @staticmethod
    def _resolve_language(language: str, supported: List[str]) -> str:
        """Return the configured spelling of *language*, matched case-insensitively.

        Raises:
            ValueError: If *language* matches none of *supported*.
        """
        # e.g. {"python": "Python"}: lets lowercase input find the exact key
        # used in the configuration.
        by_lower = {name.lower(): name for name in supported}
        try:
            return by_lower[language.lower()]
        except KeyError:
            raise ValueError(
                f"Language '{language}' not supported. Available: {supported}"
            ) from None

    def _log_registry_keys(self) -> None:
        """Best-effort debug aid: log the keys currently held by the registry."""
        # ``_registry`` is a private attribute of ModelRegistry and may not
        # exist; a failure here must never mask the original prediction error.
        try:
            available_keys = list(self.registry._registry.keys())
        except Exception:
            logger.debug("Could not inspect registry keys", exc_info=True)
            return
        logger.error("Debug - Registry keys available: %s", available_keys)

    @staticmethod
    def _format_results(texts: List[str], y_pred: Any, labels_map: Any) -> List[Dict[str, Any]]:
        """Pair each text with the labels whose predicted value equals 1.

        Args:
            texts: The input texts, in order.
            y_pred: Binary predictions; 2-D ``(n_texts, n_labels)``, or 1-D.
            labels_map: Indexable by label position (list or int-keyed dict).

        Raises:
            RuntimeError: If the predictions cannot be aligned with ``texts``.
        """
        if not texts:
            return []

        preds = np.asarray(y_pred)
        n_texts = len(texts)
        # A 1-D result is either the label vector of a single text or one
        # value per text (single-label model). Reshaping by the number of
        # texts covers both cases; reshape(1, -1) only handled the first.
        if preds.ndim == 1 and preds.size % n_texts == 0:
            preds = preds.reshape(n_texts, -1)
        if preds.ndim != 2 or preds.shape[0] != n_texts:
            raise RuntimeError(
                f"Prediction shape {preds.shape} does not match {n_texts} input text(s)"
            )

        return [
            {
                "text": text,
                # Positions where the prediction is 1 (positive class).
                "labels": [labels_map[int(idx)] for idx in np.flatnonzero(row == 1)],
            }
            for text, row in zip(texts, preds)
        ]
|