Fonty02 commited on
Commit
7bdb4b9
·
verified ·
1 Parent(s): c2d3410

Update nygaardcodecommentclassification/api/controllers.py

Browse files
nygaardcodecommentclassification/api/controllers.py CHANGED
@@ -12,7 +12,6 @@ The controller is designed to be thread-safe for concurrent access.
12
  """
13
 
14
  import logging
15
- from pathlib import Path
16
  from typing import Any, Dict, List
17
 
18
  import numpy as np
@@ -36,48 +35,42 @@ class PredictionController:
36
  Attributes:
37
  registry: ModelRegistry instance for model storage
38
  predictor: ModelPredictor instance for inference
39
- models_dir: Path to the directory containing trained models
40
 
41
  Example:
42
  ```python
43
- controller = PredictionController(Path("./models"))
44
- controller.startup() # Load models
45
 
46
  results = controller.predict(
47
  texts=["# Calculate sum"],
 
48
  language="python",
49
  model_type="catboost"
50
  )
51
- # results: [{"text": "# Calculate sum", "labels": ["summary"]}]
52
 
53
  controller.shutdown() # Release resources
54
  ```
55
  """
56
 
57
- def __init__(self, models_dir: Path) -> None:
58
- """Initialize the prediction controller.
59
-
60
- Args:
61
- models_dir: Path to directory containing trained ML models.
62
- Expected structure: models_dir/{language}/{feature_type}/
63
- """
64
  self.registry = ModelRegistry()
65
  self.predictor = ModelPredictor(self.registry)
66
- self.models_dir = models_dir
67
 
68
  def startup(self) -> None:
69
- """Load all ML models into memory.
70
 
71
  This method should be called during application startup.
72
- It scans the models directory and loads all available models
73
- into the registry for fast inference.
74
 
75
  Note:
76
  This operation may take several seconds depending on
77
  the number and size of models.
78
  """
79
- logger.info("Loading models from %s...", self.models_dir)
80
- self.registry.load_all_models(self.models_dir)
81
  logger.info("Models loaded successfully")
82
 
83
  def shutdown(self) -> None:
@@ -103,10 +96,7 @@ class PredictionController:
103
  return info
104
 
105
  def predict(
106
- self,
107
- texts: List[str],
108
- language: str,
109
- model_type: str
110
  ) -> List[Dict[str, Any]]:
111
  """Execute multi-label classification on code comments.
112
 
@@ -114,13 +104,15 @@ class PredictionController:
114
  the results with human-readable labels.
115
 
116
  Args:
117
- texts: List of code comment strings to classify
 
118
  language: Programming language context ("java", "python", "pharo")
119
  model_type: Type of model to use ("catboost")
120
 
121
  Returns:
122
  List of dicts with classification results. Each dict contains:
123
  - "text": The original input text
 
124
  - "labels": List of predicted category labels (strings)
125
 
126
  Raises:
@@ -131,33 +123,37 @@ class PredictionController:
131
  ```python
132
  results = controller.predict(
133
  texts=["This calculates fibonacci", "TODO: optimize"],
 
134
  language="python",
135
  model_type="catboost"
136
  )
137
  # Returns:
138
  # [
139
- # {"text": "This calculates fibonacci", "labels": ["summary"]},
140
- # {"text": "TODO: optimize", "labels": ["expand"]}
141
  # ]
142
  ```
143
  """
144
  # --- Request Validation ---
145
  if language not in config.LANGUAGES:
146
- raise ValueError(
147
- f"Language '{language}' not supported. "
148
- f"Available: {config.LANGUAGES}"
149
- )
150
 
151
  available_types = ["catboost"] # Currently only CatBoost is supported
152
  if model_type not in available_types:
153
  raise ValueError(
154
- f"Model '{model_type}' unavailable for {language}. "
155
- f"Available: {available_types}"
156
  )
157
 
 
 
 
 
158
  # --- Model Inference ---
159
  try:
160
- y_pred = self.predictor.predict(texts, language, model_type)
161
  except Exception as e:
162
  logger.error("Prediction failed for %s/%s: %s", language, model_type, e)
163
  raise RuntimeError(f"Internal model error: {e}") from e
@@ -167,13 +163,11 @@ class PredictionController:
167
  try:
168
  labels_map = config.LABELS[language]
169
  except KeyError as e:
170
- raise RuntimeError(
171
- f"Configuration error: Labels map missing for {language}"
172
- ) from e
173
 
174
  # Convert numeric predictions to human-readable labels
175
  results: List[Dict[str, Any]] = []
176
- for i, text_input in enumerate(texts):
177
  row_pred = y_pred[i] # Binary array (1 = label present, 0 = absent)
178
 
179
  # Find indices where prediction is 1 (positive class)
@@ -182,9 +176,8 @@ class PredictionController:
182
  # Map indices to label strings
183
  predicted_labels = [labels_map[idx] for idx in predicted_indices]
