Davy592 committed on
Commit
df25ba9
·
1 Parent(s): 992c3c4

Updated files to match new versions

Browse files
nygaardcodecommentclassification/api/API.py CHANGED
@@ -15,11 +15,12 @@ from contextlib import asynccontextmanager
15
  from datetime import datetime
16
  from functools import wraps
17
  from http import HTTPStatus
 
18
  from typing import Any, Callable, Dict
19
 
20
- from fastapi import FastAPI, Request
21
- from fastapi.middleware.cors import CORSMiddleware
22
  from fastapi.responses import RedirectResponse
 
23
 
24
  from nygaardcodecommentclassification import config
25
  from nygaardcodecommentclassification.api.controllers import PredictionController
@@ -35,9 +36,11 @@ controller = PredictionController()
35
  # This prevents blocking the async event loop during model predictions
36
  _executor = ThreadPoolExecutor(max_workers=4)
37
 
 
 
38
 
39
  @asynccontextmanager
40
- async def lifespan(app: FastAPI):
41
  """Async context manager for application lifecycle events.
42
 
43
  This handles:
@@ -48,7 +51,7 @@ async def lifespan(app: FastAPI):
48
  app: The FastAPI application instance
49
 
50
  Yields:
51
- Control back to the application after startup is complete
52
  """
53
  # Startup: load models into memory
54
  controller.startup()
@@ -199,7 +202,9 @@ async def _get_languages(request: Request) -> Dict[str, Any]:
199
 
200
  @app.post("/predict", tags=["Prediction"])
201
  @construct_response
202
- async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, Any]:
 
 
203
  """Classify code comments using multi-label classification.
204
 
205
  This endpoint performs ML inference to classify code comments into
@@ -207,6 +212,7 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
207
 
208
  Args:
209
  request: The FastAPI request object
 
210
  payload: PredictionRequest containing:
211
  - texts: List of code comments
212
  - class_names: List of class names corresponding to each comment
@@ -246,6 +252,13 @@ async def _predict(request: Request, payload: PredictionRequest) -> Dict[str, An
246
  payload.model_type,
247
  )
248
 
 
 
 
 
 
 
 
249
  return {
250
  "status-code": HTTPStatus.OK,
251
  "message": "Prediction successful",
 
15
  from datetime import datetime
16
  from functools import wraps
17
  from http import HTTPStatus
18
+ import json
19
  from typing import Any, Callable, Dict
20
 
21
+ from fastapi import FastAPI, Request, Response
 
22
  from fastapi.responses import RedirectResponse
23
+ from fastapi.middleware.cors import CORSMiddleware
24
 
25
  from nygaardcodecommentclassification import config
26
  from nygaardcodecommentclassification.api.controllers import PredictionController
 
36
  # This prevents blocking the async event loop during model predictions
37
  _executor = ThreadPoolExecutor(max_workers=4)
38
 
39
+ # ---------------------------------------------------------------------------
40
+
41
 
42
  @asynccontextmanager
43
+ async def lifespan(app: FastAPI) -> Any:
44
  """Async context manager for application lifecycle events.
45
 
46
  This handles:
 
51
  app: The FastAPI application instance
52
 
53
  Yields:
54
+ None: Control back to the application after startup is complete
55
  """
56
  # Startup: load models into memory
57
  controller.startup()
 
202
 
203
  @app.post("/predict", tags=["Prediction"])
204
  @construct_response
205
+ async def _predict(
206
+ request: Request, response: Response, payload: PredictionRequest
207
+ ) -> Dict[str, Any]:
208
  """Classify code comments using multi-label classification.
209
 
210
  This endpoint performs ML inference to classify code comments into
 
212
 
213
  Args:
214
  request: The FastAPI request object
215
+ response: The FastAPI response object
216
  payload: PredictionRequest containing:
217
  - texts: List of code comments
218
  - class_names: List of class names corresponding to each comment
 
252
  payload.model_type,
253
  )
254
 
