NIIHAAD commited on
Commit
a1d7096
·
1 Parent(s): 7f675d3
Files changed (1) hide show
  1. app.py +28 -125
app.py CHANGED
@@ -377,189 +377,92 @@ def preprocess_sound(df):
377
 
378
 
379
  def xgb_predict_safe(model, X, label_encoder=None):
380
- # Sécurité ultime : forcer exactement les features du booster
381
  booster_features = model.get_booster().feature_names
382
  X_safe = X.reindex(columns=booster_features, fill_value=0.0).astype(np.float32)
383
 
384
- dmatrix = xgb.DMatrix(
385
- X_safe.values,
386
- feature_names=booster_features
387
- )
388
 
389
  pred = model.get_booster().predict(dmatrix)[0]
390
 
391
  if label_encoder is not None:
392
- return label_encoder.inverse_transform([int(round(pred))])[0]
 
 
 
 
393
 
394
  return pred
395
 
396
 
397
  # -------- Gradio --------
398
- def predict_with_metadata(url):
399
 
 
400
  if url.strip() == "":
401
-
402
  return "❌ Veuillez entrer une URL FreeSound."
403
 
404
- # 1️ Récupérer les métadonnées brutes
405
  df_raw = fetch_sound_metadata(url)
406
- # Affichage ligne par ligne pour les métadonnées brutes
407
  raw_lines = ["=== Métadonnées brutes ==="]
408
-
409
  for col in df_raw.columns:
410
  raw_lines.append(f"{col}: {df_raw[col].iloc[0]}")
411
  raw_str = "\n".join(raw_lines)
412
 
413
-
414
- # 2️ Vérifier la durée
415
-
416
  dur = df_raw["duration"].iloc[0]
417
-
418
  if dur < 0.5:
419
-
420
- return raw_str + f"\n\n Son trop court ({dur} sec), veuillez entrer un son qui est court (0.5 à 3 s) ou un son long (10 à 60 s)"
421
-
422
  elif 3 < dur < 10 or dur > 60:
 
423
 
424
- return raw_str + f"\n\n Son trop long ou hors plage acceptable ({dur} sec) , veuillez entrer un son qui est court (0.5 à 3 s) ou un son long (10 à 60 s)"
425
-
426
-
427
-
428
- # 3️ Prétraitement seulement si durée ok
429
-
430
  df_processed = preprocess_sound(df_raw)
431
-
432
-
433
-
434
- # Supprimer les colonnes inutiles
435
-
436
  cols_to_remove = ["avg_rating", "num_downloads_class"]
437
-
438
  df_for_model = df_processed.drop(columns=[c for c in cols_to_remove if c in df_processed.columns])
439
 
440
-
441
-
442
- # Choix modèle
443
-
444
  if 0.5 <= dur <= 3:
445
-
446
- model_features = effect_model_features
447
-
448
  model_nd = effect_model_num_downloads
449
-
450
  model_ar = effect_model_avg_rating
451
-
452
- le_ar = effect_avg_rating_le
453
-
454
  sound_type = "EffectSound"
455
-
456
  else:
457
-
458
- model_features = music_model_features
459
-
460
  model_nd = music_model_num_downloads
461
-
462
  model_ar = music_model_avg_rating
463
-
464
- le_ar = music_avg_rating_le
465
-
466
  sound_type = "Music"
467
 
468
-
469
-
470
- # 🔹 Forcer exactement les colonnes du modèle
471
-
472
- expected_n_cols = len(model_features)
473
-
474
-
475
-
476
- # Supprimer tout ce qui n'est pas dans le modèle
477
-
478
- df_for_model = df_for_model[[c for c in model_features if c in df_for_model.columns]]
479
-
480
-
481
-
482
- # Ajouter les colonnes manquantes avec 0
483
-
484
- for col in model_features:
485
-
486
- if col not in df_for_model.columns:
487
-
488
- df_for_model[col] = 0.0
489
-
490
-
491
-
492
- # Réordonner exactement
493
-
494
  df_for_model = df_for_model.reindex(columns=model_features, fill_value=0.0).astype(float)
495
 
 
 
496
 
497
 
498
- # Dernière sécurité : si encore mismatch, tronquer ou ajouter des colonnes fictives
499
- """
500
- if df_for_model.shape[1] != expected_n_cols:
501
- diff = expected_n_cols - df_for_model.shape[1]
502
- if diff > 0:
503
- for i in range(diff):
504
- df_for_model[f"extra_col_{i}"] = 0.0
505
- elif diff < 0:
506
- df_for_model = df_for_model.iloc[:, :expected_n_cols]
507
- """
508
- # Prédictions
509
- pred_num_downloads = xgb_predict_safe(
510
- model_nd,
511
- df_for_model,
512
- model_features
513
- )
514
-
515
- pred_avg_rating = xgb_predict_safe(
516
- model_ar,
517
- df_for_model,
518
- model_features,
519
- label_encoder=le_ar
520
- )
521
-
522
- #pred_num_downloads = model_nd.predict(df_for_model)[0]
523
 
524
- #pred_avg_rating_enc = model_ar.predict(df_for_model)[0]
 
525
 
526
- #pred_avg_rating = le_ar.inverse_transform([pred_avg_rating_enc])[0]
527
-
528
-
529
-
530
- # Affichage ligne par ligne pour les features apr��s preprocessing
531
 
 
532
  processed_lines = ["\n=== Features après preprocessing ==="]
533
-
534
  for col in df_processed.columns:
