Djibi972 commited on
Commit
091748b
·
verified ·
1 Parent(s): e3e476b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -14
app.py CHANGED
@@ -1,8 +1,15 @@
1
- import gradio as gr, numpy as np, librosa, soundfile as sf
 
 
 
 
 
 
2
  from perch_hoplite.zoo import model_configs
3
 
4
- # Charge Perch v2 (télécharge depuis Kaggle via hoplite)
5
- MODEL = model_configs.load_model_by_name("perch_V2")
 
6
  SR = 32000
7
  WIN = 5 * SR
8
 
@@ -11,6 +18,8 @@ def _prep(wav, sr):
11
  wav = np.mean(wav, axis=1)
12
  if sr != SR:
13
  wav = librosa.resample(wav.astype(np.float32), orig_sr=sr, target_sr=SR)
 
 
14
  if len(wav) < WIN:
15
  wav = np.pad(wav, (0, WIN - len(wav)))
16
  else:
@@ -19,33 +28,50 @@ def _prep(wav, sr):
19
 
20
  def infer(audio):
21
  if audio is None:
22
- return {"error": "no audio"}
23
- wav, sr = audio
 
 
 
24
  wav = _prep(wav, sr)
25
 
26
  out = MODEL.embed(wav)
 
 
 
 
27
  logits = out.logits["label"]
28
- labels = out.label_names["label"] if hasattr(out, "label_names") else None
29
 
 
30
  idx = np.argsort(logits)[::-1][:3]
31
  topk = []
32
- for i in idx:
33
- name = labels[i] if labels is not None else f"class_{int(i)}"
34
- prob = float(np.exp(logits[i]) / np.sum(np.exp(logits[idx])))
 
 
 
 
 
 
 
35
  topk.append({"label": name, "score": round(prob, 4)})
36
 
37
  return {
38
  "topk": topk,
39
  "embedding_dim": int(out.embeddings.shape[-1]),
40
- "note": "scores non calibrés; régler un seuil selon votre usage"
41
  }
42
 
43
  demo = gr.Interface(
44
  fn=infer,
45
- inputs=gr.Audio(type="numpy", sources=["microphone", "upload"]),
46
- outputs=gr.JSON(label="Perch v2"),
47
- title="Perch 2.0 — Bioacoustics",
 
48
  allow_flagging="never"
49
  )
50
 
51
- demo.queue(api_open=True).launch()
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import librosa
4
+
5
+ # Pas besoin d'importer soundfile si non utilisé directement
6
+ # import soundfile as sf
7
+
8
  from perch_hoplite.zoo import model_configs
9
 
10
+ # CORRECTION: Utilisation du nom de modèle valide "perch" au lieu de "perch_V2"
11
+ # Le commentaire peut rester car il s'agit bien conceptuellement de la v2 du modèle.
12
+ MODEL = model_configs.load_model_by_name("perch")
13
  SR = 32000
14
  WIN = 5 * SR
15
 
 
18
  wav = np.mean(wav, axis=1)
19
  if sr != SR:
20
  wav = librosa.resample(wav.astype(np.float32), orig_sr=sr, target_sr=SR)
21
+
22
+ # S'assurer que le tableau a la bonne longueur (padding ou troncature)
23
  if len(wav) < WIN:
24
  wav = np.pad(wav, (0, WIN - len(wav)))
25
  else:
 
28
 
29
  def infer(audio):
30
  if audio is None:
31
+ # Retourner un dictionnaire vide ou un message d'erreur clair pour l'interface JSON
32
+ return {"erreur": "Aucun fichier audio fourni."}
33
+
34
+ # Gradio fournit les données audio sous forme de tuple (fréquence d'échantillonnage, données numpy)
35
+ sr, wav = audio
36
  wav = _prep(wav, sr)
37
 
38
  out = MODEL.embed(wav)
39
+ # Assurez-vous que la clé 'label' existe dans les logits
40
+ if "label" not in out.logits:
41
+ return {"erreur": "La sortie du modèle ne contient pas de logits 'label'."}
42
+
43
  logits = out.logits["label"]
44
+ labels = out.label_names.get("label")
45
 
46
+ # Calcul des 3 meilleures prédictions
47
  idx = np.argsort(logits)[::-1][:3]
48
  topk = []
49
+
50
+ # Normalisation Softmax sur les logits des 3 meilleures classes pour obtenir un score relatif
51
+ top_logits = logits[idx]
52
+ exp_logits = np.exp(top_logits - np.max(top_logits)) # Stabilité numérique
53
+ sum_exp_logits = np.sum(exp_logits)
54
+
55
+ for i in range(len(idx)):
56
+ class_index = idx[i]
57
+ name = labels[class_index] if labels is not None and class_index < len(labels) else f"classe_{int(class_index)}"
58
+ prob = float(exp_logits[i] / sum_exp_logits)
59
  topk.append({"label": name, "score": round(prob, 4)})
60
 
61
  return {
62
  "topk": topk,
63
  "embedding_dim": int(out.embeddings.shape[-1]),
64
+ "note": "Les scores sont des probabilités relatives aux 3 meilleurs résultats, pas des scores absolus."
65
  }
66
 
67
  demo = gr.Interface(
68
  fn=infer,
69
+ inputs=gr.Audio(type="numpy", label="Audio (5 secondes max)"),
70
+ outputs=gr.JSON(label="Résultats de l'inférence"),
71
+ title="Perch 2.0 — Identification Bioacoustique",
72
+ description="Téléchargez un fichier audio ou utilisez votre microphone pour identifier les sons (oiseaux, etc.). Le modèle analyse les 5 premières secondes.",
73
  allow_flagging="never"
74
  )
75
 
76
+ # api_open=True permet d'appeler l'interface comme une API
77
+ demo.queue().launch(debug=True)