# KawtarTL's picture
# Update app.py
# 95510d5 verified
import os

import freesound
import gradio as gr
import joblib
import pandas as pd
import xgboost as xgb
# ----------------------------
# Freesound configuration
# ----------------------------
# The API token is taken from the FREESOUND_API_TOKEN environment variable when
# it is set (for example as a deployment secret), so the credential can be
# rotated without editing this file.
# NOTE(security): the fallback literal below is a credential committed to the
# repository. It is kept only so existing deployments keep working; it should
# be revoked and removed once the environment variable is configured.
API_TOKEN = os.environ.get("FREESOUND_API_TOKEN") or "zE9NjEOgUMzH9K7mjiGBaPJiNwJLjSM53LevarRK"
client = freesound.FreesoundClient()
# Second argument is the authentication type handed to the Freesound client.
client.set_token(API_TOKEN, "token")
# ----------------------------
# 1️⃣ Load the models
# ----------------------------
# Each group binds, in order: the num_downloads model and the feature names it
# is called with, the avg_rating model and its feature names, and the encoder
# used to decode the avg_rating prediction. Files are loaded in the order
# listed, from the current working directory.
# Music
(
    xgb_music_num,
    xgb_music_feat_num,
    xgb_music_avg,
    xgb_music_feat_avg,
    le_music_avg,
) = map(
    joblib.load,
    (
        "xgb_num_downloads_music_model.pkl",
        "xgb_num_downloads_music_features.pkl",
        "xgb_avg_rating_music_model.pkl",
        "xgb_avg_rating_music_features.pkl",
        "xgb_avg_rating_music_label_encoder.pkl",
    ),
)
# Effect Sound
(
    xgb_effect_num,
    xgb_effect_feat_num,
    xgb_effect_avg,
    xgb_effect_feat_avg,
    le_effect_avg,
) = map(
    joblib.load,
    (
        "xgb_num_downloads_effectsound_model.pkl",
        "xgb_num_downloads_effectsound_features.pkl",
        "xgb_avg_rating_effectsound_model.pkl",
        "xgb_avg_rating_effectsound_features.pkl",
        "xgb_avg_rating_effectsound_label_encoder.pkl",
    ),
)
# ----------------------------
# 2️⃣ Fonctions utilitaires
# ----------------------------
def safe_float(v):
    """Convert *v* to ``float``, returning ``0.0`` when it cannot be converted.

    Args:
        v: Any value; typically a number, a numeric string or ``None``.

    Returns:
        ``float(v)`` on success, otherwise ``0.0``.
    """
    try:
        return float(v)
    # Only conversion failures are treated as "no value". A bare ``except``
    # would also swallow KeyboardInterrupt / SystemExit.
    except (TypeError, ValueError, OverflowError):
        return 0.0
def predict_with_model(model, features, feat_list, le=None):
    """Predict the class of a single sound with a fitted XGBoost model.

    Args:
        model: Fitted estimator exposing ``get_booster()``.
        features: Mapping (anything with ``.get``) from feature name to raw value.
        feat_list: Feature names, in the order they are fed to the booster.
        le: Optional encoder exposing ``inverse_transform``. When given, the
            integer prediction is decoded through it.

    Returns:
        The decoded label when ``le`` is given, otherwise the integer prediction.
    """
    # Normalise to a plain list so it can be used both as DataFrame columns
    # and as DMatrix feature names, whatever sequence type was stored.
    feat_list = list(feat_list)

    # Build the single input row: missing, None and non-scalar (list/dict)
    # values all become 0.0.
    row = []
    for col in feat_list:
        val = features.get(col, 0)
        if val is None or isinstance(val, (list, dict)):
            val = 0
        row.append(safe_float(val))
    X = pd.DataFrame([row], columns=feat_list)

    # Predict through the underlying booster on a DMatrix.
    dmatrix = xgb.DMatrix(X.values, feature_names=feat_list)
    raw_pred = model.get_booster().predict(dmatrix)[0]

    # Depending on the training objective the booster yields either one value
    # per row or one score per class per row. ``int()`` on a multi-element row
    # raises, so in that case take the index of the best-scoring class.
    if getattr(raw_pred, "ndim", 0) > 0:
        pred_int = int(raw_pred.argmax())
    else:
        pred_int = int(raw_pred)

    # Explicit None check: an encoder object must not be skipped just because
    # it happens to evaluate as falsy.
    if le is not None:
        return le.inverse_transform([pred_int])[0]
    return pred_int
