"""Gradio demo that classifies an uploaded audio clip with a small 1-D CNN.

The script defines ``AudioCNN``, loads its weights from
``audio_cnn_model.pth`` onto the CPU, builds a Gradio interface around
``predict_audio`` and, when run as a script, launches it.
"""

import os

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio


# ----------------------------
# Define Model
# ----------------------------
class AudioCNN(nn.Module):
    """Minimal 1-D CNN: conv -> ReLU -> max-pool -> global avg-pool -> linear.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=5, stride=2)
        self.pool = nn.MaxPool1d(2)
        # Global pooling collapses the time axis, so clips of any length
        # (long enough to survive conv1 + pool) map to a fixed-size vector.
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(16, num_classes)

    def forward(self, x):
        """Return class logits for a batch of single-channel waveforms.

        Args:
            x: Tensor of shape [B, 1, L] (conv1 expects one input channel).

        Returns:
            Tensor of shape [B, num_classes] with unnormalized logits.
        """
        x = self.pool(F.relu(self.conv1(x)))  # [B, 16, L']
        x = self.global_pool(x)  # [B, 16, 1]
        x = x.view(x.size(0), -1)  # [B, 16]
        return self.fc1(x)  # [B, num_classes]


# ----------------------------
# Load model
# ----------------------------
num_classes = 3
model_save_path = "audio_cnn_model.pth"

# Maps the argmax index of the model output to a human-readable label.
LABEL_MAP = {0: "English", 1: "Code-switched", 2: "Other"}


def _load_model(path, n_classes):
    """Build an ``AudioCNN`` and load its state dict from *path* onto the CPU.

    Args:
        path: Filesystem path of the saved state dict.
        n_classes: Number of output classes for the freshly built model.

    Returns:
        The model in eval mode, or ``None`` if the file is absent, cannot be
        read, or does not supply every parameter the model needs. A status
        message is printed in every case.
    """
    if not os.path.exists(path):
        print(f"⚠️ Model state dictionary not found at {path}")
        return None

    net = AudioCNN(n_classes)
    try:
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        state_dict = torch.load(path, map_location=torch.device("cpu"))
        # strict=False tolerates extra keys in the checkpoint, but any key the
        # model expects and does not receive stays randomly initialized.
        # Serving predictions from random weights is worse than refusing to
        # serve, so missing keys are treated as a failed load.
        result = net.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            raise RuntimeError(
                f"checkpoint is missing keys: {result.missing_keys}"
            )
        if result.unexpected_keys:
            print(
                "⚠️ Ignoring unexpected keys in checkpoint: "
                f"{result.unexpected_keys}"
            )
    except Exception as e:
        print(f"⚠️ Error loading model state dictionary: {e}")
        return None

    net.eval()
    print(f"✅ Model state dictionary loaded from {path}")
    return net


model = _load_model(model_save_path, num_classes)


# ----------------------------
# Prediction function
# ----------------------------
def predict_audio(audio_file_path):
    """Classify the audio file at *audio_file_path*.

    Args:
        audio_file_path: Path to an audio file, or ``None`` when nothing was
            uploaded.

    Returns:
        The predicted label from ``LABEL_MAP`` (``"Unknown"`` for an unmapped
        index), or a human-readable error message. Errors are returned rather
        than raised so the Gradio UI can display them.
    """
    if model is None:
        return "Model not loaded. Cannot make predictions."
    if audio_file_path is None:
        return "No audio file provided."

    try:
        # torchaudio.load yields [channels, frames].
        # NOTE(review): the sample rate is ignored. If the model was trained
        # at a fixed rate, inputs at other rates should probably be resampled
        # first -- confirm.
        waveform, _sample_rate = torchaudio.load(audio_file_path)

        # Convert stereo -> mono, keeping the channel dim: [1, frames].
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Per-clip standardization; the epsilon guards against silent
        # (zero-variance) clips.
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-6)

        # Add the batch dim expected by the model: [1, 1, frames].
        waveform = waveform.unsqueeze(0)

        with torch.no_grad():
            logits = model(waveform)
        predicted_index = int(logits.argmax(dim=1).item())
        return LABEL_MAP.get(predicted_index, "Unknown")
    except Exception as e:
        # UI boundary: surface decode/shape errors to the user instead of
        # crashing the request.
        return f"Error during prediction: {e}"


# ----------------------------
# Launch Gradio
# ----------------------------
interface = None
if model is not None:
    interface = gr.Interface(
        fn=predict_audio,
        inputs=gr.Audio(type="filepath"),
        outputs=gr.Label(),
        title="Audio Code-Switching Detector",
        description="Upload an audio file to detect if it contains code-switching.",
    )
else:
    print("⚠️ Gradio interface not created due to model loading error.")

# Launch only when executed as a script, so importing this module does not
# start a (publicly shared) server as a side effect.
if __name__ == "__main__" and interface is not None:
    interface.launch(share=True)