Spaces:
Sleeping
Sleeping
import os
from urllib.parse import urlparse

import freesound
import gradio as gr
import joblib
import pandas as pd
import xgboost as xgb
# ----------------------------
# Freesound configuration
# ----------------------------
# The API token is read from the environment (e.g. a Space secret) instead of
# being committed to source. The token that used to be hard-coded here was
# published in plain text and should be treated as compromised and revoked.
API_TOKEN = os.environ.get("FREESOUND_API_TOKEN", "")
if not API_TOKEN:
    raise RuntimeError(
        "FREESOUND_API_TOKEN environment variable is not set; "
        "configure it with a valid Freesound API token."
    )
client = freesound.FreesoundClient()
client.set_token(API_TOKEN, "token")
# ----------------------------
# 1️⃣ Load the trained models
# ----------------------------
# joblib.load unpickles its input, so these files must come from a trusted
# source. Each group is unpacked in the listed order: num_downloads model,
# num_downloads feature list, avg_rating model, avg_rating feature list,
# avg_rating label encoder.

# Music
(
    xgb_music_num,
    xgb_music_feat_num,
    xgb_music_avg,
    xgb_music_feat_avg,
    le_music_avg,
) = map(joblib.load, (
    "xgb_num_downloads_music_model.pkl",
    "xgb_num_downloads_music_features.pkl",
    "xgb_avg_rating_music_model.pkl",
    "xgb_avg_rating_music_features.pkl",
    "xgb_avg_rating_music_label_encoder.pkl",
))

# Effect Sound
(
    xgb_effect_num,
    xgb_effect_feat_num,
    xgb_effect_avg,
    xgb_effect_feat_avg,
    le_effect_avg,
) = map(joblib.load, (
    "xgb_num_downloads_effectsound_model.pkl",
    "xgb_num_downloads_effectsound_features.pkl",
    "xgb_avg_rating_effectsound_model.pkl",
    "xgb_avg_rating_effectsound_features.pkl",
    "xgb_avg_rating_effectsound_label_encoder.pkl",
))
# ----------------------------
# 2️⃣ Utility functions
# ----------------------------
def safe_float(v, default=0.0):
    """Convert *v* to float, returning *default* when that is impossible.

    Args:
        v: Any value (number, numeric string, None, list, ...).
        default: Value returned when ``float(v)`` fails. Defaults to ``0.0``,
            matching the previous hard-coded fallback.

    Returns:
        ``float(v)`` on success, otherwise *default*.
    """
    try:
        return float(v)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / containers; ValueError: non-numeric strings;
        # OverflowError: ints too large for a float. Catching only these lets
        # KeyboardInterrupt / SystemExit propagate, unlike a bare ``except``.
        return default
def predict_with_model(model, features, feat_list, le=None):
    """Predict a single class for one sound with an XGBoost model.

    Args:
        model: Fitted model exposing ``get_booster()`` (presumably an
            ``xgboost.XGBClassifier`` -- TODO confirm against training code).
        features: Mapping of feature name -> raw value for the sound.
        feat_list: Ordered feature names the model was trained on.
        le: Optional fitted label encoder; when given, the predicted class
            index is mapped back through ``le.inverse_transform``.

    Returns:
        The predicted class index (int), or the decoded label when *le* is
        given.
    """
    # Build one row in training-column order. Missing, None and nested
    # (list/dict) values become 0; everything else goes through safe_float.
    row = []
    for col in feat_list:
        val = features.get(col, 0)
        if val is None or isinstance(val, (list, dict)):
            val = 0
        row.append(safe_float(val))
    X = pd.DataFrame([row], columns=feat_list)

    # Passing feature names lets Booster.predict validate them against the
    # names stored in the trained booster.
    dmatrix = xgb.DMatrix(X.values, feature_names=list(feat_list))
    raw = model.get_booster().predict(dmatrix)

    if raw.ndim == 2:
        # multi:softprob (XGBClassifier's default for multi-class) returns one
        # probability per class; int() on that row would raise, so take the
        # most probable class instead.
        pred_int = int(raw[0].argmax())
    else:
        # multi:softmax returns the class index directly (as a float).
        # NOTE(review): for a binary:* objective raw[0] is a probability and
        # int() truncates it to 0 -- confirm the models' objectives.
        pred_int = int(raw[0])

    if le is not None:
        return le.inverse_transform([pred_int])[0]
    return pred_int