# ----------------------------
# 2️⃣ Num_downloads mapping
# ----------------------------
# Display label for each integer class index: 0 -> "Low", 1 -> "Medium",
# 2 -> "High".
NUM_DOWNLOADS_MAP = dict(enumerate(("Low", "Medium", "High")))
# ----------------------------
# 3️⃣ Extraction + prédiction
# ----------------------------
def _parse_sound_id(url):
    """Return the integer sound id held in the last path segment of *url*."""
    cleaned = (url or "").strip()
    # Drop the fragment first, then the query string, so only the path remains.
    for separator in ("#", "?"):
        cleaned = cleaned.split(separator, 1)[0]
    last_segment = cleaned.rstrip("/").rsplit("/", 1)[-1]
    try:
        return int(last_segment)
    except ValueError:
        raise ValueError(f"URL FreeSound invalide : {url!r}") from None


def extract_and_predict(url):
    """Fetch a sound's features from Freesound and predict its popularity.

    The sound id is read from the end of *url*, the sound is looked up through
    the Freesound search endpoint, and the model pair is chosen from the
    sound's duration: 0.5-3 s uses the "Effect Sound" models, 10-60 s uses the
    "Music" models, and any other duration is reported as unsupported.

    Args:
        url: Freesound sound URL (or bare id) as typed by the user.

    Returns:
        A one-row ``pandas.DataFrame``: either the prediction (``Type``,
        ``Duration``, ``Num_downloads``, ``Avg_rating``) or an ``Erreur``
        column describing what went wrong. This function never raises.
    """
    try:
        sound_id = _parse_sound_id(url)

        # Request every feature any model needs, plus the duration. Using
        # dict.fromkeys removes duplicates while keeping a stable order.
        requested = dict.fromkeys([
            "duration",
            *xgb_music_feat_num,
            *xgb_music_feat_avg,
            *xgb_effect_feat_num,
            *xgb_effect_feat_avg,
        ])
        fields = ",".join(requested)
        results = client.search(
            query="",
            filter=f"id:{sound_id}",
            fields=fields,
        )
        if not results.results:
            return pd.DataFrame([{"Erreur": "Sound not found"}])
        sound = results.results[0]
        duration = safe_float(sound.get("duration", 0))

        # Pick the model family from the duration.
        if 0.5 <= duration <= 3:
            sound_type = "Effect Sound"
            num_model, num_feats = xgb_effect_num, xgb_effect_feat_num
            avg_model, avg_feats, encoder = xgb_effect_avg, xgb_effect_feat_avg, le_effect_avg
        elif 10 <= duration <= 60:
            sound_type = "Music"
            num_model, num_feats = xgb_music_num, xgb_music_feat_num
            avg_model, avg_feats, encoder = xgb_music_avg, xgb_music_feat_avg, le_music_avg
        else:
            return pd.DataFrame([{
                "Erreur": "Durée non supportée pour prédiction",
                "Duration": duration,
            }])

        num = predict_with_model(num_model, sound, num_feats)
        avg = predict_with_model(avg_model, sound, avg_feats, encoder)
        return pd.DataFrame([{
            "Type": sound_type,
            "Duration": duration,
            # Unknown class indices fall back to their string form.
            "Num_downloads": NUM_DOWNLOADS_MAP.get(num, str(num)),
            "Avg_rating": avg,
        }])
    # Deliberate UI boundary: any failure (bad URL, network, model) is shown
    # to the user in the result table instead of crashing the callback.
    except Exception as e:
        return pd.DataFrame([{"Erreur": str(e)}])
# ----------------------------
# 4️⃣ Gradio interface
# ----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🎧 FreeSound – Prédiction XGBoost (DMatrix)")
    url = gr.Textbox(label="URL FreeSound", placeholder="https://freesound.org/s/123456/")
    btn = gr.Button("Prédire")
    out = gr.Dataframe()
    # Clicking the button passes the textbox value to extract_and_predict and
    # shows the returned DataFrame in the table.
    btn.click(extract_and_predict, url, out)

# Start the server only when this file is executed as a script, so importing
# the module (tests, tooling) does not launch a web server as a side effect.
if __name__ == "__main__":
    demo.launch()