File size: 3,014 Bytes
4a8263d
 
 
 
 
 
08dd23d
e4fb85b
 
 
5def43b
3f5ef7d
5def43b
3f5ef7d
5def43b
3f5ef7d
 
5def43b
 
3f5ef7d
 
 
 
5def43b
 
e4fb85b
 
 
5def43b
 
e4fb85b
5def43b
 
 
43d2c04
3f5ef7d
5def43b
3f5ef7d
4a8263d
3f5ef7d
4a8263d
 
3f5ef7d
5def43b
4a8263d
e4fb85b
5def43b
e4fb85b
4a8263d
 
 
 
 
 
 
e4fb85b
 
3f5ef7d
e4fb85b
 
 
3f5ef7d
 
 
 
 
 
 
 
 
e4fb85b
 
 
3f5ef7d
e4fb85b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45dcbd0
e4fb85b
 
3f5ef7d
 
e4fb85b
ddcdf91
25cfd4b
2ba6a81
5def43b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import os

# ----------------------------
# Define Model
# ----------------------------
class AudioCNN(nn.Module):
    """Minimal 1-D CNN classifier for raw waveforms.

    Pipeline: Conv1d -> ReLU -> MaxPool1d -> global average pool -> linear head.
    The adaptive pooling collapses the time axis, so inputs of any length work.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        # One conv stage: 1 input channel (mono audio) -> 16 feature maps.
        self.conv1 = nn.Conv1d(1, 16, kernel_size=5, stride=2)
        self.pool = nn.MaxPool1d(2)
        # Reduces [B, 16, L'] to [B, 16, 1] regardless of input length.
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(16, num_classes)

    def forward(self, x):
        """Map a [B, 1, L] waveform batch to [B, num_classes] logits."""
        features = self.pool(F.relu(self.conv1(x)))  # [B, 16, L']
        pooled = self.global_pool(features)          # [B, 16, 1]
        flat = pooled.view(pooled.size(0), -1)       # [B, 16]
        return self.fc1(flat)                        # [B, num_classes]

# ----------------------------
# Load model
# ----------------------------
num_classes = 3
model_save_path = "audio_cnn_model.pth"

# `model` is None when loading fails; downstream code checks for that.
model = AudioCNN(num_classes)

if os.path.exists(model_save_path):
    try:
        # weights_only=True restricts torch.load to tensor/container data,
        # preventing arbitrary-code execution from a malicious pickle
        # checkpoint (torch.load docs recommend this for untrusted files).
        state_dict = torch.load(
            model_save_path,
            map_location=torch.device('cpu'),
            weights_only=True,
        )
        # strict=False tolerates partial checkpoints, but report any key
        # mismatch instead of silently claiming a clean load.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print(f"⚠️ State dict mismatch — missing: {load_result.missing_keys}, "
                  f"unexpected: {load_result.unexpected_keys}")
        model.eval()
        print(f"✅ Model state dictionary loaded from {model_save_path}")
    except Exception as e:
        # Boundary handler: log the failure and disable inference.
        print(f"⚠️ Error loading model state dictionary: {e}")
        model = None
else:
    print(f"⚠️ Model state dictionary not found at {model_save_path}")
    model = None

# ----------------------------
# Prediction function
# ----------------------------
def predict_audio(audio_file_path):
    """Classify an audio file with the module-level `model`.

    Returns the predicted label string, or a human-readable error message
    when the model is unavailable, no file was given, or inference fails.
    """
    if model is None:
        return "Model not loaded. Cannot make predictions."
    if audio_file_path is None:
        return "No audio file provided."

    try:
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Downmix multi-channel audio to a single mono channel.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Zero-mean / unit-variance normalization; epsilon guards silence.
        mean, std = waveform.mean(), waveform.std()
        waveform = (waveform - mean) / (std + 1e-6)

        # Add the batch dimension: [1, channels, length].
        batch = waveform.unsqueeze(0)

        # Inference without gradient tracking.
        with torch.no_grad():
            logits = model(batch)
        predicted_index = int(torch.argmax(logits, dim=1).item())

        label_map = {0: 'English', 1: 'Code-switched', 2: 'Other'}
        return label_map.get(predicted_index, "Unknown")

    except Exception as e:
        # Gradio shows the returned string, so surface the failure as text.
        return f"Error during prediction: {e}"

# ----------------------------
# Launch Gradio
# ----------------------------
if model is None:
    print("⚠️ Gradio interface not created due to model loading error.")
else:
    # Expose the classifier as a web demo; share=True requests a temporary
    # public URL in addition to the local server.
    interface = gr.Interface(
        fn=predict_audio,
        inputs=gr.Audio(type="filepath"),
        outputs=gr.Label(),
        title="Audio Code-Switching Detector",
        description="Upload an audio file to detect if it contains code-switching.",
    )
    interface.launch(share=True)