KawtarTL commited on
Commit
4eeb3ef
·
verified ·
1 Parent(s): 4ba1f61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -95
app.py CHANGED
@@ -1,71 +1,11 @@
1
- import gradio as gr
2
- import pandas as pd
3
- import freesound
4
- import joblib
5
- import xgboost as xgb
6
-
7
- # ----------------------------
8
- # Config Freesound
9
- # ----------------------------
10
- API_TOKEN = "zE9NjEOgUMzH9K7mjiGBaPJiNwJLjSM53LevarRK"
11
- client = freesound.FreesoundClient()
12
- client.set_token(API_TOKEN, "token")
13
-
14
- # ----------------------------
15
- # 1️⃣ Charger les modèles
16
- # ----------------------------
17
- # Music
18
- xgb_music_num = joblib.load("xgb_num_downloads_music_model.pkl")
19
- xgb_music_feat_num = joblib.load("xgb_num_downloads_music_features.pkl")
20
- xgb_music_avg = joblib.load("xgb_avg_rating_music_model.pkl")
21
- xgb_music_feat_avg = joblib.load("xgb_avg_rating_music_features.pkl")
22
- le_music_avg = joblib.load("xgb_avg_rating_music_label_encoder.pkl")
23
-
24
- # Effect Sound
25
- xgb_effect_num = joblib.load("xgb_num_downloads_effectsound_model.pkl")
26
- xgb_effect_feat_num = joblib.load("xgb_num_downloads_effectsound_features.pkl")
27
- xgb_effect_avg = joblib.load("xgb_avg_rating_effectsound_model.pkl")
28
- xgb_effect_feat_avg = joblib.load("xgb_avg_rating_effectsound_features.pkl")
29
- le_effect_avg = joblib.load("xgb_avg_rating_effectsound_label_encoder.pkl")
30
-
31
- # ----------------------------
32
- # 2️⃣ Fonctions utilitaires
33
- # ----------------------------
34
- def safe_float(v):
35
- try:
36
- return float(v)
37
- except:
38
- return 0.0
39
-
40
- def predict_with_model(model, features, feat_list, le=None):
41
- row = []
42
- for col in feat_list:
43
- val = features.get(col, 0)
44
- if val is None or isinstance(val, (list, dict)):
45
- val = 0
46
- row.append(safe_float(val))
47
-
48
- X = pd.DataFrame([row], columns=feat_list)
49
- dmatrix = xgb.DMatrix(X.values, feature_names=feat_list)
50
-
51
- pred_int = int(model.get_booster().predict(dmatrix)[0])
52
- if le:
53
- return le.inverse_transform([pred_int])[0]
54
- return pred_int
55
-
56
- # ----------------------------
57
- # 3️⃣ Extraction + prédiction
58
- # ----------------------------
59
  def extract_and_predict(url):
60
  try:
61
  sound_id = int(url.rstrip("/").split("/")[-1])
62
-
63
- # 🔹 Inclure duration explicitement
64
- all_features = list(set(
65
- xgb_music_feat_num + xgb_music_feat_avg + xgb_effect_feat_num + xgb_effect_feat_avg
66
- ))
67
  fields = "duration," + ",".join(all_features)
68
-
69
  results = client.search(
70
  query="",
71
  filter=f"id:{sound_id}",
@@ -76,46 +16,23 @@ def extract_and_predict(url):
76
  return pd.DataFrame([{"Erreur": "Sound not found"}])
77
 
78
  sound = results.results[0]
 
 
79
  duration = safe_float(sound.get("duration", 0))
80
-
81
- # 🔹 Vérifier le type selon la durée
82
  if 0.5 <= duration <= 3:
83
  # Effect Sound
84
  num = predict_with_model(xgb_effect_num, sound, xgb_effect_feat_num)
85
  avg = predict_with_model(xgb_effect_avg, sound, xgb_effect_feat_avg, le_effect_avg)
86
- return pd.DataFrame([{
87
- "Type": "Effect Sound",
88
- "Duration": duration,
89
- "Num_downloads": num,
90
- "Avg_rating": avg
91
- }])
92
  elif 10 <= duration <= 60:
93
  # Music
94
  num = predict_with_model(xgb_music_num, sound, xgb_music_feat_num)
95
  avg = predict_with_model(xgb_music_avg, sound, xgb_music_feat_avg, le_music_avg)
96
- return pd.DataFrame([{
97
- "Type": "Music",
98
- "Duration": duration,
99
- "Num_downloads": num,
100
- "Avg_rating": avg
101
- }])
102
  else:
103
- return pd.DataFrame([{
104
- "Erreur": "Durée non supportée pour prédiction",
105
- "Duration": duration
106
- }])
107
 
108
  except Exception as e:
109
- return pd.DataFrame([{"Erreur": str(e)}])
110
-
111
- # ----------------------------
112
- # 4️⃣ Interface Gradio
113
- # ----------------------------
114
- with gr.Blocks() as demo:
115
- gr.Markdown("## 🎧 FreeSound – Prédiction XGBoost (DMatrix)")
116
- url = gr.Textbox(label="URL FreeSound", placeholder="https://freesound.org/s/123456/")
117
- btn = gr.Button("Prédire")
118
- out = gr.Dataframe()
119
- btn.click(extract_and_predict, url, out)
120
-
121
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def extract_and_predict(url):
2
  try:
3
  sound_id = int(url.rstrip("/").split("/")[-1])
4
+
5
+ # Inclure duration explicitement
6
+ all_features = list(set(xgb_music_feat_num + xgb_music_feat_avg + xgb_effect_feat_num + xgb_effect_feat_avg))
 
 
7
  fields = "duration," + ",".join(all_features)
8
+
9
  results = client.search(
10
  query="",
11
  filter=f"id:{sound_id}",
 
16
  return pd.DataFrame([{"Erreur": "Sound not found"}])
17
 
18
  sound = results.results[0]
19
+
20
+ # ⚠️ Récupérer duration séparément
21
  duration = safe_float(sound.get("duration", 0))
22
+
23
+ # Décider du type
24
  if 0.5 <= duration <= 3:
25
  # Effect Sound
26
  num = predict_with_model(xgb_effect_num, sound, xgb_effect_feat_num)
27
  avg = predict_with_model(xgb_effect_avg, sound, xgb_effect_feat_avg, le_effect_avg)
28
+ return pd.DataFrame([{"Type": "Effect Sound", "Duration": duration, "Num_downloads": num, "Avg_rating": avg}])
 
 
 
 
 
29
  elif 10 <= duration <= 60:
30
  # Music
31
  num = predict_with_model(xgb_music_num, sound, xgb_music_feat_num)
32
  avg = predict_with_model(xgb_music_avg, sound, xgb_music_feat_avg, le_music_avg)
33
+ return pd.DataFrame([{"Type": "Music", "Duration": duration, "Num_downloads": num, "Avg_rating": avg}])
 
 
 
 
 
34
  else:
35
+ return pd.DataFrame([{"Erreur": "Durée non supportée pour prédiction", "Duration": duration}])
 
 
 
36
 
37
  except Exception as e:
38
+ return pd.DataFrame([{"Erreur": str(e)}])