255
+ response.headers["X-model"] = payload.model_type
256
+ response.headers["X-language"] = payload.language
257
+
258
+ # Collect all predicted labels
259
+ all_labels = [label for result in results for label in result["labels"]]
260
+ response.headers["X-predicted-labels"] = json.dumps(all_labels)
261
+
262
  return {
263
  "status-code": HTTPStatus.OK,
264
  "message": "Prediction successful",
nygaardcodecommentclassification/api/controllers.py CHANGED
@@ -147,13 +147,11 @@ class PredictionController:
147
  f"Model '{model_type}' unavailable for {language}. Available: {available_types}"
148
  )
149
 
150
- # --- Combine texts with class names ---
151
- # Format: "comment | class_name" for the model
152
  combined_texts = [f"{text} | {class_name}" for text, class_name in zip(texts, class_names)]
153
 
154
  # --- Model Inference ---
155
  try:
156
- y_pred = self.predictor.predict(combined_texts, language, model_type)
157
  except Exception as e:
158
  logger.error("Prediction failed for %s/%s: %s", language, model_type, e)
159
  raise RuntimeError(f"Internal model error: {e}") from e
@@ -167,7 +165,7 @@ class PredictionController:
167
 
168
  # Convert numeric predictions to human-readable labels
169
  results: List[Dict[str, Any]] = []
170
- for i, (text_input, class_name) in enumerate(zip(texts, class_names)):
171
  row_pred = y_pred[i] # Binary array (1 = label present, 0 = absent)
172
 
173
  # Find indices where prediction is 1 (positive class)
@@ -176,8 +174,6 @@ class PredictionController:
176
  # Map indices to label strings
177
  predicted_labels = [labels_map[idx] for idx in predicted_indices]
178
 
179
- results.append(
180
- {"text": text_input, "class_name": class_name, "labels": predicted_labels}
181
- )
182
 
183
  return results
 
147
  f"Model '{model_type}' unavailable for {language}. Available: {available_types}"
148
  )
149
 
 
 
150
  combined_texts = [f"{text} | {class_name}" for text, class_name in zip(texts, class_names)]
151
 
152
  # --- Model Inference ---
153
  try:
154
+ y_pred, embeddings = self.predictor.predict(combined_texts, language, model_type)
155
  except Exception as e:
156
  logger.error("Prediction failed for %s/%s: %s", language, model_type, e)
157
  raise RuntimeError(f"Internal model error: {e}") from e
 
165
 
166
  # Convert numeric predictions to human-readable labels
167
  results: List[Dict[str, Any]] = []
168
+ for i, text_input in enumerate(texts):
169
  row_pred = y_pred[i] # Binary array (1 = label present, 0 = absent)
170
 
171
  # Find indices where prediction is 1 (positive class)
 
174
  # Map indices to label strings
175
  predicted_labels = [labels_map[idx] for idx in predicted_indices]
176
 
177
+ results.append({"text": text_input, "labels": predicted_labels})
 
 
178
 
179
  return results
nygaardcodecommentclassification/api/models.py CHANGED
@@ -1,7 +1,9 @@
1
  """Model Layer - ML model management and inference.
 
2
  This module handles the low-level ML operations including:
3
  - Model loading and storage via ModelRegistry
4
  - Inference execution via ModelPredictor
 
5
  Architecture:
6
  - ModelRegistry: Central storage for loaded models with lazy loading
7
  - ModelPredictor: Executes inference using registered models
@@ -11,7 +13,7 @@ import logging
11
  import os
12
  from pathlib import Path
13
  import sys
14
- from typing import Any, Dict, List, Optional
15
 
16
  import dagshub
17
  import mlflow
@@ -20,16 +22,6 @@ import torch
20
 
21
  from nygaardcodecommentclassification import config
22
 
23
- # Patch torch.load to use CPU mapping by default if CUDA is not available
24
- # This prevents "Attempting to deserialize object on a CUDA device" errors
25
- _original_torch_load = torch.load
26
- def _patched_torch_load(f, map_location=None, *args, **kwargs):
27
- """Wrapper around torch.load that uses CPU mapping if CUDA unavailable."""
28
- if map_location is None and not torch.cuda.is_available():
29
- map_location = torch.device('cpu')
30
- return _original_torch_load(f, map_location=map_location, *args, **kwargs)
31
- torch.load = _patched_torch_load
32
-
33
  # Configure module logger with explicit handler to ensure visibility
34
  logger = logging.getLogger("nygaard.models")
35
  logger.setLevel(logging.DEBUG)
@@ -44,16 +36,20 @@ if not logger.handlers:
44
 
45
  class ModelRegistry:
46
  """Central registry for ML models loaded in memory.
 
