# TuneNet — Streamlit audio genre classification app (app.py)
import streamlit as st
from pydub import AudioSegment
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
import math
import json
# Class-index -> label mapping: the model's predicted class index is used
# directly to index this list when displaying the result.
GENRES = ["Metal", "Disco", "Pop", "Classical", "Reggae", "Country", "Rock", "Hiphop", "Jazz", "Blues"]
class CNNModel2(nn.Module):
    """Three-block CNN for 10-way music-genre classification over MFCC segments.

    Each block is Conv2d -> BatchNorm -> ReLU -> MaxPool -> Dropout; the
    flattened conv features feed two fully-connected layers producing raw
    class logits (no softmax — apply it in the loss / at decode time).

    Args:
        input_shape: (channels, height, width) of one sample,
            e.g. (1, 130, 13) for a 130-frame, 13-coefficient MFCC segment.
        num_classes: number of output logits (default 10).
    """

    def __init__(self, input_shape, num_classes=10):
        super(CNNModel2, self).__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 32, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(3, stride=2, padding=1)
        self.dropout1 = nn.Dropout(0.2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(3, stride=2, padding=1)
        self.dropout2 = nn.Dropout(0.1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=2)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(2, stride=2, padding=1)
        self.dropout3 = nn.Dropout(0.1)
        # Probe the conv stack once so fc1's input size adapts to input_shape.
        self.flatten_dim = self._calculate_flatten_dim(input_shape)
        self.fc1 = nn.Linear(self.flatten_dim, 128)
        self.dropout4 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)

    def _calculate_flatten_dim(self, input_shape):
        """Return the element count after the conv stack for one sample.

        Runs a dummy zeros tensor through conv/pool layers to size fc1.
        Wrapped in no_grad: this is a shape-only probe, so building an
        autograd graph here would be wasted work at construction time.
        (The probe does touch BatchNorm running stats, but any subsequent
        load_state_dict — as done by this app — overwrites them.)
        """
        with torch.no_grad():
            x = torch.zeros(1, *input_shape)
            x = self.pool1(F.relu(self.bn1(self.conv1(x))))
            x = self.pool2(F.relu(self.bn2(self.conv2(x))))
            x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        return x.numel()

    def forward(self, x):
        """Compute class logits for a batch x of shape (N, C, H, W)."""
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.dropout1(x)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.dropout2(x)
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.dropout3(x)
        x = x.view(-1, self.flatten_dim)
        x = F.relu(self.fc1(x))
        x = self.dropout4(x)
        x = self.fc2(x)
        return x
def get_mfccs(file, fs=22500, duration=30, n_fft=2048, hop_length=512, n_mfcc=13, num_segments=10):
    """Split an audio file into equal segments and extract MFCCs per segment.

    The file is resampled to `fs`, divided into `num_segments` equal slices
    of the first `duration` seconds, and each full-length slice yields one
    (frames x n_mfcc) MFCC matrix. As a side effect, the collected data is
    dumped to 'data.json'.

    Returns:
        Tuple of numpy arrays (mfccs, genre_names, genre_nums); the two
        genre arrays are always empty here (keys kept for dataset-format
        compatibility).
    """
    data = {"genre_name": [], "genre_num": [], "mfcc": []}
    segment_len = fs * duration // num_segments
    expected_frames = math.ceil(segment_len / hop_length)
    signal, sr = librosa.load(file, sr=fs)
    for idx in range(num_segments):
        lo = idx * segment_len
        segment_mfcc = librosa.feature.mfcc(
            y=signal[lo:lo + segment_len],
            sr=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mfcc=n_mfcc,
        ).T  # transpose to (frames, coefficients)
        # Keep only full-length segments; a short trailing slice is dropped.
        if len(segment_mfcc) == expected_frames:
            data["mfcc"].append(segment_mfcc.tolist())
    with open('data.json', "w") as out_fp:
        json.dump(data, out_fp, indent=4)
    return np.array(data["mfcc"]), np.array(data["genre_name"]), np.array(data["genre_num"])
def cut_audio(input_file, duration_s=30):
    """Export the first `duration_s` seconds of a WAV file.

    Generalized from a hard-coded 30-second clip; the default reproduces
    the original behavior and output filename exactly.

    Args:
        input_file: path to the source .wav file.
        duration_s: clip length in seconds (default 30, the model's
            expected input length).

    Returns:
        Path of the trimmed WAV written to the working directory.
    """
    audio = AudioSegment.from_wav(input_file)
    clip = audio[:duration_s * 1000]  # pydub slices in milliseconds
    output_file = f"output_{duration_s}_seconds.wav"
    clip.export(output_file, format="wav")
    return output_file
# --- Model setup: build the network and load pretrained weights on CPU. ---
model = CNNModel2((1, 130, 13))
model.load_state_dict(torch.load("model_cnn2.pth", map_location=torch.device('cpu')))
model.eval()  # disable dropout / use BatchNorm running stats for inference

# --- Streamlit UI. ---
st.title("Audio Genre Classification")
st.write("Upload an audio file to classify its genre.")
uploaded_file = st.file_uploader("Choose an audio file...", type=["wav"])

if uploaded_file is not None:
    # Persist the upload so pydub/librosa can read it from disk.
    file_path = "temp.wav"
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    file_path = cut_audio(file_path)
    mfccs, genre_names, genre_nums = get_mfccs(file_path)
    # (num_segments, frames, n_mfcc) -> (num_segments, 1, frames, n_mfcc)
    X_to_pred = torch.tensor(mfccs, dtype=torch.float32).unsqueeze(1)
    # One batched forward pass over all segments instead of ten
    # batch-of-1 passes; identical results in eval mode.
    with torch.no_grad():
        logits = model(X_to_pred)
    predictions = logits.argmax(dim=1).tolist()
    st.write("Prediction:")
    # Majority vote across the per-segment predictions.
    st.write(GENRES[max(set(predictions), key=predictions.count)])