Jabrave committed on
Commit
6495d4e
·
verified ·
1 Parent(s): fa9a9ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -14
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from transformers import AutoFeatureExtractor
2
  from transformers import AutoModelForAudioClassification
3
- import torchaudio
4
  from detect_face import detect_face
5
  from transformers import AutoModelForImageClassification
6
  from transformers import AutoImageProcessor
@@ -71,18 +71,14 @@ def predict_with_model(image, model, processor):
71
  "confidence": round(confidence * 100, 2)
72
  }
73
 
74
- def predict_audio(audio_path):
75
 
76
- torchaudio.set_audio_backend("soundfile")
77
 
78
- waveform, sample_rate = torchaudio.load(
79
- audio_path,
80
- backend="soundfile"
81
- )
82
 
83
  inputs = voice_processor(
84
- waveform.squeeze().numpy(),
85
- sampling_rate=sample_rate,
86
  return_tensors="pt"
87
  )
88
 
@@ -90,13 +86,9 @@ def predict_audio(audio_path):
90
  outputs = voice_model(**inputs)
91
 
92
  logits = outputs.logits
93
-
94
  predicted_class = logits.argmax(-1).item()
95
 
96
- confidence = torch.softmax(
97
- logits,
98
- dim=1
99
- )[0][predicted_class].item()
100
 
101
  label = voice_model.config.id2label[predicted_class]
102
 
 
1
  from transformers import AutoFeatureExtractor
2
  from transformers import AutoModelForAudioClassification
3
+ import librosa
4
  from detect_face import detect_face
5
  from transformers import AutoModelForImageClassification
6
  from transformers import AutoImageProcessor
 
71
  "confidence": round(confidence * 100, 2)
72
  }
73
 
 
74
 
75
+ def predict_audio(audio_path):
76
 
77
+ waveform, sr = librosa.load(audio_path, sr=16000)
 
 
 
78
 
79
  inputs = voice_processor(
80
+ waveform,
81
+ sampling_rate=16000,
82
  return_tensors="pt"
83
  )
84
 
 
86
  outputs = voice_model(**inputs)
87
 
88
  logits = outputs.logits
 
89
  predicted_class = logits.argmax(-1).item()
90
 
91
+ confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()
 
 
 
92
 
93
  label = voice_model.config.id2label[predicted_class]
94