47
  This class manages the lifecycle of ML models, providing:
48
  - Automatic discovery and loading of models from the filesystem
49
  - Organized storage by language and model type
50
  - Memory management with explicit cleanup
 
51
  Attributes:
52
  _registry: Internal dictionary storing loaded models
 
53
  Example:
54
  ```python
55
  registry = ModelRegistry()
56
  registry.load_all_models(Path("./models"))
 
57
  # Access a loaded model
58
  model_entry = registry.get_model("python", "catboost")
59
  if model_entry:
@@ -68,11 +64,14 @@ class ModelRegistry:
68
 
69
  def load_all_models(self) -> None:
70
  """Load all ML models from MLflow tracking server.
 
71
  This method connects to the MLflow tracking server (DagsHub) and loads
72
  CatBoost classifiers and sentence transformer embedders for all
73
  configured languages.
 
74
  Environment Variables:
75
  DAGSHUB_USER_TOKEN: Authentication token for DagsHub/MLflow
 
76
  Note:
77
  - Continues loading other models if one fails
78
  - Logs all loading activity for debugging
@@ -102,66 +101,104 @@ class ModelRegistry:
102
 
103
  def _load_catboost_models(self, lang: str) -> None:
104
  """Load CatBoost models for a specific language from MLflow.
 
105
  Downloads and loads the CatBoost classifier and sentence transformer
106
  embedder directly from MLflow tracking server.
 
107
  Args:
108
  lang: The programming language code (e.g., "python", "java")
109
  """
110
  # Find the CatBoost run
111
  catboost_runs = mlflow.search_runs(
112
  experiment_names=["evaluating"], filter_string="tags.model = 'catboost'"
113
- )
114
  if catboost_runs.empty:
115
  logger.error("No CatBoost run found in 'evaluating' experiment")
116
  return
117
  catboost_run = catboost_runs.iloc[0]
118
  catboost_run_id = catboost_run.run_id
119
  catboost_run_name = catboost_run.get("tags.mlflow.runName", "unknown")
 
 
120
  logger.info(
121
- "Found CatBoost run: '%s' (started: %s)",
122
  catboost_run_name,
123
- catboost_run.get("start_time", "unknown"),
 
124
  )
125
 
126
- # Find the embedder run
127
- embedder_runs = mlflow.search_runs(
128
- experiment_names=["evaluating"],
129
- filter_string="run_name = 'sentence_transformer_paraphrase-MiniLM-L6-v2'",
130
- )
131
- if embedder_runs.empty:
132
- logger.error(
133
- "No embedder run found for 'sentence_transformer_paraphrase-MiniLM-L6-v2'"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  )
135
- return
136
- embedder_run = embedder_runs.iloc[0]
137
- embedder_run_id = embedder_run.run_id
138
- embedder_run_name = embedder_run.get("tags.mlflow.runName", "unknown")
139
- logger.info(
140
- "Found Embedder run: '%s' (started: %s)",
141
- embedder_run_name,
142
- embedder_run.get("start_time", "unknown"),
143
- )
144
 
145
  try:
146
  # Load the CatBoost model from MLflow
147
  model_uri = f"runs:/{catboost_run_id}/model_{lang}"
148
  logger.info(
149
- "[%s] Loading CatBoost classifier from run '%s'...",
150
  lang.upper(),
151
  catboost_run_name,
 
152
  )
153
  model = mlflow.sklearn.load_model(model_uri)
154
 
155
  # Load the sentence transformer embedder from MLflow
156
  embedder_uri = f"runs:/{embedder_run_id}/model_{lang}"
157
  logger.info(
158
- "[%s] Loading sentence transformer from run '%s'...",
159
  lang.upper(),
160
  embedder_run_name,
 
161
  )
162
  embedder = mlflow.sklearn.load_model(embedder_uri)
163
- if hasattr(embedder, "to"):
164
- embedder.to("cpu")
165
 
166
  # Register the model with its metadata
167
  self._registry[lang]["catboost"] = {
@@ -180,9 +217,11 @@ class ModelRegistry:
180
 
181
  def get_model(self, language: str, model_type: str) -> Optional[Dict[str, Any]]:
182
  """Retrieve a loaded model entry by language and type.
 
183
  Args:
184
  language: The programming language code
185
  model_type: The type of model
 
186
  Returns:
187
  Dict containing the model and metadata, or None if not found.
188
  The dict contains:
@@ -194,6 +233,7 @@ class ModelRegistry:
194
 
195
  def clear(self) -> None:
196
  """Clear all models from the registry and free memory.
 
197
  This method should be called during application shutdown to
198
  release GPU memory and other resources.
199
  """
@@ -207,12 +247,15 @@ class ModelRegistry:
207
 
208
  class ModelPredictor:
209
  """Handles low-level prediction logic.
 
210
  Attributes:
211
  registry: Reference to the ModelRegistry for model access
 
212
  Example:
213
  ```python
