# HuggingFace Spaces page header (scraped with the code): "Spaces: Runtime error".
# The deployed Space was reporting a runtime error at capture time.
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchaudio | |
| import os | |
# ----------------------------
# Define Model
# ----------------------------
class AudioCNN(nn.Module):
    """Tiny 1-D CNN classifier: conv → max-pool → global-avg-pool → linear."""

    def __init__(self, num_classes=3):
        super().__init__()
        # Attribute names must stay unchanged so saved state dicts still load.
        self.conv1 = nn.Conv1d(1, 16, kernel_size=5, stride=2)
        self.pool = nn.MaxPool1d(2)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(16, num_classes)

    def forward(self, x):
        """Map a [B, 1, L] waveform batch to [B, num_classes] logits."""
        features = self.pool(F.relu(self.conv1(x)))  # [B, 16, L']
        pooled = self.global_pool(features)          # [B, 16, 1]
        flat = torch.flatten(pooled, 1)              # [B, 16]
        return self.fc1(flat)                        # [B, num_classes]
# ----------------------------
# Load model
# ----------------------------
num_classes = 3
model_save_path = "audio_cnn_model.pth"

model = AudioCNN(num_classes)
if not os.path.exists(model_save_path):
    # No checkpoint on disk — run without a model so the app can still report it.
    print(f"⚠️ Model state dictionary not found at {model_save_path}")
    model = None
else:
    try:
        # Force CPU placement so the checkpoint loads on machines without a GPU.
        state_dict = torch.load(model_save_path, map_location=torch.device('cpu'))
        # strict=False tolerates missing/unexpected keys in the checkpoint.
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        print(f"✅ Model state dictionary loaded from {model_save_path}")
    except Exception as e:
        print(f"⚠️ Error loading model state dictionary: {e}")
        model = None
# ----------------------------
# Prediction function
# ----------------------------
def predict_audio(audio_file_path):
    """Classify an uploaded audio clip with the module-level AudioCNN.

    Parameters
    ----------
    audio_file_path : str or None
        Filesystem path to the audio file (Gradio ``type="filepath"`` input).

    Returns
    -------
    str
        The predicted class label, or a human-readable error message when
        the model is unavailable, no file was given, or inference fails.
    """
    if model is None:
        return "Model not loaded. Cannot make predictions."
    if audio_file_path is None:
        return "No audio file provided."
    try:
        waveform, sample_rate = torchaudio.load(audio_file_path)
        # NOTE(review): the clip is used at its native sample rate; if the
        # model was trained at a fixed rate, a resample belongs here — confirm.
        if waveform.numel() == 0:
            return "Error during prediction: audio file contains no samples."
        # Down-mix stereo → mono by averaging channels.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Zero-mean / unit-variance normalization; eps guards silent clips.
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-6)
        # [1, length] -> [1, 1, length]: add the batch dimension.
        waveform = waveform.unsqueeze(0)
        with torch.no_grad():
            outputs = model(waveform)
        # argmax over class logits (replaces the deprecated `.data` idiom).
        predicted_index = torch.argmax(outputs, dim=1).item()
        label_map = {0: 'English', 1: 'Code-switched', 2: 'Other'}
        return label_map.get(predicted_index, "Unknown")
    except Exception as e:
        return f"Error during prediction: {e}"
# ----------------------------
# Launch Gradio
# ----------------------------
if model is None:
    # No model → no UI; surface the reason on the console instead.
    print("⚠️ Gradio interface not created due to model loading error.")
else:
    interface = gr.Interface(
        fn=predict_audio,
        inputs=gr.Audio(type="filepath"),
        outputs=gr.Label(),
        title="Audio Code-Switching Detector",
        description="Upload an audio file to detect if it contains code-switching.",
    )
    interface.launch(share=True)