# (Hugging Face Spaces page residue — "Spaces: Sleeping" status header, not app code.)
| import streamlit as st | |
| from pydub import AudioSegment | |
| import os | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import librosa | |
| import math | |
| import json | |
# Class labels indexed by the CNN's output neuron: prediction i -> GENRES[i].
# NOTE(review): presumably this matches the label order used at training time
# (GTZAN genres) — confirm against the training pipeline.
GENRES = ["Metal", "Disco", "Pop", "Classical", "Reggae", "Country", "Rock", "Hiphop", "Jazz", "Blues"]
class CNNModel2(nn.Module):
    """Three-block CNN classifier over MFCC "images".

    Expects input of shape (batch, C, H, W) where (C, H, W) == input_shape
    (here MFCC segments shaped (1, time_frames, n_mfcc)).  Returns raw
    logits of shape (batch, num_classes) — no softmax applied.

    Args:
        input_shape: (channels, height, width) of a single example.
        num_classes: number of output classes (default 10 genres).
    """

    def __init__(self, input_shape, num_classes=10):
        super(CNNModel2, self).__init__()
        # Block 1: 3x3 conv -> BN -> overlapping 3x3/stride-2 max-pool.
        self.conv1 = nn.Conv2d(input_shape[0], 32, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(3, stride=2, padding=1)
        self.dropout1 = nn.Dropout(0.2)
        # Block 2
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(3, stride=2, padding=1)
        self.dropout2 = nn.Dropout(0.1)
        # Block 3: smaller 2x2 kernel on the already-reduced feature map.
        self.conv3 = nn.Conv2d(64, 64, kernel_size=2)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(2, stride=2, padding=1)
        self.dropout3 = nn.Dropout(0.1)
        # Size the first FC layer by probing the conv stack with a dummy
        # input, so the model adapts to any input_shape.
        self.flatten_dim = self._calculate_flatten_dim(input_shape)
        self.fc1 = nn.Linear(self.flatten_dim, 128)
        self.dropout4 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)

    def _calculate_flatten_dim(self, input_shape):
        """Return the flattened size of the conv stack's output for one example.

        Runs a zero tensor through conv/BN/pool only (dropout is shape-
        preserving and irrelevant here).  Wrapped in no_grad so the probe
        builds no autograd graph; it does nudge BatchNorm running stats,
        but load_state_dict() overwrites those before inference.
        """
        with torch.no_grad():
            x = torch.zeros(1, *input_shape)
            x = self.pool1(F.relu(self.bn1(self.conv1(x))))
            x = self.pool2(F.relu(self.bn2(self.conv2(x))))
            x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        return x.numel()

    def forward(self, x):
        """Compute class logits for a batch of MFCC images."""
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.dropout1(x)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.dropout2(x)
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.dropout3(x)
        # Use the explicit batch size rather than -1: if the input's spatial
        # dims don't match input_shape, this raises instead of silently
        # producing a wrong batch dimension.
        x = x.view(x.size(0), self.flatten_dim)
        x = F.relu(self.fc1(x))
        x = self.dropout4(x)
        x = self.fc2(x)
        return x
def get_mfccs(file, fs=22050, duration=30, n_fft=2048, hop_length=512, n_mfcc=13, num_segments=10):
    """Split an audio file into equal segments and extract MFCCs per segment.

    BUGFIX: the sample-rate default was 22500, a typo for librosa's
    standard 22050 Hz.  At 22050 Hz each 3-second segment yields exactly
    130 MFCC frames, matching the CNN's expected (1, 130, 13) input;
    22500 produced 132 frames and broke the model's flatten layer.

    Args:
        file: path (or file-like object) readable by librosa.load.
        fs: target sample rate the audio is resampled to.
        duration: seconds of audio assumed available (segments past the
            real end of the file come out short and are dropped).
        n_fft, hop_length, n_mfcc: librosa MFCC parameters.
        num_segments: how many equal segments to slice the track into.

    Returns:
        (mfccs, genre_names, genre_nums) as numpy arrays.  The two label
        arrays are always empty here (inference path); the keys exist to
        mirror the training-data layout dumped to data.json.

    Side effects:
        Overwrites 'data.json' in the working directory with the
        extracted features.
    """
    data = {
        "genre_name": [],
        "genre_num": [],
        "mfcc": []
    }
    samples_per_track = fs * duration
    samps_per_segment = int(samples_per_track / num_segments)
    # Expected frame count per segment; used to reject truncated segments.
    mfccs_per_segment = math.ceil(samps_per_segment / hop_length)
    audio, fs = librosa.load(file, sr=fs)
    for seg in range(num_segments):
        start_sample = seg * samps_per_segment
        end_sample = start_sample + samps_per_segment
        mfcc = librosa.feature.mfcc(y=audio[start_sample:end_sample],
                                    sr=fs,
                                    n_fft=n_fft,
                                    hop_length=hop_length,
                                    n_mfcc=n_mfcc)
        mfcc = mfcc.T  # -> (frames, n_mfcc)
        # Keep only full-length segments (short ones come from audio that
        # ran out before `duration` seconds).
        if len(mfcc) == mfccs_per_segment:
            data["mfcc"].append(mfcc.tolist())
    with open('data.json', "w") as filepath:
        json.dump(data, filepath, indent=4)
    return np.array(data["mfcc"]), np.array(data["genre_name"]), np.array(data["genre_num"])
def cut_audio(input_file, seconds=30, output_file="output_30_seconds.wav"):
    """Trim a WAV file to its first `seconds` seconds and save the clip.

    Generalized from a hard-coded 30-second cut; defaults preserve the
    original behavior and output filename exactly.

    Args:
        input_file: path to the source WAV file.
        seconds: length of the clip to keep, in seconds.
        output_file: path the trimmed WAV is exported to (overwritten).

    Returns:
        The path of the exported clip (== output_file).
    """
    audio = AudioSegment.from_wav(input_file)
    clip = audio[:seconds * 1000]  # pydub slices by milliseconds
    clip.export(output_file, format="wav")
    return output_file
# --- Model setup -------------------------------------------------------
# Input shape (channels, time_frames, n_mfcc) must match what get_mfccs
# produces per segment.
model = CNNModel2((1, 130, 13))
# NOTE(review): consider torch.load(..., weights_only=True) on torch >= 1.13
# to restrict unpickling; the file is local here, so left as-is for
# compatibility with older torch versions.
model.load_state_dict(torch.load("model_cnn2.pth", map_location=torch.device('cpu')))
model.eval()

# --- Streamlit UI ------------------------------------------------------
st.title("Audio Genre Classification")
st.write("Upload an audio file to classify its genre.")

uploaded_file = st.file_uploader("Choose an audio file...", type=["wav"])
if uploaded_file is not None:
    # Persist the upload to disk so pydub/librosa can read it by path.
    file_path = "temp.wav"
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    file_path = cut_audio(file_path)
    mfccs, genre_names, genre_nums = get_mfccs(file_path)
    # (num_segments, frames, n_mfcc) -> (num_segments, 1, frames, n_mfcc):
    # each segment becomes a single-channel "image" for the CNN.
    X_to_pred = torch.tensor(mfccs, dtype=torch.float32).unsqueeze(1)
    # One batched forward pass instead of a per-segment Python loop;
    # argmax per row gives the same per-segment predictions.
    with torch.no_grad():
        outputs = model(X_to_pred)
    predictions = outputs.argmax(dim=1).tolist()
    st.write("Prediction:")
    # Majority vote across the segments decides the track's genre.
    st.write(GENRES[max(set(predictions), key=predictions.count)])