# MichalIwaniuk's picture
# comit
# a2e21e4
import gradio as gr
import numpy as np
import librosa
import librosa.display
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
import io
# Preprocessing parameters
SR: int = 22050           # sampling rate (Hz) passed to librosa.load for every clip
N_MELS: int = 128         # number of mel bands in the spectrogram
TARGET_FRAMES: int = 216  # time-axis length spectrograms are padded/cropped to

# Class codes, paired position-by-position with the model's output vector.
LABELS = [
    'cel', 'cla', 'flu', 'gac', 'gel', 'org',
    'pia', 'sax', 'tru', 'vio', 'voi',
]
# Polish display names for the class codes in LABELS; these are the strings
# shown to the user in the prediction widget.
# NOTE(review): the eleven codes match the IRMAS class set, where 'cla' is the
# clarinet ("klarnet"), not the harpsichord ("klawesyn") - confirm against the
# dataset the model was actually trained on.
polskie_nazwy = {
    'cel': 'wiolonczela',
    'cla': 'klarnet',
    'flu': 'flet',
    'gac': 'gitara klasyczna',
    'gel': 'gitara elektryczna',
    'org': 'organy',
    'pia': 'fortepian',
    'sax': 'saksofon',
    'tru': 'trąbka',
    'vio': 'skrzypce',
    'voi': 'głos ludzki',
}
# Dark colour scheme for the UI: start from Gradio's Base theme, then override
# individual colour variables on it.
_base_theme = gr.themes.Base(primary_hue="blue", neutral_hue="gray", font="sans")
dark_theme = _base_theme.set(
    body_background_fill="#121212",
    block_background_fill="#1E1E1E",
    body_text_color="#ffffff",
    button_primary_background_fill="#2d72d9",
    button_primary_text_color="#ffffff",
)
# Load the model once at import time; predict_and_plot reuses this instance.
# The path is relative, so it is resolved against the working directory.
_MODEL_PATH = "model.h5"
model = tf.keras.models.load_model(_MODEL_PATH)
def compute_melspectrogram(y, sr=SR, n_mels=N_MELS):
    """Return the mel spectrogram of the audio signal *y*, expressed in dB.

    Args:
        y: Audio samples, as returned by ``librosa.load``.
        sr: Sampling rate of *y* in Hz.
        n_mels: Number of mel bands to compute.

    Returns:
        The mel power spectrogram converted to decibels with the
        spectrogram's own maximum as the reference (``ref=np.max``).
    """
    mel_power = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
    mel_db = librosa.power_to_db(mel_power, ref=np.max)
    return mel_db
def resize_spectrogram(S, target_frames=TARGET_FRAMES):
    """Pad or centre-crop *S* along axis 1 to exactly *target_frames* columns.

    Args:
        S: 2-D array; axis 1 is the one being resized.
        target_frames: Desired length of axis 1.

    Returns:
        *S* unchanged when it already has the right length; otherwise a
        zero-padded copy (padding split between both sides, the extra column
        going to the right when the amount is odd) or a centred slice.
    """
    n_frames = S.shape[1]

    if n_frames < target_frames:
        missing = target_frames - n_frames
        pad_left = missing // 2
        pad_right = missing - pad_left
        # NOTE(review): the pad value is 0; for a dB spectrogram referenced to
        # its maximum that is peak level rather than silence - presumably this
        # matches the training preprocessing, confirm before changing.
        return np.pad(S, ((0, 0), (pad_left, pad_right)), mode='constant')

    if n_frames > target_frames:
        surplus = n_frames - target_frames
        first = surplus // 2
        return S[:, first:first + target_frames]

    return S
def _render_spectrogram(spec):
    """Plot the dB mel spectrogram *spec* and return the plot as a PIL image."""
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        librosa.display.specshow(
            spec, sr=SR, x_axis='time', y_axis='mel', cmap='magma', ax=ax
        )
        ax.set_title("Mel-spektrogram")
        # Use the figure's own method: plt.tight_layout() acts on pyplot's
        # global "current figure", which is not safe with concurrent requests.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def predict_and_plot(audio_path):
    """Classify an audio file and render its mel spectrogram.

    Args:
        audio_path: Path to the audio file supplied by the Gradio audio input.

    Returns:
        A ``(pred_dict, image)`` tuple: ``pred_dict`` maps each Polish
        instrument name to the model's score for that class, and ``image`` is
        a PIL image of the full-length (unresized) mel spectrogram.

    Raises:
        gr.Error: If no file was provided or the file contains no samples.
    """
    # Submitting without a file passes None; fail with a readable message
    # instead of an obscure loader error.
    if not audio_path:
        raise gr.Error("Najpierw wgraj plik audio.")

    y, _ = librosa.load(audio_path, sr=SR)
    if y.size == 0:
        raise gr.Error("Plik audio jest pusty.")

    S_full = compute_melspectrogram(y)
    S = resize_spectrogram(S_full)

    # Add a leading batch axis and a trailing channel axis for the model.
    x = S[np.newaxis, ..., np.newaxis]
    preds = model.predict(x, verbose=0)[0]

    image = _render_spectrogram(S_full)
    pred_dict = {polskie_nazwy[label]: float(p) for label, p in zip(LABELS, preds)}
    return pred_dict, image
# Gradio UI: one audio-file input, two outputs (the top-5 class scores and the
# spectrogram image). User-facing strings are Polish; their diacritics were
# mis-encoded and are restored here.
demo = gr.Interface(
    fn=predict_and_plot,
    inputs=gr.Audio(type="filepath", label="Wgraj plik WAV"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predykcja"),
        gr.Image(label="Spektrogram"),
    ],
    title="Rozpoznawanie instrumentów",
    description="Model klasyfikuje dźwięki do kilku z klas instrumentów.",
    theme=dark_theme,
    submit_btn="Zatwierdź",
    clear_btn="Wyczyść",
)

# Start the web server only when the file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()