Spaces:
Runtime error
Runtime error
Update nygaardcodecommentclassification/api/controllers.py
#2
by
Fonty02 - opened
nygaardcodecommentclassification/api/controllers.py
CHANGED
|
@@ -113,24 +113,19 @@ class PredictionController:
|
|
| 113 |
Raises:
|
| 114 |
ValueError: If language is not supported or model type unavailable
|
| 115 |
RuntimeError: If prediction fails or labels configuration is missing
|
| 116 |
-
|
| 117 |
-
Example:
|
| 118 |
-
```python
|
| 119 |
-
results = controller.predict(
|
| 120 |
-
texts=["This calculates fibonacci", "TODO: optimize"],
|
| 121 |
-
language="python",
|
| 122 |
-
model_type="catboost"
|
| 123 |
-
)
|
| 124 |
-
# Returns:
|
| 125 |
-
# [
|
| 126 |
-
# {"text": "This calculates fibonacci", "labels": ["summary"]},
|
| 127 |
-
# {"text": "TODO: optimize", "labels": ["expand"]}
|
| 128 |
-
# ]
|
| 129 |
-
```
|
| 130 |
"""
|
| 131 |
-
# --- Request Validation ---
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
raise ValueError(f"Language '{language}' not supported. Available: {config.LANGUAGES}")
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
available_types = ["catboost"] # Currently only CatBoost is supported
|
| 136 |
if model_type not in available_types:
|
|
@@ -138,31 +133,42 @@ class PredictionController:
|
|
| 138 |
f"Model '{model_type}' unavailable for {language}. Available: {available_types}"
|
| 139 |
)
|
| 140 |
|
| 141 |
-
# --- Model Inference ---
|
| 142 |
try:
|
| 143 |
-
|
|
|
|
| 144 |
except Exception as e:
|
| 145 |
-
logger.error("Prediction failed for %s/%s: %s",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
raise RuntimeError(f"Internal model error: {e}") from e
|
| 147 |
|
| 148 |
-
# --- Result Formatting ---
|
| 149 |
-
# Get the label mapping
|
| 150 |
try:
|
| 151 |
-
labels_map = config.LABELS[
|
| 152 |
except KeyError as e:
|
| 153 |
-
raise RuntimeError(f"Configuration error: Labels map missing for {
|
| 154 |
|
| 155 |
# Convert numeric predictions to human-readable labels
|
| 156 |
results: List[Dict[str, Any]] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
for i, text_input in enumerate(texts):
|
| 158 |
row_pred = y_pred[i] # Binary array (1 = label present, 0 = absent)
|
| 159 |
-
|
| 160 |
# Find indices where prediction is 1 (positive class)
|
| 161 |
predicted_indices = np.where(row_pred == 1)[0]
|
| 162 |
-
|
| 163 |
# Map indices to label strings
|
| 164 |
predicted_labels = [labels_map[idx] for idx in predicted_indices]
|
| 165 |
-
|
| 166 |
results.append({"text": text_input, "labels": predicted_labels})
|
| 167 |
|
| 168 |
return results
|
|
|
|
|
|
| 113 |
Raises:
|
| 114 |
ValueError: If language is not supported or model type unavailable
|
| 115 |
RuntimeError: If prediction fails or labels configuration is missing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
"""
|
| 117 |
+
# --- 1. Robust Request Validation (Case-Insensitive) ---
|
| 118 |
+
|
| 119 |
+
# Crea una mappa { "python": "Python", "java": "Java" } basata sul config
|
| 120 |
+
# Questo permette di trovare la chiave corretta anche se l'input è minuscolo
|
| 121 |
+
supported_languages_map = {l.lower(): l for l in config.LANGUAGES}
|
| 122 |
+
input_lang_lower = language.lower()
|
| 123 |
+
|
| 124 |
+
if input_lang_lower not in supported_languages_map:
|
| 125 |
raise ValueError(f"Language '{language}' not supported. Available: {config.LANGUAGES}")
|
| 126 |
+
|
| 127 |
+
# Recupera la stringa esatta usata nel config e nel registry (es. "Python" o "python")
|
| 128 |
+
canonical_language = supported_languages_map[input_lang_lower]
|
| 129 |
|
| 130 |
available_types = ["catboost"] # Currently only CatBoost is supported
|
| 131 |
if model_type not in available_types:
|
|
|
|
| 133 |
f"Model '{model_type}' unavailable for {language}. Available: {available_types}"
|
| 134 |
)
|
| 135 |
|
| 136 |
+
# --- 2. Model Inference ---
|
| 137 |
try:
|
| 138 |
+
# Usiamo canonical_language per essere sicuri di matchare la chiave nel Registry
|
| 139 |
+
y_pred = self.predictor.predict(texts, canonical_language, model_type)
|
| 140 |
except Exception as e:
|
| 141 |
+
logger.error("Prediction failed for %s/%s: %s", canonical_language, model_type, e)
|
| 142 |
+
# Loggo anche le chiavi disponibili nel registry per debug
|
| 143 |
+
try:
|
| 144 |
+
available_keys = list(self.registry._registry.keys())
|
| 145 |
+
logger.error("Debug - Registry keys available: %s", available_keys)
|
| 146 |
+
except:
|
| 147 |
+
pass
|
| 148 |
raise RuntimeError(f"Internal model error: {e}") from e
|
| 149 |
|
| 150 |
+
# --- 3. Result Formatting ---
|
| 151 |
+
# Get the label mapping using the canonical language key
|
| 152 |
try:
|
| 153 |
+
labels_map = config.LABELS[canonical_language]
|
| 154 |
except KeyError as e:
|
| 155 |
+
raise RuntimeError(f"Configuration error: Labels map missing for {canonical_language}") from e
|
| 156 |
|
| 157 |
# Convert numeric predictions to human-readable labels
|
| 158 |
results: List[Dict[str, Any]] = []
|
| 159 |
+
|
| 160 |
+
# Se c'è solo un testo, predict potrebbe ritornare un array 1D invece di 2D.
|
| 161 |
+
# Assicuriamoci che y_pred sia sempre 2D (n_samples, n_labels)
|
| 162 |
+
if y_pred.ndim == 1:
|
| 163 |
+
y_pred = y_pred.reshape(1, -1)
|
| 164 |
+
|
| 165 |
for i, text_input in enumerate(texts):
|
| 166 |
row_pred = y_pred[i] # Binary array (1 = label present, 0 = absent)
|
|
|
|
| 167 |
# Find indices where prediction is 1 (positive class)
|
| 168 |
predicted_indices = np.where(row_pred == 1)[0]
|
|
|
|
| 169 |
# Map indices to label strings
|
| 170 |
predicted_labels = [labels_map[idx] for idx in predicted_indices]
|
|
|
|
| 171 |
results.append({"text": text_input, "labels": predicted_labels})
|
| 172 |
|
| 173 |
return results
|
| 174 |
+
|