535
-
536
  processed_lines.append(f"{col}: {df_processed[col].iloc[0]}")
537
-
538
  processed_str = "\n".join(processed_lines)
539
 
540
-
541
-
542
  prediction_lines = [
543
-
544
- "\n=== Prédictions ===",
545
-
546
  f"Type détecté : {sound_type}",
547
-
548
  f"📥 Num downloads prédit : {pred_num_downloads}",
549
-
550
  f"⭐ Avg rating prédit : {pred_avg_rating}"
551
-
552
  ]
553
-
554
-
555
-
556
  prediction_str = "\n".join(prediction_lines)
557
 
558
-
559
-
560
- return 'rien à afficher'
561
-
562
-
563
 
564
 
565
  def preprocess_name(df, vec_dim=8):
 
377
 
378
 
379
  def xgb_predict_safe(model, X, label_encoder=None):
 
380
  booster_features = model.get_booster().feature_names
381
  X_safe = X.reindex(columns=booster_features, fill_value=0.0).astype(np.float32)
382
 
383
+ dmatrix = xgb.DMatrix(X_safe.values, feature_names=list(booster_features))
 
 
 
384
 
385
  pred = model.get_booster().predict(dmatrix)[0]
386
 
387
  if label_encoder is not None:
388
+ # label_encoder est une liste de classes
389
+ pred_int = int(round(pred))
390
+ if pred_int < 0: pred_int = 0
391
+ if pred_int >= len(label_encoder): pred_int = len(label_encoder) - 1
392
+ return label_encoder[pred_int]
393
 
394
  return pred
395
 
396
 
397
  # -------- Gradio --------
 
398
 
399
+ def predict_with_metadata(url):
400
  if url.strip() == "":
 
401
  return "❌ Veuillez entrer une URL FreeSound."
402
 
403
+ # 1️ Récupérer les métadonnées brutes
404
  df_raw = fetch_sound_metadata(url)
 
405
  raw_lines = ["=== Métadonnées brutes ==="]
 
406
  for col in df_raw.columns:
407
  raw_lines.append(f"{col}: {df_raw[col].iloc[0]}")
408
  raw_str = "\n".join(raw_lines)
409
 
410
+ # 2️⃣ Vérifier la durée
 
 
411
  dur = df_raw["duration"].iloc[0]
 
412
  if dur < 0.5:
413
+ return raw_str + f"\n\n❌ Son trop court ({dur} sec). Plage acceptée: 0.5-3 ou 10-60 sec"
 
 
414
  elif 3 < dur < 10 or dur > 60:
415
+ return raw_str + f"\n\n❌ Son hors plage ({dur} sec). Plage acceptée: 0.5-3 ou 10-60 sec"
416
 
417
+ # 3️⃣ Prétraitement
 
 
 
 
 
418
  df_processed = preprocess_sound(df_raw)
 
 
 
 
 
419
  cols_to_remove = ["avg_rating", "num_downloads_class"]
 
420
  df_for_model = df_processed.drop(columns=[c for c in cols_to_remove if c in df_processed.columns])
421
 
422
+ # 4️⃣ Choix modèle selon durée
 
 
 
423
  if 0.5 <= dur <= 3:
 
 
 
424
  model_nd = effect_model_num_downloads
 
425
  model_ar = effect_model_avg_rating
426
+ model_features = effect_model_features
 
 
427
  sound_type = "EffectSound"
 
428
  else:
 
 
 
429
  model_nd = music_model_num_downloads
 
430
  model_ar = music_model_avg_rating
431
+ model_features = music_model_features
 
 
432
  sound_type = "Music"
433
 
434
+ # 5️⃣ Forcer exactement les colonnes du modèle
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  df_for_model = df_for_model.reindex(columns=model_features, fill_value=0.0).astype(float)
436
 
437
+ # 6️⃣ DMatrix XGBoost
438
+ dmatrix = xgb.DMatrix(df_for_model.values, feature_names=list(df_for_model.columns))
439
 
440
 
441
+ # 7️⃣ Faire les prédictions
442
+ NUM_DOWNLOADS_MAP = {0: "Low", 1: "Medium", 2: "High"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
 
444
+ pred_num_downloads_int = int(model_nd.get_booster().predict(dmatrix)[0])
445
+ pred_avg_rating_int = int(model_ar.get_booster().predict(dmatrix)[0])
446
 
447
+ pred_num_downloads = NUM_DOWNLOADS_MAP.get(pred_num_downloads_int, str(pred_num_downloads_int))
448
+ pred_avg_rating = NUM_DOWNLOADS_MAP.get(pred_avg_rating_int, str(pred_avg_rating_int))
 
 
 
449
 
450
+ # 8️⃣ Affichage des features prétraitées
451
  processed_lines = ["\n=== Features après preprocessing ==="]
 
452
  for col in df_processed.columns:
 
453
  processed_lines.append(f"{col}: {df_processed[col].iloc[0]}")
 
454
  processed_str = "\n".join(processed_lines)
455
 
456
+ # 9️ Résultat final
 
457
  prediction_lines = [
458
+ "\n=== Prédictions ===",
 
 
459
  f"Type détecté : {sound_type}",
 
460
  f"📥 Num downloads prédit : {pred_num_downloads}",
 
461
  f"⭐ Avg rating prédit : {pred_avg_rating}"
 
462
  ]
 
 
 
463
  prediction_str = "\n".join(prediction_lines)
464
 
465
+ return raw_str + processed_str + prediction_str
 
 
 
 
466
 
467
 
468
  def preprocess_name(df, vec_dim=8):