Update nygaardcodecommentclassification/api/controllers.py

#2
by Fonty02 - opened
nygaardcodecommentclassification/api/controllers.py CHANGED
@@ -113,24 +113,19 @@ class PredictionController:
113
  Raises:
114
  ValueError: If language is not supported or model type unavailable
115
  RuntimeError: If prediction fails or labels configuration is missing
116
-
117
- Example:
118
- ```python
119
- results = controller.predict(
120
- texts=["This calculates fibonacci", "TODO: optimize"],
121
- language="python",
122
- model_type="catboost"
123
- )
124
- # Returns:
125
- # [
126
- # {"text": "This calculates fibonacci", "labels": ["summary"]},
127
- # {"text": "TODO: optimize", "labels": ["expand"]}
128
- # ]
129
- ```
130
  """
131
- # --- Request Validation ---
132
- if language not in config.LANGUAGES:
 
 
 
 
 
 
133
  raise ValueError(f"Language '{language}' not supported. Available: {config.LANGUAGES}")
 
 
 
134
 
135
  available_types = ["catboost"] # Currently only CatBoost is supported
136
  if model_type not in available_types:
@@ -138,31 +133,42 @@ class PredictionController:
138
  f"Model '{model_type}' unavailable for {language}. Available: {available_types}"
139
  )
140
 
141
- # --- Model Inference ---
142
  try:
143
- y_pred = self.predictor.predict(texts, language, model_type)
 
144
  except Exception as e:
145
- logger.error("Prediction failed for %s/%s: %s", language, model_type, e)
 
 
 
 
 
 
146
  raise RuntimeError(f"Internal model error: {e}") from e
147
 
148
- # --- Result Formatting ---
149
- # Get the label mapping for this language
150
  try:
151
- labels_map = config.LABELS[language]
152
  except KeyError as e:
153
- raise RuntimeError(f"Configuration error: Labels map missing for {language}") from e
154
 
155
  # Convert numeric predictions to human-readable labels
156
  results: List[Dict[str, Any]] = []
 
 
 
 
 
 
157
  for i, text_input in enumerate(texts):
158
  row_pred = y_pred[i] # Binary array (1 = label present, 0 = absent)
159
-
160
  # Find indices where prediction is 1 (positive class)
161
  predicted_indices = np.where(row_pred == 1)[0]
162
-
163
  # Map indices to label strings
164
  predicted_labels = [labels_map[idx] for idx in predicted_indices]
165
-
166
  results.append({"text": text_input, "labels": predicted_labels})
167
 
168
  return results
 
 
113
  Raises:
114
  ValueError: If language is not supported or model type unavailable
115
  RuntimeError: If prediction fails or labels configuration is missing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  """
117
+ # --- 1. Robust Request Validation (Case-Insensitive) ---
118
+
119
+ # Crea una mappa { "python": "Python", "java": "Java" } basata sul config
120
+ # Questo permette di trovare la chiave corretta anche se l'input è minuscolo
121
+ supported_languages_map = {l.lower(): l for l in config.LANGUAGES}
122
+ input_lang_lower = language.lower()
123
+
124
+ if input_lang_lower not in supported_languages_map:
125
  raise ValueError(f"Language '{language}' not supported. Available: {config.LANGUAGES}")
126
+
127
+ # Recupera la stringa esatta usata nel config e nel registry (es. "Python" o "python")
128
+ canonical_language = supported_languages_map[input_lang_lower]
129
 
130
  available_types = ["catboost"] # Currently only CatBoost is supported
131
  if model_type not in available_types:
 
133
  f"Model '{model_type}' unavailable for {language}. Available: {available_types}"
134
  )
135
 
136
+ # --- 2. Model Inference ---
137
  try:
138
+ # Usiamo canonical_language per essere sicuri di matchare la chiave nel Registry
139
+ y_pred = self.predictor.predict(texts, canonical_language, model_type)
140
  except Exception as e:
141
+ logger.error("Prediction failed for %s/%s: %s", canonical_language, model_type, e)
142
+ # Loggo anche le chiavi disponibili nel registry per debug
143
+ try:
144
+ available_keys = list(self.registry._registry.keys())
145
+ logger.error("Debug - Registry keys available: %s", available_keys)
146
+ except:
147
+ pass
148
  raise RuntimeError(f"Internal model error: {e}") from e
149
 
150
+ # --- 3. Result Formatting ---
151
+ # Get the label mapping using the canonical language key
152
  try:
153
+ labels_map = config.LABELS[canonical_language]
154
  except KeyError as e:
155
+ raise RuntimeError(f"Configuration error: Labels map missing for {canonical_language}") from e
156
 
157
  # Convert numeric predictions to human-readable labels
158
  results: List[Dict[str, Any]] = []
159
+
160
+ # Se c'è solo un testo, predict potrebbe ritornare un array 1D invece di 2D.
161
+ # Assicuriamoci che y_pred sia sempre 2D (n_samples, n_labels)
162
+ if y_pred.ndim == 1:
163
+ y_pred = y_pred.reshape(1, -1)
164
+
165
  for i, text_input in enumerate(texts):
166
  row_pred = y_pred[i] # Binary array (1 = label present, 0 = absent)
 
167
  # Find indices where prediction is 1 (positive class)
168
  predicted_indices = np.where(row_pred == 1)[0]
 
169
  # Map indices to label strings
170
  predicted_labels = [labels_map[idx] for idx in predicted_indices]
 
171
  results.append({"text": text_input, "labels": predicted_labels})
172
 
173
  return results
174
+