# ----------------------------
# Num_downloads class index -> label
# ----------------------------
# Labels listed in class-index order: index 0 is "Low", 2 is "High".
NUM_DOWNLOADS_MAP = dict(enumerate(("Low", "Medium", "High")))
# ----------------------------
# 3️⃣ Extraction + prediction
# ----------------------------
_INVALID_URL_MSG = "URL FreeSound invalide (identifiant numérique introuvable)"


def _parse_sound_id(url):
    """Return the positive integer id ending a Freesound URL, or None."""
    # urlparse drops any query string / fragment; strip() removes stray
    # whitespace pasted with the URL. A bare numeric id also parses.
    path = urlparse((url or "").strip()).path
    last = path.rstrip("/").rsplit("/", 1)[-1]
    try:
        sound_id = int(last)
    except ValueError:
        return None
    return sound_id if sound_id > 0 else None


def _predict_row(sound, duration, type_label, num_model, num_feats,
                 avg_model, avg_feats, le_avg):
    """Run one category's two models on *sound* and build the result row."""
    num = predict_with_model(num_model, sound, num_feats)
    avg = predict_with_model(avg_model, sound, avg_feats, le_avg)
    return pd.DataFrame([{
        "Type": type_label,
        "Duration": duration,
        "Num_downloads": NUM_DOWNLOADS_MAP.get(num, str(num)),
        "Avg_rating": avg,
    }])


def extract_and_predict(url):
    """Fetch a Freesound sound by URL and predict its popularity.

    The sound's duration (seconds, per the Freesound API) selects the model
    family: 0.5-3 -> "Effect Sound", 10-60 -> "Music"; any other duration is
    reported as unsupported.

    Args:
        url: Freesound sound URL whose last path segment is the numeric id
            (e.g. ``https://freesound.org/s/123456/``); a bare id also works.

    Returns:
        A one-row DataFrame with Type / Duration / Num_downloads / Avg_rating
        on success, or with an "Erreur" column describing the failure. Errors
        are returned rather than raised so the Gradio table always shows one.
    """
    sound_id = _parse_sound_id(url)
    if sound_id is None:
        return pd.DataFrame([{"Erreur": _INVALID_URL_MSG}])

    try:
        # Request duration plus every feature any model needs. dict.fromkeys
        # dedupes while keeping a deterministic order (set() did neither for
        # a feature that is also "duration").
        fields = ",".join(dict.fromkeys([
            "duration",
            *xgb_music_feat_num,
            *xgb_music_feat_avg,
            *xgb_effect_feat_num,
            *xgb_effect_feat_avg,
        ]))
        results = client.search(query="", filter=f"id:{sound_id}", fields=fields)
        if not results.results:
            return pd.DataFrame([{"Erreur": "Sound not found"}])
        sound = results.results[0]

        duration = safe_float(sound.get("duration", 0))
        if 0.5 <= duration <= 3:
            return _predict_row(
                sound, duration, "Effect Sound",
                xgb_effect_num, xgb_effect_feat_num,
                xgb_effect_avg, xgb_effect_feat_avg, le_effect_avg,
            )
        if 10 <= duration <= 60:
            return _predict_row(
                sound, duration, "Music",
                xgb_music_num, xgb_music_feat_num,
                xgb_music_avg, xgb_music_feat_avg, le_music_avg,
            )
        return pd.DataFrame([{
            "Erreur": "Durée non supportée pour prédiction",
            "Duration": duration,
        }])
    except Exception as e:
        # UI boundary: surface any API / model failure in the result table.
        return pd.DataFrame([{"Erreur": str(e)}])
# ----------------------------
# 4️⃣ Gradio interface
# ----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🎧 FreeSound – Prédiction XGBoost (DMatrix)")
    url = gr.Textbox(label="URL FreeSound", placeholder="https://freesound.org/s/123456/")
    btn = gr.Button("Prédire")
    out = gr.Dataframe()
    # One URL textbox in, one result table out.
    btn.click(extract_and_predict, inputs=url, outputs=out)

# Start the server only when run as a script, so importing this module
# (e.g. from tests) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()