Spaces:
Sleeping
Sleeping
| import tempfile | |
| import torch | |
| import torch.nn.functional as F | |
| import torchaudio | |
| import gradio as gr | |
| from transformers import Wav2Vec2FeatureExtractor, AutoConfig | |
| from models import Wav2Vec2ForSpeechClassification, HubertForSpeechClassification | |
# Pretrained checkpoint identifier shared by the config, the feature
# extractor and the speech-classification model below.
MODEL_ID = "Gizachew/wev2vec-large960-agu-amharic"

config = AutoConfig.from_pretrained(MODEL_ID)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForSpeechClassification.from_pretrained(MODEL_ID)

# Sample rate the feature extractor was configured with.
sampling_rate = feature_extractor.sampling_rate
# Example recordings offered in the dropdown; the paths are relative to the
# current working directory. Replace the stems below with your own files.
_EXAMPLE_STEMS = (
    "a5-06-02-02-60",
    "f2-04-02-02-65",
    "h3-06-02-02-41",
    "n1-01-01-01-25",
    "s4-06-01-02-51",
)
test_example = [f"test_examples/{stem}.wav" for stem in _EXAMPLE_STEMS]
# Gradio components wired into the interface.
# Upload widget; type="filepath" makes Gradio pass the callback a file path.
audio_input = gr.Audio(type="filepath", label="Upload file")
# Dropdown listing the bundled example recordings.
dropdown_input = gr.Dropdown(choices=test_example, label="Choose an example audio file")
# Right-aligned, right-to-left text area that displays the prediction string.
text_output = gr.TextArea(
    type="text",
    rtl=True,
    text_align="right",
    label="Emotion Prediction Output",
)
def SER(audio):
    """Predict the most likely emotion class for an audio file.

    Args:
        audio: Path to an audio file that ``torchaudio.load`` can read.

    Returns:
        A string ``"<label>: <score>%"`` holding the highest-scoring class from
        ``config.id2label`` and its softmax probability as a percentage with
        one decimal place.
    """
    # Load straight from the given path. Copying into a temporary ".wav" file
    # leaked the source file handle and mislabelled non-WAV uploads.
    speech_array, _sampling_rate = torchaudio.load(audio)

    # torchaudio.load returns a (channels, frames) tensor. Averaging over the
    # channel axis turns stereo (or mono) input into a single 1-D waveform,
    # whereas squeeze() would leave stereo as a 2-D array.
    speech_array = speech_array.mean(dim=0)

    # Resample to the rate the feature extractor expects. Resample's default
    # target is 16 kHz, so new_freq is passed explicitly to stay in sync with
    # feature_extractor.sampling_rate.
    resampler = torchaudio.transforms.Resample(orig_freq=_sampling_rate, new_freq=sampling_rate)
    speech = resampler(speech_array).numpy()

    inputs = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    scores = F.softmax(logits, dim=1).cpu().numpy()[0]

    # Use a plain int so the id2label lookup does not depend on numpy scalar hashing.
    max_index = int(scores.argmax())
    label = config.id2label[max_index]
    score = scores[max_index]
    return f"{label}: {score * 100:.1f}%"
def process_audio(audio_path, example_path=None):
    """Run emotion recognition on the uploaded file or the selected example.

    Gradio passes one positional argument per input component. The interface
    has two inputs (upload and example dropdown), so this callback must accept
    both.

    Args:
        audio_path: Path of the uploaded audio, or None/"" if nothing was uploaded.
        example_path: Path chosen in the example dropdown, or None.

    Returns:
        The prediction string from ``SER``, or a prompt asking for input when
        neither source is provided.
    """
    # Prefer an explicit upload; otherwise fall back to the dropdown choice.
    path = audio_path or example_path
    if not path:
        return "Please upload an audio file or choose an example audio file."
    return SER(path)
# Build the UI: both input components feed process_audio, and its string
# result is shown in the text area.
interface_inputs = [audio_input, dropdown_input]
iface = gr.Interface(fn=process_audio, inputs=interface_inputs, outputs=text_output)

# share=True asks Gradio for a public share link (presumably ignored when
# hosted on HF Spaces — confirm against the deployed Gradio version).
iface.launch(share=True)