File size: 6,865 Bytes
713632e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bdb4b9
 
713632e
 
 
7bdb4b9
713632e
 
 
7bdb4b9
713632e
 
 
 
 
7bdb4b9
 
713632e
 
 
 
7bdb4b9
713632e
 
7bdb4b9
 
713632e
 
 
 
 
7bdb4b9
 
713632e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5497e51
7bdb4b9
5497e51
713632e
 
 
 
 
 
7bdb4b9
 
713632e
 
 
 
 
 
7bdb4b9
713632e
 
 
 
 
 
 
 
 
 
7bdb4b9
713632e
 
 
 
 
7bdb4b9
 
713632e
 
 
 
 
7bdb4b9
 
 
 
713632e
 
 
 
7bdb4b9
713632e
 
7bdb4b9
 
713632e
 
df25ba9
713632e
 
 
 
 
 
 
 
 
7bdb4b9
713632e
 
 
df25ba9
713632e
 
 
 
 
 
 
 
df25ba9
713632e
7bdb4b9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""Controller Layer - Business logic for prediction operations.

This module implements the business logic layer following the MVC pattern.
It acts as an intermediary between the API endpoints (views) and the
ML models (models layer), handling:
- Model lifecycle management (loading/unloading)
- Request validation and preprocessing
- Response formatting and label mapping
- Error handling and logging

The controller is designed to be thread-safe for concurrent access.
"""

import logging
from typing import Any, Dict, List

import numpy as np

from nygaardcodecommentclassification import config
from nygaardcodecommentclassification.api.models import ModelPredictor, ModelRegistry

# Module-level logger shared by everything in this file. It is registered
# under the fixed name "controllers" (not __name__), so any logging
# configuration has to target that exact name.
logger = logging.getLogger("controllers")


class PredictionController:
    """Manages prediction logic, model lifecycle, and response formatting.

    This controller orchestrates the ML prediction pipeline, including:
    - Loading and managing ML models via ModelRegistry
    - Validating prediction requests against supported languages/models
    - Executing predictions through ModelPredictor
    - Mapping numeric predictions to human-readable labels

    Attributes:
        registry: ModelRegistry instance for model storage
        predictor: ModelPredictor instance for inference

    Example:
        ```python
        controller = PredictionController()
        controller.startup()  # Load models from MLflow

        results = controller.predict(
            texts=["# Calculate sum"],
            class_names=["Utils"],
            language="python",
            model_type="catboost"
        )
        # results: [{"text": "# Calculate sum", "class_name": "Utils", "labels": ["summary"]}]

        controller.shutdown()  # Release resources
        ```
    """

    # Model types accepted for every language. Kept in one place so that
    # get_models_info() and the validation in predict() cannot drift apart.
    # Currently only CatBoost models are supported.
    _SUPPORTED_MODEL_TYPES = ("catboost",)

    def __init__(self) -> None:
        """Initialize the prediction controller.

        Creates an empty ModelRegistry and a ModelPredictor bound to it.
        No models are loaded here; call startup() for that.
        """
        self.registry = ModelRegistry()
        self.predictor = ModelPredictor(self.registry)

    def startup(self) -> None:
        """Load all ML models into the registry.

        This method should be called during application startup. It
        delegates to ``ModelRegistry.load_all_models()`` and logs before
        and after the load.

        Note:
            This operation may take several seconds depending on
            the number and size of models.
        """
        logger.info("Loading models from MLflow...")
        self.registry.load_all_models()
        logger.info("Models loaded successfully")

    def shutdown(self) -> None:
        """Release all model resources.

        Delegates to ``ModelRegistry.clear()``; whatever the registry holds
        is released there. This should be called during application shutdown.
        """
        self.registry.clear()
        logger.info("Models cleared and resources released")

    def get_models_info(self) -> Dict[str, List[str]]:
        """Return available models grouped by programming language.

        Returns:
            Dict mapping every language in ``config.LANGUAGES`` to a fresh
            list of the supported model types.
            Example: {"java": ["catboost"], "python": ["catboost"], "pharo": ["catboost"]}
        """
        # A new list per language so callers may mutate one entry without
        # affecting the others (or the class-level constant).
        return {lang: list(self._SUPPORTED_MODEL_TYPES) for lang in config.LANGUAGES}

    def predict(
        self, texts: List[str], class_names: List[str], language: str, model_type: str
    ) -> List[Dict[str, Any]]:
        """Execute multi-label classification on code comments.

        This method validates the request, runs ML inference, and formats
        the results with human-readable labels.

        Args:
            texts: List of code comment strings
            class_names: List of class names corresponding to each comment
                (must have the same length as ``texts``)
            language: Programming language context; must be one of
                ``config.LANGUAGES``
            model_type: Type of model to use ("catboost")

        Returns:
            List of dicts with classification results, one per input text
            and in the same order. Each dict contains:
            - "text": The original input text
            - "class_name": The class name corresponding to the input text
            - "labels": List of predicted category labels (strings)

        Raises:
            ValueError: If language is not supported, the two input lists
                differ in length, or the model type is unavailable
            RuntimeError: If prediction fails or labels configuration is missing

        Example:
            ```python
            results = controller.predict(
                texts=["This calculates fibonacci", "TODO: optimize"],
                class_names=["MathUtils", "Calculator"],
                language="python",
                model_type="catboost"
            )
            # Returns:
            # [
            #     {"text": "This calculates fibonacci", "class_name": "MathUtils", "labels": ["summary"]},
            #     {"text": "TODO: optimize", "class_name": "Calculator", "labels": ["expand"]}
            # ]
            ```
        """
        # --- Request Validation ---
        if language not in config.LANGUAGES:
            raise ValueError(f"Language '{language}' not supported. Available: {config.LANGUAGES}")

        if len(texts) != len(class_names):
            raise ValueError(f"Mismatch: {len(texts)} texts but {len(class_names)} class names")

        # Converted to a list so the error message keeps its list formatting.
        available_types = list(self._SUPPORTED_MODEL_TYPES)
        if model_type not in available_types:
            raise ValueError(
                f"Model '{model_type}' unavailable for {language}. Available: {available_types}"
            )

        # The model input is the comment and its class name joined by " | ".
        combined_texts = [f"{text} | {class_name}" for text, class_name in zip(texts, class_names)]

        # --- Model Inference ---
        try:
            # The predictor also returns embeddings; they are not needed here.
            y_pred, _embeddings = self.predictor.predict(combined_texts, language, model_type)
        except Exception as e:
            # Boundary with the model layer: log and surface one error type.
            logger.error("Prediction failed for %s/%s: %s", language, model_type, e)
            raise RuntimeError(f"Internal model error: {e}") from e

        # --- Result Formatting ---
        # Get the label mapping for this language
        try:
            labels_map = config.LABELS[language]
        except KeyError as e:
            raise RuntimeError(f"Configuration error: Labels map missing for {language}") from e

        # Convert numeric predictions to human-readable labels
        results: List[Dict[str, Any]] = []
        for i, (text_input, class_name) in enumerate(zip(texts, class_names)):
            # Row i is treated as a binary vector (1 = label present, 0 = absent).
            # np.asarray keeps the comparison element-wise even if the
            # predictor hands back plain Python sequences.
            row_pred = np.asarray(y_pred[i])

            # Find indices where prediction is 1 (positive class)
            predicted_indices = np.where(row_pred == 1)[0]

            # Map indices to label strings
            predicted_labels = [labels_map[idx] for idx in predicted_indices]

            results.append(
                {
                    "text": text_input,
                    "class_name": class_name,
                    "labels": predicted_labels,
                }
            )

        return results