Spaces:
Sleeping
Sleeping
| import torch | |
| import torchaudio | |
| import gradio as gr | |
| import os | |
| import numpy as np | |
| from transformers import Wav2Vec2Processor, Wav2Vec2Model | |
| from safetensors.torch import load_file | |
| import torch.nn as nn | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face Hub repository holding both the fine-tuned weights and the processor.
MODEL_REPO_ID = "creativepurus/accent-wav2vec2"

# Fetch the fine-tuned checkpoint file (hf_hub_download returns a local path).
model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.safetensors")

# Load the processor saved with the fine-tuned model so preprocessing at
# inference time uses the same configuration it was saved with.
processor = Wav2Vec2Processor.from_pretrained(MODEL_REPO_ID)

# Read the checkpoint tensors onto the CPU; they are loaded into the classifier
# with model.load_state_dict later in this file.
state_dict = load_file(model_path, device="cpu")
# Classifier architecture. Attribute names and layer sizes must match the ones
# used when the checkpoint was saved, because the state_dict is loaded strictly.
class Wav2Vec2Classifier(nn.Module):
    """Two-class accent classifier.

    A pretrained Wav2Vec2 encoder whose per-frame hidden states are averaged
    into a single vector per input, passed through dropout and then a linear
    layer that produces two logits.
    """

    def __init__(self):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h")
        self.dropout = nn.Dropout(0.3)
        encoder_width = self.wav2vec2.config.hidden_size
        self.classifier = nn.Linear(encoder_width, 2)

    def forward(self, input_values, attention_mask=None):
        """Compute class logits for a batch of processed waveforms.

        Args:
            input_values: The ``input_values`` tensor produced by the processor.
            attention_mask: Optional mask, forwarded unchanged to the encoder.

        Returns:
            A logits tensor with one row per input and two columns (one per class).
        """
        encoder_out = self.wav2vec2(input_values, attention_mask=attention_mask)
        # Mean over dim 1 (the frame axis of last_hidden_state) gives one
        # embedding per input.
        # NOTE(review): the mean does not use attention_mask, so padded frames
        # are included when a padded batch is passed; fine for one clip at a time.
        clip_embedding = encoder_out.last_hidden_state.mean(dim=1)
        return self.classifier(self.dropout(clip_embedding))
# Build the classifier and restore the fine-tuned weights into it.
model = Wav2Vec2Classifier()
model.load_state_dict(state_dict, strict=True)  # every checkpoint key must match
model.train(False)  # identical to model.eval(): inference mode, dropout disabled
# Prediction function wired to the "Predict Accent" button.
def predict_accent(audio):
    """Predict which English accent is spoken in an audio clip.

    Args:
        audio: Filesystem path to the clip, as produced by the ``gr.Audio``
            input configured with ``type="filepath"``; ``None`` when the user
            clicks predict without providing any audio.

    Returns:
        The predicted label: ``"Canadian English"`` or ``"England English"``.

    Raises:
        gr.Error: If ``audio`` is ``None``, so Gradio shows a readable message
            instead of a traceback from ``torchaudio.load``.
    """
    target_rate = 16000  # sampling rate handed to the processor
    label_map = {0: "Canadian English", 1: "England English"}

    if audio is None:
        raise gr.Error("Please upload or record an audio clip first.")

    waveform, sample_rate = torchaudio.load(audio)  # (channels, samples)
    # Downmix to mono instead of squeeze(): squeezing a stereo clip leaves a
    # (2, samples) array, which the processor treats as a batch of two clips.
    # The (2, 2) logits would then give a flattened argmax of up to 3, which
    # is not a key in label_map. For mono input this equals squeeze().
    waveform = waveform.mean(dim=0)
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        waveform = resampler(waveform)

    input_values = processor(
        waveform.numpy(), return_tensors="pt", sampling_rate=target_rate
    ).input_values  # a batch containing this single clip
    with torch.no_grad():
        logits = model(input_values)
    # Take the argmax over classes for the single clip (row 0), not over the
    # flattened tensor.
    predicted_class_id = logits[0].argmax().item()
    return label_map[predicted_class_id]
| # # Gradio UI | |
| # interface = gr.Interface( | |
| # fn=predict_accent, | |
| # inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio (WAV)"), | |
| # outputs=gr.Textbox(label="Predicted Accent"), | |
| # title="Accent Classification", | |
| # description="This app classifies English accents as either Canadian or England using a fine-tuned Wav2Vec2 model.", | |
| # allow_flagging="never" | |
| # ) | |
# Custom CSS for the gr.Blocks UI (passed as gr.Blocks(css=custom_css)).
# "#predict-btn" styles the button created with elem_id="predict-btn";
# "#author-section" styles the <div id="author-section"> in the footer Markdown.
custom_css = """
#predict-btn {
background-color: orange !important;
color: white !important;
font-weight: bold;
}
#author-section {
font-size: 18px;
font-weight: 500;
}
"""
# Gradio UI: header, an input column (audio + button), an output column
# (predicted label), and an author footer. Styling comes from custom_css.
# The emoji below were previously mis-encoded (UTF-8 bytes shown as
# ISO-8859-7 text, e.g. "π£οΈ"); they are restored here.
# NOTE(review): the link emoji before "GitHub" could not be decoded
# unambiguously; 🔗 is a best guess.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## 🗣️ Accent Classification App")
    gr.Markdown("This app classifies English accents as either **Canadian** or **England** using a fine-tuned Wav2Vec2 model.")
    with gr.Row():
        with gr.Column():
            # type="filepath" hands predict_accent a path that torchaudio can load.
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio (WAV)")
            predict_button = gr.Button("Predict Accent", elem_id="predict-btn")
        with gr.Column():
            result_output = gr.Textbox(label="Predicted Accent")
    predict_button.click(fn=predict_accent, inputs=audio_input, outputs=result_output)
    gr.Markdown("---")
    gr.Markdown("""
<div id="author-section">
👨‍💻 Created by <strong>Anand Purushottam</strong>
🔗 <a href="https://github.com/creativepurus" target="_blank">GitHub</a> |
<a href="https://linkedin.com/in/creativepurus" target="_blank">LinkedIn</a>
</div>
""",
    )
# Start the web server only when this file is executed directly, so importing
# the module (e.g. to call predict_accent from a test) does not launch it.
if __name__ == "__main__":
    demo.launch()