"""Gradio app that classifies an English accent from a speech recording.

Loads fine-tuned Wav2Vec2 classifier weights from the Hugging Face Hub at
import time and serves a small ``gr.Blocks`` UI.
"""

import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import Wav2Vec2Model, Wav2Vec2Processor

# Sample rate the processor is called with; inputs are resampled to match.
TARGET_SAMPLE_RATE = 16000

# Class index -> human-readable label shown in the UI.
LABEL_MAP = {0: "Canadian English", 1: "England English"}

model_path = hf_hub_download(
    repo_id="creativepurus/accent-wav2vec2",
    filename="model.safetensors",
)

# Load processor
processor = Wav2Vec2Processor.from_pretrained("creativepurus/accent-wav2vec2")

# Load model weights from model.safetensors
state_dict = load_file(model_path, device="cpu")


# Define the same model architecture used during training
class Wav2Vec2Classifier(nn.Module):
    """Wav2Vec2 encoder followed by dropout and a 2-way linear classifier."""

    def __init__(self):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h")
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.wav2vec2.config.hidden_size, 2)

    def forward(self, input_values, attention_mask=None):
        """Return class logits for a batch of raw audio inputs.

        Args:
            input_values: Tensor passed straight to the Wav2Vec2 encoder.
            attention_mask: Optional mask forwarded to the encoder.

        Returns:
            Logits from the linear head, one row per batch item.
        """
        outputs = self.wav2vec2(input_values, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        # Mean-pool the encoder output over dim 1 to get one vector per item.
        pooled = hidden_states.mean(dim=1)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)
        return logits


# Instantiate and load the model
model = Wav2Vec2Classifier()
model.load_state_dict(state_dict)
model.eval()


# Prediction function
def predict_accent(audio):
    """Predict the accent label for an audio file.

    Args:
        audio: Path to an audio file, as supplied by ``gr.Audio`` with
            ``type="filepath"``. ``None`` when nothing was uploaded/recorded.

    Returns:
        One of the label strings in ``LABEL_MAP``.

    Raises:
        gr.Error: If no audio was provided or the recording has no samples.
    """
    if audio is None:
        # Clicking the button with no input passes None; fail with a message
        # the UI can display instead of an opaque loader error.
        raise gr.Error("Please upload or record an audio clip first.")

    waveform, sample_rate = torchaudio.load(audio)
    if waveform.numel() == 0:
        raise gr.Error("The audio clip is empty.")

    # torchaudio returns (channels, time). Downmix so stereo recordings are
    # treated as ONE utterance; squeezing a 2-channel tensor would otherwise
    # be interpreted by the processor as a batch of two clips.
    waveform = waveform.mean(dim=0)

    if sample_rate != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate,
            new_freq=TARGET_SAMPLE_RATE,
        )
        waveform = resampler(waveform)

    input_values = processor(
        waveform.numpy(),
        return_tensors="pt",
        sampling_rate=TARGET_SAMPLE_RATE,
    ).input_values

    with torch.no_grad():
        logits = model(input_values)

    # Single-item batch: take the argmax over the class axis only.
    predicted_class_id = int(logits.argmax(dim=-1).item())
    return LABEL_MAP[predicted_class_id]


# Earlier gr.Interface version of the UI, kept for reference:
# # Gradio UI
# interface = gr.Interface(
#     fn=predict_accent,
#     inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio (WAV)"),
#     outputs=gr.Textbox(label="Predicted Accent"),
#     title="Accent Classification",
#     description="This app classifies English accents as either Canadian or England using a fine-tuned Wav2Vec2 model.",
#     allow_flagging="never"
# )

# Gradio UI with gr.Blocks and Custom Styling
custom_css = """
#predict-btn {
    background-color: orange !important;
    color: white !important;
    font-weight: bold;
}
#author-section {
    font-size: 18px;
    font-weight: 500;
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## 🗣️ Accent Classification App")
    gr.Markdown("This app classifies English accents as either **Canadian** or **England** using a fine-tuned Wav2Vec2 model.")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio (WAV)")
            predict_button = gr.Button("Predict Accent", elem_id="predict-btn")
        with gr.Column():
            result_output = gr.Textbox(label="Predicted Accent")

    predict_button.click(fn=predict_accent, inputs=audio_input, outputs=result_output)

    gr.Markdown("---")
    # NOTE(review): this Markdown body is empty and no component uses
    # elem_id="author-section", so the #author-section CSS rule above is
    # currently unused - confirm whether author content is missing here.
    gr.Markdown(
        """
""",
    )

demo.launch()