import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import os
# ----------------------------
# Define Model
# ----------------------------
class AudioCNN(nn.Module):
    """1-D CNN audio classifier: conv → max-pool → global average → linear head."""

    def __init__(self, num_classes=3):
        super(AudioCNN, self).__init__()
        # Layer creation order is significant: it fixes both the state_dict
        # keys and the RNG consumption order for weight initialization.
        self.conv1 = nn.Conv1d(1, 16, kernel_size=5, stride=2)
        self.pool = nn.MaxPool1d(2)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(16, num_classes)

    def forward(self, x):
        """Map a [B, 1, L] waveform batch to [B, num_classes] logits."""
        features = self.pool(F.relu(self.conv1(x)))   # [B, 16, L']
        pooled = self.global_pool(features)           # [B, 16, 1]
        flat = pooled.view(pooled.size(0), -1)        # [B, 16]
        return self.fc1(flat)                         # [B, num_classes]
# ----------------------------
# Load model
# ----------------------------
num_classes = 3
model_save_path = "audio_cnn_model.pth"
model = AudioCNN(num_classes)
if os.path.exists(model_save_path):
    try:
        # weights_only=True restricts unpickling to tensors and primitive
        # containers, preventing arbitrary-code execution from a tampered
        # checkpoint file (torch.load uses pickle under the hood).
        state_dict = torch.load(
            model_save_path,
            map_location=torch.device('cpu'),
            weights_only=True,
        )
        # strict=False tolerates architecture drift; report what was skipped
        # instead of silently ignoring it so partial loads are visible.
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            print(f"⚠️ Partial load — missing keys: {missing}, unexpected keys: {unexpected}")
        model.eval()  # inference mode: disables dropout/batchnorm updates
        print(f"✅ Model state dictionary loaded from {model_save_path}")
    except Exception as e:
        print(f"⚠️ Error loading model state dictionary: {e}")
        model = None  # downstream code checks `model is None` before predicting
else:
    print(f"⚠️ Model state dictionary not found at {model_save_path}")
    model = None
# ----------------------------
# Prediction function
# ----------------------------
def predict_audio(audio_file_path):
    """Classify an audio file and return its predicted language label.

    Args:
        audio_file_path: Path to an audio file readable by torchaudio,
            or None when no file was supplied by the UI.

    Returns:
        The predicted label string, or a human-readable error message
        when the model is unavailable, input is missing, or loading fails.
    """
    if model is None:
        return "Model not loaded. Cannot make predictions."
    if audio_file_path is None:
        return "No audio file provided."
    try:
        waveform, _sample_rate = torchaudio.load(audio_file_path)
        # Mix multi-channel audio down to a single mono channel.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # Zero-mean / unit-variance normalization; epsilon guards near-silence.
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-6)
        # Add a batch dimension: [channels, length] -> [1, channels, length].
        batch = waveform.unsqueeze(0)
        with torch.no_grad():
            logits = model(batch)
        predicted_index = int(torch.argmax(logits, dim=1).item())
        label_map = {0: 'English', 1: 'Code-switched', 2: 'Other'}
        return label_map.get(predicted_index, "Unknown")
    except Exception as e:
        return f"Error during prediction: {e}"
# ----------------------------
# Launch Gradio
# ----------------------------
if model is None:
    # No usable model — skip UI creation rather than serve a broken app.
    print("⚠️ Gradio interface not created due to model loading error.")
else:
    # Wire the classifier into a simple file-upload UI.
    interface = gr.Interface(
        fn=predict_audio,
        inputs=gr.Audio(type="filepath"),
        outputs=gr.Label(),
        title="Audio Code-Switching Detector",
        description="Upload an audio file to detect if it contains code-switching.",
    )
    # share=True requests a public tunnel URL in addition to localhost.
    interface.launch(share=True)