184
 
185
- results.append({
186
- "text": text_input,
187
- "labels": predicted_labels
188
- })
189
 
190
- return results
 
12
  """
13
 
14
  import logging
 
15
  from typing import Any, Dict, List
16
 
17
  import numpy as np
 
35
  Attributes:
36
  registry: ModelRegistry instance for model storage
37
  predictor: ModelPredictor instance for inference
 
38
 
39
  Example:
40
  ```python
41
+ controller = PredictionController()
42
+ controller.startup() # Load models from MLflow
43
 
44
  results = controller.predict(
45
  texts=["# Calculate sum"],
46
+ class_names=["Utils"],
47
  language="python",
48
  model_type="catboost"
49
  )
50
+ # results: [{"text": "# Calculate sum", "class_name": "Utils", "labels": ["summary"]}]
51
 
52
  controller.shutdown() # Release resources
53
  ```
54
  """
55
 
56
+ def __init__(self) -> None:
57
+ """Initialize the prediction controller."""
 
 
 
 
 
58
  self.registry = ModelRegistry()
59
  self.predictor = ModelPredictor(self.registry)
 
60
 
61
  def startup(self) -> None:
62
+ """Load all ML models into memory from MLflow.
63
 
64
  This method should be called during application startup.
65
+ It connects to the MLflow tracking server and loads all available
66
+ models into the registry for fast inference.
67
 
68
  Note:
69
  This operation may take several seconds depending on
70
  the number and size of models.
71
  """
72
+ logger.info("Loading models from MLflow...")
73
+ self.registry.load_all_models()
74
  logger.info("Models loaded successfully")
75
 
76
  def shutdown(self) -> None:
 
96
  return info
97
 
98
  def predict(
99
+ self, texts: List[str], class_names: List[str], language: str, model_type: str
 
 
 
100
  ) -> List[Dict[str, Any]]:
101
  """Execute multi-label classification on code comments.
102
 
 
104
  the results with human-readable labels.
105
 
106
  Args:
107
+ texts: List of code comment strings
108
+ class_names: List of class names corresponding to each comment
109
  language: Programming language context ("java", "python", "pharo")
110
  model_type: Type of model to use ("catboost")
111
 
112
  Returns:
113
  List of dicts with classification results. Each dict contains:
114
  - "text": The original input text
115
+ - "class_name": The class name corresponding to the input text
116
  - "labels": List of predicted category labels (strings)
117
 
118
  Raises:
 
123
  ```python
124
  results = controller.predict(
125
  texts=["This calculates fibonacci", "TODO: optimize"],
126
+ class_names=["MathUtils", "Calculator"],
127
  language="python",
128
  model_type="catboost"
129
  )
130
  # Returns:
131
  # [
132
+ # {"text": "This calculates fibonacci", "class_name": "MathUtils", "labels": ["summary"]},
133
+ # {"text": "TODO: optimize", "class_name": "Calculator", "labels": ["expand"]}
134
  # ]
135
  ```
136
  """
137
  # --- Request Validation ---
138
  if language not in config.LANGUAGES:
139
+ raise ValueError(f"Language '{language}' not supported. Available: {config.LANGUAGES}")
140
+
141
+ if len(texts) != len(class_names):
142
+ raise ValueError(f"Mismatch: {len(texts)} texts but {len(class_names)} class names")
143
 
144
  available_types = ["catboost"] # Currently only CatBoost is supported
145
  if model_type not in available_types:
146
  raise ValueError(
147
+ f"Model '{model_type}' unavailable for {language}. Available: {available_types}"
 
148
  )
149
 
150
+ # --- Combine texts with class names ---
151
+ # Format: "comment | class_name" for the model
152
+ combined_texts = [f"{text} | {class_name}" for text, class_name in zip(texts, class_names)]
153
+
154
  # --- Model Inference ---
155
  try:
156
+ y_pred = self.predictor.predict(combined_texts, language, model_type)
157
  except Exception as e:
158
  logger.error("Prediction failed for %s/%s: %s", language, model_type, e)
159
  raise RuntimeError(f"Internal model error: {e}") from e
 
163
  try:
164
  labels_map = config.LABELS[language]
165
  except KeyError as e:
166
+ raise RuntimeError(f"Configuration error: Labels map missing for {language}") from e
 
 
167
 
168
  # Convert numeric predictions to human-readable labels
169
  results: List[Dict[str, Any]] = []
170
+ for i, (text_input, class_name) in enumerate(zip(texts, class_names)):
171
  row_pred = y_pred[i] # Binary array (1 = label present, 0 = absent)
172
 
173
  # Find indices where prediction is 1 (positive class)
 
176
  # Map indices to label strings
177
  predicted_labels = [labels_map[idx] for idx in predicted_indices]
178
 
179
+ results.append(
180
+ {"text": text_input, "class_name": class_name, "labels": predicted_labels}
181
+ )
 
182
 
183
+ return results