Yilin0601 committed
Commit 3e63959 (verified)
Parent(s): af74093

Update app.py

Files changed (1):
  1. app.py (+10 -11)
app.py CHANGED
@@ -5,13 +5,10 @@ import librosa
 from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
 
 # --------------------------------------------------
-# Configuration
+# Load Your Fine-Tuned Model
 # --------------------------------------------------
-# Your fine-tuned model has 8 classes, corresponding to levels 3..10
-num_labels = 8
-
-# Load your fine-tuned model from the Hugging Face Hub
-# (Replace "Yilin0601/wav2vec2-accuracy-checkpoints" with your actual repo if different)
+# This model was fine-tuned with labels remapped from [3..10] to [0..7].
+# Make sure the model repo name below is correct and accessible.
 model = Wav2Vec2ForSequenceClassification.from_pretrained(
     "Yilin0601/wav2vec2-accuracy-checkpoints"
 )
@@ -29,11 +26,15 @@ def predict(audio):
     # Gradio provides audio as (sample_rate, np.array)
     sample_rate, audio_data = audio
 
+    # Ensure the audio is floating-point (librosa requires float32 or float64)
+    if audio_data.dtype not in [np.float32, np.float64]:
+        audio_data = audio_data.astype(np.float32)
+
     # Convert stereo to mono if needed
     if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
         audio_data = np.mean(audio_data, axis=1)
 
-    # Resample to 16 kHz if not already
+    # Resample to 16 kHz if needed
     if sample_rate != 16000:
         audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
 
@@ -50,11 +51,9 @@ def predict(audio):
     with torch.no_grad():
         logits = model(**inputs).logits
 
-    # Argmax over logits -> integer class in [0..7]
+    # The model output is an 8-class prediction (0..7), corresponding to original labels 3..10
     pred_class = torch.argmax(logits, dim=-1).item()
-
-    # Map [0..7] back to levels [3..10] by adding 3
-    predicted_level = pred_class + 3
+    predicted_level = pred_class + 3  # Map back to [3..10]
 
     return f"Predicted Level: {predicted_level}"
 
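
For reference, a minimal, self-contained sketch of what the updated preprocessing does to typical Gradio microphone input. The 48 kHz stereo int16 array below is a hypothetical stand-in for Gradio's (sample_rate, np.array) output; librosa.resample rejects integer arrays, which is what the new dtype check guards against.

import numpy as np
import librosa

# Hypothetical Gradio-style input: 1 s of stereo int16 audio at 48 kHz.
sample_rate = 48000
audio_data = (np.random.randn(sample_rate, 2) * 1000).astype(np.int16)

# Dtype guard added in this commit: librosa.resample requires float input.
if audio_data.dtype not in [np.float32, np.float64]:
    audio_data = audio_data.astype(np.float32)

# Stereo -> mono, then resample to the 16 kHz rate wav2vec2 expects.
if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
    audio_data = np.mean(audio_data, axis=1)
if sample_rate != 16000:
    audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)

print(audio_data.dtype, audio_data.shape)  # float32 (16000,)

# Class-to-level mapping used at inference: argmax index in [0..7] -> level in [3..10].
pred_class = 4                    # e.g. the argmax over the 8 logits
predicted_level = pred_class + 3  # -> level 7

Casting to float32 rather than float64 keeps the array in the dtype the wav2vec2 feature extractor works with natively, and the inline +3 remap replaces the deleted comment block without changing the result.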