Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -395,31 +395,37 @@ def xgb_predict_safe(model, X, label_encoder=None):
|
|
| 395 |
|
| 396 |
|
| 397 |
# -------- Gradio --------
|
| 398 |
-
def predict_with_model(model,
|
| 399 |
-
"""
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
if val is None or isinstance(val, (list, dict)):
|
| 405 |
-
val = 0
|
| 406 |
-
row.append(float(val)) # s'assurer que c'est float
|
| 407 |
|
| 408 |
-
#
|
| 409 |
-
|
| 410 |
|
| 411 |
-
#
|
| 412 |
-
dmatrix = xgb.DMatrix(
|
| 413 |
|
| 414 |
-
# Prédiction
|
| 415 |
-
|
|
|
|
| 416 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
if le:
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
| 419 |
return pred_int
|
| 420 |
|
| 421 |
|
| 422 |
-
|
| 423 |
def predict_with_metadata(url):
|
| 424 |
if url.strip() == "":
|
| 425 |
return "❌ Veuillez entrer une URL FreeSound."
|
|
@@ -462,14 +468,17 @@ def predict_with_metadata(url):
|
|
| 462 |
dmatrix = xgb.DMatrix(df_for_model.values, feature_names=list(df_for_model.columns))
|
| 463 |
|
| 464 |
|
| 465 |
-
# 7️ Faire les prédictions
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
NUM_DOWNLOADS_MAP = {0: "Low", 1: "Medium", 2: "High"}
|
|
|
|
| 467 |
|
| 468 |
-
#
|
| 469 |
-
|
| 470 |
-
pred_avg_rating = predict_with_model(model_ar, df_for_model
|
| 471 |
-
|
| 472 |
-
|
| 473 |
# 8️ Affichage des features prétraitées
|
| 474 |
processed_lines = ["\n=== Features après preprocessing ==="]
|
| 475 |
for col in df_processed.columns:
|
|
|
|
| 395 |
|
| 396 |
|
| 397 |
# -------- Gradio --------
|
| 398 |
+
def predict_with_model(model, df_input, feat_list, le=None):
|
| 399 |
+
"""
|
| 400 |
+
On passe directement le DataFrame filtré pour éviter les erreurs de dictionnaire
|
| 401 |
+
"""
|
| 402 |
+
# 1. On s'assure de n'avoir que les colonnes attendues par le booster
|
| 403 |
+
booster_feats = model.get_booster().feature_names
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
+
# 2. On aligne le DataFrame sur ces colonnes précisément
|
| 406 |
+
X_aligned = df_input.reindex(columns=booster_feats, fill_value=0.0).astype(float)
|
| 407 |
|
| 408 |
+
# 3. Création de la DMatrix avec les noms de features officiels du modèle
|
| 409 |
+
dmatrix = xgb.DMatrix(X_aligned.values, feature_names=booster_feats)
|
| 410 |
|
| 411 |
+
# 4. Prédiction
|
| 412 |
+
preds = model.get_booster().predict(dmatrix)
|
| 413 |
+
pred_val = preds[0]
|
| 414 |
|
| 415 |
+
# Si c'est une classification (plusieurs probabilités), on prend l'index max
|
| 416 |
+
if len(preds.shape) > 1 and preds.shape[1] > 1:
|
| 417 |
+
pred_int = int(np.argmax(pred_val))
|
| 418 |
+
else:
|
| 419 |
+
pred_int = int(round(float(pred_val)))
|
| 420 |
+
|
| 421 |
if le:
|
| 422 |
+
try:
|
| 423 |
+
return le.inverse_transform([pred_int])[0]
|
| 424 |
+
except:
|
| 425 |
+
return f"Classe inconnue ({pred_int})"
|
| 426 |
return pred_int
|
| 427 |
|
| 428 |
|
|
|
|
| 429 |
def predict_with_metadata(url):
|
| 430 |
if url.strip() == "":
|
| 431 |
return "❌ Veuillez entrer une URL FreeSound."
|
|
|
|
| 468 |
dmatrix = xgb.DMatrix(df_for_model.values, feature_names=list(df_for_model.columns))
|
| 469 |
|
| 470 |
|
| 471 |
+
# 7️ Faire les prédictions
|
| 472 |
+
# On passe 'df_for_model' directement (qui est déjà un DataFrame)
|
| 473 |
+
pred_num_downloads_val = predict_with_model(model_nd, df_for_model, model_features)
|
| 474 |
+
|
| 475 |
+
# Mapping pour num_downloads si le modèle renvoie un entier
|
| 476 |
NUM_DOWNLOADS_MAP = {0: "Low", 1: "Medium", 2: "High"}
|
| 477 |
+
pred_num_downloads = NUM_DOWNLOADS_MAP.get(pred_num_downloads_val, str(pred_num_downloads_val))
|
| 478 |
|
| 479 |
+
# Prédiction du rating avec le LabelEncoder
|
| 480 |
+
current_le = music_avg_rating_le if dur >= 10 else effect_avg_rating_le
|
| 481 |
+
pred_avg_rating = predict_with_model(model_ar, df_for_model, model_features, le=current_le)
|
|
|
|
|
|
|
| 482 |
# 8️ Affichage des features prétraitées
|
| 483 |
processed_lines = ["\n=== Features après preprocessing ==="]
|
| 484 |
for col in df_processed.columns:
|