# Moremoholo2's picture
# Update app.py
# 3f5ef7d verified
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import os
# ----------------------------
# Define Model
# ----------------------------
class AudioCNN(nn.Module):
    """Small 1-D CNN that maps a single-channel waveform to class logits.

    Pipeline: ``Conv1d(1 -> 16, kernel_size=5, stride=2)`` -> ReLU ->
    ``MaxPool1d(2)`` -> global average pool over time ->
    ``Linear(16 -> num_classes)``.

    The global average pool collapses the time axis, so the input length is
    not fixed. The input must still be long enough for the convolution and
    the max-pool to each produce at least one frame (at least 7 samples with
    the kernel/stride/pool sizes used here).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int = 3):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable so
        # previously saved checkpoints still load.
        self.conv1 = nn.Conv1d(1, 16, kernel_size=5, stride=2)
        self.pool = nn.MaxPool1d(2)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(16, num_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Tensor of shape ``[B, 1, L]``, or an unbatched ``[1, L]``
                waveform.

        Returns:
            Logits of shape ``[B, num_classes]``; for unbatched input the
            batch axis is dropped again and the shape is ``[num_classes]``.
        """
        # Add a temporary batch axis for unbatched input so the flatten step
        # below does not mistake the channel axis for the batch axis.
        unbatched = x.dim() == 2
        if unbatched:
            x = x.unsqueeze(0)

        x = self.pool(F.relu(self.conv1(x)))  # [B, 16, L']
        x = self.global_pool(x)               # [B, 16, 1]
        x = x.flatten(1)                      # [B, 16]
        x = self.fc1(x)                       # [B, num_classes]

        return x.squeeze(0) if unbatched else x
# ----------------------------
# Load model
# ----------------------------
# Module-level model setup: build the network and try to restore its weights.
# After this section `model` is either an AudioCNN in eval mode or None (the
# checkpoint file is absent or could not be loaded).
num_classes = 3
model_save_path = "audio_cnn_model.pth"
model = AudioCNN(num_classes)

if os.path.exists(model_save_path):
    try:
        # NOTE(security): torch.load unpickles the file, and unpickling can
        # execute arbitrary code. Only load checkpoints from a trusted source;
        # consider passing weights_only=True if the installed PyTorch
        # version supports it.
        state_dict = torch.load(model_save_path, map_location=torch.device('cpu'))

        # strict=False tolerates key mismatches instead of raising, so the
        # mismatches are reported below; otherwise a wrong checkpoint would
        # silently leave layers at their random initialisation.
        load_result = model.load_state_dict(state_dict, strict=False)
        model.eval()
        print(f"✅ Model state dictionary loaded from {model_save_path}")
        if load_result.missing_keys:
            print(f"⚠️ Keys missing from checkpoint (left at random initialisation): {load_result.missing_keys}")
        if load_result.unexpected_keys:
            print(f"⚠️ Unexpected keys in checkpoint (ignored): {load_result.unexpected_keys}")
    except Exception as e:
        # Top-level boundary: report the problem and disable the model rather
        # than crash at import time.
        print(f"⚠️ Error loading model state dictionary: {e}")
        model = None
else:
    print(f"⚠️ Model state dictionary not found at {model_save_path}")
    model = None
# ----------------------------
# Prediction function
# ----------------------------
def predict_audio(audio_file_path):
    """Classify one audio file with the module-level ``model``.

    The file is loaded with ``torchaudio.load``, averaged down to a single
    channel when it has more than one, standardised with its own mean and
    standard deviation, and passed through the model as a batch of one. The
    sample rate returned by the loader is not used (no resampling is done).

    Args:
        audio_file_path: Path of the audio file to classify, or ``None``.

    Returns:
        A string: the predicted label (``"Unknown"`` if the predicted index
        has no entry in the label table), or a human-readable message when
        the model is unavailable, no file was given, or any step raised.
    """
    if model is None:
        return "Model not loaded. Cannot make predictions."
    if audio_file_path is None:
        return "No audio file provided."

    # Class index -> label shown to the user.
    class_names = {0: 'English', 1: 'Code-switched', 2: 'Other'}

    try:
        signal, _sample_rate = torchaudio.load(audio_file_path)

        # Collapse multi-channel audio to one channel by averaging.
        if signal.shape[0] > 1:
            signal = signal.mean(dim=0, keepdim=True)

        # Standardise; the small epsilon avoids dividing by zero on a
        # constant signal.
        offset = signal.mean()
        scale = signal.std() + 1e-6
        batch = ((signal - offset) / scale).unsqueeze(0)  # [1, 1, length]

        with torch.no_grad():
            logits = model(batch)
            best_index = torch.max(logits, 1)[1].item()

        return class_names.get(best_index, "Unknown")
    except Exception as e:
        # Report the failure as the function's result instead of raising.
        return f"Error during prediction: {e}"
# ----------------------------
# Launch Gradio
# ----------------------------
# Build and launch the web UI only when a model is available; otherwise just
# report why no interface was created.
if model is None:
    print("⚠️ Gradio interface not created due to model loading error.")
else:
    # Audio arrives as a file path (what predict_audio expects) and the
    # returned string is shown in a Label component.
    interface = gr.Interface(
        predict_audio,
        gr.Audio(type="filepath"),
        gr.Label(),
        title="Audio Code-Switching Detector",
        description="Upload an audio file to detect if it contains code-switching.",
    )
    interface.launch(share=True)