Spaces:
Build error
Build error
File size: 3,904 Bytes
cfa5958 978917b cfa5958 978917b cfa5958 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import torch
import whisper
import torchaudio
import gradio as gr
import torch.nn as nn
from huggingface_hub import hf_hub_download
# Define the same model class used during training
class DialectClassifier(nn.Module):
    """Three-layer MLP that maps flattened Whisper encoder features to dialect logits."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # Attribute names (fc1/fc2/fc3) must match the checkpoint's state_dict keys.
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten each sample and return raw (unsoftmaxed) class logits."""
        h = x.view(x.size(0), -1)  # (batch, ...) -> (batch, features)
        h = self.relu(self.fc1(h))
        h = self.relu(self.fc2(h))
        return self.fc3(h)
# Function to preprocess audio and extract features
def preprocess_audio(file_path, whisper_model, device):
    """Load an audio file and run it through the Whisper encoder.

    Returns the encoder feature tensor with a leading batch dimension of 1.
    """

    def _load_mono_16k(path):
        # Decode the file and normalise it to 16 kHz mono, exactly 30 s long.
        wav, sr = torchaudio.load(path)
        if sr != 16000:
            wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(wav)
        if wav.shape[0] > 1:
            # Average the channels down to mono.
            wav = torch.mean(wav, dim=0, keepdim=True)
        target = 16000 * 30  # 30 seconds at 16 kHz
        length = wav.shape[1]
        if length < target:
            wav = torch.nn.functional.pad(wav, (0, target - length))  # zero-pad tail
        elif length > target:
            wav = wav[:, :target]  # drop everything past 30 s
        return wav

    audio = whisper.pad_or_trim(_load_mono_16k(file_path).flatten())
    mel = whisper.log_mel_spectrogram(audio).to_dense()
    with torch.no_grad():
        # Add batch dim and move to the encoder's device before encoding.
        batched = mel.unsqueeze(0).to(device)
        return whisper_model.encoder(batched)
# --- One-time model setup (runs at import: downloads weights, loads Whisper) ---
repo_id = "dipankar53/assamese_dialect_classifier_model"
model_filename = "dialect_classifier_model.pth"
# Fetch the trained classifier-head weights from the Hugging Face Hub.
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
# Dialect name -> training label index; inverted at prediction time.
label_to_idx = {"Darrangiya Accent": 0, "Kamrupiya Accent": 1, "Upper Assam": 2, "Nalbaria Accent": 3}
# Load Whisper (used as a feature extractor only — its encoder output feeds the classifier).
whisper_model = whisper.load_model("medium")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the trained classifier head.
num_classes = len(label_to_idx)
# Probe the encoder with a dummy 30 s mel input (80 mel bins x 3000 frames)
# to discover the flattened feature size the classifier expects.
sample_input = torch.randn(1, 80, 3000).to(device)
with torch.no_grad():
    sample_output = whisper_model.encoder(sample_input)
input_dim = sample_output.view(1, -1).shape[1]  # Flatten and get dimension
model = DialectClassifier(input_dim, num_classes)
# NOTE(review): torch.load on a downloaded checkpoint runs pickle; consider
# weights_only=True if the installed torch version supports it — verify.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm updates
# Function to predict the dialect of a single audio file
def predict_dialect(audio_path):
    """Gradio handler: classify the dialect of a single audio file.

    Returns a human-readable result string; failures are reported as text
    instead of raised, so the UI never crashes.
    """
    try:
        # Whisper-encode the audio, then flatten for the MLP head.
        encoded = preprocess_audio(audio_path, whisper_model, device)
        flat = encoded.view(1, -1)
        with torch.no_grad():
            logits = model(flat)
            _, best = torch.max(logits, 1)
        # Invert the label mapping to recover the dialect name.
        idx_to_label = {v: k for k, v in label_to_idx.items()}
        predicted_label = idx_to_label[best.item()]
        return f"Predicted Dialect: {predicted_label}"
    except Exception as e:
        return f"Error: {str(e)}"
# Define Gradio interface
# Define the Gradio interface: single audio input (file upload or microphone),
# plain-text output carrying the predicted dialect.
interface = gr.Interface(
    fn=predict_dialect,
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
    outputs="text",
    title="Assamese Dialect Prediction",
    description="Upload an Assamese audio file to predict its dialect.",
)

# Launch the web app only when run as a script (not on import).
if __name__ == "__main__":
    interface.launch()  # fixed: removed stray trailing "|" that was a SyntaxError