|
|
"""Controller Layer - Business logic for prediction operations. |
|
|
|
|
|
This module implements the business logic layer following the MVC pattern. |
|
|
It acts as an intermediary between the API endpoints (views) and the |
|
|
ML models (models layer), handling: |
|
|
- Model lifecycle management (loading/unloading) |
|
|
- Request validation and preprocessing |
|
|
- Response formatting and label mapping |
|
|
- Error handling and logging |
|
|
|
|
|
The controller is designed to be thread-safe for concurrent access. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
from typing import Any, Dict, List |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from nygaardcodecommentclassification import config |
|
|
from nygaardcodecommentclassification.api.models import ModelPredictor, ModelRegistry |
|
|
|
|
|
|
|
|
# Module-level logger shared by all controller code in this file. The explicit
# name "controllers" (rather than __name__) lets deployments configure/filter
# this layer's output under a stable, import-path-independent logger name.
logger = logging.getLogger("controllers")
|
|
|
|
|
|
|
|
class PredictionController:
    """Manages prediction logic, model lifecycle, and response formatting.

    This controller orchestrates the ML prediction pipeline, including:
    - Loading and managing ML models via ModelRegistry
    - Validating prediction requests against supported languages/models
    - Executing predictions through ModelPredictor
    - Mapping numeric predictions to human-readable labels

    Attributes:
        registry: ModelRegistry instance for model storage
        predictor: ModelPredictor instance for inference

    Example:
        ```python
        controller = PredictionController()
        controller.startup()  # Load models from MLflow

        results = controller.predict(
            texts=["# Calculate sum"],
            class_names=["Utils"],
            language="python",
            model_type="catboost"
        )
        # results: [{"text": "# Calculate sum", "class_name": "Utils", "labels": ["summary"]}]

        controller.shutdown()  # Release resources
        ```
    """

    def __init__(self) -> None:
        """Initialize the prediction controller with a registry and its predictor."""
        self.registry = ModelRegistry()
        self.predictor = ModelPredictor(self.registry)

    def startup(self) -> None:
        """Load all ML models into memory from MLflow.

        This method should be called during application startup.
        It connects to the MLflow tracking server and loads all available
        models into the registry for fast inference.

        Note:
            This operation may take several seconds depending on
            the number and size of models.
        """
        logger.info("Loading models from MLflow...")
        self.registry.load_all_models()
        logger.info("Models loaded successfully")

    def shutdown(self) -> None:
        """Release all model resources.

        Clears the model registry and frees GPU memory if applicable.
        This should be called during application shutdown.
        """
        self.registry.clear()
        logger.info("Models cleared and resources released")

    def get_models_info(self) -> Dict[str, List[str]]:
        """Return available models grouped by programming language.

        Returns:
            Dict mapping language codes to lists of available model types.
            Example: {"java": ["catboost"], "python": ["catboost"], "pharo": ["catboost"]}
        """
        # Currently only CatBoost models are deployed for every language.
        return {lang: ["catboost"] for lang in config.LANGUAGES}

    def predict(
        self, texts: List[str], class_names: List[str], language: str, model_type: str
    ) -> List[Dict[str, Any]]:
        """Execute multi-label classification on code comments.

        This method validates the request, runs ML inference, and formats
        the results with human-readable labels.

        Args:
            texts: List of code comment strings
            class_names: List of class names corresponding to each comment
            language: Programming language context ("java", "python", "pharo")
            model_type: Type of model to use ("catboost")

        Returns:
            List of dicts with classification results. Each dict contains:
            - "text": The original input text
            - "class_name": The class name corresponding to the input text
            - "labels": List of predicted category labels (strings)

        Raises:
            ValueError: If language is not supported, model type unavailable,
                or texts/class_names lengths differ
            RuntimeError: If prediction fails or labels configuration is missing

        Example:
            ```python
            results = controller.predict(
                texts=["This calculates fibonacci", "TODO: optimize"],
                class_names=["MathUtils", "Calculator"],
                language="python",
                model_type="catboost"
            )
            # Returns:
            # [
            #     {"text": "This calculates fibonacci", "class_name": "MathUtils", "labels": ["summary"]},
            #     {"text": "TODO: optimize", "class_name": "Calculator", "labels": ["expand"]}
            # ]
            ```
        """
        # --- Request validation -------------------------------------------
        if language not in config.LANGUAGES:
            raise ValueError(f"Language '{language}' not supported. Available: {config.LANGUAGES}")

        if len(texts) != len(class_names):
            raise ValueError(f"Mismatch: {len(texts)} texts but {len(class_names)} class names")

        available_types = ["catboost"]
        if model_type not in available_types:
            raise ValueError(
                f"Model '{model_type}' unavailable for {language}. Available: {available_types}"
            )

        # The model was trained on "text | class_name" pairs, so the request
        # fields are joined into a single string per sample before inference.
        combined_texts = [f"{text} | {class_name}" for text, class_name in zip(texts, class_names)]

        # --- Inference -----------------------------------------------------
        # Embeddings are also returned by the predictor but are not needed
        # for response formatting, so they are discarded here.
        try:
            y_pred, _ = self.predictor.predict(combined_texts, language, model_type)
        except Exception as e:
            logger.error("Prediction failed for %s/%s: %s", language, model_type, e)
            raise RuntimeError(f"Internal model error: {e}") from e

        # --- Response formatting -------------------------------------------
        try:
            labels_map = config.LABELS[language]
        except KeyError as e:
            raise RuntimeError(f"Configuration error: Labels map missing for {language}") from e

        results: List[Dict[str, Any]] = []
        for i, text_input in enumerate(texts):
            row_pred = y_pred[i]

            # Multi-label output: each position set to 1 marks an active category.
            predicted_indices = np.where(row_pred == 1)[0]

            # Cast numpy integer indices to plain ints before the lookup so the
            # mapping works whether labels_map is a list or an int-keyed dict,
            # and the response contains only native Python types.
            predicted_labels = [labels_map[int(idx)] for idx in predicted_indices]

            # Bug fix: include "class_name" in each result, as promised by the
            # documented response schema; it was previously dropped.
            results.append(
                {"text": text_input, "class_name": class_names[i], "labels": predicted_labels}
            )

        return results
|
|
|