NIIHAAD commited on
Commit
d0b8c26
·
verified ·
1 Parent(s): 6f37b35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -23
app.py CHANGED
@@ -395,31 +395,37 @@ def xgb_predict_safe(model, X, label_encoder=None):
395
 
396
 
397
  # -------- Gradio --------
398
- def predict_with_model(model, features, feat_list, le=None):
399
- """Prédiction XGBoost sûre comme dans ton exemple"""
400
- # Préparer la ligne
401
- row = []
402
- for col in feat_list:
403
- val = features.get(col, 0)
404
- if val is None or isinstance(val, (list, dict)):
405
- val = 0
406
- row.append(float(val)) # s'assurer que c'est float
407
 
408
- # Créer DataFrame
409
- X = pd.DataFrame([row], columns=feat_list)
410
 
411
- # Transformer en DMatrix
412
- dmatrix = xgb.DMatrix(X.values, feature_names=feat_list)
413
 
414
- # Prédiction
415
- pred_int = int(model.get_booster().predict(dmatrix)[0])
 
416
 
 
 
 
 
 
 
417
  if le:
418
- return le.inverse_transform([pred_int])[0]
 
 
 
419
  return pred_int
420
 
421
 
422
-
423
  def predict_with_metadata(url):
424
  if url.strip() == "":
425
  return "❌ Veuillez entrer une URL FreeSound."
@@ -462,14 +468,17 @@ def predict_with_metadata(url):
462
  dmatrix = xgb.DMatrix(df_for_model.values, feature_names=list(df_for_model.columns))
463
 
464
 
465
- # 7️ Faire les prédictions s
 
 
 
 
466
  NUM_DOWNLOADS_MAP = {0: "Low", 1: "Medium", 2: "High"}
 
467
 
468
- # Utiliser la fonction simplifiée
469
- pred_num_downloads = predict_with_model(model_nd, df_for_model.iloc[0].to_dict(), model_features)
470
- pred_avg_rating = predict_with_model(model_ar, df_for_model.iloc[0].to_dict(), model_features, le=music_avg_rating_le if dur >= 10 else effect_avg_rating_le)
471
-
472
-
473
  # 8️ Affichage des features prétraitées
474
  processed_lines = ["\n=== Features après preprocessing ==="]
475
  for col in df_processed.columns:
 
395
 
396
 
397
  # -------- Gradio --------
398
+ def predict_with_model(model, df_input, feat_list, le=None):
399
+ """
400
+ On passe directement le DataFrame filtré pour éviter les erreurs de dictionnaire
401
+ """
402
+ # 1. On s'assure de n'avoir que les colonnes attendues par le booster
403
+ booster_feats = model.get_booster().feature_names
 
 
 
404
 
405
+ # 2. On aligne le DataFrame sur ces colonnes précisément
406
+ X_aligned = df_input.reindex(columns=booster_feats, fill_value=0.0).astype(float)
407
 
408
+ # 3. Création de la DMatrix avec les noms de features officiels du modèle
409
+ dmatrix = xgb.DMatrix(X_aligned.values, feature_names=booster_feats)
410
 
411
+ # 4. Prédiction
412
+ preds = model.get_booster().predict(dmatrix)
413
+ pred_val = preds[0]
414
 
415
+ # Si c'est une classification (plusieurs probabilités), on prend l'index max
416
+ if len(preds.shape) > 1 and preds.shape[1] > 1:
417
+ pred_int = int(np.argmax(pred_val))
418
+ else:
419
+ pred_int = int(round(float(pred_val)))
420
+
421
  if le:
422
+ try:
423
+ return le.inverse_transform([pred_int])[0]
424
+ except:
425
+ return f"Classe inconnue ({pred_int})"
426
  return pred_int
427
 
428
 
 
429
  def predict_with_metadata(url):
430
  if url.strip() == "":
431
  return "❌ Veuillez entrer une URL FreeSound."
 
468
  dmatrix = xgb.DMatrix(df_for_model.values, feature_names=list(df_for_model.columns))
469
 
470
 
471
+ # 7️ Faire les prédictions
472
+ # On passe 'df_for_model' directement (qui est déjà un DataFrame)
473
+ pred_num_downloads_val = predict_with_model(model_nd, df_for_model, model_features)
474
+
475
+ # Mapping pour num_downloads si le modèle renvoie un entier
476
  NUM_DOWNLOADS_MAP = {0: "Low", 1: "Medium", 2: "High"}
477
+ pred_num_downloads = NUM_DOWNLOADS_MAP.get(pred_num_downloads_val, str(pred_num_downloads_val))
478
 
479
+ # Prédiction du rating avec le LabelEncoder
480
+ current_le = music_avg_rating_le if dur >= 10 else effect_avg_rating_le
481
+ pred_avg_rating = predict_with_model(model_ar, df_for_model, model_features, le=current_le)
 
 
482
  # 8️ Affichage des features prétraitées
483
  processed_lines = ["\n=== Features après preprocessing ==="]
484
  for col in df_processed.columns: