import torch
import whisper
import torchaudio
import gradio as gr
import torch.nn as nn
from huggingface_hub import hf_hub_download
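
# Assumed dependencies (not pinned in this file): torch, torchaudio, gradio,
# huggingface_hub, and openai-whisper. Note that the Whisper package installs
# as "openai-whisper" but imports as "whisper"; listing the wrong name in
# requirements.txt is a common cause of Space build errors.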


# Define the same model class used during training
class DialectClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(DialectClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input tensor
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
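
# Note: the classifier consumes the *flattened* Whisper encoder output, so
# input_dim is derived at load time below rather than hardcoded here.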


# Function to preprocess audio and extract features
def preprocess_audio(file_path, whisper_model, device):
    def load_audio(file_path):
        waveform, sample_rate = torchaudio.load(file_path)
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
        # Convert to single channel (mono) if necessary
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Pad or trim audio to 30 seconds
        desired_length = 16000 * 30  # 30 seconds at 16 kHz
        current_length = waveform.shape[1]
        if current_length < desired_length:
            # Pad with zeros
            padding = desired_length - current_length
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        elif current_length > desired_length:
            # Trim to desired length
            waveform = waveform[:, :desired_length]
        return waveform

    audio = load_audio(file_path)
    audio = whisper.pad_or_trim(audio.flatten())  # Whisper expects exactly 30 s at 16 kHz
    mel = whisper.log_mel_spectrogram(audio)  # -> (80, 3000) log-mel features; already dense
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)  # Add batch dimension and move to device
        features = whisper_model.encoder(mel)
    return features
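
# Shape note: whisper.pad_or_trim makes the manual pad/trim in load_audio
# belt-and-braces rather than strictly necessary. With the "medium" checkpoint
# the encoder output should be (1, 1500, 1024), i.e. 1,536,000 values once
# flattened, which is what the dummy forward pass below measures.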


repo_id = "dipankar53/assamese_dialect_classifier_model"
model_filename = "dialect_classifier_model.pth"
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

label_to_idx = {"Darrangiya Accent": 0, "Kamrupiya Accent": 1, "Upper Assam": 2, "Nalbaria Accent": 3}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Whisper model (its encoder is used as a feature extractor)
whisper_model = whisper.load_model("medium", device=device)

# Derive the classifier input dimension with a dummy forward pass
num_classes = len(label_to_idx)
sample_input = torch.randn(1, 80, 3000).to(device)  # Shaped like one 30 s log-mel spectrogram
with torch.no_grad():
    sample_output = whisper_model.encoder(sample_input)
input_dim = sample_output.view(1, -1).shape[1]  # Flatten and get dimension

# Load the trained classifier
model = DialectClassifier(input_dim, num_classes)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
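
# Size caveat: if the encoder output is (1, 1500, 1024) as expected, fc1 alone
# holds 1,536,000 x 128 ≈ 197M weights, so the downloaded .pth is large and a
# CPU-only Space needs correspondingly generous RAM.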


# Function to predict the dialect of a single audio file
def predict_dialect(audio_path):
    try:
        # Preprocess audio and extract features
        features = preprocess_audio(audio_path, whisper_model, device)
        features = features.view(1, -1)  # Flatten features
        # Perform prediction
        with torch.no_grad():
            outputs = model(features)
            _, predicted = torch.max(outputs, 1)
        # Map predicted index back to dialect label
        idx_to_label = {idx: label for label, idx in label_to_idx.items()}
        predicted_label = idx_to_label[predicted.item()]
        return f"Predicted Dialect: {predicted_label}"
    except Exception as e:
        return f"Error: {str(e)}"


# Define Gradio interface
interface = gr.Interface(
    fn=predict_dialect,
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
    outputs="text",
    title="Assamese Dialect Prediction",
    description="Upload an Assamese audio file to predict its dialect.",
)

# Launch the interface
if __name__ == "__main__":
    interface.launch()
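
# On Hugging Face Spaces the bare launch() above is sufficient; when running
# locally and a public URL is wanted, Gradio also accepts
# interface.launch(share=True).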