214
  registry = ModelRegistry()
215
  registry.load_all_models(Path("./models"))
 
216
  predictor = ModelPredictor(registry)
217
  predictions = predictor.predict(
218
  texts=["# Calculate sum of list"],
@@ -225,25 +268,33 @@ class ModelPredictor:
225
 
226
  def __init__(self, model_registry: ModelRegistry) -> None:
227
  """Initialize the predictor with a model registry.
 
228
  Args:
229
  model_registry: The ModelRegistry instance containing loaded models
230
  """
231
  self.registry = model_registry
232
 
233
- def predict(self, texts: List[str], language: str, model_type: str) -> np.ndarray:
 
 
234
  """Execute prediction on a list of texts.
 
235
  This method handles the full inference pipeline:
236
  1. Retrieve the appropriate model from the registry
237
  2. Extract features (e.g., generate embeddings)
238
  3. Run model inference
239
  4. Return raw predictions
 
240
  Args:
241
  texts: List of code comment strings to classify
242
  language: Programming language context for model selection
243
  model_type: Type of model to use
 
244
  Returns:
245
- numpy array of predictions with shape (n_samples, n_labels).
246
- For multi-label classification, each row is a binary array.
 
 
247
  Raises:
248
  ValueError: If the requested model is not available or
249
  if an unsupported feature/model type is specified
@@ -267,7 +318,7 @@ class ModelPredictor:
267
  embeddings = embedder.encode(texts, show_progress_bar=False)
268
 
269
  # Run CatBoost prediction on embeddings
270
- return model.predict(embeddings)
271
 
272
  raise ValueError("Unsupported feature type for CatBoost")
273
 
 
1
  """Model Layer - ML model management and inference.
2
+
3
  This module handles the low-level ML operations including:
4
  - Model loading and storage via ModelRegistry
5
  - Inference execution via ModelPredictor
6
+
7
  Architecture:
8
  - ModelRegistry: Central storage for loaded models with lazy loading
9
  - ModelPredictor: Executes inference using registered models
 
13
  import os
14
  from pathlib import Path
15
  import sys
16
+ from typing import Any, Dict, List, Optional, Tuple
17
 
18
  import dagshub
19
  import mlflow
 
22
 
23
  from nygaardcodecommentclassification import config
24
 
 
 
 
 
 
 
 
 
 
 
25
  # Configure module logger with explicit handler to ensure visibility
26
  logger = logging.getLogger("nygaard.models")
27
  logger.setLevel(logging.DEBUG)
 
36
 
37
  class ModelRegistry:
38
  """Central registry for ML models loaded in memory.
39
+
40
  This class manages the lifecycle of ML models, providing:
41
  - Automatic discovery and loading of models from the filesystem
42
  - Organized storage by language and model type
43
  - Memory management with explicit cleanup
44
+
45
  Attributes:
46
  _registry: Internal dictionary storing loaded models
47
+
48
  Example:
49
  ```python
50
  registry = ModelRegistry()
51
  registry.load_all_models(Path("./models"))
52
+
53
  # Access a loaded model
54
  model_entry = registry.get_model("python", "catboost")
55
  if model_entry:
 
64
 
65
  def load_all_models(self) -> None:
66
  """Load all ML models from MLflow tracking server.
67
+
68
  This method connects to the MLflow tracking server (DagsHub) and loads
69
  CatBoost classifiers and sentence transformer embedders for all
70
  configured languages.
71
+
72
  Environment Variables:
73
  DAGSHUB_USER_TOKEN: Authentication token for DagsHub/MLflow
74
+
75
  Note:
76
  - Continues loading other models if one fails
77
  - Logs all loading activity for debugging
 
101
 
102
  def _load_catboost_models(self, lang: str) -> None:
103
  """Load CatBoost models for a specific language from MLflow.
104
+
105
  Downloads and loads the CatBoost classifier and sentence transformer
106
  embedder directly from MLflow tracking server.
107
+
108
  Args:
109
  lang: The programming language code (e.g., "python", "java")
110
  """
111
  # Find the CatBoost run
112
  catboost_runs = mlflow.search_runs(
113
  experiment_names=["evaluating"], filter_string="tags.model = 'catboost'"
114
+ ).sort_values(by="metrics.final_score", ascending=False)
115
  if catboost_runs.empty:
116
  logger.error("No CatBoost run found in 'evaluating' experiment")
117
  return
118
  catboost_run = catboost_runs.iloc[0]
119
  catboost_run_id = catboost_run.run_id
120
  catboost_run_name = catboost_run.get("tags.mlflow.runName", "unknown")
121
+ catboost_git_commit = catboost_run.get("tags.mlflow.source.git.commit")
122
+
123
  logger.info(
124
+ "Found CatBoost run: '%s' (ID: %s, commit: %s)",
125
  catboost_run_name,
126
+ catboost_run_id,
127
+ catboost_git_commit,
128
  )
129
 
130
+ # Find the embedder run with same git commit and source file
131
+ embedder_run = None
132
+ embedder_run_id = None
133
+ embedder_run_name = None
134
+
135
+ if catboost_git_commit:
136
+ # Search for sentence transformer with same git commit
137
+ logger.info(
138
+ "[%s] Searching for embedder with git commit: %s",
139
+ lang.upper(),
140
+ catboost_git_commit,
141
+ )
142
+ embedder_runs = mlflow.search_runs(
143
+ experiment_names=["evaluating"],
144
+ filter_string=f"tags.`mlflow.source.git.commit` = '{catboost_git_commit}' and run_name LIKE 'sentence_transformer%'",
145
+ )
146
+
147
+ if not embedder_runs.empty:
148
+ embedder_run = embedder_runs.iloc[0]
149
+ embedder_run_id = embedder_run.run_id
150
+ embedder_run_name = embedder_run.get("tags.mlflow.runName", "unknown")
151
+ logger.info(
152
+ "[%s] Found embedder with matching git commit: '%s' (ID: %s)",
153
+ lang.upper(),
154
+ embedder_run_name,
155
+ embedder_run_id,
156
+ )
157
+
158
+ # Fallback: search by default name if git commit search failed
159
+ if not embedder_run_id:
160
+ logger.info(
161
+ "[%s] Falling back to default embedder search",
162
+ lang.upper(),
163
+ )
164
+ embedder_runs = mlflow.search_runs(
165
+ experiment_names=["evaluating"],
166
+ filter_string="run_name = 'sentence_transformer_paraphrase-MiniLM-L6-v2'",
167
+ )
168
+ if embedder_runs.empty:
169
+ logger.error(
170
+ "No embedder run found for 'sentence_transformer_paraphrase-MiniLM-L6-v2'"
171
+ )
172
+ return
173
+ embedder_run = embedder_runs.iloc[0]
174
+ embedder_run_id = embedder_run.run_id
175
+ embedder_run_name = embedder_run.get("tags.mlflow.runName", "unknown")
176
+ logger.info(
177
+ "Found Embedder run: '%s' (ID: %s)",
178
+ embedder_run_name,
179
+ embedder_run_id,
180
  )
 
 
 
 
 
 
 
 
 
181
 
182
  try:
183
  # Load the CatBoost model from MLflow
184
  model_uri = f"runs:/{catboost_run_id}/model_{lang}"
185
  logger.info(
186
+ "[%s] Loading CatBoost classifier from run '%s' (ID: %s)...",
187
  lang.upper(),
188
  catboost_run_name,
189
+ catboost_run_id,
190
  )
191
  model = mlflow.sklearn.load_model(model_uri)
192
 
193
  # Load the sentence transformer embedder from MLflow
194
  embedder_uri = f"runs:/{embedder_run_id}/model_{lang}"
195
  logger.info(
196
+ "[%s] Loading sentence transformer from run '%s' (ID: %s)...",
197
  lang.upper(),
198
  embedder_run_name,
199
+ embedder_run_id,
200
  )
201
  embedder = mlflow.sklearn.load_model(embedder_uri)
 
 
202
 
203
  # Register the model with its metadata
204
  self._registry[lang]["catboost"] = {
 
217
 
218
  def get_model(self, language: str, model_type: str) -> Optional[Dict[str, Any]]:
219
  """Retrieve a loaded model entry by language and type.
220
+
221
  Args:
222
  language: The programming language code
223
  model_type: The type of model
224
+
225
  Returns:
226
  Dict containing the model and metadata, or None if not found.
227
  The dict contains:
 
233
 
234
  def clear(self) -> None:
235
  """Clear all models from the registry and free memory.
236
+
237
  This method should be called during application shutdown to
238
  release GPU memory and other resources.
239
  """
 
247
 
248
  class ModelPredictor:
249
  """Handles low-level prediction logic.
250
+
251
  Attributes:
252
  registry: Reference to the ModelRegistry for model access
253
+
254
  Example:
255
  ```python
256
  registry = ModelRegistry()
257
  registry.load_all_models(Path("./models"))
258
+
259
  predictor = ModelPredictor(registry)
260
  predictions = predictor.predict(
261
  texts=["# Calculate sum of list"],
 
268
 
269
  def __init__(self, model_registry: ModelRegistry) -> None:
270
  """Initialize the predictor with a model registry.
271
+
272
  Args:
273
  model_registry: The ModelRegistry instance containing loaded models
274
  """
275
  self.registry = model_registry
276
 
277
+ def predict(
278
+ self, texts: List[str], language: str, model_type: str
279
+ ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
280
  """Execute prediction on a list of texts.
281
+
282
  This method handles the full inference pipeline:
283
  1. Retrieve the appropriate model from the registry
284
  2. Extract features (e.g., generate embeddings)
285
  3. Run model inference
286
  4. Return raw predictions
287
+
288
  Args:
289
  texts: List of code comment strings to classify
290
  language: Programming language context for model selection
291
  model_type: Type of model to use
292
+
293
  Returns:
294
+ Tuple containing:
295
+ - numpy array of predictions with shape (n_samples, n_labels).
296
+ - numpy array of embeddings (if available, else None).
297
+
298
  Raises:
299
  ValueError: If the requested model is not available or
300
  if an unsupported feature/model type is specified
 
318
  embeddings = embedder.encode(texts, show_progress_bar=False)
319
 
320
  # Run CatBoost prediction on embeddings
321
+ return model.predict(embeddings), embeddings
322
 
323
  raise ValueError("Unsupported feature type for CatBoost")
324